From 5e3a7f30f7ace8d1909ed39a1b9f38714172b902 Mon Sep 17 00:00:00 2001 From: Cold <16419454+cold502@user.noreply.gitee.com> Date: Wed, 10 Dec 2025 18:02:31 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0es=E5=BD=92=E6=A1=A3=20?= =?UTF-8?q?=E5=88=86=E5=B8=83=E5=BC=8F=E5=92=8Cconstants=E5=8F=98=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 +- config/constants.go | 8 ++ mongo/mongo.go | 57 ++++++++++-- ragflow/client.go | 73 +++++++++------- ragflow/worker_pool.go | 192 ++++++++++++----------------------------- redis/redis.go | 28 ++++++ 6 files changed, 186 insertions(+), 174 deletions(-) diff --git a/.gitignore b/.gitignore index aeb617a..5e12a7a 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,4 @@ rabbitmq/开发指南.md ragflow/agent文档.md ragflow/README_GLOBAL.md redis/stream使用示例.md -ragflow/client_http.go +ragflow/client_backup.go.bak diff --git a/config/constants.go b/config/constants.go index 0cf2867..8963e8d 100644 --- a/config/constants.go +++ b/config/constants.go @@ -32,3 +32,11 @@ var ArchiveDelay = 3600 // HistoryContextLimit 读取历史对话轮数(用于新 Session 上下文注入) var HistoryContextLimit int64 = 5 + +// -------------------- Stream 消费配置 -------------------- + +// DefaultBatchSize 批量读取消息数量(削峰填谷) +var DefaultBatchSize int64 = 200 + +// DefaultBlockTimeout 阻塞超时时间(毫秒) +var DefaultBlockTimeout int64 = 2000 diff --git a/mongo/mongo.go b/mongo/mongo.go index 8a43cf2..f1f9eb9 100644 --- a/mongo/mongo.go +++ b/mongo/mongo.go @@ -8,6 +8,7 @@ import ( "time" "gitee.com/red-future---jilin-g/common/consts" + "gitee.com/red-future---jilin-g/common/do" "gitee.com/red-future---jilin-g/common/redis" "gitee.com/red-future---jilin-g/common/utils" "github.com/gogf/gf/v2/errors/gerror" @@ -21,7 +22,12 @@ import ( "go.mongodb.org/mongo-driver/v2/mongo/options" ) -var db = new(mongo.Database) +var db *mongo.Database + +// GetDB 获取 MongoDB 数据库实例 +func GetDB() *mongo.Database { + return db +} func init() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -88,12 +94,49 @@ func oneOptionsToMap(ctx context.Context, opts ...options.Lister[options.FindOne return } +// GetTenantInfo 获取租户信息 +// 优先从 token 获取,失败则从请求参数 customerServiceId 查询 customer_service_account 表 +func GetTenantInfo(ctx context.Context) (user do.User, err error) { + // 1. 优先从 token 获取 + user, err = utils.GetUserInfo(ctx) + if err == nil { + return + } + + // 2. token 获取失败,尝试从请求参数获取 customerServiceId + req := g.RequestFromCtx(ctx) + if req == nil { + return user, gerror.New("无法获取租户信息:无 token 且无 request") + } + + customerServiceId := req.Get("customerServiceId").String() + if customerServiceId == "" { + customerServiceId = req.Get("customer_service_id").String() + } + if customerServiceId == "" { + return user, gerror.New("无法获取租户信息:无 token 且无 customerServiceId 参数") + } + + // 3. 直接查询 customer_service_account 表获取 tenantId + filter := bson.M{"customerServiceId": customerServiceId, "isDeleted": false} + var account struct { + TenantId interface{} `bson:"tenantId"` + } + if findErr := db.Collection("customer_service_account").FindOne(ctx, filter).Decode(&account); findErr != nil { + return user, gerror.Newf("通过 customerServiceId 查询租户失败: %v", findErr) + } + + user.TenantId = account.TenantId + user.UserName = customerServiceId + return +} + // Find 查询多条记录 func Find(ctx context.Context, filter bson.M, result interface{}, collection string, opts ...options.Lister[options.FindOptions]) (err error) { if err = utils.ValidStructPtr(result); err != nil { return } - user, err := utils.GetUserInfo(ctx) + user, err := GetTenantInfo(ctx) if err != nil { return } @@ -135,7 +178,7 @@ func FindOne(ctx context.Context, filter bson.M, result interface{}, collection if err = utils.ValidStructPtr(result); err != nil { return } - user, err := utils.GetUserInfo(ctx) + user, err := GetTenantInfo(ctx) if err != nil { return } @@ -198,7 +241,7 @@ func Delete(ctx context.Context, filter bson.M, collection string, opts ...optio err = gerror.New("缺少查询条件") return } - user, err := utils.GetUserInfo(ctx) + user, err := GetTenantInfo(ctx) if err != nil { return } @@ -219,7 +262,7 @@ func Update(ctx context.Context, filter bson.M, update bson.M, collection string return } filter["isDeleted"] = false - user, err := utils.GetUserInfo(ctx) + user, err := GetTenantInfo(ctx) if err != nil { return } @@ -238,7 +281,7 @@ func Update(ctx context.Context, filter bson.M, update bson.M, collection string // Insert 插入多条记录 func Insert(ctx context.Context, documents []interface{}, collection string, opts ...options.Lister[options.InsertManyOptions]) (ids []interface{}, err error) { - user, err := utils.GetUserInfo(ctx) + user, err := GetTenantInfo(ctx) if err != nil { return } @@ -265,7 +308,7 @@ func Insert(ctx context.Context, documents []interface{}, collection string, opt // Count 查询总数 func Count(ctx context.Context, filter bson.M, collection string) (count int64, err error) { - user, err := utils.GetUserInfo(ctx) + user, err := GetTenantInfo(ctx) if err != nil { return } diff --git a/ragflow/client.go b/ragflow/client.go index 612b2a7..a5f64b2 100644 --- a/ragflow/client.go +++ b/ragflow/client.go @@ -1,7 +1,9 @@ package ragflow import ( + "bytes" "context" + "io" "net" "net/http" "net/url" @@ -12,9 +14,13 @@ import ( "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/frame/g" - "github.com/gogf/gf/v2/net/gclient" ) +// gclient 完全不能用! +// 1. New() 默认 ResponseHeaderTimeout=30s +// 2. Clone() 内部调用 New(),链式调用会重置 Transport +// 3. 必须用原生 http.Client + var ( // globalClient 全局 RAGFlow 客户端(单例,延迟初始化) globalClient *Client @@ -51,15 +57,15 @@ func initClient() { ResponseHeaderTimeout: 180 * time.Second, // 等待响应头超时(关键!) } - // 初始化全局客户端 - httpClient := gclient.New() - httpClient.SetBrowserMode(false) - httpClient.SetHeader("Authorization", "Bearer "+apiKey) - httpClient.SetHeader("Content-Type", "application/json") - httpClient.SetTimeout(180 * time.Second) // RAGFlow AI 推理需要较长时间 + // 使用原生 http.Client(gclient 完全不能用,Clone() 内部调用 New() 会重置 Transport) + httpClient := &http.Client{ + Transport: transport, + Timeout: 0, // 不设置全局超时,由 context 控制 + } - // 设置自定义 Transport - httpClient.Client.Transport = transport + // 验证 Transport 设置 + g.Log().Infof(ctx, "✅ Transport 配置: ResponseHeaderTimeout=%v, MaxIdleConnsPerHost=%d, DisableKeepAlives=%v", + transport.ResponseHeaderTimeout, transport.MaxIdleConnsPerHost, transport.DisableKeepAlives) globalClient = &Client{ BaseURL: strings.TrimSuffix(baseURL, "/"), @@ -90,7 +96,7 @@ func GetGlobalClient() *Client { type Client struct { BaseURL string APIKey string - HTTPClient *gclient.Client // HTTP 客户端 + HTTPClient *http.Client // 原生 HTTP 客户端(gclient 不能用) } // CommonResponse 通用响应结构 @@ -118,39 +124,44 @@ func (c *Client) request(ctx context.Context, method, path string, body interfac reqBody = string(jsonData) } - // 设置 180 秒超时(RAGFlow AI 推理需要较长时间) - reqCtx, cancel := context.WithTimeout(ctx, 180*time.Second) + // 使用独立的 context 设置 300 秒超时(RAGFlow 高并发时响应较慢) + reqCtx, cancel := context.WithTimeout(context.Background(), 300*time.Second) defer cancel() + startTime := time.Now() - var resp *gclient.Response - - switch method { - case "GET": - resp, err = c.HTTPClient.Get(reqCtx, fullURL) - case "POST": - resp, err = c.HTTPClient.Post(reqCtx, fullURL, reqBody) - case "PUT": - resp, err = c.HTTPClient.Put(reqCtx, fullURL, reqBody) - case "DELETE": - resp, err = c.HTTPClient.Delete(reqCtx, fullURL, reqBody) - default: - return gerror.Newf("unsupported method: %s", method) + // 创建请求 + req, err := http.NewRequestWithContext(reqCtx, method, fullURL, bytes.NewReader([]byte(reqBody))) + if err != nil { + return gerror.Newf("create request failed: %v", err) } + // 设置请求头 + req.Header.Set("Authorization", "Bearer "+c.APIKey) + req.Header.Set("Content-Type", "application/json") + + // 发送请求 + g.Log().Infof(ctx, "[RAGFlow HTTP] 发送请求: method=%s, url=%s", method, fullURL) + resp, err := c.HTTPClient.Do(req) + elapsed := time.Since(startTime) if err != nil { - g.Log().Errorf(ctx, "[RAGFlow HTTP] 请求失败: method=%s, url=%s, error=%v", method, fullURL, err) + g.Log().Errorf(ctx, "[RAGFlow HTTP] 请求失败(耗时 %v): method=%s, url=%s, error=%v", elapsed, method, fullURL, err) return gerror.Newf("request failed: %v", err) } - defer resp.Close() + g.Log().Infof(ctx, "[RAGFlow HTTP] 收到响应(耗时 %v): status=%d, url=%s", elapsed, resp.StatusCode, fullURL) + defer resp.Body.Close() - respBody := resp.ReadAll() + // 读取响应 + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return gerror.Newf("read response failed: %v", err) + } // 打印响应详情 - g.Log().Debugf(ctx, "[RAGFlow HTTP] 响应: status=%d, body=%s", resp.StatusCode, string(respBody)) + g.Log().Debugf(ctx, "[RAGFlow HTTP] 响应: status=%d, body=%s", resp.StatusCode, respBody) if resp.StatusCode != http.StatusOK { - g.Log().Errorf(ctx, "[RAGFlow HTTP] 非200响应: status=%d, body=%s", resp.StatusCode, string(respBody)) - return gerror.Newf("http status %d: %s", resp.StatusCode, string(respBody)) + g.Log().Errorf(ctx, "[RAGFlow HTTP] 非200响应: status=%d, body=%s", resp.StatusCode, respBody) + return gerror.Newf("http status %d: %s", resp.StatusCode, respBody) } if err = gjson.DecodeTo(respBody, result); err != nil { diff --git a/ragflow/worker_pool.go b/ragflow/worker_pool.go index 06f7878..98a1b48 100644 --- a/ragflow/worker_pool.go +++ b/ragflow/worker_pool.go @@ -2,189 +2,111 @@ package ragflow import ( "context" + "strings" + "time" "gitee.com/red-future---jilin-g/common/redis" "github.com/gogf/gf/v2/os/glog" - "github.com/gogf/gf/v2/os/grpool" ) -// 默认协程池大小 -const defaultPoolSize = 200 +// 默认批量大小(每次从 Redis 读取并发送的消息数) +const defaultBatchSize = 200 -// workerPool 协程池单例(grpool.New 是原型模式,需要变量引用) -var workerPool = grpool.New(defaultPoolSize) - -// WorkerPool RAGFlow 请求处理协程池(封装 grpool) -type WorkerPool struct { - pool *grpool.Pool - size int -} - -// Pool 协程池单例实例(直接引用使用) -var Pool = &WorkerPool{ - pool: workerPool, - size: defaultPoolSize, -} - -// Submit 提交任务到协程池 -// 参数: -// - ctx: 上下文 -// - task: 要执行的任务函数 -// -// 返回:error 提交失败时返回错误 -func (w *WorkerPool) Submit(ctx context.Context, task func(ctx context.Context)) error { - return w.pool.Add(ctx, func(ctx context.Context) { - defer func() { - if r := recover(); r != nil { - glog.Errorf(ctx, "协程池任务执行 panic: %v", r) - } - }() - - task(ctx) - }) -} - -// Size 获取协程池大小 -func (w *WorkerPool) Size() int { - return w.size -} - -// Jobs 获取当前等待执行的任务数量 -func (w *WorkerPool) Jobs() int { - return w.pool.Jobs() -} - -// Close 关闭协程池 -func (w *WorkerPool) Close() { - w.pool.Close() -} - -// WorkerStats 协程池统计信息 -type WorkerStats struct { - PoolSize int // 协程池大小 - Jobs int // 等待执行的任务数 -} - -// Stats 获取协程池统计信息 -func (w *WorkerPool) Stats() WorkerStats { - return WorkerStats{ - PoolSize: w.size, - Jobs: w.pool.Jobs(), - } -} - -// PrintStats 打印协程池统计信息 -func (w *WorkerPool) PrintStats(ctx context.Context) { - stats := w.Stats() - glog.Infof(ctx, "协程池统计 - 池大小: %d, 等待任务: %d", stats.PoolSize, stats.Jobs) -} - -// QueueProcessor Stream 处理器,从 Redis Stream 中取出任务并提交到协程池 +// QueueProcessor Stream 处理器,批量读取消息并发送到 RAGFlow type QueueProcessor struct { - pool *WorkerPool - streamKey string // Stream 键名 - groupName string // 消费者组名称 - consumerName string // 消费者名称 - timeout int64 // 阻塞超时时间(毫秒) - batchSize int64 // 每次读取的消息数量 - stopChan chan struct{} + streamKey string // Stream 键名 + groupName string // 消费者组名称 + consumerName string // 消费者名称 + timeout int64 // 阻塞超时时间(毫秒) + batchSize int64 // 最大并发数(信号量容量) + stopChan chan struct{} // 停止信号 + semaphore chan struct{} // 并发信号量(控制最大并发) handleFunc func(ctx context.Context, message map[string]interface{}) error } // NewQueueProcessor 创建 Stream 处理器 -// 参数: -// - pool: 协程池 -// - streamKey: Redis Stream 键名 -// - groupName: 消费者组名称 -// - consumerName: 消费者名称(唯一标识) -// - timeout: 从 Stream 取消息的超时时间(毫秒) -// - batchSize: 每次读取的消息数量 -// - handleFunc: 消息处理函数 -func NewQueueProcessor(pool *WorkerPool, streamKey, groupName, consumerName string, timeout int64, batchSize int64, handleFunc func(ctx context.Context, message map[string]interface{}) error) *QueueProcessor { +func NewQueueProcessor(streamKey, groupName, consumerName string, timeout, batchSize int64, handleFunc func(ctx context.Context, message map[string]interface{}) error) *QueueProcessor { return &QueueProcessor{ - pool: pool, streamKey: streamKey, groupName: groupName, consumerName: consumerName, timeout: timeout, batchSize: batchSize, stopChan: make(chan struct{}), + semaphore: make(chan struct{}, batchSize), // 信号量容量 = 最大并发数 handleFunc: handleFunc, } } // Start 启动 Stream 处理器 -// 会阻塞运行,持续从 Redis Stream 中取出消息并提交到协程池处理 +// 削峰填谷:每次读取 batchSize 条消息,并发发送,发完立刻读下一批 func (q *QueueProcessor) Start(ctx context.Context) error { - glog.Infof(ctx, "Stream 处理器启动 - Stream: %s, 消费者组: %s, 消费者: %s, 超时: %dms", - q.streamKey, q.groupName, q.consumerName, q.timeout) + glog.Infof(ctx, "Stream 处理器启动 - Stream: %s, 消费者组: %s, 消费者: %s, 批量大小: %d", + q.streamKey, q.groupName, q.consumerName, q.batchSize) + + // 确保 Consumer Group 存在(重试直到成功) + for { + if err := redis.CreateConsumerGroup(ctx, q.streamKey, q.groupName); err != nil { + // BUSYGROUP 表示已存在,不是错误 + if strings.Contains(err.Error(), "BUSYGROUP") { + glog.Debugf(ctx, "Consumer Group 已存在") + break + } + glog.Warningf(ctx, "创建 Consumer Group 失败: %v,1秒后重试", err) + time.Sleep(time.Second) + continue + } + glog.Infof(ctx, "Consumer Group 创建成功") + break + } - loopCount := 0 for { select { case <-q.stopChan: glog.Info(ctx, "Stream 处理器收到停止信号") return nil default: - loopCount++ - if loopCount%10 == 1 { - glog.Debugf(ctx, "[DEBUG] 第 %d 次循环,准备读取消息...", loopCount) - } - - // 从 Redis Stream 中读取消息 - messages, err := q.fetchMessages(ctx) + // 1. 从 Redis Stream 读取一批消息 + messages, err := redis.ReadFromStream(ctx, q.streamKey, q.groupName, q.consumerName, q.batchSize, q.timeout) if err != nil { glog.Errorf(ctx, "从 Stream 读取消息失败: %v", err) continue } - // 没有新消息,继续等待 if len(messages) == 0 { - if loopCount%10 == 1 { - glog.Debugf(ctx, "[DEBUG] 第 %d 次循环,无新消息", loopCount) - } continue } - glog.Infof(ctx, "[DEBUG] 收到 %d 条消息", len(messages)) + glog.Debugf(ctx, "读取 %d 条消息,开始发送", len(messages)) - // 处理每条消息 + // 2. 用信号量控制并发:获取信号量后发送,完成后释放 for _, msg := range messages { - glog.Infof(ctx, "[DEBUG] 处理消息 ID: %s, Values: %+v", msg.ID, msg.Values) - // 提交到协程池处理 - if err := q.submitTask(ctx, msg); err != nil { - glog.Errorf(ctx, "提交任务到协程池失败: %v, 消息ID: %s", err, msg.ID) - } + // 获取信号量(阻塞直到有空位) + q.semaphore <- struct{}{} + go func(m redis.StreamMessage) { + defer func() { <-q.semaphore }() // 完成后释放信号量 + q.processMessage(ctx, m) + }(msg) } + // 3. 立刻读下一批(不等待,信号量自动控制并发数) } } } +// processMessage 处理单条消息(异步执行) +func (q *QueueProcessor) processMessage(ctx context.Context, message redis.StreamMessage) { + // 调用处理函数发送到 RAGFlow + if err := q.handleFunc(ctx, message.Values); err != nil { + glog.Errorf(ctx, "消息处理失败: %v, 消息ID: %s", err, message.ID) + } + + // 无论成功失败都 ACK(避免重复消费) + if err := redis.AckMessage(ctx, q.streamKey, q.groupName, message.ID); err != nil { + glog.Errorf(ctx, "确认消息失败: %v, 消息ID: %s", err, message.ID) + } +} + // Stop 停止队列处理器 func (q *QueueProcessor) Stop() { close(q.stopChan) } - -// fetchMessages 从 Redis Stream 中读取消息 -func (q *QueueProcessor) fetchMessages(ctx context.Context) ([]redis.StreamMessage, error) { - // 从消费者组读取消息 - return redis.ReadFromStream(ctx, q.streamKey, q.groupName, q.consumerName, q.batchSize, q.timeout) -} - -// submitTask 将消息处理任务提交到协程池 -func (q *QueueProcessor) submitTask(ctx context.Context, message redis.StreamMessage) error { - return q.pool.Submit(ctx, func(ctx context.Context) { - // 处理消息 - if err := q.handleFunc(ctx, message.Values); err != nil { - glog.Errorf(ctx, "处理消息失败: %v, 消息ID: %s", err, message.ID) - return - } - - // 处理成功后确认消息 - if err := redis.AckMessage(ctx, q.streamKey, q.groupName, message.ID); err != nil { - glog.Errorf(ctx, "确认消息失败: %v, 消息ID: %s", err, message.ID) - } else { - glog.Debugf(ctx, "消息处理完成并已确认: %s", message.ID) - } - }) -} diff --git a/redis/redis.go b/redis/redis.go index a74df47..7156952 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -95,6 +95,13 @@ func AddToStream(ctx context.Context, streamKey string, msg interface{}) (messag return } +// CreateConsumerGroup 创建消费者组(如果不存在) +// XGROUP CREATE streamKey groupName $ MKSTREAM +func CreateConsumerGroup(ctx context.Context, streamKey, groupName string) error { + _, err := redisClient.Do(ctx, "XGROUP", "CREATE", streamKey, groupName, "$", "MKSTREAM") + return err +} + // ReadFromStream 从 Stream 读取消息(消费者组模式) // 使用 gredis Do() 方法执行 XREADGROUP 命令 func ReadFromStream(ctx context.Context, streamKey, groupName, consumerName string, count int64, blockMs int64) ([]StreamMessage, error) { @@ -447,3 +454,24 @@ func DelSessionCache(ctx context.Context, userId string) error { _, err := redisClient.Del(ctx, key) return err } + +// TryLock 尝试获取分布式锁(非阻塞) +// key: 锁的键名 +// expireSeconds: 锁的过期时间(秒),防止死锁 +// 返回 true 表示获取成功,false 表示锁已被其他节点持有 +func TryLock(ctx context.Context, key string, expireSeconds int) bool { + // SET key value NX EX expireSeconds + result, err := redisClient.Do(ctx, "SET", key, gtime.Now().String(), "NX", "EX", expireSeconds) + if err != nil { + glog.Errorf(ctx, "获取分布式锁失败: %v", err) + return false + } + return result.String() == "OK" +} + +// Unlock 释放分布式锁 +func Unlock(ctx context.Context, key string) { + if _, err := redisClient.Del(ctx, key); err != nil { + glog.Errorf(ctx, "释放分布式锁失败: %v", err) + } +}