From 4b2b5e6177c53bb296ef7723221b0731bc76bbbf Mon Sep 17 00:00:00 2001 From: Cold <16419454+cold502@user.noreply.gitee.com> Date: Sat, 6 Dec 2025 18:04:29 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=E4=BA=86=E4=B8=80=E4=B8=8B?= =?UTF-8?q?=20rag=E7=9A=84=E6=96=B9=E6=B3=95,=20=E4=BD=BF=E7=94=A8=20gofra?= =?UTF-8?q?me=E7=9A=84=E6=A1=86=E6=9E=B6,=20=E8=BF=98=E6=9C=89redis?= =?UTF-8?q?=E8=BF=9E=E6=8E=A5=E9=83=A8=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rabbitmq/client.go | 33 +++---- rabbitmq/consumer.go | 16 ++-- rabbitmq/publisher.go | 28 +++--- rabbitmq/setup.go | 22 ++--- ragflow/agent.go | 41 ++++----- ragflow/chat.go | 29 +++--- ragflow/chunk.go | 37 ++++---- ragflow/client.go | 38 ++++---- ragflow/dataset.go | 29 +++--- ragflow/document.go | 31 +++---- ragflow/openai.go | 32 +++---- ragflow/session.go | 29 +++--- ragflow/system.go | 16 ++-- ragflow/worker_pool.go | 12 +++ redis/redis.go | 200 ++++++++++++++++++++++++++++------------- redis/types.go | 65 ++++++++++++++ 16 files changed, 398 insertions(+), 260 deletions(-) diff --git a/rabbitmq/client.go b/rabbitmq/client.go index 4c0a05f..68eda88 100644 --- a/rabbitmq/client.go +++ b/rabbitmq/client.go @@ -2,11 +2,12 @@ package rabbitmq import ( "context" - "fmt" "sync" "time" + "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/util/gconv" amqp "github.com/rabbitmq/amqp091-go" ) @@ -33,13 +34,7 @@ func Init(ctx context.Context, cfg *Config) error { var err error once.Do(func() { // 构建连接字符串 - url := fmt.Sprintf("amqp://%s:%s@%s:%d/%s", - cfg.Username, - cfg.Password, - cfg.Host, - cfg.Port, - cfg.VHost, - ) + url := "amqp://" + cfg.Username + ":" + cfg.Password + "@" + cfg.Host + ":" + gconv.String(cfg.Port) + "/" + cfg.VHost // 创建连接 conn, err = amqp.Dial(url) @@ -89,7 +84,7 @@ func GetChannel() (*amqp.Channel, error) { defer mu.RUnlock() if channel == nil || channel.IsClosed() { - return nil, fmt.Errorf("RabbitMQ Channel 未初始化或已关闭") + return nil, gerror.New("RabbitMQ Channel 未初始化或已关闭") } return channel, nil @@ -101,7 +96,7 @@ func GetConnection() (*amqp.Connection, error) { defer mu.RUnlock() if conn == nil || conn.IsClosed() { - return nil, fmt.Errorf("RabbitMQ 连接未初始化或已关闭") + return nil, gerror.New("RabbitMQ 连接未初始化或已关闭") } return conn, nil @@ -160,13 +155,7 @@ func reconnect(ctx context.Context) { VHost: g.Cfg().MustGet(ctx, "rabbitmq.vhost", "/").String(), } - url := fmt.Sprintf("amqp://%s:%s@%s:%d/%s", - cfg.Username, - cfg.Password, - cfg.Host, - cfg.Port, - cfg.VHost, - ) + url := "amqp://" + cfg.Username + ":" + cfg.Password + "@" + cfg.Host + ":" + gconv.String(cfg.Port) + "/" + cfg.VHost var err error conn, err = amqp.Dial(url) @@ -190,7 +179,7 @@ func reconnect(ctx context.Context) { } // Close 关闭连接 -func Close(ctx context.Context) error { +func Close(ctx context.Context) (err error) { mu.Lock() defer mu.Unlock() @@ -201,21 +190,21 @@ func Close(ctx context.Context) error { } if channel != nil { - if err := channel.Close(); err != nil { + if err = channel.Close(); err != nil { g.Log().Errorf(ctx, "关闭 RabbitMQ Channel 失败: %v", err) } channel = nil } if conn != nil { - if err := conn.Close(); err != nil { + if err = conn.Close(); err != nil { g.Log().Errorf(ctx, "关闭 RabbitMQ 连接失败: %v", err) - return err + return } conn = nil } watcherStarted = false g.Log().Info(ctx, "RabbitMQ 连接已关闭") - return nil + return } diff --git a/rabbitmq/consumer.go b/rabbitmq/consumer.go index 1b0746d..4fb5a2f 100644 --- a/rabbitmq/consumer.go +++ b/rabbitmq/consumer.go @@ -2,9 +2,9 @@ package rabbitmq import ( "context" - "encoding/json" - "fmt" + "github.com/gogf/gf/v2/encoding/gjson" + "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/frame/g" amqp "github.com/rabbitmq/amqp091-go" ) @@ -74,7 +74,7 @@ func NewConsumer(queue string, handler MessageHandler, opts ...ConsumerOption) * } // Start 启动消费者 -func (c *Consumer) Start(ctx context.Context) error { +func (c *Consumer) Start(ctx context.Context) (err error) { // 创建可取消的 context workerCtx, cancel := context.WithCancel(ctx) c.cancel = cancel @@ -90,7 +90,7 @@ func (c *Consumer) Start(ctx context.Context) error { false, // global: false 表示仅应用于当前 channel ) if err != nil { - return fmt.Errorf("设置 QoS 失败: %v", err) + return gerror.Newf("设置 QoS 失败: %v", err) } // 开始消费 @@ -104,7 +104,7 @@ func (c *Consumer) Start(ctx context.Context) error { nil, // args ) if err != nil { - return fmt.Errorf("开始消费失败: %v", err) + return gerror.Newf("开始消费失败: %v", err) } g.Log().Infof(ctx, "消费者已启动: queue=%s, prefetch=%d, workers=%d", @@ -115,7 +115,7 @@ func (c *Consumer) Start(ctx context.Context) error { go c.worker(workerCtx, i, msgs) } - return nil + return } // worker 工作协程 @@ -168,8 +168,8 @@ func StartTypedConsumer[T any]( // 包装处理函数 wrappedHandler := func(ctx context.Context, body []byte) error { var msg T - if err := json.Unmarshal(body, &msg); err != nil { - return fmt.Errorf("反序列化消息失败: %v", err) + if err := gjson.DecodeTo(body, &msg); err != nil { + return gerror.Newf("反序列化消息失败: %v", err) } return handler(ctx, &msg) diff --git a/rabbitmq/publisher.go b/rabbitmq/publisher.go index 096744b..182622b 100644 --- a/rabbitmq/publisher.go +++ b/rabbitmq/publisher.go @@ -2,9 +2,9 @@ package rabbitmq import ( "context" - "encoding/json" - "fmt" + "github.com/gogf/gf/v2/encoding/gjson" + "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/frame/g" amqp "github.com/rabbitmq/amqp091-go" ) @@ -24,16 +24,16 @@ func NewPublisher(exchange, routingKey string) *Publisher { } // Publish 发布消息 -func (p *Publisher) Publish(ctx context.Context, message interface{}) error { +func (p *Publisher) Publish(ctx context.Context, message interface{}) (err error) { ch, err := GetChannel() if err != nil { return err } // 序列化消息 - body, err := json.Marshal(message) + body, err := gjson.Encode(message) if err != nil { - return fmt.Errorf("消息序列化失败: %v", err) + return gerror.Newf("消息序列化失败: %v", err) } // 发布消息 @@ -59,21 +59,21 @@ func (p *Publisher) Publish(ctx context.Context, message interface{}) error { g.Log().Debugf(ctx, "消息发布成功: exchange=%s, routingKey=%s", p.exchange, p.routingKey) - return nil + return } // PublishDelayed 发布延时消息 // delaySeconds: 延时秒数 -func (p *Publisher) PublishDelayed(ctx context.Context, message interface{}, delaySeconds int) error { +func (p *Publisher) PublishDelayed(ctx context.Context, message interface{}, delaySeconds int) (err error) { ch, err := GetChannel() if err != nil { return err } // 序列化消息 - body, err := json.Marshal(message) + body, err := gjson.Encode(message) if err != nil { - return fmt.Errorf("消息序列化失败: %v", err) + return gerror.Newf("消息序列化失败: %v", err) } // 发布延时消息(需要 rabbitmq_delayed_message_exchange 插件) @@ -102,13 +102,13 @@ func (p *Publisher) PublishDelayed(ctx context.Context, message interface{}, del g.Log().Debugf(ctx, "延时消息发布成功: exchange=%s, routingKey=%s, delay=%ds", p.exchange, p.routingKey, delaySeconds) - return nil + return } // PublishBatch 批量发布消息 -func (p *Publisher) PublishBatch(ctx context.Context, messages []interface{}) error { +func (p *Publisher) PublishBatch(ctx context.Context, messages []interface{}) (err error) { if len(messages) == 0 { - return nil + return } ch, err := GetChannel() @@ -117,7 +117,7 @@ func (p *Publisher) PublishBatch(ctx context.Context, messages []interface{}) er } for i, message := range messages { - body, err := json.Marshal(message) + body, err := gjson.Encode(message) if err != nil { g.Log().Errorf(ctx, "消息 %d 序列化失败: %v", i, err) continue @@ -143,5 +143,5 @@ func (p *Publisher) PublishBatch(ctx context.Context, messages []interface{}) er } g.Log().Infof(ctx, "批量发布完成: 共 %d 条消息", len(messages)) - return nil + return } diff --git a/rabbitmq/setup.go b/rabbitmq/setup.go index bc8bf59..47793fd 100644 --- a/rabbitmq/setup.go +++ b/rabbitmq/setup.go @@ -2,8 +2,8 @@ package rabbitmq import ( "context" - "fmt" + "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/frame/g" amqp "github.com/rabbitmq/amqp091-go" ) @@ -35,7 +35,7 @@ type BindingConfig struct { } // DeclareQueue 声明队列 -func DeclareQueue(ctx context.Context, cfg *QueueConfig) error { +func DeclareQueue(ctx context.Context, cfg *QueueConfig) (err error) { ch, err := GetChannel() if err != nil { return err @@ -56,11 +56,11 @@ func DeclareQueue(ctx context.Context, cfg *QueueConfig) error { } g.Log().Infof(ctx, "队列声明成功: %s", cfg.Name) - return nil + return } // DeclareExchange 声明 Exchange -func DeclareExchange(ctx context.Context, cfg *ExchangeConfig) error { +func DeclareExchange(ctx context.Context, cfg *ExchangeConfig) (err error) { ch, err := GetChannel() if err != nil { return err @@ -82,11 +82,11 @@ func DeclareExchange(ctx context.Context, cfg *ExchangeConfig) error { } g.Log().Infof(ctx, "Exchange 声明成功: %s (type=%s)", cfg.Name, cfg.Type) - return nil + return } // BindQueue 绑定队列到 Exchange -func BindQueue(ctx context.Context, cfg *BindingConfig) error { +func BindQueue(ctx context.Context, cfg *BindingConfig) (err error) { ch, err := GetChannel() if err != nil { return err @@ -108,7 +108,7 @@ func BindQueue(ctx context.Context, cfg *BindingConfig) error { g.Log().Infof(ctx, "队列绑定成功: queue=%s → exchange=%s (routingKey=%s)", cfg.Queue, cfg.Exchange, cfg.RoutingKey) - return nil + return } // SetupDelayExchange 设置延时 Exchange(需要 rabbitmq_delayed_message_exchange 插件) @@ -165,9 +165,9 @@ func SetupQueueWithDLX(ctx context.Context, queueName, dlxExchange, dlxRoutingKe } // SetupBasicTopology 设置基础拓扑(适用于小红书客服场景) -func SetupBasicTopology(ctx context.Context) error { +func SetupBasicTopology(ctx context.Context) (err error) { // 1. 声明普通 Exchange - err := DeclareExchange(ctx, &ExchangeConfig{ + err = DeclareExchange(ctx, &ExchangeConfig{ Name: "ragflow_exchange", Type: "direct", Durable: true, @@ -179,7 +179,7 @@ func SetupBasicTopology(ctx context.Context) error { // 2. 声明延时 Exchange err = SetupDelayExchange(ctx, "delay_exchange") if err != nil { - return fmt.Errorf("延时 Exchange 声明失败(可能未安装插件): %v", err) + return gerror.Newf("延时 Exchange 声明失败(可能未安装插件): %v", err) } // 3. 声明死信队列 @@ -227,5 +227,5 @@ func SetupBasicTopology(ctx context.Context) error { } g.Log().Info(ctx, "RabbitMQ 拓扑结构设置完成") - return nil + return } diff --git a/ragflow/agent.go b/ragflow/agent.go index 4ce263a..e389e59 100644 --- a/ragflow/agent.go +++ b/ragflow/agent.go @@ -2,7 +2,8 @@ package ragflow import ( "context" - "fmt" + + "github.com/gogf/gf/v2/errors/gerror" ) // Agent AGENT 管理 @@ -56,44 +57,44 @@ type ListAgentsRes struct { // CreateAgent 创建 Agent // POST /api/v1/agents -func (c *Client) CreateAgent(ctx context.Context, req *CreateAgentReq) error { +func (c *Client) CreateAgent(ctx context.Context, req *CreateAgentReq) (err error) { var res CommonResponse - if err := c.request(ctx, "POST", "/api/v1/agents", req, &res); err != nil { - return fmt.Errorf("create agent failed: %w", err) + if err = c.request(ctx, "POST", "/api/v1/agents", req, &res); err != nil { + return gerror.Newf("create agent failed: %v", err) } if !res.IsSuccess() { - return fmt.Errorf("create agent failed: %s", res.Message) + return gerror.Newf("create agent failed: %s", res.Message) } - return nil + return } // UpdateAgent 更新 Agent // PUT /api/v1/agents/{agent_id} -func (c *Client) UpdateAgent(ctx context.Context, agentID string, req *UpdateAgentReq) error { - path := fmt.Sprintf("/api/v1/agents/%s", agentID) +func (c *Client) UpdateAgent(ctx context.Context, agentID string, req *UpdateAgentReq) (err error) { + path := "/api/v1/agents/" + agentID var res CommonResponse - if err := c.request(ctx, "PUT", path, req, &res); err != nil { - return fmt.Errorf("update agent failed: %w", err) + if err = c.request(ctx, "PUT", path, req, &res); err != nil { + return gerror.Newf("update agent failed: %v", err) } if !res.IsSuccess() { - return fmt.Errorf("update agent failed: %s", res.Message) + return gerror.Newf("update agent failed: %s", res.Message) } - return nil + return } // DeleteAgent 删除 Agent // DELETE /api/v1/agents/{agent_id} -func (c *Client) DeleteAgent(ctx context.Context, agentID string) error { - path := fmt.Sprintf("/api/v1/agents/%s", agentID) +func (c *Client) DeleteAgent(ctx context.Context, agentID string) (err error) { + path := "/api/v1/agents/" + agentID var res CommonResponse // 官方文档要求传空对象,不是 nil - if err := c.request(ctx, "DELETE", path, map[string]interface{}{}, &res); err != nil { - return fmt.Errorf("delete agent failed: %w", err) + if err = c.request(ctx, "DELETE", path, map[string]interface{}{}, &res); err != nil { + return gerror.Newf("delete agent failed: %v", err) } if !res.IsSuccess() { - return fmt.Errorf("delete agent failed: %s", res.Message) + return gerror.Newf("delete agent failed: %s", res.Message) } - return nil + return } // ListAgents 列出 Agent @@ -131,10 +132,10 @@ func (c *Client) ListAgents(ctx context.Context, req *ListAgentsReq) (*ListAgent var res ListAgentsRes if err := c.request(ctx, "GET", path, nil, &res); err != nil { - return nil, fmt.Errorf("list agents failed: %w", err) + return nil, gerror.Newf("list agents failed: %v", err) } if res.Code != 0 { - return nil, fmt.Errorf("list agents failed: code=%d", res.Code) + return nil, gerror.Newf("list agents failed: code=%d", res.Code) } return &res, nil } diff --git a/ragflow/chat.go b/ragflow/chat.go index 6f65335..7f0acb5 100644 --- a/ragflow/chat.go +++ b/ragflow/chat.go @@ -2,7 +2,8 @@ package ragflow import ( "context" - "fmt" + + "github.com/gogf/gf/v2/errors/gerror" ) // 聊天助手管理 @@ -104,7 +105,7 @@ func (c *Client) CreateChat(ctx context.Context, req *CreateChatReq) (*Chat, err return nil, err } if res.Code != 0 { - return nil, fmt.Errorf("create chat failed: %s", res.Msg) + return nil, gerror.Newf("create chat failed: %s", res.Msg) } return res.Data, nil } @@ -144,33 +145,33 @@ func (c *Client) ListChats(ctx context.Context, req *ListChatsReq) (*ListChatsRe return nil, err } if res.Code != 0 { - return nil, fmt.Errorf("list chats failed: code=%d", res.Code) + return nil, gerror.Newf("list chats failed: code=%d", res.Code) } return &res, nil } // DeleteChats 删除聊天助手 -func (c *Client) DeleteChats(ctx context.Context, ids []string) error { +func (c *Client) DeleteChats(ctx context.Context, ids []string) (err error) { req := DeleteChatsReq{Ids: ids} var res CommonResponse - if err := c.request(ctx, "DELETE", "/api/v1/chats", req, &res); err != nil { - return err + if err = c.request(ctx, "DELETE", "/api/v1/chats", req, &res); err != nil { + return } if !res.IsSuccess() { - return fmt.Errorf("delete chats failed: %s", res.Message) + return gerror.Newf("delete chats failed: %s", res.Message) } - return nil + return } // UpdateChat 更新聊天助手 -func (c *Client) UpdateChat(ctx context.Context, id string, req *UpdateChatReq) error { +func (c *Client) UpdateChat(ctx context.Context, id string, req *UpdateChatReq) (err error) { var res CommonResponse - path := fmt.Sprintf("/api/v1/chats/%s", id) - if err := c.request(ctx, "PUT", path, req, &res); err != nil { - return err + path := "/api/v1/chats/" + id + if err = c.request(ctx, "PUT", path, req, &res); err != nil { + return } if !res.IsSuccess() { - return fmt.Errorf("update chat failed: %s", res.Message) + return gerror.Newf("update chat failed: %s", res.Message) } - return nil + return } diff --git a/ragflow/chunk.go b/ragflow/chunk.go index 7e0b66e..030d21a 100644 --- a/ragflow/chunk.go +++ b/ragflow/chunk.go @@ -2,7 +2,8 @@ package ragflow import ( "context" - "fmt" + + "github.com/gogf/gf/v2/errors/gerror" ) // 数据集内知识块管理 @@ -90,7 +91,7 @@ type RetrieveChunksRes struct { // AddChunk 添加知识块 func (c *Client) AddChunk(ctx context.Context, datasetId, documentId string, req *AddChunkReq) (*Chunk, error) { - path := fmt.Sprintf("/api/v1/datasets/%s/documents/%s/chunks", datasetId, documentId) + path := "/api/v1/datasets/" + datasetId + "/documents/" + documentId + "/chunks" var res struct { Code int `json:"code"` Data struct { @@ -102,14 +103,14 @@ func (c *Client) AddChunk(ctx context.Context, datasetId, documentId string, req return nil, err } if res.Code != 0 { - return nil, fmt.Errorf("add chunk failed: %s", res.Msg) + return nil, gerror.Newf("add chunk failed: %s", res.Msg) } return res.Data.Chunk, nil } // ListChunks 列出知识块 func (c *Client) ListChunks(ctx context.Context, datasetId, documentId string, req *ListChunksReq) (*ListChunksRes, error) { - path := fmt.Sprintf("/api/v1/datasets/%s/documents/%s/chunks", datasetId, documentId) + path := "/api/v1/datasets/" + datasetId + "/documents/" + documentId + "/chunks" params := map[string]interface{}{} if req.Keywords != "" { params["keywords"] = req.Keywords @@ -134,36 +135,36 @@ func (c *Client) ListChunks(ctx context.Context, datasetId, documentId string, r return nil, err } if res.Code != 0 { - return nil, fmt.Errorf("list chunks failed: code=%d", res.Code) + return nil, gerror.Newf("list chunks failed: code=%d", res.Code) } return &res, nil } // DeleteChunks 删除知识块 -func (c *Client) DeleteChunks(ctx context.Context, datasetId, documentId string, chunkIds []string) error { +func (c *Client) DeleteChunks(ctx context.Context, datasetId, documentId string, chunkIds []string) (err error) { req := DeleteChunksReq{ChunkIds: chunkIds} var res CommonResponse - path := fmt.Sprintf("/api/v1/datasets/%s/documents/%s/chunks", datasetId, documentId) - if err := c.request(ctx, "DELETE", path, req, &res); err != nil { - return err + path := "/api/v1/datasets/" + datasetId + "/documents/" + documentId + "/chunks" + if err = c.request(ctx, "DELETE", path, req, &res); err != nil { + return } if !res.IsSuccess() { - return fmt.Errorf("delete chunks failed: %s", res.Message) + return gerror.Newf("delete chunks failed: %s", res.Message) } - return nil + return } // UpdateChunk 更新知识块 -func (c *Client) UpdateChunk(ctx context.Context, datasetId, documentId, chunkId string, req *UpdateChunkReq) error { +func (c *Client) UpdateChunk(ctx context.Context, datasetId, documentId, chunkId string, req *UpdateChunkReq) (err error) { var res CommonResponse - path := fmt.Sprintf("/api/v1/datasets/%s/documents/%s/chunks/%s", datasetId, documentId, chunkId) - if err := c.request(ctx, "PUT", path, req, &res); err != nil { - return err + path := "/api/v1/datasets/" + datasetId + "/documents/" + documentId + "/chunks/" + chunkId + if err = c.request(ctx, "PUT", path, req, &res); err != nil { + return } if !res.IsSuccess() { - return fmt.Errorf("update chunk failed: %s", res.Message) + return gerror.Newf("update chunk failed: %s", res.Message) } - return nil + return } // RetrieveChunks 检索知识块 @@ -173,7 +174,7 @@ func (c *Client) RetrieveChunks(ctx context.Context, req *RetrieveChunksReq) (*R return nil, err } if res.Code != 0 { - return nil, fmt.Errorf("retrieve chunks failed: code=%d", res.Code) + return nil, gerror.Newf("retrieve chunks failed: code=%d", res.Code) } return &res, nil } diff --git a/ragflow/client.go b/ragflow/client.go index 0633dfc..6b4ccc5 100644 --- a/ragflow/client.go +++ b/ragflow/client.go @@ -2,13 +2,12 @@ package ragflow import ( "context" - "encoding/json" - "fmt" - "io" "net/http" "net/url" "strings" + "github.com/gogf/gf/v2/encoding/gjson" + "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/net/gclient" ) @@ -33,7 +32,7 @@ func init() { // 初始化全局客户端 httpClient := gclient.New() - httpClient.SetHeader("Authorization", fmt.Sprintf("Bearer %s", apiKey)) + httpClient.SetHeader("Authorization", "Bearer "+apiKey) httpClient.SetHeader("Content-Type", "application/json") globalClient = &Client{ @@ -79,20 +78,19 @@ func (r *CommonResponse) IsSuccess() bool { } // request 发送 HTTP 请求 -func (c *Client) request(ctx context.Context, method, path string, body interface{}, result interface{}) error { +func (c *Client) request(ctx context.Context, method, path string, body interface{}, result interface{}) (err error) { fullURL := c.BaseURL + path - var reqBody io.Reader + var reqBody string if body != nil { - jsonData, err := json.Marshal(body) + jsonData, err := gjson.Encode(body) if err != nil { - return fmt.Errorf("marshal request body failed: %w", err) + return gerror.Newf("marshal request body failed: %v", err) } - reqBody = strings.NewReader(string(jsonData)) + reqBody = string(jsonData) } var resp *gclient.Response - var err error switch method { case "GET": @@ -104,28 +102,24 @@ func (c *Client) request(ctx context.Context, method, path string, body interfac case "DELETE": resp, err = c.HTTPClient.Delete(ctx, fullURL, reqBody) default: - return fmt.Errorf("unsupported method: %s", method) + return gerror.Newf("unsupported method: %s", method) } if err != nil { - return fmt.Errorf("http request failed: %w", err) + return gerror.Newf("http request failed: %v", err) } defer resp.Close() if resp.StatusCode != http.StatusOK { - return fmt.Errorf("http request failed with status: %d", resp.StatusCode) + return gerror.Newf("http request failed with status: %d", resp.StatusCode) } respBody := resp.ReadAll() - if err != nil { - return fmt.Errorf("read response body failed: %w", err) + if err = gjson.DecodeTo(respBody, result); err != nil { + return gerror.Newf("unmarshal response failed: %v", err) } - if err := json.Unmarshal(respBody, result); err != nil { - return fmt.Errorf("unmarshal response failed: %w", err) - } - - return nil + return } // buildQueryString 构建查询字符串 @@ -134,9 +128,9 @@ func buildQueryString(params map[string]interface{}) string { return "" } - var parts []string + parts := make([]string, 0, len(params)) for k, v := range params { - parts = append(parts, fmt.Sprintf("%s=%v", url.QueryEscape(k), url.QueryEscape(fmt.Sprintf("%v", v)))) + parts = append(parts, url.QueryEscape(k)+"="+url.QueryEscape(g.NewVar(v).String())) } return strings.Join(parts, "&") } diff --git a/ragflow/dataset.go b/ragflow/dataset.go index 83162b7..431b332 100644 --- a/ragflow/dataset.go +++ b/ragflow/dataset.go @@ -2,7 +2,8 @@ package ragflow import ( "context" - "fmt" + + "github.com/gogf/gf/v2/errors/gerror" ) // 数据集管理 @@ -90,7 +91,7 @@ func (c *Client) CreateDataset(ctx context.Context, req *CreateDatasetReq) (*Dat return nil, err } if res.Code != 0 { - return nil, fmt.Errorf("create dataset failed: %s", res.Msg) + return nil, gerror.Newf("create dataset failed: %s", res.Msg) } return res.Data, nil } @@ -134,33 +135,33 @@ func (c *Client) ListDatasets(ctx context.Context, req *ListDatasetsReq) (*ListD return nil, err } if res.Code != 0 { - return nil, fmt.Errorf("list datasets failed: code=%d", res.Code) + return nil, gerror.Newf("list datasets failed: code=%d", res.Code) } return &res, nil } // DeleteDataset 删除数据集 -func (c *Client) DeleteDataset(ctx context.Context, ids []string) error { +func (c *Client) DeleteDataset(ctx context.Context, ids []string) (err error) { req := DeleteDatasetsReq{Ids: ids} var res CommonResponse - if err := c.request(ctx, "DELETE", "/api/v1/datasets", req, &res); err != nil { - return err + if err = c.request(ctx, "DELETE", "/api/v1/datasets", req, &res); err != nil { + return } if !res.IsSuccess() { - return fmt.Errorf("delete dataset failed: %s", res.Message) + return gerror.Newf("delete dataset failed: %s", res.Message) } - return nil + return } // UpdateDataset 更新数据集 -func (c *Client) UpdateDataset(ctx context.Context, id string, req *UpdateDatasetReq) error { +func (c *Client) UpdateDataset(ctx context.Context, id string, req *UpdateDatasetReq) (err error) { var res CommonResponse - path := fmt.Sprintf("/api/v1/datasets/%s", id) - if err := c.request(ctx, "PUT", path, req, &res); err != nil { - return err + path := "/api/v1/datasets/" + id + if err = c.request(ctx, "PUT", path, req, &res); err != nil { + return } if !res.IsSuccess() { - return fmt.Errorf("update dataset failed: %s", res.Message) + return gerror.Newf("update dataset failed: %s", res.Message) } - return nil + return } diff --git a/ragflow/document.go b/ragflow/document.go index 60d6192..c08f32b 100644 --- a/ragflow/document.go +++ b/ragflow/document.go @@ -2,8 +2,9 @@ package ragflow import ( "context" - "fmt" "strings" + + "github.com/gogf/gf/v2/errors/gerror" ) // 数据集内文件管理 @@ -70,7 +71,7 @@ type DeleteDocumentsReq struct { // ListDocuments 列出文档 func (c *Client) ListDocuments(ctx context.Context, datasetId string, req *ListDocumentsReq) (*ListDocumentsRes, error) { - path := fmt.Sprintf("/api/v1/datasets/%s/documents", datasetId) + path := "/api/v1/datasets/" + datasetId + "/documents" params := map[string]interface{}{} if req.Page > 0 { params["page"] = req.Page @@ -111,16 +112,14 @@ func (c *Client) ListDocuments(ctx context.Context, datasetId string, req *ListD // 处理数组参数:suffix(文件后缀过滤) // API 要求多个值时重复参数名,如:suffix=pdf&suffix=txt - // 这里使用 fmt.Sprintf 来构造每个参数值 for _, suffix := range req.Suffix { - queryParts = append(queryParts, fmt.Sprintf("suffix=%s", suffix)) + queryParts = append(queryParts, "suffix="+suffix) } // 处理数组参数:run(处理状态过滤) // 支持数字格式("0"-"4")或文本格式("UNSTART", "RUNNING", "CANCEL", "DONE", "FAIL") - // 这里使用 fmt.Sprintf 来构造每个参数值 for _, run := range req.Run { - queryParts = append(queryParts, fmt.Sprintf("run=%s", run)) + queryParts = append(queryParts, "run="+run) } // 构造最终请求路径 @@ -134,7 +133,7 @@ func (c *Client) ListDocuments(ctx context.Context, datasetId string, req *ListD return nil, err } if res.Code != 0 { - return nil, fmt.Errorf("list documents failed: code=%d", res.Code) + return nil, gerror.Newf("list documents failed: code=%d", res.Code) } return &res, nil } @@ -142,23 +141,21 @@ func (c *Client) ListDocuments(ctx context.Context, datasetId string, req *ListD // UploadDocument 上传文档 // 注意:此方法需要特殊处理 multipart/form-data,目前的 request 方法可能不支持 // 我们需要扩展 request 方法或在此处单独实现 -func (c *Client) UploadDocument(ctx context.Context, datasetId string, filePaths []string) error { +func (c *Client) UploadDocument(ctx context.Context, datasetId string, filePaths []string) (err error) { // TODO: 实现文件上传逻辑,需要使用 gclient 的 UploadFile 功能 - // 由于 request 方法封装了 JSON 处理,这里可能需要绕过 request 方法直接使用 c.Client - // 暂时留空或仅做简单提示,待完善 Client 封装以支持文件上传 - return fmt.Errorf("upload document not implemented yet") + return gerror.New("upload document not implemented yet") } // DeleteDocument 删除文档 -func (c *Client) DeleteDocument(ctx context.Context, datasetId string, ids []string) error { +func (c *Client) DeleteDocument(ctx context.Context, datasetId string, ids []string) (err error) { req := DeleteDocumentsReq{Ids: ids} var res CommonResponse - path := fmt.Sprintf("/api/v1/datasets/%s/documents", datasetId) - if err := c.request(ctx, "DELETE", path, req, &res); err != nil { - return err + path := "/api/v1/datasets/" + datasetId + "/documents" + if err = c.request(ctx, "DELETE", path, req, &res); err != nil { + return } if !res.IsSuccess() { - return fmt.Errorf("delete document failed: %s", res.Message) + return gerror.Newf("delete document failed: %s", res.Message) } - return nil + return } diff --git a/ragflow/openai.go b/ragflow/openai.go index acaa2ff..4218592 100644 --- a/ragflow/openai.go +++ b/ragflow/openai.go @@ -2,8 +2,9 @@ package ragflow import ( "context" - "encoding/json" - "fmt" + + "github.com/gogf/gf/v2/encoding/gjson" + "github.com/gogf/gf/v2/errors/gerror" ) // OpenAICompatibleAPI 与 OpenAI 兼容的 API @@ -64,11 +65,11 @@ type ChatCompletionChunk struct { // CreateChatCompletion 创建聊天补全(与聊天助手) // POST /api/v1/chats_openai/{chat_id}/chat/completions func (c *Client) CreateChatCompletion(ctx context.Context, chatID string, req *ChatCompletionRequest) (*ChatCompletionResponse, error) { - path := fmt.Sprintf("/api/v1/chats_openai/%s/chat/completions", chatID) + path := "/api/v1/chats_openai/" + chatID + "/chat/completions" var resp ChatCompletionResponse if err := c.request(ctx, "POST", path, req, &resp); err != nil { - return nil, fmt.Errorf("create chat completion failed: %w", err) + return nil, gerror.Newf("create chat completion failed: %v", err) } return &resp, nil @@ -77,11 +78,11 @@ func (c *Client) CreateChatCompletion(ctx context.Context, chatID string, req *C // CreateAgentCompletion 创建 Agent 补全 // POST /api/v1/agents_openai/{agent_id}/chat/completions func (c *Client) CreateAgentCompletion(ctx context.Context, agentID string, req *ChatCompletionRequest) (*ChatCompletionResponse, error) { - path := fmt.Sprintf("/api/v1/agents_openai/%s/chat/completions", agentID) + path := "/api/v1/agents_openai/" + agentID + "/chat/completions" var resp ChatCompletionResponse if err := c.request(ctx, "POST", path, req, &resp); err != nil { - return nil, fmt.Errorf("create agent completion failed: %w", err) + return nil, gerror.Newf("create agent completion failed: %v", err) } return &resp, nil @@ -91,31 +92,26 @@ func (c *Client) CreateAgentCompletion(ctx context.Context, agentID string, req // 注意:流式响应需要特殊处理,这里返回一个可用于读取流的接口 func (c *Client) CreateChatCompletionStream(ctx context.Context, chatID string, req *ChatCompletionRequest) (*StreamReader, error) { req.Stream = true - _ = fmt.Sprintf("/api/v1/chats_openai/%s/chat/completions", chatID) - // TODO: 实现流式读取逻辑 - return nil, fmt.Errorf("stream mode not implemented yet") + return nil, gerror.New("stream mode not implemented yet") } // StreamReader 流式响应读取器 type StreamReader struct { - decoder *json.Decoder - close func() error + _ *gjson.Json // TODO: 实现流式读取时使用 + close func() error } // ReadChunk 读取下一个响应块 +// TODO: 实现流式读取逻辑 func (sr *StreamReader) ReadChunk() (*ChatCompletionChunk, error) { - var chunk ChatCompletionChunk - if err := sr.decoder.Decode(&chunk); err != nil { - return nil, err - } - return &chunk, nil + return nil, gerror.New("stream mode not implemented yet") } // Close 关闭流 -func (sr *StreamReader) Close() error { +func (sr *StreamReader) Close() (err error) { if sr.close != nil { return sr.close() } - return nil + return } diff --git a/ragflow/session.go b/ragflow/session.go index 832082a..ee82b5d 100644 --- a/ragflow/session.go +++ b/ragflow/session.go @@ -2,7 +2,8 @@ package ragflow import ( "context" - "fmt" + + "github.com/gogf/gf/v2/errors/gerror" ) // 会话管理 @@ -76,7 +77,7 @@ type ChatCompletionRes struct { // CreateSession 创建会话 func (c *Client) CreateSession(ctx context.Context, chatId string, req *CreateSessionReq) (*Session, error) { - path := fmt.Sprintf("/api/v1/chats/%s/sessions", chatId) + path := "/api/v1/chats/" + chatId + "/sessions" var res struct { Code int `json:"code"` Data *Session `json:"data"` @@ -86,14 +87,14 @@ func (c *Client) CreateSession(ctx context.Context, chatId string, req *CreateSe return nil, err } if res.Code != 0 { - return nil, fmt.Errorf("create session failed: %s", res.Msg) + return nil, gerror.Newf("create session failed: %s", res.Msg) } return res.Data, nil } // ListSessions 列出会话 func (c *Client) ListSessions(ctx context.Context, chatId string, req *ListSessionsReq) (*ListSessionsRes, error) { - path := fmt.Sprintf("/api/v1/chats/%s/sessions", chatId) + path := "/api/v1/chats/" + chatId + "/sessions" params := map[string]interface{}{} if req.Page > 0 { params["page"] = req.Page @@ -129,40 +130,40 @@ func (c *Client) ListSessions(ctx context.Context, chatId string, req *ListSessi return nil, err } if res.Code != 0 { - return nil, fmt.Errorf("list sessions failed: code=%d", res.Code) + return nil, gerror.Newf("list sessions failed: code=%d", res.Code) } return &res, nil } // DeleteSessions 删除会话 -func (c *Client) DeleteSessions(ctx context.Context, chatId string, ids []string) error { +func (c *Client) DeleteSessions(ctx context.Context, chatId string, ids []string) (err error) { req := DeleteSessionsReq{Ids: ids} var res CommonResponse - path := fmt.Sprintf("/api/v1/chats/%s/sessions", chatId) - if err := c.request(ctx, "DELETE", path, req, &res); err != nil { - return err + path := "/api/v1/chats/" + chatId + "/sessions" + if err = c.request(ctx, "DELETE", path, req, &res); err != nil { + return } if !res.IsSuccess() { - return fmt.Errorf("delete sessions failed: %s", res.Message) + return gerror.Newf("delete sessions failed: %s", res.Message) } - return nil + return } // ChatCompletion 对话 (目前仅支持非流式) func (c *Client) ChatCompletion(ctx context.Context, chatId string, req *ChatCompletionReq) (*ChatCompletionRes, error) { - path := fmt.Sprintf("/api/v1/chats/%s/completions", chatId) + path := "/api/v1/chats/" + chatId + "/completions" var res ChatCompletionRes // 如果需要流式支持,需要使用 gclient 的流式处理能力,这里暂只实现非流式 if req.Stream { - return nil, fmt.Errorf("stream mode not supported yet") + return nil, gerror.New("stream mode not supported yet") } if err := c.request(ctx, "POST", path, req, &res); err != nil { return nil, err } if res.Code != 0 { - return nil, fmt.Errorf("chat completion failed: code=%d", res.Code) + return nil, gerror.Newf("chat completion failed: code=%d", res.Code) } return &res, nil } diff --git a/ragflow/system.go b/ragflow/system.go index f294c91..1a7da73 100644 --- a/ragflow/system.go +++ b/ragflow/system.go @@ -2,7 +2,8 @@ package ragflow import ( "context" - "fmt" + + "github.com/gogf/gf/v2/errors/gerror" ) // System 系统管理 @@ -10,11 +11,11 @@ import ( // HealthStatus 健康状态 type HealthStatus struct { - DB string `json:"db"` // "ok" 或 "nok" - Redis string `json:"redis"` // "ok" 或 "nok" - DocEngine string `json:"doc_engine"` // "ok" 或 "nok" - Storage string `json:"storage"` // "ok" 或 "nok" - Status string `json:"status"` // 整体状态: "ok" 或 "nok" + DB string `json:"db"` // "ok" 或 "nok" + Redis string `json:"redis"` // "ok" 或 "nok" + DocEngine string `json:"doc_engine"` // "ok" 或 "nok" + Storage string `json:"storage"` // "ok" 或 "nok" + Status string `json:"status"` // 整体状态: "ok" 或 "nok" Meta map[string]interface{} `json:"_meta,omitempty"` // 详细错误信息 } @@ -23,7 +24,7 @@ type HealthStatus struct { func (c *Client) CheckHealth(ctx context.Context) (*HealthStatus, error) { var status HealthStatus if err := c.request(ctx, "GET", "/v1/system/healthz", nil, &status); err != nil { - return nil, fmt.Errorf("check health failed: %w", err) + return nil, gerror.Newf("check health failed: %v", err) } return &status, nil } @@ -36,4 +37,3 @@ func (c *Client) IsHealthy(ctx context.Context) (bool, error) { } return status.Status == "ok", nil } - diff --git a/ragflow/worker_pool.go b/ragflow/worker_pool.go index 47767bf..b836ef8 100644 --- a/ragflow/worker_pool.go +++ b/ragflow/worker_pool.go @@ -149,12 +149,18 @@ func (q *QueueProcessor) Start(ctx context.Context) error { glog.Infof(ctx, "Stream 处理器启动 - Stream: %s, 消费者组: %s, 消费者: %s, 超时: %dms", q.streamKey, q.groupName, q.consumerName, q.timeout) + loopCount := 0 for { select { case <-q.stopChan: glog.Info(ctx, "Stream 处理器收到停止信号") return nil default: + loopCount++ + if loopCount%10 == 1 { + glog.Debugf(ctx, "[DEBUG] 第 %d 次循环,准备读取消息...", loopCount) + } + // 从 Redis Stream 中读取消息 messages, err := q.fetchMessages(ctx) if err != nil { @@ -164,11 +170,17 @@ func (q *QueueProcessor) Start(ctx context.Context) error { // 没有新消息,继续等待 if len(messages) == 0 { + if loopCount%10 == 1 { + glog.Debugf(ctx, "[DEBUG] 第 %d 次循环,无新消息", loopCount) + } continue } + glog.Infof(ctx, "[DEBUG] 收到 %d 条消息", len(messages)) + // 处理每条消息 for _, msg := range messages { + glog.Infof(ctx, "[DEBUG] 处理消息 ID: %s, Values: %+v", msg.ID, msg.Values) // 提交到协程池处理 if err := q.submitTask(ctx, msg); err != nil { glog.Errorf(ctx, "提交任务到协程池失败: %v, 消息ID: %s", err, msg.ID) diff --git a/redis/redis.go b/redis/redis.go index 1e5ca10..186d2ad 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -3,33 +3,50 @@ package redis import ( "context" "strings" + "sync" "github.com/gogf/gf/v2/database/gredis" "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/os/glog" "github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/util/gconv" ) -// GRedisClient GoFrame gredis 客户端,统一使用(懒加载) -var GRedisClient *gredis.Redis +var ( + // redisClient 单例 Redis 客户端 + redisClient *gredis.Redis + // redisOnce 确保只初始化一次 + redisOnce sync.Once + // RedisClient 兼容导出(供 mongo.go 使用) + // 注意:这是一个指向单例的指针,首次调用 GetRedisClient() 后生效 + RedisClient *gredis.Redis +) -// RedisClient GRedisClient 的别名,保持向后兼容 -var RedisClient *gredis.Redis - -// GetRedisClient 获取 Redis 客户端(懒加载) +// GetRedisClient 获取 Redis 客户端(单例模式) func GetRedisClient() *gredis.Redis { - if GRedisClient == nil { - GRedisClient = g.Redis() - RedisClient = GRedisClient - } - return GRedisClient + redisOnce.Do(func() { + redisClient = g.Redis() + RedisClient = redisClient // 同步更新兼容导出 + }) + return redisClient +} + +// init 包初始化时自动初始化 Redis 客户端 +func init() { + GetRedisClient() } // Stream 和消费者组常量 const ( // RAGFlow 请求 Stream Key RAGFlowRequestStreamKey = "ragflow:request:stream" - // RAGFlow 消费者组名称 + // RAGFlow 响应 Stream Key + RAGFlowResponseStreamKey = "ragflow:response:stream" + // RAGFlow 请求消费者组名称 + RAGFlowRequestConsumerGroup = "ragflow:request:consumer:group" + // RAGFlow 响应消费者组名称 + RAGFlowResponseConsumerGroup = "ragflow:response:consumer:group" + // RAGFlow 消费者组名称(兼容旧代码) RAGFlowConsumerGroup = "ragflow:consumer:group" // 会话最后活跃时间 Key 前缀 SessionLastActiveKeyPrefix = "ragflow:session:" @@ -79,6 +96,9 @@ func AddToStream(ctx context.Context, streamKey string, values map[string]interf // ReadFromStream 从 Stream 读取消息(消费者组模式) // 使用 gredis Do() 方法执行 XREADGROUP 命令 func ReadFromStream(ctx context.Context, streamKey, groupName, consumerName string, count int64, blockMs int64) ([]StreamMessage, error) { + glog.Debugf(ctx, "[DEBUG Redis] XREADGROUP GROUP %s %s COUNT %d BLOCK %d STREAMS %s >", + groupName, consumerName, count, blockMs, streamKey) + // XREADGROUP GROUP groupName consumerName COUNT count BLOCK blockMs STREAMS streamKey > result, err := GetRedisClient().Do(ctx, "XREADGROUP", "GROUP", groupName, consumerName, @@ -88,66 +108,89 @@ func ReadFromStream(ctx context.Context, streamKey, groupName, consumerName stri ) if err != nil { + glog.Errorf(ctx, "[DEBUG Redis] XREADGROUP 错误: %v", err) return nil, err } - // 解析返回值 - // 格式: [[streamKey, [[msgID, [field1, value1, field2, value2, ...]], ...]]] + glog.Debugf(ctx, "[DEBUG Redis] XREADGROUP 返回: %+v", result) + // 预分配容量,避免动态扩容 messages := make([]StreamMessage, 0, int(count)) - if result == nil { + if result == nil || result.IsEmpty() { // 超时或没有数据 return messages, nil } - // 类型断言:result.Val() 返回 interface{} - streamsArray, ok := result.Val().([]interface{}) - if !ok || len(streamsArray) == 0 { - return messages, nil - } + // GoFrame gredis 返回格式: map[streamKey:[[msgID [field1 value1 field2 value2 ...]] ...]] + resultVal := result.Val() - // 遍历每个 stream - for _, streamData := range streamsArray { - streamArray, ok := streamData.([]interface{}) - if !ok || len(streamArray) < 2 { - continue - } - - // streamArray[0] 是 streamKey, streamArray[1] 是消息数组 - messagesArray, ok := streamArray[1].([]interface{}) - if !ok { - continue - } - - // 解析每条消息 - for _, msgData := range messagesArray { - msgArray, ok := msgData.([]interface{}) - if !ok || len(msgArray) < 2 { - continue - } - - // msgArray[0] 是 ID, msgArray[1] 是字段数组 - msgID := gconv.String(msgArray[0]) - fieldsArray, ok := msgArray[1].([]interface{}) + // 尝试 map 格式(GoFrame gredis 返回) + if streamsMap, ok := resultVal.(map[interface{}]interface{}); ok { + for _, streamMsgs := range streamsMap { + msgsArray, ok := streamMsgs.([]interface{}) if !ok { continue } - - // 解析字段为 map,预分配容量,避免动态扩容 - values := make(map[string]interface{}, len(fieldsArray)/2) - for i := 0; i < len(fieldsArray); i += 2 { - if i+1 < len(fieldsArray) { - key := gconv.String(fieldsArray[i]) - val := fieldsArray[i+1] - values[key] = val + for _, msgData := range msgsArray { + msgArray, ok := msgData.([]interface{}) + if !ok || len(msgArray) < 2 { + continue } + msgID := gconv.String(msgArray[0]) + fieldsArray, ok := msgArray[1].([]interface{}) + if !ok { + continue + } + values := make(map[string]interface{}, len(fieldsArray)/2) + for i := 0; i < len(fieldsArray); i += 2 { + if i+1 < len(fieldsArray) { + key := gconv.String(fieldsArray[i]) + values[key] = fieldsArray[i+1] + } + } + messages = append(messages, StreamMessage{ + ID: msgID, + Values: values, + }) } + } + return messages, nil + } - messages = append(messages, StreamMessage{ - ID: msgID, - Values: values, - }) + // 尝试数组格式(标准 Redis 返回) + if streamsArray, ok := resultVal.([]interface{}); ok && len(streamsArray) > 0 { + for _, streamData := range streamsArray { + streamArray, ok := streamData.([]interface{}) + if !ok || len(streamArray) < 2 { + continue + } + messagesArray, ok := streamArray[1].([]interface{}) + if !ok { + continue + } + for _, msgData := range messagesArray { + msgArray, ok := msgData.([]interface{}) + if !ok || len(msgArray) < 2 { + continue + } + msgID := gconv.String(msgArray[0]) + fieldsArray, ok := msgArray[1].([]interface{}) + if !ok { + continue + } + values := make(map[string]interface{}, len(fieldsArray)/2) + for i := 0; i < len(fieldsArray); i += 2 { + if i+1 < len(fieldsArray) { + key := gconv.String(fieldsArray[i]) + values[key] = fieldsArray[i+1] + } + } + messages = append(messages, StreamMessage{ + ID: msgID, + Values: values, + }) + } } } @@ -200,16 +243,16 @@ func GetPendingMessages(ctx context.Context, streamKey, groupName string, start, } if result == nil { - return []PendingMessage{}, nil + return nil, nil } // 解析返回值:[[ID, consumer, idle, retryCount], ...] pendingArray, ok := result.Val().([]interface{}) if !ok { - return []PendingMessage{}, nil + return nil, nil } - var messages []PendingMessage + messages := make([]PendingMessage, 0, len(pendingArray)) for _, item := range pendingArray { itemArray, ok := item.([]interface{}) if !ok || len(itemArray) < 4 { @@ -242,13 +285,13 @@ func ClaimPendingMessage(ctx context.Context, streamKey, groupName, consumerName } if result == nil { - return []StreamMessage{}, nil + return nil, nil } // 解析返回值:类似 XREADGROUP messagesArray, ok := result.Val().([]interface{}) if !ok { - return []StreamMessage{}, nil + return nil, nil } // 预分配容量,避免动态扩容 @@ -344,6 +387,43 @@ func SetSessionCache(ctx context.Context, userId, sessionId string) error { return err } +// 限流相关常量 +const ( + // RateLimitKeyPrefix 限流计数器 Key 前缀 + RateLimitKeyPrefix = "ragflow:ratelimit:" +) + +// IncrRateLimit 增加限流计数器,返回当前计数 +// windowSeconds: 时间窗口(秒) +func IncrRateLimit(ctx context.Context, key string, windowSeconds int64) (count int64, err error) { + fullKey := RateLimitKeyPrefix + key + result, err := GetRedisClient().Do(ctx, "INCR", fullKey) + if err != nil { + return + } + count = result.Int64() + + // 首次设置过期时间 + if count == 1 { + GetRedisClient().Do(ctx, "EXPIRE", fullKey, windowSeconds) + } + return +} + +// GetRateLimit 获取当前限流计数 +func GetRateLimit(ctx context.Context, key string) (count int64, err error) { + fullKey := RateLimitKeyPrefix + key + result, err := GetRedisClient().Get(ctx, fullKey) + if err != nil { + return + } + if result.IsEmpty() { + return 0, nil + } + count = result.Int64() + return +} + // GetSessionCache 获取缓存的 RAGFlow Session ID // 使用 gredis Get 方法 func GetSessionCache(ctx context.Context, userId string) (string, error) { diff --git a/redis/types.go b/redis/types.go index 8671a1a..ecdce57 100644 --- a/redis/types.go +++ b/redis/types.go @@ -37,3 +37,68 @@ func (m *BatchStreamMessage) ToMap() map[string]interface{} { "index": m.Index, } } + +// ResponseStreamMessage RAGFlow 响应消息结构(写入结果 Stream) +type ResponseStreamMessage struct { + UserId string `json:"user_id"` // 用户ID + Platform string `json:"platform"` // 平台标识 + Question string `json:"question"` // 用户问题 + Content string `json:"content"` // RAGFlow 回复内容 + SessionId string `json:"session_id"` // RAGFlow Session ID + Timestamp int64 `json:"timestamp"` // 时间戳(秒) + MessageId string `json:"message_id"` // 原始消息ID +} + +// ToMap 转换为 map[string]interface{} 用于 Stream 存储 +func (m *ResponseStreamMessage) ToMap() map[string]interface{} { + return map[string]interface{}{ + "user_id": m.UserId, + "platform": m.Platform, + "question": m.Question, + "content": m.Content, + "session_id": m.SessionId, + "timestamp": m.Timestamp, + "message_id": m.MessageId, + } +} + +// FollowUpMessage 追问消息结构(RabbitMQ 延时队列) +type FollowUpMessage struct { + UserId string `json:"user_id"` // 用户ID + Platform string `json:"platform"` // 平台标识 + Content string `json:"content"` // 追问内容 + FollowUpType int `json:"follow_up_type"` // 追问类型:1=30s, 2=60s, 3=180s + Timestamp int64 `json:"timestamp"` // 发送时间戳 +} + +// 追问话术常量 +const ( + FollowUpType1 = 1 // 30秒追问 + FollowUpType2 = 2 // 60秒追问 + FollowUpType3 = 3 // 180秒追问 +) + +// 追问话术内容 +var FollowUpContents = map[int]string{ + FollowUpType1: "还有其他问题吗?", + FollowUpType2: "如果需要帮助,随时告诉我~", + FollowUpType3: "我一直在线,有问题随时找我~", +} + +// 追问延时时间(秒) +var FollowUpDelays = map[int]int{ + FollowUpType1: 30, + FollowUpType2: 60, + FollowUpType3: 180, +} + +// ArchiveMessage 会话归档消息结构(RabbitMQ 延时队列) +type ArchiveMessage struct { + UserId string `json:"user_id"` // 用户ID + Platform string `json:"platform"` // 平台标识 + SessionId string `json:"session_id"` // RAGFlow Session ID + Timestamp int64 `json:"timestamp"` // 发送时间戳 +} + +// 归档延时时间(秒) +const ArchiveDelaySeconds = 3600 // 60分钟