重构了一下 rag的方法, 使用 goframe的框架, 还有redis连接部分
This commit is contained in:
@@ -2,11 +2,12 @@ package rabbitmq
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gogf/gf/v2/errors/gerror"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
amqp "github.com/rabbitmq/amqp091-go"
|
amqp "github.com/rabbitmq/amqp091-go"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -33,13 +34,7 @@ func Init(ctx context.Context, cfg *Config) error {
|
|||||||
var err error
|
var err error
|
||||||
once.Do(func() {
|
once.Do(func() {
|
||||||
// 构建连接字符串
|
// 构建连接字符串
|
||||||
url := fmt.Sprintf("amqp://%s:%s@%s:%d/%s",
|
url := "amqp://" + cfg.Username + ":" + cfg.Password + "@" + cfg.Host + ":" + gconv.String(cfg.Port) + "/" + cfg.VHost
|
||||||
cfg.Username,
|
|
||||||
cfg.Password,
|
|
||||||
cfg.Host,
|
|
||||||
cfg.Port,
|
|
||||||
cfg.VHost,
|
|
||||||
)
|
|
||||||
|
|
||||||
// 创建连接
|
// 创建连接
|
||||||
conn, err = amqp.Dial(url)
|
conn, err = amqp.Dial(url)
|
||||||
@@ -89,7 +84,7 @@ func GetChannel() (*amqp.Channel, error) {
|
|||||||
defer mu.RUnlock()
|
defer mu.RUnlock()
|
||||||
|
|
||||||
if channel == nil || channel.IsClosed() {
|
if channel == nil || channel.IsClosed() {
|
||||||
return nil, fmt.Errorf("RabbitMQ Channel 未初始化或已关闭")
|
return nil, gerror.New("RabbitMQ Channel 未初始化或已关闭")
|
||||||
}
|
}
|
||||||
|
|
||||||
return channel, nil
|
return channel, nil
|
||||||
@@ -101,7 +96,7 @@ func GetConnection() (*amqp.Connection, error) {
|
|||||||
defer mu.RUnlock()
|
defer mu.RUnlock()
|
||||||
|
|
||||||
if conn == nil || conn.IsClosed() {
|
if conn == nil || conn.IsClosed() {
|
||||||
return nil, fmt.Errorf("RabbitMQ 连接未初始化或已关闭")
|
return nil, gerror.New("RabbitMQ 连接未初始化或已关闭")
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
@@ -160,13 +155,7 @@ func reconnect(ctx context.Context) {
|
|||||||
VHost: g.Cfg().MustGet(ctx, "rabbitmq.vhost", "/").String(),
|
VHost: g.Cfg().MustGet(ctx, "rabbitmq.vhost", "/").String(),
|
||||||
}
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf("amqp://%s:%s@%s:%d/%s",
|
url := "amqp://" + cfg.Username + ":" + cfg.Password + "@" + cfg.Host + ":" + gconv.String(cfg.Port) + "/" + cfg.VHost
|
||||||
cfg.Username,
|
|
||||||
cfg.Password,
|
|
||||||
cfg.Host,
|
|
||||||
cfg.Port,
|
|
||||||
cfg.VHost,
|
|
||||||
)
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
conn, err = amqp.Dial(url)
|
conn, err = amqp.Dial(url)
|
||||||
@@ -190,7 +179,7 @@ func reconnect(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Close 关闭连接
|
// Close 关闭连接
|
||||||
func Close(ctx context.Context) error {
|
func Close(ctx context.Context) (err error) {
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
defer mu.Unlock()
|
defer mu.Unlock()
|
||||||
|
|
||||||
@@ -201,21 +190,21 @@ func Close(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if channel != nil {
|
if channel != nil {
|
||||||
if err := channel.Close(); err != nil {
|
if err = channel.Close(); err != nil {
|
||||||
g.Log().Errorf(ctx, "关闭 RabbitMQ Channel 失败: %v", err)
|
g.Log().Errorf(ctx, "关闭 RabbitMQ Channel 失败: %v", err)
|
||||||
}
|
}
|
||||||
channel = nil
|
channel = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
if err := conn.Close(); err != nil {
|
if err = conn.Close(); err != nil {
|
||||||
g.Log().Errorf(ctx, "关闭 RabbitMQ 连接失败: %v", err)
|
g.Log().Errorf(ctx, "关闭 RabbitMQ 连接失败: %v", err)
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
conn = nil
|
conn = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
watcherStarted = false
|
watcherStarted = false
|
||||||
g.Log().Info(ctx, "RabbitMQ 连接已关闭")
|
g.Log().Info(ctx, "RabbitMQ 连接已关闭")
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,9 +2,9 @@ package rabbitmq
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
|
"github.com/gogf/gf/v2/encoding/gjson"
|
||||||
|
"github.com/gogf/gf/v2/errors/gerror"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
amqp "github.com/rabbitmq/amqp091-go"
|
amqp "github.com/rabbitmq/amqp091-go"
|
||||||
)
|
)
|
||||||
@@ -74,7 +74,7 @@ func NewConsumer(queue string, handler MessageHandler, opts ...ConsumerOption) *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start 启动消费者
|
// Start 启动消费者
|
||||||
func (c *Consumer) Start(ctx context.Context) error {
|
func (c *Consumer) Start(ctx context.Context) (err error) {
|
||||||
// 创建可取消的 context
|
// 创建可取消的 context
|
||||||
workerCtx, cancel := context.WithCancel(ctx)
|
workerCtx, cancel := context.WithCancel(ctx)
|
||||||
c.cancel = cancel
|
c.cancel = cancel
|
||||||
@@ -90,7 +90,7 @@ func (c *Consumer) Start(ctx context.Context) error {
|
|||||||
false, // global: false 表示仅应用于当前 channel
|
false, // global: false 表示仅应用于当前 channel
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("设置 QoS 失败: %v", err)
|
return gerror.Newf("设置 QoS 失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 开始消费
|
// 开始消费
|
||||||
@@ -104,7 +104,7 @@ func (c *Consumer) Start(ctx context.Context) error {
|
|||||||
nil, // args
|
nil, // args
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("开始消费失败: %v", err)
|
return gerror.Newf("开始消费失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
g.Log().Infof(ctx, "消费者已启动: queue=%s, prefetch=%d, workers=%d",
|
g.Log().Infof(ctx, "消费者已启动: queue=%s, prefetch=%d, workers=%d",
|
||||||
@@ -115,7 +115,7 @@ func (c *Consumer) Start(ctx context.Context) error {
|
|||||||
go c.worker(workerCtx, i, msgs)
|
go c.worker(workerCtx, i, msgs)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// worker 工作协程
|
// worker 工作协程
|
||||||
@@ -168,8 +168,8 @@ func StartTypedConsumer[T any](
|
|||||||
// 包装处理函数
|
// 包装处理函数
|
||||||
wrappedHandler := func(ctx context.Context, body []byte) error {
|
wrappedHandler := func(ctx context.Context, body []byte) error {
|
||||||
var msg T
|
var msg T
|
||||||
if err := json.Unmarshal(body, &msg); err != nil {
|
if err := gjson.DecodeTo(body, &msg); err != nil {
|
||||||
return fmt.Errorf("反序列化消息失败: %v", err)
|
return gerror.Newf("反序列化消息失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return handler(ctx, &msg)
|
return handler(ctx, &msg)
|
||||||
|
|||||||
@@ -2,9 +2,9 @@ package rabbitmq
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
|
"github.com/gogf/gf/v2/encoding/gjson"
|
||||||
|
"github.com/gogf/gf/v2/errors/gerror"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
amqp "github.com/rabbitmq/amqp091-go"
|
amqp "github.com/rabbitmq/amqp091-go"
|
||||||
)
|
)
|
||||||
@@ -24,16 +24,16 @@ func NewPublisher(exchange, routingKey string) *Publisher {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Publish 发布消息
|
// Publish 发布消息
|
||||||
func (p *Publisher) Publish(ctx context.Context, message interface{}) error {
|
func (p *Publisher) Publish(ctx context.Context, message interface{}) (err error) {
|
||||||
ch, err := GetChannel()
|
ch, err := GetChannel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 序列化消息
|
// 序列化消息
|
||||||
body, err := json.Marshal(message)
|
body, err := gjson.Encode(message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("消息序列化失败: %v", err)
|
return gerror.Newf("消息序列化失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 发布消息
|
// 发布消息
|
||||||
@@ -59,21 +59,21 @@ func (p *Publisher) Publish(ctx context.Context, message interface{}) error {
|
|||||||
g.Log().Debugf(ctx, "消息发布成功: exchange=%s, routingKey=%s",
|
g.Log().Debugf(ctx, "消息发布成功: exchange=%s, routingKey=%s",
|
||||||
p.exchange, p.routingKey)
|
p.exchange, p.routingKey)
|
||||||
|
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// PublishDelayed 发布延时消息
|
// PublishDelayed 发布延时消息
|
||||||
// delaySeconds: 延时秒数
|
// delaySeconds: 延时秒数
|
||||||
func (p *Publisher) PublishDelayed(ctx context.Context, message interface{}, delaySeconds int) error {
|
func (p *Publisher) PublishDelayed(ctx context.Context, message interface{}, delaySeconds int) (err error) {
|
||||||
ch, err := GetChannel()
|
ch, err := GetChannel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 序列化消息
|
// 序列化消息
|
||||||
body, err := json.Marshal(message)
|
body, err := gjson.Encode(message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("消息序列化失败: %v", err)
|
return gerror.Newf("消息序列化失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 发布延时消息(需要 rabbitmq_delayed_message_exchange 插件)
|
// 发布延时消息(需要 rabbitmq_delayed_message_exchange 插件)
|
||||||
@@ -102,13 +102,13 @@ func (p *Publisher) PublishDelayed(ctx context.Context, message interface{}, del
|
|||||||
g.Log().Debugf(ctx, "延时消息发布成功: exchange=%s, routingKey=%s, delay=%ds",
|
g.Log().Debugf(ctx, "延时消息发布成功: exchange=%s, routingKey=%s, delay=%ds",
|
||||||
p.exchange, p.routingKey, delaySeconds)
|
p.exchange, p.routingKey, delaySeconds)
|
||||||
|
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// PublishBatch 批量发布消息
|
// PublishBatch 批量发布消息
|
||||||
func (p *Publisher) PublishBatch(ctx context.Context, messages []interface{}) error {
|
func (p *Publisher) PublishBatch(ctx context.Context, messages []interface{}) (err error) {
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ch, err := GetChannel()
|
ch, err := GetChannel()
|
||||||
@@ -117,7 +117,7 @@ func (p *Publisher) PublishBatch(ctx context.Context, messages []interface{}) er
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i, message := range messages {
|
for i, message := range messages {
|
||||||
body, err := json.Marshal(message)
|
body, err := gjson.Encode(message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
g.Log().Errorf(ctx, "消息 %d 序列化失败: %v", i, err)
|
g.Log().Errorf(ctx, "消息 %d 序列化失败: %v", i, err)
|
||||||
continue
|
continue
|
||||||
@@ -143,5 +143,5 @@ func (p *Publisher) PublishBatch(ctx context.Context, messages []interface{}) er
|
|||||||
}
|
}
|
||||||
|
|
||||||
g.Log().Infof(ctx, "批量发布完成: 共 %d 条消息", len(messages))
|
g.Log().Infof(ctx, "批量发布完成: 共 %d 条消息", len(messages))
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ package rabbitmq
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
|
|
||||||
|
"github.com/gogf/gf/v2/errors/gerror"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
amqp "github.com/rabbitmq/amqp091-go"
|
amqp "github.com/rabbitmq/amqp091-go"
|
||||||
)
|
)
|
||||||
@@ -35,7 +35,7 @@ type BindingConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeclareQueue 声明队列
|
// DeclareQueue 声明队列
|
||||||
func DeclareQueue(ctx context.Context, cfg *QueueConfig) error {
|
func DeclareQueue(ctx context.Context, cfg *QueueConfig) (err error) {
|
||||||
ch, err := GetChannel()
|
ch, err := GetChannel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -56,11 +56,11 @@ func DeclareQueue(ctx context.Context, cfg *QueueConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
g.Log().Infof(ctx, "队列声明成功: %s", cfg.Name)
|
g.Log().Infof(ctx, "队列声明成功: %s", cfg.Name)
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeclareExchange 声明 Exchange
|
// DeclareExchange 声明 Exchange
|
||||||
func DeclareExchange(ctx context.Context, cfg *ExchangeConfig) error {
|
func DeclareExchange(ctx context.Context, cfg *ExchangeConfig) (err error) {
|
||||||
ch, err := GetChannel()
|
ch, err := GetChannel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -82,11 +82,11 @@ func DeclareExchange(ctx context.Context, cfg *ExchangeConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
g.Log().Infof(ctx, "Exchange 声明成功: %s (type=%s)", cfg.Name, cfg.Type)
|
g.Log().Infof(ctx, "Exchange 声明成功: %s (type=%s)", cfg.Name, cfg.Type)
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// BindQueue 绑定队列到 Exchange
|
// BindQueue 绑定队列到 Exchange
|
||||||
func BindQueue(ctx context.Context, cfg *BindingConfig) error {
|
func BindQueue(ctx context.Context, cfg *BindingConfig) (err error) {
|
||||||
ch, err := GetChannel()
|
ch, err := GetChannel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -108,7 +108,7 @@ func BindQueue(ctx context.Context, cfg *BindingConfig) error {
|
|||||||
|
|
||||||
g.Log().Infof(ctx, "队列绑定成功: queue=%s → exchange=%s (routingKey=%s)",
|
g.Log().Infof(ctx, "队列绑定成功: queue=%s → exchange=%s (routingKey=%s)",
|
||||||
cfg.Queue, cfg.Exchange, cfg.RoutingKey)
|
cfg.Queue, cfg.Exchange, cfg.RoutingKey)
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupDelayExchange 设置延时 Exchange(需要 rabbitmq_delayed_message_exchange 插件)
|
// SetupDelayExchange 设置延时 Exchange(需要 rabbitmq_delayed_message_exchange 插件)
|
||||||
@@ -165,9 +165,9 @@ func SetupQueueWithDLX(ctx context.Context, queueName, dlxExchange, dlxRoutingKe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetupBasicTopology 设置基础拓扑(适用于小红书客服场景)
|
// SetupBasicTopology 设置基础拓扑(适用于小红书客服场景)
|
||||||
func SetupBasicTopology(ctx context.Context) error {
|
func SetupBasicTopology(ctx context.Context) (err error) {
|
||||||
// 1. 声明普通 Exchange
|
// 1. 声明普通 Exchange
|
||||||
err := DeclareExchange(ctx, &ExchangeConfig{
|
err = DeclareExchange(ctx, &ExchangeConfig{
|
||||||
Name: "ragflow_exchange",
|
Name: "ragflow_exchange",
|
||||||
Type: "direct",
|
Type: "direct",
|
||||||
Durable: true,
|
Durable: true,
|
||||||
@@ -179,7 +179,7 @@ func SetupBasicTopology(ctx context.Context) error {
|
|||||||
// 2. 声明延时 Exchange
|
// 2. 声明延时 Exchange
|
||||||
err = SetupDelayExchange(ctx, "delay_exchange")
|
err = SetupDelayExchange(ctx, "delay_exchange")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("延时 Exchange 声明失败(可能未安装插件): %v", err)
|
return gerror.Newf("延时 Exchange 声明失败(可能未安装插件): %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 声明死信队列
|
// 3. 声明死信队列
|
||||||
@@ -227,5 +227,5 @@ func SetupBasicTopology(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
g.Log().Info(ctx, "RabbitMQ 拓扑结构设置完成")
|
g.Log().Info(ctx, "RabbitMQ 拓扑结构设置完成")
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ package ragflow
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
|
"github.com/gogf/gf/v2/errors/gerror"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Agent AGENT 管理
|
// Agent AGENT 管理
|
||||||
@@ -56,44 +57,44 @@ type ListAgentsRes struct {
|
|||||||
|
|
||||||
// CreateAgent 创建 Agent
|
// CreateAgent 创建 Agent
|
||||||
// POST /api/v1/agents
|
// POST /api/v1/agents
|
||||||
func (c *Client) CreateAgent(ctx context.Context, req *CreateAgentReq) error {
|
func (c *Client) CreateAgent(ctx context.Context, req *CreateAgentReq) (err error) {
|
||||||
var res CommonResponse
|
var res CommonResponse
|
||||||
if err := c.request(ctx, "POST", "/api/v1/agents", req, &res); err != nil {
|
if err = c.request(ctx, "POST", "/api/v1/agents", req, &res); err != nil {
|
||||||
return fmt.Errorf("create agent failed: %w", err)
|
return gerror.Newf("create agent failed: %v", err)
|
||||||
}
|
}
|
||||||
if !res.IsSuccess() {
|
if !res.IsSuccess() {
|
||||||
return fmt.Errorf("create agent failed: %s", res.Message)
|
return gerror.Newf("create agent failed: %s", res.Message)
|
||||||
}
|
}
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAgent 更新 Agent
|
// UpdateAgent 更新 Agent
|
||||||
// PUT /api/v1/agents/{agent_id}
|
// PUT /api/v1/agents/{agent_id}
|
||||||
func (c *Client) UpdateAgent(ctx context.Context, agentID string, req *UpdateAgentReq) error {
|
func (c *Client) UpdateAgent(ctx context.Context, agentID string, req *UpdateAgentReq) (err error) {
|
||||||
path := fmt.Sprintf("/api/v1/agents/%s", agentID)
|
path := "/api/v1/agents/" + agentID
|
||||||
var res CommonResponse
|
var res CommonResponse
|
||||||
if err := c.request(ctx, "PUT", path, req, &res); err != nil {
|
if err = c.request(ctx, "PUT", path, req, &res); err != nil {
|
||||||
return fmt.Errorf("update agent failed: %w", err)
|
return gerror.Newf("update agent failed: %v", err)
|
||||||
}
|
}
|
||||||
if !res.IsSuccess() {
|
if !res.IsSuccess() {
|
||||||
return fmt.Errorf("update agent failed: %s", res.Message)
|
return gerror.Newf("update agent failed: %s", res.Message)
|
||||||
}
|
}
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteAgent 删除 Agent
|
// DeleteAgent 删除 Agent
|
||||||
// DELETE /api/v1/agents/{agent_id}
|
// DELETE /api/v1/agents/{agent_id}
|
||||||
func (c *Client) DeleteAgent(ctx context.Context, agentID string) error {
|
func (c *Client) DeleteAgent(ctx context.Context, agentID string) (err error) {
|
||||||
path := fmt.Sprintf("/api/v1/agents/%s", agentID)
|
path := "/api/v1/agents/" + agentID
|
||||||
var res CommonResponse
|
var res CommonResponse
|
||||||
// 官方文档要求传空对象,不是 nil
|
// 官方文档要求传空对象,不是 nil
|
||||||
if err := c.request(ctx, "DELETE", path, map[string]interface{}{}, &res); err != nil {
|
if err = c.request(ctx, "DELETE", path, map[string]interface{}{}, &res); err != nil {
|
||||||
return fmt.Errorf("delete agent failed: %w", err)
|
return gerror.Newf("delete agent failed: %v", err)
|
||||||
}
|
}
|
||||||
if !res.IsSuccess() {
|
if !res.IsSuccess() {
|
||||||
return fmt.Errorf("delete agent failed: %s", res.Message)
|
return gerror.Newf("delete agent failed: %s", res.Message)
|
||||||
}
|
}
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListAgents 列出 Agent
|
// ListAgents 列出 Agent
|
||||||
@@ -131,10 +132,10 @@ func (c *Client) ListAgents(ctx context.Context, req *ListAgentsReq) (*ListAgent
|
|||||||
|
|
||||||
var res ListAgentsRes
|
var res ListAgentsRes
|
||||||
if err := c.request(ctx, "GET", path, nil, &res); err != nil {
|
if err := c.request(ctx, "GET", path, nil, &res); err != nil {
|
||||||
return nil, fmt.Errorf("list agents failed: %w", err)
|
return nil, gerror.Newf("list agents failed: %v", err)
|
||||||
}
|
}
|
||||||
if res.Code != 0 {
|
if res.Code != 0 {
|
||||||
return nil, fmt.Errorf("list agents failed: code=%d", res.Code)
|
return nil, gerror.Newf("list agents failed: code=%d", res.Code)
|
||||||
}
|
}
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ package ragflow
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
|
"github.com/gogf/gf/v2/errors/gerror"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 聊天助手管理
|
// 聊天助手管理
|
||||||
@@ -104,7 +105,7 @@ func (c *Client) CreateChat(ctx context.Context, req *CreateChatReq) (*Chat, err
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if res.Code != 0 {
|
if res.Code != 0 {
|
||||||
return nil, fmt.Errorf("create chat failed: %s", res.Msg)
|
return nil, gerror.Newf("create chat failed: %s", res.Msg)
|
||||||
}
|
}
|
||||||
return res.Data, nil
|
return res.Data, nil
|
||||||
}
|
}
|
||||||
@@ -144,33 +145,33 @@ func (c *Client) ListChats(ctx context.Context, req *ListChatsReq) (*ListChatsRe
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if res.Code != 0 {
|
if res.Code != 0 {
|
||||||
return nil, fmt.Errorf("list chats failed: code=%d", res.Code)
|
return nil, gerror.Newf("list chats failed: code=%d", res.Code)
|
||||||
}
|
}
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteChats 删除聊天助手
|
// DeleteChats 删除聊天助手
|
||||||
func (c *Client) DeleteChats(ctx context.Context, ids []string) error {
|
func (c *Client) DeleteChats(ctx context.Context, ids []string) (err error) {
|
||||||
req := DeleteChatsReq{Ids: ids}
|
req := DeleteChatsReq{Ids: ids}
|
||||||
var res CommonResponse
|
var res CommonResponse
|
||||||
if err := c.request(ctx, "DELETE", "/api/v1/chats", req, &res); err != nil {
|
if err = c.request(ctx, "DELETE", "/api/v1/chats", req, &res); err != nil {
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
if !res.IsSuccess() {
|
if !res.IsSuccess() {
|
||||||
return fmt.Errorf("delete chats failed: %s", res.Message)
|
return gerror.Newf("delete chats failed: %s", res.Message)
|
||||||
}
|
}
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateChat 更新聊天助手
|
// UpdateChat 更新聊天助手
|
||||||
func (c *Client) UpdateChat(ctx context.Context, id string, req *UpdateChatReq) error {
|
func (c *Client) UpdateChat(ctx context.Context, id string, req *UpdateChatReq) (err error) {
|
||||||
var res CommonResponse
|
var res CommonResponse
|
||||||
path := fmt.Sprintf("/api/v1/chats/%s", id)
|
path := "/api/v1/chats/" + id
|
||||||
if err := c.request(ctx, "PUT", path, req, &res); err != nil {
|
if err = c.request(ctx, "PUT", path, req, &res); err != nil {
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
if !res.IsSuccess() {
|
if !res.IsSuccess() {
|
||||||
return fmt.Errorf("update chat failed: %s", res.Message)
|
return gerror.Newf("update chat failed: %s", res.Message)
|
||||||
}
|
}
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ package ragflow
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
|
"github.com/gogf/gf/v2/errors/gerror"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 数据集内知识块管理
|
// 数据集内知识块管理
|
||||||
@@ -90,7 +91,7 @@ type RetrieveChunksRes struct {
|
|||||||
|
|
||||||
// AddChunk 添加知识块
|
// AddChunk 添加知识块
|
||||||
func (c *Client) AddChunk(ctx context.Context, datasetId, documentId string, req *AddChunkReq) (*Chunk, error) {
|
func (c *Client) AddChunk(ctx context.Context, datasetId, documentId string, req *AddChunkReq) (*Chunk, error) {
|
||||||
path := fmt.Sprintf("/api/v1/datasets/%s/documents/%s/chunks", datasetId, documentId)
|
path := "/api/v1/datasets/" + datasetId + "/documents/" + documentId + "/chunks"
|
||||||
var res struct {
|
var res struct {
|
||||||
Code int `json:"code"`
|
Code int `json:"code"`
|
||||||
Data struct {
|
Data struct {
|
||||||
@@ -102,14 +103,14 @@ func (c *Client) AddChunk(ctx context.Context, datasetId, documentId string, req
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if res.Code != 0 {
|
if res.Code != 0 {
|
||||||
return nil, fmt.Errorf("add chunk failed: %s", res.Msg)
|
return nil, gerror.Newf("add chunk failed: %s", res.Msg)
|
||||||
}
|
}
|
||||||
return res.Data.Chunk, nil
|
return res.Data.Chunk, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListChunks 列出知识块
|
// ListChunks 列出知识块
|
||||||
func (c *Client) ListChunks(ctx context.Context, datasetId, documentId string, req *ListChunksReq) (*ListChunksRes, error) {
|
func (c *Client) ListChunks(ctx context.Context, datasetId, documentId string, req *ListChunksReq) (*ListChunksRes, error) {
|
||||||
path := fmt.Sprintf("/api/v1/datasets/%s/documents/%s/chunks", datasetId, documentId)
|
path := "/api/v1/datasets/" + datasetId + "/documents/" + documentId + "/chunks"
|
||||||
params := map[string]interface{}{}
|
params := map[string]interface{}{}
|
||||||
if req.Keywords != "" {
|
if req.Keywords != "" {
|
||||||
params["keywords"] = req.Keywords
|
params["keywords"] = req.Keywords
|
||||||
@@ -134,36 +135,36 @@ func (c *Client) ListChunks(ctx context.Context, datasetId, documentId string, r
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if res.Code != 0 {
|
if res.Code != 0 {
|
||||||
return nil, fmt.Errorf("list chunks failed: code=%d", res.Code)
|
return nil, gerror.Newf("list chunks failed: code=%d", res.Code)
|
||||||
}
|
}
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteChunks 删除知识块
|
// DeleteChunks 删除知识块
|
||||||
func (c *Client) DeleteChunks(ctx context.Context, datasetId, documentId string, chunkIds []string) error {
|
func (c *Client) DeleteChunks(ctx context.Context, datasetId, documentId string, chunkIds []string) (err error) {
|
||||||
req := DeleteChunksReq{ChunkIds: chunkIds}
|
req := DeleteChunksReq{ChunkIds: chunkIds}
|
||||||
var res CommonResponse
|
var res CommonResponse
|
||||||
path := fmt.Sprintf("/api/v1/datasets/%s/documents/%s/chunks", datasetId, documentId)
|
path := "/api/v1/datasets/" + datasetId + "/documents/" + documentId + "/chunks"
|
||||||
if err := c.request(ctx, "DELETE", path, req, &res); err != nil {
|
if err = c.request(ctx, "DELETE", path, req, &res); err != nil {
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
if !res.IsSuccess() {
|
if !res.IsSuccess() {
|
||||||
return fmt.Errorf("delete chunks failed: %s", res.Message)
|
return gerror.Newf("delete chunks failed: %s", res.Message)
|
||||||
}
|
}
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateChunk 更新知识块
|
// UpdateChunk 更新知识块
|
||||||
func (c *Client) UpdateChunk(ctx context.Context, datasetId, documentId, chunkId string, req *UpdateChunkReq) error {
|
func (c *Client) UpdateChunk(ctx context.Context, datasetId, documentId, chunkId string, req *UpdateChunkReq) (err error) {
|
||||||
var res CommonResponse
|
var res CommonResponse
|
||||||
path := fmt.Sprintf("/api/v1/datasets/%s/documents/%s/chunks/%s", datasetId, documentId, chunkId)
|
path := "/api/v1/datasets/" + datasetId + "/documents/" + documentId + "/chunks/" + chunkId
|
||||||
if err := c.request(ctx, "PUT", path, req, &res); err != nil {
|
if err = c.request(ctx, "PUT", path, req, &res); err != nil {
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
if !res.IsSuccess() {
|
if !res.IsSuccess() {
|
||||||
return fmt.Errorf("update chunk failed: %s", res.Message)
|
return gerror.Newf("update chunk failed: %s", res.Message)
|
||||||
}
|
}
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// RetrieveChunks 检索知识块
|
// RetrieveChunks 检索知识块
|
||||||
@@ -173,7 +174,7 @@ func (c *Client) RetrieveChunks(ctx context.Context, req *RetrieveChunksReq) (*R
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if res.Code != 0 {
|
if res.Code != 0 {
|
||||||
return nil, fmt.Errorf("retrieve chunks failed: code=%d", res.Code)
|
return nil, gerror.Newf("retrieve chunks failed: code=%d", res.Code)
|
||||||
}
|
}
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,13 +2,12 @@ package ragflow
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gogf/gf/v2/encoding/gjson"
|
||||||
|
"github.com/gogf/gf/v2/errors/gerror"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
"github.com/gogf/gf/v2/net/gclient"
|
"github.com/gogf/gf/v2/net/gclient"
|
||||||
)
|
)
|
||||||
@@ -33,7 +32,7 @@ func init() {
|
|||||||
|
|
||||||
// 初始化全局客户端
|
// 初始化全局客户端
|
||||||
httpClient := gclient.New()
|
httpClient := gclient.New()
|
||||||
httpClient.SetHeader("Authorization", fmt.Sprintf("Bearer %s", apiKey))
|
httpClient.SetHeader("Authorization", "Bearer "+apiKey)
|
||||||
httpClient.SetHeader("Content-Type", "application/json")
|
httpClient.SetHeader("Content-Type", "application/json")
|
||||||
|
|
||||||
globalClient = &Client{
|
globalClient = &Client{
|
||||||
@@ -79,20 +78,19 @@ func (r *CommonResponse) IsSuccess() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// request 发送 HTTP 请求
|
// request 发送 HTTP 请求
|
||||||
func (c *Client) request(ctx context.Context, method, path string, body interface{}, result interface{}) error {
|
func (c *Client) request(ctx context.Context, method, path string, body interface{}, result interface{}) (err error) {
|
||||||
fullURL := c.BaseURL + path
|
fullURL := c.BaseURL + path
|
||||||
|
|
||||||
var reqBody io.Reader
|
var reqBody string
|
||||||
if body != nil {
|
if body != nil {
|
||||||
jsonData, err := json.Marshal(body)
|
jsonData, err := gjson.Encode(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("marshal request body failed: %w", err)
|
return gerror.Newf("marshal request body failed: %v", err)
|
||||||
}
|
}
|
||||||
reqBody = strings.NewReader(string(jsonData))
|
reqBody = string(jsonData)
|
||||||
}
|
}
|
||||||
|
|
||||||
var resp *gclient.Response
|
var resp *gclient.Response
|
||||||
var err error
|
|
||||||
|
|
||||||
switch method {
|
switch method {
|
||||||
case "GET":
|
case "GET":
|
||||||
@@ -104,28 +102,24 @@ func (c *Client) request(ctx context.Context, method, path string, body interfac
|
|||||||
case "DELETE":
|
case "DELETE":
|
||||||
resp, err = c.HTTPClient.Delete(ctx, fullURL, reqBody)
|
resp, err = c.HTTPClient.Delete(ctx, fullURL, reqBody)
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unsupported method: %s", method)
|
return gerror.Newf("unsupported method: %s", method)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("http request failed: %w", err)
|
return gerror.Newf("http request failed: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Close()
|
defer resp.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return fmt.Errorf("http request failed with status: %d", resp.StatusCode)
|
return gerror.Newf("http request failed with status: %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
respBody := resp.ReadAll()
|
respBody := resp.ReadAll()
|
||||||
if err != nil {
|
if err = gjson.DecodeTo(respBody, result); err != nil {
|
||||||
return fmt.Errorf("read response body failed: %w", err)
|
return gerror.Newf("unmarshal response failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(respBody, result); err != nil {
|
return
|
||||||
return fmt.Errorf("unmarshal response failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildQueryString 构建查询字符串
|
// buildQueryString 构建查询字符串
|
||||||
@@ -134,9 +128,9 @@ func buildQueryString(params map[string]interface{}) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
var parts []string
|
parts := make([]string, 0, len(params))
|
||||||
for k, v := range params {
|
for k, v := range params {
|
||||||
parts = append(parts, fmt.Sprintf("%s=%v", url.QueryEscape(k), url.QueryEscape(fmt.Sprintf("%v", v))))
|
parts = append(parts, url.QueryEscape(k)+"="+url.QueryEscape(g.NewVar(v).String()))
|
||||||
}
|
}
|
||||||
return strings.Join(parts, "&")
|
return strings.Join(parts, "&")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ package ragflow
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
|
"github.com/gogf/gf/v2/errors/gerror"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 数据集管理
|
// 数据集管理
|
||||||
@@ -90,7 +91,7 @@ func (c *Client) CreateDataset(ctx context.Context, req *CreateDatasetReq) (*Dat
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if res.Code != 0 {
|
if res.Code != 0 {
|
||||||
return nil, fmt.Errorf("create dataset failed: %s", res.Msg)
|
return nil, gerror.Newf("create dataset failed: %s", res.Msg)
|
||||||
}
|
}
|
||||||
return res.Data, nil
|
return res.Data, nil
|
||||||
}
|
}
|
||||||
@@ -134,33 +135,33 @@ func (c *Client) ListDatasets(ctx context.Context, req *ListDatasetsReq) (*ListD
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if res.Code != 0 {
|
if res.Code != 0 {
|
||||||
return nil, fmt.Errorf("list datasets failed: code=%d", res.Code)
|
return nil, gerror.Newf("list datasets failed: code=%d", res.Code)
|
||||||
}
|
}
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteDataset 删除数据集
|
// DeleteDataset 删除数据集
|
||||||
func (c *Client) DeleteDataset(ctx context.Context, ids []string) error {
|
func (c *Client) DeleteDataset(ctx context.Context, ids []string) (err error) {
|
||||||
req := DeleteDatasetsReq{Ids: ids}
|
req := DeleteDatasetsReq{Ids: ids}
|
||||||
var res CommonResponse
|
var res CommonResponse
|
||||||
if err := c.request(ctx, "DELETE", "/api/v1/datasets", req, &res); err != nil {
|
if err = c.request(ctx, "DELETE", "/api/v1/datasets", req, &res); err != nil {
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
if !res.IsSuccess() {
|
if !res.IsSuccess() {
|
||||||
return fmt.Errorf("delete dataset failed: %s", res.Message)
|
return gerror.Newf("delete dataset failed: %s", res.Message)
|
||||||
}
|
}
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateDataset 更新数据集
|
// UpdateDataset 更新数据集
|
||||||
func (c *Client) UpdateDataset(ctx context.Context, id string, req *UpdateDatasetReq) error {
|
func (c *Client) UpdateDataset(ctx context.Context, id string, req *UpdateDatasetReq) (err error) {
|
||||||
var res CommonResponse
|
var res CommonResponse
|
||||||
path := fmt.Sprintf("/api/v1/datasets/%s", id)
|
path := "/api/v1/datasets/" + id
|
||||||
if err := c.request(ctx, "PUT", path, req, &res); err != nil {
|
if err = c.request(ctx, "PUT", path, req, &res); err != nil {
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
if !res.IsSuccess() {
|
if !res.IsSuccess() {
|
||||||
return fmt.Errorf("update dataset failed: %s", res.Message)
|
return gerror.Newf("update dataset failed: %s", res.Message)
|
||||||
}
|
}
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,8 +2,9 @@ package ragflow
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gogf/gf/v2/errors/gerror"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 数据集内文件管理
|
// 数据集内文件管理
|
||||||
@@ -70,7 +71,7 @@ type DeleteDocumentsReq struct {
|
|||||||
|
|
||||||
// ListDocuments 列出文档
|
// ListDocuments 列出文档
|
||||||
func (c *Client) ListDocuments(ctx context.Context, datasetId string, req *ListDocumentsReq) (*ListDocumentsRes, error) {
|
func (c *Client) ListDocuments(ctx context.Context, datasetId string, req *ListDocumentsReq) (*ListDocumentsRes, error) {
|
||||||
path := fmt.Sprintf("/api/v1/datasets/%s/documents", datasetId)
|
path := "/api/v1/datasets/" + datasetId + "/documents"
|
||||||
params := map[string]interface{}{}
|
params := map[string]interface{}{}
|
||||||
if req.Page > 0 {
|
if req.Page > 0 {
|
||||||
params["page"] = req.Page
|
params["page"] = req.Page
|
||||||
@@ -111,16 +112,14 @@ func (c *Client) ListDocuments(ctx context.Context, datasetId string, req *ListD
|
|||||||
|
|
||||||
// 处理数组参数:suffix(文件后缀过滤)
|
// 处理数组参数:suffix(文件后缀过滤)
|
||||||
// API 要求多个值时重复参数名,如:suffix=pdf&suffix=txt
|
// API 要求多个值时重复参数名,如:suffix=pdf&suffix=txt
|
||||||
// 这里使用 fmt.Sprintf 来构造每个参数值
|
|
||||||
for _, suffix := range req.Suffix {
|
for _, suffix := range req.Suffix {
|
||||||
queryParts = append(queryParts, fmt.Sprintf("suffix=%s", suffix))
|
queryParts = append(queryParts, "suffix="+suffix)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理数组参数:run(处理状态过滤)
|
// 处理数组参数:run(处理状态过滤)
|
||||||
// 支持数字格式("0"-"4")或文本格式("UNSTART", "RUNNING", "CANCEL", "DONE", "FAIL")
|
// 支持数字格式("0"-"4")或文本格式("UNSTART", "RUNNING", "CANCEL", "DONE", "FAIL")
|
||||||
// 这里使用 fmt.Sprintf 来构造每个参数值
|
|
||||||
for _, run := range req.Run {
|
for _, run := range req.Run {
|
||||||
queryParts = append(queryParts, fmt.Sprintf("run=%s", run))
|
queryParts = append(queryParts, "run="+run)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构造最终请求路径
|
// 构造最终请求路径
|
||||||
@@ -134,7 +133,7 @@ func (c *Client) ListDocuments(ctx context.Context, datasetId string, req *ListD
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if res.Code != 0 {
|
if res.Code != 0 {
|
||||||
return nil, fmt.Errorf("list documents failed: code=%d", res.Code)
|
return nil, gerror.Newf("list documents failed: code=%d", res.Code)
|
||||||
}
|
}
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
@@ -142,23 +141,21 @@ func (c *Client) ListDocuments(ctx context.Context, datasetId string, req *ListD
|
|||||||
// UploadDocument 上传文档
|
// UploadDocument 上传文档
|
||||||
// 注意:此方法需要特殊处理 multipart/form-data,目前的 request 方法可能不支持
|
// 注意:此方法需要特殊处理 multipart/form-data,目前的 request 方法可能不支持
|
||||||
// 我们需要扩展 request 方法或在此处单独实现
|
// 我们需要扩展 request 方法或在此处单独实现
|
||||||
func (c *Client) UploadDocument(ctx context.Context, datasetId string, filePaths []string) error {
|
func (c *Client) UploadDocument(ctx context.Context, datasetId string, filePaths []string) (err error) {
|
||||||
// TODO: 实现文件上传逻辑,需要使用 gclient 的 UploadFile 功能
|
// TODO: 实现文件上传逻辑,需要使用 gclient 的 UploadFile 功能
|
||||||
// 由于 request 方法封装了 JSON 处理,这里可能需要绕过 request 方法直接使用 c.Client
|
return gerror.New("upload document not implemented yet")
|
||||||
// 暂时留空或仅做简单提示,待完善 Client 封装以支持文件上传
|
|
||||||
return fmt.Errorf("upload document not implemented yet")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteDocument 删除文档
|
// DeleteDocument 删除文档
|
||||||
func (c *Client) DeleteDocument(ctx context.Context, datasetId string, ids []string) error {
|
func (c *Client) DeleteDocument(ctx context.Context, datasetId string, ids []string) (err error) {
|
||||||
req := DeleteDocumentsReq{Ids: ids}
|
req := DeleteDocumentsReq{Ids: ids}
|
||||||
var res CommonResponse
|
var res CommonResponse
|
||||||
path := fmt.Sprintf("/api/v1/datasets/%s/documents", datasetId)
|
path := "/api/v1/datasets/" + datasetId + "/documents"
|
||||||
if err := c.request(ctx, "DELETE", path, req, &res); err != nil {
|
if err = c.request(ctx, "DELETE", path, req, &res); err != nil {
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
if !res.IsSuccess() {
|
if !res.IsSuccess() {
|
||||||
return fmt.Errorf("delete document failed: %s", res.Message)
|
return gerror.Newf("delete document failed: %s", res.Message)
|
||||||
}
|
}
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,8 +2,9 @@ package ragflow
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"github.com/gogf/gf/v2/encoding/gjson"
|
||||||
|
"github.com/gogf/gf/v2/errors/gerror"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OpenAICompatibleAPI 与 OpenAI 兼容的 API
|
// OpenAICompatibleAPI 与 OpenAI 兼容的 API
|
||||||
@@ -64,11 +65,11 @@ type ChatCompletionChunk struct {
|
|||||||
// CreateChatCompletion 创建聊天补全(与聊天助手)
|
// CreateChatCompletion 创建聊天补全(与聊天助手)
|
||||||
// POST /api/v1/chats_openai/{chat_id}/chat/completions
|
// POST /api/v1/chats_openai/{chat_id}/chat/completions
|
||||||
func (c *Client) CreateChatCompletion(ctx context.Context, chatID string, req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
|
func (c *Client) CreateChatCompletion(ctx context.Context, chatID string, req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
|
||||||
path := fmt.Sprintf("/api/v1/chats_openai/%s/chat/completions", chatID)
|
path := "/api/v1/chats_openai/" + chatID + "/chat/completions"
|
||||||
|
|
||||||
var resp ChatCompletionResponse
|
var resp ChatCompletionResponse
|
||||||
if err := c.request(ctx, "POST", path, req, &resp); err != nil {
|
if err := c.request(ctx, "POST", path, req, &resp); err != nil {
|
||||||
return nil, fmt.Errorf("create chat completion failed: %w", err)
|
return nil, gerror.Newf("create chat completion failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &resp, nil
|
return &resp, nil
|
||||||
@@ -77,11 +78,11 @@ func (c *Client) CreateChatCompletion(ctx context.Context, chatID string, req *C
|
|||||||
// CreateAgentCompletion 创建 Agent 补全
|
// CreateAgentCompletion 创建 Agent 补全
|
||||||
// POST /api/v1/agents_openai/{agent_id}/chat/completions
|
// POST /api/v1/agents_openai/{agent_id}/chat/completions
|
||||||
func (c *Client) CreateAgentCompletion(ctx context.Context, agentID string, req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
|
func (c *Client) CreateAgentCompletion(ctx context.Context, agentID string, req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
|
||||||
path := fmt.Sprintf("/api/v1/agents_openai/%s/chat/completions", agentID)
|
path := "/api/v1/agents_openai/" + agentID + "/chat/completions"
|
||||||
|
|
||||||
var resp ChatCompletionResponse
|
var resp ChatCompletionResponse
|
||||||
if err := c.request(ctx, "POST", path, req, &resp); err != nil {
|
if err := c.request(ctx, "POST", path, req, &resp); err != nil {
|
||||||
return nil, fmt.Errorf("create agent completion failed: %w", err)
|
return nil, gerror.Newf("create agent completion failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &resp, nil
|
return &resp, nil
|
||||||
@@ -91,31 +92,26 @@ func (c *Client) CreateAgentCompletion(ctx context.Context, agentID string, req
|
|||||||
// 注意:流式响应需要特殊处理,这里返回一个可用于读取流的接口
|
// 注意:流式响应需要特殊处理,这里返回一个可用于读取流的接口
|
||||||
func (c *Client) CreateChatCompletionStream(ctx context.Context, chatID string, req *ChatCompletionRequest) (*StreamReader, error) {
|
func (c *Client) CreateChatCompletionStream(ctx context.Context, chatID string, req *ChatCompletionRequest) (*StreamReader, error) {
|
||||||
req.Stream = true
|
req.Stream = true
|
||||||
_ = fmt.Sprintf("/api/v1/chats_openai/%s/chat/completions", chatID)
|
|
||||||
|
|
||||||
// TODO: 实现流式读取逻辑
|
// TODO: 实现流式读取逻辑
|
||||||
return nil, fmt.Errorf("stream mode not implemented yet")
|
return nil, gerror.New("stream mode not implemented yet")
|
||||||
}
|
}
|
||||||
|
|
||||||
// StreamReader 流式响应读取器
|
// StreamReader 流式响应读取器
|
||||||
type StreamReader struct {
|
type StreamReader struct {
|
||||||
decoder *json.Decoder
|
_ *gjson.Json // TODO: 实现流式读取时使用
|
||||||
close func() error
|
close func() error
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadChunk 读取下一个响应块
|
// ReadChunk 读取下一个响应块
|
||||||
|
// TODO: 实现流式读取逻辑
|
||||||
func (sr *StreamReader) ReadChunk() (*ChatCompletionChunk, error) {
|
func (sr *StreamReader) ReadChunk() (*ChatCompletionChunk, error) {
|
||||||
var chunk ChatCompletionChunk
|
return nil, gerror.New("stream mode not implemented yet")
|
||||||
if err := sr.decoder.Decode(&chunk); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &chunk, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close 关闭流
|
// Close 关闭流
|
||||||
func (sr *StreamReader) Close() error {
|
func (sr *StreamReader) Close() (err error) {
|
||||||
if sr.close != nil {
|
if sr.close != nil {
|
||||||
return sr.close()
|
return sr.close()
|
||||||
}
|
}
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ package ragflow
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
|
"github.com/gogf/gf/v2/errors/gerror"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 会话管理
|
// 会话管理
|
||||||
@@ -76,7 +77,7 @@ type ChatCompletionRes struct {
|
|||||||
|
|
||||||
// CreateSession 创建会话
|
// CreateSession 创建会话
|
||||||
func (c *Client) CreateSession(ctx context.Context, chatId string, req *CreateSessionReq) (*Session, error) {
|
func (c *Client) CreateSession(ctx context.Context, chatId string, req *CreateSessionReq) (*Session, error) {
|
||||||
path := fmt.Sprintf("/api/v1/chats/%s/sessions", chatId)
|
path := "/api/v1/chats/" + chatId + "/sessions"
|
||||||
var res struct {
|
var res struct {
|
||||||
Code int `json:"code"`
|
Code int `json:"code"`
|
||||||
Data *Session `json:"data"`
|
Data *Session `json:"data"`
|
||||||
@@ -86,14 +87,14 @@ func (c *Client) CreateSession(ctx context.Context, chatId string, req *CreateSe
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if res.Code != 0 {
|
if res.Code != 0 {
|
||||||
return nil, fmt.Errorf("create session failed: %s", res.Msg)
|
return nil, gerror.Newf("create session failed: %s", res.Msg)
|
||||||
}
|
}
|
||||||
return res.Data, nil
|
return res.Data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListSessions 列出会话
|
// ListSessions 列出会话
|
||||||
func (c *Client) ListSessions(ctx context.Context, chatId string, req *ListSessionsReq) (*ListSessionsRes, error) {
|
func (c *Client) ListSessions(ctx context.Context, chatId string, req *ListSessionsReq) (*ListSessionsRes, error) {
|
||||||
path := fmt.Sprintf("/api/v1/chats/%s/sessions", chatId)
|
path := "/api/v1/chats/" + chatId + "/sessions"
|
||||||
params := map[string]interface{}{}
|
params := map[string]interface{}{}
|
||||||
if req.Page > 0 {
|
if req.Page > 0 {
|
||||||
params["page"] = req.Page
|
params["page"] = req.Page
|
||||||
@@ -129,40 +130,40 @@ func (c *Client) ListSessions(ctx context.Context, chatId string, req *ListSessi
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if res.Code != 0 {
|
if res.Code != 0 {
|
||||||
return nil, fmt.Errorf("list sessions failed: code=%d", res.Code)
|
return nil, gerror.Newf("list sessions failed: code=%d", res.Code)
|
||||||
}
|
}
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteSessions 删除会话
|
// DeleteSessions 删除会话
|
||||||
func (c *Client) DeleteSessions(ctx context.Context, chatId string, ids []string) error {
|
func (c *Client) DeleteSessions(ctx context.Context, chatId string, ids []string) (err error) {
|
||||||
req := DeleteSessionsReq{Ids: ids}
|
req := DeleteSessionsReq{Ids: ids}
|
||||||
var res CommonResponse
|
var res CommonResponse
|
||||||
path := fmt.Sprintf("/api/v1/chats/%s/sessions", chatId)
|
path := "/api/v1/chats/" + chatId + "/sessions"
|
||||||
if err := c.request(ctx, "DELETE", path, req, &res); err != nil {
|
if err = c.request(ctx, "DELETE", path, req, &res); err != nil {
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
if !res.IsSuccess() {
|
if !res.IsSuccess() {
|
||||||
return fmt.Errorf("delete sessions failed: %s", res.Message)
|
return gerror.Newf("delete sessions failed: %s", res.Message)
|
||||||
}
|
}
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatCompletion 对话 (目前仅支持非流式)
|
// ChatCompletion 对话 (目前仅支持非流式)
|
||||||
func (c *Client) ChatCompletion(ctx context.Context, chatId string, req *ChatCompletionReq) (*ChatCompletionRes, error) {
|
func (c *Client) ChatCompletion(ctx context.Context, chatId string, req *ChatCompletionReq) (*ChatCompletionRes, error) {
|
||||||
path := fmt.Sprintf("/api/v1/chats/%s/completions", chatId)
|
path := "/api/v1/chats/" + chatId + "/completions"
|
||||||
var res ChatCompletionRes
|
var res ChatCompletionRes
|
||||||
|
|
||||||
// 如果需要流式支持,需要使用 gclient 的流式处理能力,这里暂只实现非流式
|
// 如果需要流式支持,需要使用 gclient 的流式处理能力,这里暂只实现非流式
|
||||||
if req.Stream {
|
if req.Stream {
|
||||||
return nil, fmt.Errorf("stream mode not supported yet")
|
return nil, gerror.New("stream mode not supported yet")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.request(ctx, "POST", path, req, &res); err != nil {
|
if err := c.request(ctx, "POST", path, req, &res); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if res.Code != 0 {
|
if res.Code != 0 {
|
||||||
return nil, fmt.Errorf("chat completion failed: code=%d", res.Code)
|
return nil, gerror.Newf("chat completion failed: code=%d", res.Code)
|
||||||
}
|
}
|
||||||
return &res, nil
|
return &res, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ package ragflow
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
|
"github.com/gogf/gf/v2/errors/gerror"
|
||||||
)
|
)
|
||||||
|
|
||||||
// System 系统管理
|
// System 系统管理
|
||||||
@@ -23,7 +24,7 @@ type HealthStatus struct {
|
|||||||
func (c *Client) CheckHealth(ctx context.Context) (*HealthStatus, error) {
|
func (c *Client) CheckHealth(ctx context.Context) (*HealthStatus, error) {
|
||||||
var status HealthStatus
|
var status HealthStatus
|
||||||
if err := c.request(ctx, "GET", "/v1/system/healthz", nil, &status); err != nil {
|
if err := c.request(ctx, "GET", "/v1/system/healthz", nil, &status); err != nil {
|
||||||
return nil, fmt.Errorf("check health failed: %w", err)
|
return nil, gerror.Newf("check health failed: %v", err)
|
||||||
}
|
}
|
||||||
return &status, nil
|
return &status, nil
|
||||||
}
|
}
|
||||||
@@ -36,4 +37,3 @@ func (c *Client) IsHealthy(ctx context.Context) (bool, error) {
|
|||||||
}
|
}
|
||||||
return status.Status == "ok", nil
|
return status.Status == "ok", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -149,12 +149,18 @@ func (q *QueueProcessor) Start(ctx context.Context) error {
|
|||||||
glog.Infof(ctx, "Stream 处理器启动 - Stream: %s, 消费者组: %s, 消费者: %s, 超时: %dms",
|
glog.Infof(ctx, "Stream 处理器启动 - Stream: %s, 消费者组: %s, 消费者: %s, 超时: %dms",
|
||||||
q.streamKey, q.groupName, q.consumerName, q.timeout)
|
q.streamKey, q.groupName, q.consumerName, q.timeout)
|
||||||
|
|
||||||
|
loopCount := 0
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-q.stopChan:
|
case <-q.stopChan:
|
||||||
glog.Info(ctx, "Stream 处理器收到停止信号")
|
glog.Info(ctx, "Stream 处理器收到停止信号")
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
|
loopCount++
|
||||||
|
if loopCount%10 == 1 {
|
||||||
|
glog.Debugf(ctx, "[DEBUG] 第 %d 次循环,准备读取消息...", loopCount)
|
||||||
|
}
|
||||||
|
|
||||||
// 从 Redis Stream 中读取消息
|
// 从 Redis Stream 中读取消息
|
||||||
messages, err := q.fetchMessages(ctx)
|
messages, err := q.fetchMessages(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -164,11 +170,17 @@ func (q *QueueProcessor) Start(ctx context.Context) error {
|
|||||||
|
|
||||||
// 没有新消息,继续等待
|
// 没有新消息,继续等待
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
|
if loopCount%10 == 1 {
|
||||||
|
glog.Debugf(ctx, "[DEBUG] 第 %d 次循环,无新消息", loopCount)
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
glog.Infof(ctx, "[DEBUG] 收到 %d 条消息", len(messages))
|
||||||
|
|
||||||
// 处理每条消息
|
// 处理每条消息
|
||||||
for _, msg := range messages {
|
for _, msg := range messages {
|
||||||
|
glog.Infof(ctx, "[DEBUG] 处理消息 ID: %s, Values: %+v", msg.ID, msg.Values)
|
||||||
// 提交到协程池处理
|
// 提交到协程池处理
|
||||||
if err := q.submitTask(ctx, msg); err != nil {
|
if err := q.submitTask(ctx, msg); err != nil {
|
||||||
glog.Errorf(ctx, "提交任务到协程池失败: %v, 消息ID: %s", err, msg.ID)
|
glog.Errorf(ctx, "提交任务到协程池失败: %v, 消息ID: %s", err, msg.ID)
|
||||||
|
|||||||
168
redis/redis.go
168
redis/redis.go
@@ -3,33 +3,50 @@ package redis
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/database/gredis"
|
"github.com/gogf/gf/v2/database/gredis"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
"github.com/gogf/gf/v2/os/glog"
|
||||||
"github.com/gogf/gf/v2/os/gtime"
|
"github.com/gogf/gf/v2/os/gtime"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GRedisClient GoFrame gredis 客户端,统一使用(懒加载)
|
var (
|
||||||
var GRedisClient *gredis.Redis
|
// redisClient 单例 Redis 客户端
|
||||||
|
redisClient *gredis.Redis
|
||||||
|
// redisOnce 确保只初始化一次
|
||||||
|
redisOnce sync.Once
|
||||||
|
// RedisClient 兼容导出(供 mongo.go 使用)
|
||||||
|
// 注意:这是一个指向单例的指针,首次调用 GetRedisClient() 后生效
|
||||||
|
RedisClient *gredis.Redis
|
||||||
|
)
|
||||||
|
|
||||||
// RedisClient GRedisClient 的别名,保持向后兼容
|
// GetRedisClient 获取 Redis 客户端(单例模式)
|
||||||
var RedisClient *gredis.Redis
|
|
||||||
|
|
||||||
// GetRedisClient 获取 Redis 客户端(懒加载)
|
|
||||||
func GetRedisClient() *gredis.Redis {
|
func GetRedisClient() *gredis.Redis {
|
||||||
if GRedisClient == nil {
|
redisOnce.Do(func() {
|
||||||
GRedisClient = g.Redis()
|
redisClient = g.Redis()
|
||||||
RedisClient = GRedisClient
|
RedisClient = redisClient // 同步更新兼容导出
|
||||||
}
|
})
|
||||||
return GRedisClient
|
return redisClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// init 包初始化时自动初始化 Redis 客户端
|
||||||
|
func init() {
|
||||||
|
GetRedisClient()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stream 和消费者组常量
|
// Stream 和消费者组常量
|
||||||
const (
|
const (
|
||||||
// RAGFlow 请求 Stream Key
|
// RAGFlow 请求 Stream Key
|
||||||
RAGFlowRequestStreamKey = "ragflow:request:stream"
|
RAGFlowRequestStreamKey = "ragflow:request:stream"
|
||||||
// RAGFlow 消费者组名称
|
// RAGFlow 响应 Stream Key
|
||||||
|
RAGFlowResponseStreamKey = "ragflow:response:stream"
|
||||||
|
// RAGFlow 请求消费者组名称
|
||||||
|
RAGFlowRequestConsumerGroup = "ragflow:request:consumer:group"
|
||||||
|
// RAGFlow 响应消费者组名称
|
||||||
|
RAGFlowResponseConsumerGroup = "ragflow:response:consumer:group"
|
||||||
|
// RAGFlow 消费者组名称(兼容旧代码)
|
||||||
RAGFlowConsumerGroup = "ragflow:consumer:group"
|
RAGFlowConsumerGroup = "ragflow:consumer:group"
|
||||||
// 会话最后活跃时间 Key 前缀
|
// 会话最后活跃时间 Key 前缀
|
||||||
SessionLastActiveKeyPrefix = "ragflow:session:"
|
SessionLastActiveKeyPrefix = "ragflow:session:"
|
||||||
@@ -79,6 +96,9 @@ func AddToStream(ctx context.Context, streamKey string, values map[string]interf
|
|||||||
// ReadFromStream 从 Stream 读取消息(消费者组模式)
|
// ReadFromStream 从 Stream 读取消息(消费者组模式)
|
||||||
// 使用 gredis Do() 方法执行 XREADGROUP 命令
|
// 使用 gredis Do() 方法执行 XREADGROUP 命令
|
||||||
func ReadFromStream(ctx context.Context, streamKey, groupName, consumerName string, count int64, blockMs int64) ([]StreamMessage, error) {
|
func ReadFromStream(ctx context.Context, streamKey, groupName, consumerName string, count int64, blockMs int64) ([]StreamMessage, error) {
|
||||||
|
glog.Debugf(ctx, "[DEBUG Redis] XREADGROUP GROUP %s %s COUNT %d BLOCK %d STREAMS %s >",
|
||||||
|
groupName, consumerName, count, blockMs, streamKey)
|
||||||
|
|
||||||
// XREADGROUP GROUP groupName consumerName COUNT count BLOCK blockMs STREAMS streamKey >
|
// XREADGROUP GROUP groupName consumerName COUNT count BLOCK blockMs STREAMS streamKey >
|
||||||
result, err := GetRedisClient().Do(ctx,
|
result, err := GetRedisClient().Do(ctx,
|
||||||
"XREADGROUP", "GROUP", groupName, consumerName,
|
"XREADGROUP", "GROUP", groupName, consumerName,
|
||||||
@@ -88,68 +108,91 @@ func ReadFromStream(ctx context.Context, streamKey, groupName, consumerName stri
|
|||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
glog.Errorf(ctx, "[DEBUG Redis] XREADGROUP 错误: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析返回值
|
glog.Debugf(ctx, "[DEBUG Redis] XREADGROUP 返回: %+v", result)
|
||||||
// 格式: [[streamKey, [[msgID, [field1, value1, field2, value2, ...]], ...]]]
|
|
||||||
// 预分配容量,避免动态扩容
|
// 预分配容量,避免动态扩容
|
||||||
messages := make([]StreamMessage, 0, int(count))
|
messages := make([]StreamMessage, 0, int(count))
|
||||||
|
|
||||||
if result == nil {
|
if result == nil || result.IsEmpty() {
|
||||||
// 超时或没有数据
|
// 超时或没有数据
|
||||||
return messages, nil
|
return messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 类型断言:result.Val() 返回 interface{}
|
// GoFrame gredis 返回格式: map[streamKey:[[msgID [field1 value1 field2 value2 ...]] ...]]
|
||||||
streamsArray, ok := result.Val().([]interface{})
|
resultVal := result.Val()
|
||||||
if !ok || len(streamsArray) == 0 {
|
|
||||||
return messages, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 遍历每个 stream
|
// 尝试 map 格式(GoFrame gredis 返回)
|
||||||
for _, streamData := range streamsArray {
|
if streamsMap, ok := resultVal.(map[interface{}]interface{}); ok {
|
||||||
streamArray, ok := streamData.([]interface{})
|
for _, streamMsgs := range streamsMap {
|
||||||
if !ok || len(streamArray) < 2 {
|
msgsArray, ok := streamMsgs.([]interface{})
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// streamArray[0] 是 streamKey, streamArray[1] 是消息数组
|
|
||||||
messagesArray, ok := streamArray[1].([]interface{})
|
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
for _, msgData := range msgsArray {
|
||||||
// 解析每条消息
|
|
||||||
for _, msgData := range messagesArray {
|
|
||||||
msgArray, ok := msgData.([]interface{})
|
msgArray, ok := msgData.([]interface{})
|
||||||
if !ok || len(msgArray) < 2 {
|
if !ok || len(msgArray) < 2 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// msgArray[0] 是 ID, msgArray[1] 是字段数组
|
|
||||||
msgID := gconv.String(msgArray[0])
|
msgID := gconv.String(msgArray[0])
|
||||||
fieldsArray, ok := msgArray[1].([]interface{})
|
fieldsArray, ok := msgArray[1].([]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析字段为 map,预分配容量,避免动态扩容
|
|
||||||
values := make(map[string]interface{}, len(fieldsArray)/2)
|
values := make(map[string]interface{}, len(fieldsArray)/2)
|
||||||
for i := 0; i < len(fieldsArray); i += 2 {
|
for i := 0; i < len(fieldsArray); i += 2 {
|
||||||
if i+1 < len(fieldsArray) {
|
if i+1 < len(fieldsArray) {
|
||||||
key := gconv.String(fieldsArray[i])
|
key := gconv.String(fieldsArray[i])
|
||||||
val := fieldsArray[i+1]
|
values[key] = fieldsArray[i+1]
|
||||||
values[key] = val
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
messages = append(messages, StreamMessage{
|
messages = append(messages, StreamMessage{
|
||||||
ID: msgID,
|
ID: msgID,
|
||||||
Values: values,
|
Values: values,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试数组格式(标准 Redis 返回)
|
||||||
|
if streamsArray, ok := resultVal.([]interface{}); ok && len(streamsArray) > 0 {
|
||||||
|
for _, streamData := range streamsArray {
|
||||||
|
streamArray, ok := streamData.([]interface{})
|
||||||
|
if !ok || len(streamArray) < 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
messagesArray, ok := streamArray[1].([]interface{})
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, msgData := range messagesArray {
|
||||||
|
msgArray, ok := msgData.([]interface{})
|
||||||
|
if !ok || len(msgArray) < 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
msgID := gconv.String(msgArray[0])
|
||||||
|
fieldsArray, ok := msgArray[1].([]interface{})
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
values := make(map[string]interface{}, len(fieldsArray)/2)
|
||||||
|
for i := 0; i < len(fieldsArray); i += 2 {
|
||||||
|
if i+1 < len(fieldsArray) {
|
||||||
|
key := gconv.String(fieldsArray[i])
|
||||||
|
values[key] = fieldsArray[i+1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
messages = append(messages, StreamMessage{
|
||||||
|
ID: msgID,
|
||||||
|
Values: values,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return messages, nil
|
return messages, nil
|
||||||
}
|
}
|
||||||
@@ -200,16 +243,16 @@ func GetPendingMessages(ctx context.Context, streamKey, groupName string, start,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if result == nil {
|
if result == nil {
|
||||||
return []PendingMessage{}, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析返回值:[[ID, consumer, idle, retryCount], ...]
|
// 解析返回值:[[ID, consumer, idle, retryCount], ...]
|
||||||
pendingArray, ok := result.Val().([]interface{})
|
pendingArray, ok := result.Val().([]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
return []PendingMessage{}, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var messages []PendingMessage
|
messages := make([]PendingMessage, 0, len(pendingArray))
|
||||||
for _, item := range pendingArray {
|
for _, item := range pendingArray {
|
||||||
itemArray, ok := item.([]interface{})
|
itemArray, ok := item.([]interface{})
|
||||||
if !ok || len(itemArray) < 4 {
|
if !ok || len(itemArray) < 4 {
|
||||||
@@ -242,13 +285,13 @@ func ClaimPendingMessage(ctx context.Context, streamKey, groupName, consumerName
|
|||||||
}
|
}
|
||||||
|
|
||||||
if result == nil {
|
if result == nil {
|
||||||
return []StreamMessage{}, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析返回值:类似 XREADGROUP
|
// 解析返回值:类似 XREADGROUP
|
||||||
messagesArray, ok := result.Val().([]interface{})
|
messagesArray, ok := result.Val().([]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
return []StreamMessage{}, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 预分配容量,避免动态扩容
|
// 预分配容量,避免动态扩容
|
||||||
@@ -344,6 +387,43 @@ func SetSessionCache(ctx context.Context, userId, sessionId string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 限流相关常量
|
||||||
|
const (
|
||||||
|
// RateLimitKeyPrefix 限流计数器 Key 前缀
|
||||||
|
RateLimitKeyPrefix = "ragflow:ratelimit:"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IncrRateLimit 增加限流计数器,返回当前计数
|
||||||
|
// windowSeconds: 时间窗口(秒)
|
||||||
|
func IncrRateLimit(ctx context.Context, key string, windowSeconds int64) (count int64, err error) {
|
||||||
|
fullKey := RateLimitKeyPrefix + key
|
||||||
|
result, err := GetRedisClient().Do(ctx, "INCR", fullKey)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
count = result.Int64()
|
||||||
|
|
||||||
|
// 首次设置过期时间
|
||||||
|
if count == 1 {
|
||||||
|
GetRedisClient().Do(ctx, "EXPIRE", fullKey, windowSeconds)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRateLimit 获取当前限流计数
|
||||||
|
func GetRateLimit(ctx context.Context, key string) (count int64, err error) {
|
||||||
|
fullKey := RateLimitKeyPrefix + key
|
||||||
|
result, err := GetRedisClient().Get(ctx, fullKey)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if result.IsEmpty() {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
count = result.Int64()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// GetSessionCache 获取缓存的 RAGFlow Session ID
|
// GetSessionCache 获取缓存的 RAGFlow Session ID
|
||||||
// 使用 gredis Get 方法
|
// 使用 gredis Get 方法
|
||||||
func GetSessionCache(ctx context.Context, userId string) (string, error) {
|
func GetSessionCache(ctx context.Context, userId string) (string, error) {
|
||||||
|
|||||||
@@ -37,3 +37,68 @@ func (m *BatchStreamMessage) ToMap() map[string]interface{} {
|
|||||||
"index": m.Index,
|
"index": m.Index,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResponseStreamMessage RAGFlow 响应消息结构(写入结果 Stream)
|
||||||
|
type ResponseStreamMessage struct {
|
||||||
|
UserId string `json:"user_id"` // 用户ID
|
||||||
|
Platform string `json:"platform"` // 平台标识
|
||||||
|
Question string `json:"question"` // 用户问题
|
||||||
|
Content string `json:"content"` // RAGFlow 回复内容
|
||||||
|
SessionId string `json:"session_id"` // RAGFlow Session ID
|
||||||
|
Timestamp int64 `json:"timestamp"` // 时间戳(秒)
|
||||||
|
MessageId string `json:"message_id"` // 原始消息ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToMap 转换为 map[string]interface{} 用于 Stream 存储
|
||||||
|
func (m *ResponseStreamMessage) ToMap() map[string]interface{} {
|
||||||
|
return map[string]interface{}{
|
||||||
|
"user_id": m.UserId,
|
||||||
|
"platform": m.Platform,
|
||||||
|
"question": m.Question,
|
||||||
|
"content": m.Content,
|
||||||
|
"session_id": m.SessionId,
|
||||||
|
"timestamp": m.Timestamp,
|
||||||
|
"message_id": m.MessageId,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FollowUpMessage 追问消息结构(RabbitMQ 延时队列)
|
||||||
|
type FollowUpMessage struct {
|
||||||
|
UserId string `json:"user_id"` // 用户ID
|
||||||
|
Platform string `json:"platform"` // 平台标识
|
||||||
|
Content string `json:"content"` // 追问内容
|
||||||
|
FollowUpType int `json:"follow_up_type"` // 追问类型:1=30s, 2=60s, 3=180s
|
||||||
|
Timestamp int64 `json:"timestamp"` // 发送时间戳
|
||||||
|
}
|
||||||
|
|
||||||
|
// 追问话术常量
|
||||||
|
const (
|
||||||
|
FollowUpType1 = 1 // 30秒追问
|
||||||
|
FollowUpType2 = 2 // 60秒追问
|
||||||
|
FollowUpType3 = 3 // 180秒追问
|
||||||
|
)
|
||||||
|
|
||||||
|
// 追问话术内容
|
||||||
|
var FollowUpContents = map[int]string{
|
||||||
|
FollowUpType1: "还有其他问题吗?",
|
||||||
|
FollowUpType2: "如果需要帮助,随时告诉我~",
|
||||||
|
FollowUpType3: "我一直在线,有问题随时找我~",
|
||||||
|
}
|
||||||
|
|
||||||
|
// 追问延时时间(秒)
|
||||||
|
var FollowUpDelays = map[int]int{
|
||||||
|
FollowUpType1: 30,
|
||||||
|
FollowUpType2: 60,
|
||||||
|
FollowUpType3: 180,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ArchiveMessage 会话归档消息结构(RabbitMQ 延时队列)
|
||||||
|
type ArchiveMessage struct {
|
||||||
|
UserId string `json:"user_id"` // 用户ID
|
||||||
|
Platform string `json:"platform"` // 平台标识
|
||||||
|
SessionId string `json:"session_id"` // RAGFlow Session ID
|
||||||
|
Timestamp int64 `json:"timestamp"` // 发送时间戳
|
||||||
|
}
|
||||||
|
|
||||||
|
// 归档延时时间(秒)
|
||||||
|
const ArchiveDelaySeconds = 3600 // 60分钟
|
||||||
|
|||||||
Reference in New Issue
Block a user