refactor: 移除 Ragflow 和 NATS 相关代码

This commit is contained in:
2026-04-15 16:55:15 +08:00
parent f09cc8640d
commit bfdfe9d896
38 changed files with 70 additions and 6795 deletions

View File

@@ -16,7 +16,6 @@ import (
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/log/model/entity"
"gitea.com/red-future/common/redis"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/errors/gerror"
@@ -30,6 +29,15 @@ import (
"go.mongodb.org/mongo-driver/v2/mongo/options"
)
// Redis 数据缓存 Key 常量
const (
CleanList = "list:tenantId-%v:collection-%s:*" // 清理列表Key
CleanCount = "count:tenantId-%v:collection-%s:*" // 清理计数Key
List = "list:tenantId-%v:collection-%s:filter:%s:options:%s" // 列表查询Key
Count = "count:tenantId-%v:collection-%s:filter:%s" // 计数查询Key
One = "one:tenantId-%v:collection-%s:filter:%s" // 单条查询Key
)
// =============================================================================
// 向后兼容的MongoDB结构体
// =============================================================================
@@ -175,10 +183,10 @@ func (m *mongoDB) Count(ctx context.Context, filter bson.M, collection string) (
filter["isDeleted"] = false
delete(filter, "tenantId")
filterKey := fmt.Sprintf("%+v", filter)
redisKey := fmt.Sprintf(redis.Count, user.TenantId, collection, filterKey)
redisKey := fmt.Sprintf(Count, user.TenantId, collection, filterKey)
if !m.noCache {
var resultStr *gvar.Var
resultStr, err = redis.RedisClient().Get(ctx, redisKey)
resultStr, err = g.Redis().Get(ctx, redisKey)
if err != nil {
return
}
@@ -193,7 +201,7 @@ func (m *mongoDB) Count(ctx context.Context, filter bson.M, collection string) (
}
count, err = db.Collection(collection).CountDocuments(ctx, filter)
if !m.noCache {
err = redis.RedisClient().SetEX(ctx, redisKey, count, int64(time.Hour))
err = g.Redis().SetEX(ctx, redisKey, count, int64(time.Hour))
if err != nil {
return
}
@@ -221,10 +229,10 @@ func (m *mongoDB) Find(ctx context.Context, filter bson.M, result interface{}, c
}
filterKey := fmt.Sprintf("%+v", filter)
optionsKey := fmt.Sprintf("%+v%+v", page, orderBy)
redisKey := fmt.Sprintf(redis.List, user.TenantId, collection, filterKey, optionsKey)
redisKey := fmt.Sprintf(List, user.TenantId, collection, filterKey, optionsKey)
if !m.noCache {
var resultStr *gvar.Var
resultStr, err = redis.RedisClient().Get(ctx, redisKey)
resultStr, err = g.Redis().Get(ctx, redisKey)
if err != nil {
return
}
@@ -284,7 +292,7 @@ func (m *mongoDB) Find(ctx context.Context, filter bson.M, result interface{}, c
return
}
if !m.noCache {
err = redis.RedisClient().SetEX(ctx, redisKey, result, int64(time.Hour))
err = g.Redis().SetEX(ctx, redisKey, result, int64(time.Hour))
if err != nil {
return
}
@@ -313,10 +321,10 @@ func (m *mongoDB) FindOne(ctx context.Context, filter bson.M, result interface{}
}
filter["isDeleted"] = false
filterKey := fmt.Sprintf("%+v", filter)
redisKey := fmt.Sprintf(redis.One, user.TenantId, collection, filterKey)
redisKey := fmt.Sprintf(One, user.TenantId, collection, filterKey)
if !m.noCache {
var resultStr *gvar.Var
resultStr, err = redis.RedisClient().Get(ctx, redisKey)
resultStr, err = g.Redis().Get(ctx, redisKey)
if err != nil {
return
}
@@ -338,7 +346,7 @@ func (m *mongoDB) FindOne(ctx context.Context, filter bson.M, result interface{}
err = nil
}
if !m.noCache {
err = redis.RedisClient().SetEX(ctx, redisKey, result, int64(time.Hour))
err = g.Redis().SetEX(ctx, redisKey, result, int64(time.Hour))
if err != nil {
return err
}
@@ -358,24 +366,24 @@ func (m *mongoDB) getDeletedData(ctx context.Context, filter bson.M, collection
}
func (m *mongoDB) CleanRedis(ctx context.Context, filter bson.M, tenantId interface{}, collection string) (err error) {
listKeys := fmt.Sprintf(redis.CleanList, tenantId, collection)
keys, err := redis.RedisClient().Keys(ctx, listKeys)
listKeys := fmt.Sprintf(CleanList, tenantId, collection)
keys, err := g.Redis().Keys(ctx, listKeys)
if err != nil {
return
}
for _, key := range keys {
_, err = redis.RedisClient().Del(ctx, key)
_, err = g.Redis().Del(ctx, key)
if err != nil {
return
}
}
countKeys := fmt.Sprintf(redis.CleanCount, tenantId, collection)
keys, err = redis.RedisClient().Keys(ctx, countKeys)
countKeys := fmt.Sprintf(CleanCount, tenantId, collection)
keys, err = g.Redis().Keys(ctx, countKeys)
if err != nil {
return
}
for _, key := range keys {
_, err = redis.RedisClient().Del(ctx, key)
_, err = g.Redis().Del(ctx, key)
if err != nil {
return
}
@@ -383,8 +391,8 @@ func (m *mongoDB) CleanRedis(ctx context.Context, filter bson.M, tenantId interf
filter["isDeleted"] = false
delete(filter, "tenantId")
filterKey := fmt.Sprintf("%+v", filter)
oneKey := fmt.Sprintf(redis.One, tenantId, collection, filterKey)
_, err = redis.RedisClient().Del(ctx, oneKey)
oneKey := fmt.Sprintf(One, tenantId, collection, filterKey)
_, err = g.Redis().Del(ctx, oneKey)
if err != nil {
return
}
@@ -422,10 +430,20 @@ func (m *mongoDB) log(ctx context.Context, ids []bson.ObjectID, filter bson.M, c
log.CreatedAt = now
log.UpdatedAt = now
log.TenantId = tenantId
// 使用新的 context 进行 Redis 操作
if _, err := redis.AddToStream(ctx, LogRedisKey, log); err != nil {
// 将结构体转换为 map
values := gconv.Map(log)
// XADD streamKey * field1 value1 field2 value2 ...
args := make([]interface{}, 0, len(values)*2+2)
args = append(args, LogRedisKey, "*") // "*" 自动生成ID
for key, val := range values {
args = append(args, key, val)
}
_, err := g.Redis().Do(ctx, "XADD", args...)
if err != nil {
glog.Error(ctx, "mongoLog-AddToStream err: %v", err)
}
return
}

View File

@@ -1,167 +0,0 @@
package message
import (
"context"
"fmt"
"sync"
"time"
"github.com/gogf/gf/v2/frame/g"
"github.com/nats-io/nats.go"
)
var (
muNats sync.RWMutex
natsConns map[string]*nats.Conn // key: 数据源名称, value: NATS 连接
natsJS map[string]nats.JetStreamContext // key: 数据源名称, value: JetStream 上下文
)
func init() {
natsConns = make(map[string]*nats.Conn)
natsJS = make(map[string]nats.JetStreamContext)
}
// natsConnect 建立 NATS 连接
func natsConnect(ctx context.Context, name string) error {
if g.Cfg().MustGet(ctx, "nats").IsEmpty() {
g.Log().Errorf(ctx, "❌ NATS 配置不存在")
return fmt.Errorf("NATS Configuration does not exist")
}
// 确定数据源名称
dsName := "default"
if !g.IsEmpty(name) {
dsName = name
}
g.Log().Infof(ctx, "🔔 NATS [%s] 开始创建连接", dsName)
muNats.Lock()
defer muNats.Unlock()
// 安全地关闭旧连接(仅针对该数据源)
if oldConn, exists := natsConns[dsName]; exists && oldConn != nil && !oldConn.IsClosed() {
oldConn.Close()
delete(natsConns, dsName)
delete(natsJS, dsName)
}
// 从配置文件读取 NATS 地址
natsURL := g.Cfg().MustGet(ctx, fmt.Sprintf("nats.%s.url", dsName)).String()
if natsURL == "" {
// 默认使用本地地址
natsURL = nats.DefaultURL
}
// 连接选项配置
opts := []nats.Option{
nats.Name(fmt.Sprintf("goframe-nats-client-%s", dsName)),
nats.NoReconnect(),
nats.PingInterval(10 * time.Second),
nats.MaxPingsOutstanding(5),
nats.ClosedHandler(func(nc *nats.Conn) {
g.Log().Infof(ctx, "NATS [%s] 连接已关闭: %s", dsName, nc.ConnectedUrl())
}),
nats.ErrorHandler(func(nc *nats.Conn, sub *nats.Subscription, err error) {
g.Log().Errorf(ctx, "❌ NATS [%s] 错误: %v", dsName, err)
}),
}
newConn, err := nats.Connect(natsURL, opts...)
if err != nil {
g.Log().Errorf(ctx, "❌ NATS [%s] 连接失败: %v", dsName, err)
return err
}
// 等待连接就绪
if newConn.Status() != nats.CONNECTED {
select {
case <-time.After(5 * time.Second):
// 连接超时,清理资源
newConn.Close()
g.Log().Errorf(ctx, "❌ NATS [%s] 连接超时", dsName)
return fmt.Errorf("NATS 连接超时")
case <-newConn.StatusChanged(nats.CONNECTED):
// 连接成功
g.Log().Infof(ctx, "✅ NATS [%s] 连接成功: %s", dsName, newConn.ConnectedUrl())
case <-ctx.Done():
// 外部上下文被取消,清理资源
newConn.Close()
g.Log().Errorf(ctx, "NATS [%s] 连接被取消: %v", dsName, ctx.Err())
return fmt.Errorf("NATS 连接被取消: %w", ctx.Err())
}
}
// 创建 JetStream 实例
newJS, err := newConn.JetStream(nats.MaxWait(10 * time.Second))
if err != nil {
// 创建 JetStream 失败,清理连接
newConn.Close()
g.Log().Errorf(ctx, "❌ NATS [%s] 创建 JetStream 失败: %v", dsName, err)
return err
}
// 保存连接和 JetStream 上下文
natsConns[dsName] = newConn
natsJS[dsName] = newJS
return nil
}
// natsPing 检测 NATS 连接状态
func natsPing(ctx context.Context, name string) bool {
// 确定数据源名称
dsName := "default"
if !g.IsEmpty(name) {
dsName = name
}
muNats.RLock()
defer muNats.RUnlock()
nc, exists := natsConns[dsName]
if !exists || nc == nil || nc.IsClosed() || nc.Status() != nats.CONNECTED {
g.Log().Errorf(ctx, "❌ NATS [%s] 连接已关闭或不可用", dsName)
return false
}
g.Log().Infof(ctx, "📊 NATS [%s] 连接正常: %s", dsName, nc.ConnectedUrl())
return true
}
// natsClose 关闭 NATS 连接
func natsClose(ctx context.Context, name string) error {
// 确定数据源名称
dsName := "default"
if !g.IsEmpty(name) {
dsName = name
}
muNats.Lock()
defer muNats.Unlock()
if nc, exists := natsConns[dsName]; exists && nc != nil && !nc.IsClosed() {
nc.Close()
}
delete(natsConns, dsName)
delete(natsJS, dsName)
g.Log().Infof(ctx, "✅ NATS [%s] 连接已关闭", dsName)
return nil
}
// getNatsConn 获取 NATS 连接(内部使用)
func getNatsConn(name string) *nats.Conn {
dsName := "default"
if !g.IsEmpty(name) {
dsName = name
}
return natsConns[dsName]
}
// getNatsJS 获取 JetStream 上下文(内部使用)
func getNatsJS(name string) nats.JetStreamContext {
dsName := "default"
if !g.IsEmpty(name) {
dsName = name
}
return natsJS[dsName]
}

View File

@@ -1,164 +0,0 @@
package message
import (
"context"
"fmt"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
amqp "github.com/rabbitmq/amqp091-go"
"sync"
)
var (
muRabbitMQ sync.RWMutex
rabbitmqConns map[string]*amqp.Connection
rabbitmqChannels map[string]*amqp.Channel
)
func init() {
rabbitmqConns = make(map[string]*amqp.Connection)
rabbitmqChannels = make(map[string]*amqp.Channel)
}
// rabbitmqConnect 建立 RabbitMQ 连接
func rabbitmqConnect(ctx context.Context, name string) error {
if g.Cfg().MustGet(ctx, "rabbitmq").IsEmpty() {
g.Log().Errorf(ctx, "❌ RabbitMQ 配置不存在")
return fmt.Errorf("RabbitMQ Configuration does not exist")
}
// 确定数据源名称
dsName := "default"
if !g.IsEmpty(name) {
dsName = name
}
g.Log().Infof(ctx, "🔔 RabbitMQ [%s] 开始创建连接", dsName)
muRabbitMQ.Lock()
defer muRabbitMQ.Unlock()
// 安全地关闭旧连接(仅针对该数据源)
if oldConn, exists := rabbitmqConns[dsName]; exists && oldConn != nil && !oldConn.IsClosed() {
oldConn.Close()
}
if oldChannel, exists := rabbitmqChannels[dsName]; exists && oldChannel != nil && !oldChannel.IsClosed() {
oldChannel.Close()
}
delete(rabbitmqConns, dsName)
delete(rabbitmqChannels, dsName)
// 从配置文件读取 RabbitMQ 配置
host := g.Cfg().MustGet(ctx, fmt.Sprintf("rabbitmq.%s.host", dsName)).String()
port := g.Cfg().MustGet(ctx, fmt.Sprintf("rabbitmq.%s.port", dsName)).Int()
username := g.Cfg().MustGet(ctx, fmt.Sprintf("rabbitmq.%s.username", dsName)).String()
password := g.Cfg().MustGet(ctx, fmt.Sprintf("rabbitmq.%s.password", dsName)).String()
vHost := g.Cfg().MustGet(ctx, fmt.Sprintf("rabbitmq.%s.vhost", dsName), "/").String()
if g.IsEmpty(host) {
return fmt.Errorf("❌ RabbitMQ 配置错误: host 不能为空 (数据源: %s)", dsName)
}
if g.IsEmpty(port) {
return fmt.Errorf("❌ RabbitMQ 配置错误: port 不能为空 (数据源: %s)", dsName)
}
if g.IsEmpty(username) {
return fmt.Errorf("❌ RabbitMQ 配置错误: username 不能为空 (数据源: %s)", dsName)
}
if g.IsEmpty(password) {
return fmt.Errorf("❌ RabbitMQ 配置错误: password 不能为空 (数据源: %s)", dsName)
}
// 构建连接 URL
url := "amqp://" + username + ":" + password + "@" + host + ":" + gconv.String(port) + "/" + vHost
// 创建连接
newConn, err := amqp.Dial(url)
if err != nil {
g.Log().Errorf(ctx, "❌ RabbitMQ [%s] 连接失败: %v", dsName, err)
return err
}
// 创建 Channel
newChannel, err := newConn.Channel()
if err != nil {
g.Log().Errorf(ctx, "❌ RabbitMQ [%s] 创建 Channel 失败: %v", dsName, err)
newConn.Close()
return err
}
// 保存连接和 Channel
rabbitmqConns[dsName] = newConn
rabbitmqChannels[dsName] = newChannel
g.Log().Infof(ctx, "✅ RabbitMQ [%s] 连接成功", dsName)
return nil
}
// rabbitmqPing 检测 RabbitMQ 连接状态
func rabbitmqPing(ctx context.Context, name string) bool {
// 确定数据源名称
dsName := "default"
if !g.IsEmpty(name) {
dsName = name
}
muRabbitMQ.RLock()
defer muRabbitMQ.RUnlock()
conn, exists := rabbitmqConns[dsName]
channel, channelExists := rabbitmqChannels[dsName]
if !exists || conn == nil || conn.IsClosed() || !channelExists || channel == nil || channel.IsClosed() {
g.Log().Errorf(ctx, "❌ RabbitMQ [%s] 连接已关闭或不可用", dsName)
return false
}
g.Log().Infof(ctx, "📊 RabbitMQ [%s] 连接正常", dsName)
return true
}
// rabbitmqClose 关闭 RabbitMQ 连接
func rabbitmqClose(ctx context.Context, name string) error {
// 确定数据源名称
dsName := "default"
if !g.IsEmpty(name) {
dsName = name
}
muRabbitMQ.Lock()
defer muRabbitMQ.Unlock()
var lastErr error
if channel, exists := rabbitmqChannels[dsName]; exists && channel != nil && !channel.IsClosed() {
if err := channel.Close(); err != nil {
g.Log().Errorf(ctx, "❌ RabbitMQ [%s] 关闭 Channel 失败: %v", dsName, err)
lastErr = err
}
}
delete(rabbitmqChannels, dsName)
if conn, exists := rabbitmqConns[dsName]; exists && conn != nil && !conn.IsClosed() {
if err := conn.Close(); err != nil {
g.Log().Errorf(ctx, "❌ RabbitMQ [%s] 关闭连接失败: %v", dsName, err)
lastErr = err
}
}
delete(rabbitmqConns, dsName)
g.Log().Infof(ctx, "✅ RabbitMQ [%s] 连接已关闭", dsName)
return lastErr
}
// getRabbitMQConn 获取 RabbitMQ 连接(内部使用)
func getRabbitMQConn(name string) *amqp.Connection {
dsName := "default"
if !g.IsEmpty(name) {
dsName = name
}
return rabbitmqConns[dsName]
}
// getRabbitMQChannel 获取 RabbitMQ Channel内部使用
func getRabbitMQChannel(name string) *amqp.Channel {
dsName := "default"
if !g.IsEmpty(name) {
dsName = name
}
return rabbitmqChannels[dsName]
}

View File

@@ -1,198 +0,0 @@
// =============================================================================
// Redis 连接管理
// 负责 Redis 的连接、重连、健康检查和优雅关闭
// =============================================================================
package message
import (
"context"
"fmt"
"sync"
"time"
"github.com/gogf/gf/v2/database/gredis"
"github.com/gogf/gf/v2/frame/g"
)
var (
muRedis sync.RWMutex
redisConns map[string]*gredis.Redis
redisConfigs map[string]*gredis.Config
)
func init() {
redisConns = make(map[string]*gredis.Redis)
redisConfigs = make(map[string]*gredis.Config)
}
// redisConnect 建立 Redis 连接
// name: 数据源名称,如果为空则使用默认数据源
func redisConnect(ctx context.Context, name string) error {
if g.Cfg().MustGet(ctx, "redis").IsEmpty() {
g.Log().Errorf(ctx, "❌ Redis 配置不存在")
return fmt.Errorf("redis Configuration does not exist")
}
// 确定数据源名称
dsName := "default"
if !g.IsEmpty(name) {
dsName = name
}
g.Log().Infof(ctx, "🔔 Redis [%s] 开始创建连接", dsName)
muRedis.Lock()
defer muRedis.Unlock()
// 安全地关闭旧连接(仅针对该数据源)
if oldRedis, exists := redisConns[dsName]; exists && oldRedis != nil {
oldRedis.Close(ctx)
delete(redisConns, dsName)
}
// 从配置文件读取 Redis 配置
redisAddr := g.Cfg().MustGet(ctx, fmt.Sprintf("redis.%s.address", dsName)).String()
if g.IsEmpty(redisAddr) {
g.Log().Errorf(ctx, "❌ Redis 配置错误: address 不能为空 (数据源: %s)", dsName)
return fmt.Errorf("❌ Redis 配置错误: address 不能为空 (数据源: %s)", dsName)
}
redisDB := g.Cfg().MustGet(ctx, fmt.Sprintf("redis.%s.db", dsName)).Int()
if redisDB < 0 || redisDB > 15 {
g.Log().Errorf(ctx, "❌ Redis 配置错误: db 必须在 0-15 之间 (当前值: %d)", redisDB)
return fmt.Errorf("❌ Redis 配置错误: db 必须在 0-15 之间 (当前值: %d)", redisDB)
}
idleTimeout := g.Cfg().MustGet(ctx, fmt.Sprintf("redis.%s.idleTimeout", dsName)).String()
redisIdleTimeout, err := time.ParseDuration(idleTimeout)
if err != nil {
g.Log().Errorf(ctx, "❌ Redis idleTimeout 格式错误: %v", err)
return err
}
maxConnLifetime := g.Cfg().MustGet(ctx, fmt.Sprintf("redis.%s.maxConnLifetime", dsName)).String()
redisMaxConnLifetime, err := time.ParseDuration(maxConnLifetime)
if err != nil {
g.Log().Errorf(ctx, "❌ Redis maxConnLifetime 格式错误: %v", err)
return err
}
waitTimeout := g.Cfg().MustGet(ctx, fmt.Sprintf("redis.%s.waitTimeout", dsName)).String()
redisWaitTimeout, err := time.ParseDuration(waitTimeout)
if err != nil {
g.Log().Errorf(ctx, "❌ Redis waitTimeout 格式错误: %v", err)
return err
}
dialTimeout := g.Cfg().MustGet(ctx, fmt.Sprintf("redis.%s.dialTimeout", dsName)).String()
redisDialTimeout, err := time.ParseDuration(dialTimeout)
if err != nil {
g.Log().Errorf(ctx, "❌ Redis dialTimeout 格式错误: %v", err)
return err
}
readTimeout := g.Cfg().MustGet(ctx, fmt.Sprintf("redis.%s.readTimeout", dsName)).String()
redisReadTimeout, err := time.ParseDuration(readTimeout)
if err != nil {
g.Log().Errorf(ctx, "❌ Redis readTimeout 格式错误: %v", err)
return err
}
writeTimeout := g.Cfg().MustGet(ctx, fmt.Sprintf("redis.%s.writeTimeout", dsName)).String()
redisWriteTimeout, err := time.ParseDuration(writeTimeout)
if err != nil {
g.Log().Errorf(ctx, "❌ Redis writeTimeout 格式错误: %v", err)
return err
}
maxActive := g.Cfg().MustGet(ctx, fmt.Sprintf("redis.%s.maxActive", dsName)).Int()
if g.IsEmpty(maxActive) {
g.Log().Errorf(ctx, "❌ Redis maxActive 配置错误: %v", maxActive)
return fmt.Errorf("❌ Redis maxActive 配置错误")
}
// 构建 GoFrame Redis 配置
redisConfig := &gredis.Config{
Address: redisAddr,
Db: redisDB,
IdleTimeout: redisIdleTimeout,
MaxConnLifetime: redisMaxConnLifetime,
WaitTimeout: redisWaitTimeout,
DialTimeout: redisDialTimeout,
ReadTimeout: redisReadTimeout,
WriteTimeout: redisWriteTimeout,
MaxActive: maxActive,
}
redisConfigs[dsName] = redisConfig
// 使用 GoFrame 的 Redis 连接
newRedis, err := gredis.New(redisConfig)
if err != nil {
g.Log().Errorf(ctx, "❌ Redis [%s] 连接失败: %v", dsName, err)
return err
}
// 测试连接(直接调用避免死锁)
_, err = newRedis.Do(ctx, "PING")
if err != nil {
g.Log().Errorf(ctx, "❌ Redis [%s] 连接失败: ping 失败 - %v", dsName, err)
_ = newRedis.Close(ctx)
return err
}
redisConns[dsName] = newRedis
g.Log().Infof(ctx, "✅ Redis [%s] 连接成功: %s (DB: %d)", dsName, redisAddr, redisDB)
return nil
}
// redisPing 检测 Redis 连接状态(带超时保护)
func redisPing(ctx context.Context, name string) bool {
// 确定数据源名称
dsName := "default"
if !g.IsEmpty(name) {
dsName = name
}
muRedis.RLock()
defer muRedis.RUnlock()
rc, exists := redisConns[dsName]
if !exists || rc == nil {
g.Log().Errorf(ctx, "❌ Redis [%s] 连接未建立", dsName)
return false
}
// 创建带超时的子上下文,避免死锁
timeoutCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
_, err := rc.Do(timeoutCtx, "PING")
if err != nil {
g.Log().Errorf(ctx, "❌ Redis [%s] ping 失败: %v", dsName, err)
return false
}
g.Log().Infof(ctx, "📊 Redis [%s] 连接正常", dsName)
return true
}
// redisClose 关闭 Redis 连接
func redisClose(ctx context.Context, name string) error {
// 确定数据源名称
dsName := "default"
if !g.IsEmpty(name) {
dsName = name
}
muRedis.Lock()
defer muRedis.Unlock()
if rc, exists := redisConns[dsName]; exists && rc != nil {
if err := rc.Close(ctx); err != nil {
g.Log().Errorf(ctx, "❌ Redis [%s] 关闭失败: %v", dsName, err)
return err
}
delete(redisConns, dsName)
}
g.Log().Infof(ctx, "✅ Redis [%s] 连接已关闭", dsName)
return nil
}
// getRedisConn 获取 Redis 连接(内部使用)
func getRedisConn(name string) *gredis.Redis {
dsName := "default"
if !g.IsEmpty(name) {
dsName = name
}
return redisConns[dsName]
}

View File

@@ -1,32 +0,0 @@
package message
import "context"
type messagePublishConfig interface {
GetPublishMsgType()
}
type messagePublishDelayConfig interface {
GetPublishDelayMsgType()
}
type messageSubscribeConfig interface {
GetSubscribeMsgType()
}
// messageUtil 消息队列公共配置接口
// 只暴露核心的发布/订阅功能,配置访问器方法不需要在公共接口中
type messageUtil interface {
// Publish 发布消息
Publish(ctx context.Context, msg messagePublishConfig) error
// PublishDelay 发布延迟消息
PublishDelay(ctx context.Context, msg messagePublishDelayConfig) error
// Subscribe 订阅消息
Subscribe(ctx context.Context, msg messageSubscribeConfig) error
// Ping 检测连接状态
Ping(ctx context.Context) bool
// Connect 连接
Connect(ctx context.Context) error
// Close 关闭连接
Close(ctx context.Context) error
}

View File

@@ -1,114 +0,0 @@
package message
import (
"context"
"fmt"
"time"
"github.com/gogf/gf/v2/frame/g"
"sync"
)
// MessageType 消息队列类型
type messageType string
const (
// MessageRedis Redis 消息队列
MessageRedis messageType = "redis"
// MessageRabbitMQ RabbitMQ 消息队列
MessageRabbitMQ messageType = "rabbitmq"
// MessageNATS NATS 消息队列
MessageNATS messageType = "nats"
)
// configFactory 消息队列配置工厂函数类型
type configFactory func() messageUtil
// PluginManager 消息队列插件管理器
type pluginManager struct {
mu sync.RWMutex
instances map[messageType]messageUtil // 已连接的插件实例
}
var (
defaultPluginManager = newPluginManager()
)
// newPluginManager 创建插件管理器
func newPluginManager() *pluginManager {
return &pluginManager{
instances: make(map[messageType]messageUtil),
}
}
// register 注册插件(内部方法)
func (m *pluginManager) register(msgType messageType, instance messageUtil) error {
m.mu.Lock()
defer m.mu.Unlock()
m.instances[msgType] = instance
return nil
}
// RegisterPlugin 注册消息队列插件
// 所有插件必须通过此方法注册,自动进行连接检测
// 只有连接成功的插件才会被注册,连接失败的插件不会被注册
// 异步无限重连,只有连接成功了才注册
// name: 数据源名称,用于标识不同的连接实例
func RegisterPlugin(ctx context.Context, name string, msgType messageType, factory configFactory) error {
if factory == nil {
g.Log().Errorf(ctx, "❌ factory cannot be nil")
return fmt.Errorf("factory cannot be nil")
}
// 开启异步连接,无限重试直到成功
go func() {
// 创建实例
instance := factory()
// 创建通知 channel
pluginKey := fmt.Sprintf("%s-%s", msgType, name)
if !instance.Ping(ctx) {
// 使用统一的重连函数
if err := commonConnect(ctx, msgType, name, func(ctx context.Context) error {
return instance.Connect(ctx)
}, func(ctx context.Context) error {
return instance.Close(ctx)
}); err != nil {
g.Log().Errorf(ctx, "❌ [%s][%s] 连接失败: %v", msgType, name, err)
return
}
}
// 连接成功,注册插件
defaultPluginManager.mu.Lock()
defaultPluginManager.instances[messageType(pluginKey)] = instance
defaultPluginManager.mu.Unlock()
g.Log().Infof(ctx, "✅ [%s][%s] 插件注册成功", msgType, name)
}()
return nil
}
// GetMsgPlugin 获取消息队列插件(默认数据源),如果未注册则等待
func GetMsgPlugin(ctx context.Context, msgType messageType) (messageUtil, error) {
return GetMsgPluginWithName(ctx, msgType, "default")
}
// GetMsgPluginWithName 获取指定数据源的消息队列插件,如果未注册则等待直到超时
func GetMsgPluginWithName(ctx context.Context, msgType messageType, name string) (messageUtil, error) {
pluginKey := fmt.Sprintf("%s-%s", msgType, name)
for {
defaultPluginManager.mu.RLock()
instance, ok := defaultPluginManager.instances[messageType(pluginKey)]
defaultPluginManager.mu.RUnlock()
if ok {
return instance, nil
}
// 未注册,等待一段时间后重试
select {
case <-ctx.Done():
return nil, fmt.Errorf("wait for plugin ready canceled: %s with datasource: %s", msgType, name)
default:
time.Sleep(3 * time.Second)
}
}
}

View File

@@ -1,373 +0,0 @@
package message
import (
"context"
"encoding/json"
"fmt"
"github.com/gogf/gf/v2/frame/g"
"github.com/nats-io/nats.go"
"time"
)
type NatsPublishMsgConfig struct {
QueueName string
Durable bool
Data any
}
type NatsPublishDelayMsgConfig struct {
QueueName string
Durable bool
DelayTime int
Data any
}
type NatsSubscribeMsgConfig struct {
QueueName string
ConsumerName string
Durable bool
DelayTime int
AutoAck bool
PrefetchCount int
HandleFunc func(ctx context.Context, message map[string]interface{}) error
}
func (*NatsPublishMsgConfig) GetPublishMsgType() {
}
func (*NatsPublishDelayMsgConfig) GetPublishDelayMsgType() {
}
func (*NatsSubscribeMsgConfig) GetSubscribeMsgType() {
}
type natsMsg struct {
name string // 数据源名称
}
func init() {
// 注册 Nats 插件(默认数据源)
RegisterPlugin(context.Background(), "default", MessageNATS, func() messageUtil {
return &natsMsg{name: "default"}
})
}
// Connect 连接 NATS
func (c *natsMsg) Connect(ctx context.Context) error {
return natsConnect(ctx, c.name)
}
// Ping 检测 NATS 连接状态
func (c *natsMsg) Ping(ctx context.Context) bool {
return natsPing(ctx, c.name)
}
// Close 关闭 NATS 连接
func (c *natsMsg) Close(ctx context.Context) error {
return natsClose(ctx, c.name)
}
// Publish 发布消息
func (c *natsMsg) Publish(ctx context.Context, msgConfig messagePublishConfig) error {
cfg, ok := msgConfig.(*NatsPublishMsgConfig)
if !ok {
return fmt.Errorf("无效的 NATS 配置类型")
}
if g.IsEmpty(cfg.QueueName) {
return fmt.Errorf("必须提供队列名称")
}
if g.IsEmpty(cfg.Data) {
return fmt.Errorf("必须提供数据")
}
return c.createPublish(ctx, cfg.QueueName, cfg.Durable, 0, cfg.Data)
}
// PublishDelay 发布延迟消息
func (c *natsMsg) PublishDelay(ctx context.Context, msgConfig messagePublishDelayConfig) error {
cfg, ok := msgConfig.(*NatsPublishDelayMsgConfig)
if !ok {
return fmt.Errorf("无效的 NATS 配置类型")
}
if g.IsEmpty(cfg.QueueName) {
return fmt.Errorf("必须提供队列名称")
}
if g.IsEmpty(cfg.DelayTime) {
return fmt.Errorf("延迟时间必须大于 0")
}
if g.IsEmpty(cfg.Data) {
return fmt.Errorf("必须提供数据")
}
return c.createPublish(ctx, cfg.QueueName, cfg.Durable, cfg.DelayTime, cfg.Data)
}
// Publish 发布消息
func (c *natsMsg) createPublish(ctx context.Context, subject string, durable bool, delayTime int, data any) error {
delayMsg := delayTime > 0
if err := c.createStream(ctx, subject, durable, delayMsg); err != nil {
return err
}
payload, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("序列化数据失败: %w", err)
}
m := nats.NewMsg(subject)
m.Data = payload // 所有消息都需要设置数据
if delayMsg {
// 使用 @at 指定具体延迟时间,而不是 @every 重复执行
futureTime := time.Now().Add(time.Duration(delayTime) * time.Second).Format(time.RFC3339Nano)
m.Header.Set("Nats-Schedule", fmt.Sprintf("@at %s", futureTime))
m.Subject = subject + ".schedule"
m.Header.Set("Nats-Schedule-Target", subject)
g.Log().Infof(ctx, "📅 NATS 延迟消息配置: DelayTime=%ds, Schedule=@at %s, Header=%s", delayTime, futureTime, m.Header)
}
// 发布消息到 JetStream
js := getNatsJS(c.name)
if js == nil {
g.Log().Errorf(ctx, "❌ NATS [%s] JetStream 不存在", c.name)
return fmt.Errorf("NATS JetStream 不存在")
}
ack, err := js.PublishMsg(m)
if err != nil {
g.Log().Errorf(ctx, "❌ NATS 发布消息失败: err=%v, Subject=%s", err, m.Subject)
return err
}
g.Log().Infof(ctx, "✅ NATS 发布消息成功: Stream=%v, StreamSeq=%d", ack.Stream, ack.Sequence)
return nil
}
// Subscribe 订阅消息
func (c *natsMsg) Subscribe(ctx context.Context, msgConfig messageSubscribeConfig) error {
cfg, ok := msgConfig.(*NatsSubscribeMsgConfig)
if !ok {
return fmt.Errorf("无效的 NATS 配置类型")
}
if g.IsEmpty(cfg.QueueName) {
return fmt.Errorf("必须提供队列名称")
}
if g.IsEmpty(cfg.ConsumerName) {
return fmt.Errorf("必须提供消费者名称")
}
if g.IsEmpty(cfg.HandleFunc) {
return fmt.Errorf("必须提供处理函数")
}
if g.IsEmpty(cfg.PrefetchCount) {
cfg.PrefetchCount = 1
}
return c.createSubscribe(ctx, cfg.QueueName, cfg.ConsumerName, cfg.PrefetchCount, cfg.DelayTime, cfg.AutoAck, cfg.Durable, cfg.HandleFunc)
}
// createSubscribe 内部订阅消息
func (c *natsMsg) createSubscribe(ctx context.Context, subject, consumerName string, prefetchCount, delayTime int, autoAck, durable bool, handler func(ctx context.Context, message map[string]any) error) error {
g.Log().Infof(ctx, "🔔 NATS 开始订阅: QueueName=%s, ConsumerName=%s", subject, consumerName)
// 创建推送订阅的回调函数
msgHandler := func(msg *nats.Msg) {
var data map[string]any
if err := json.Unmarshal(msg.Data, &data); err != nil {
g.Log().Errorf(ctx, "❌ 解析消息失败: %v", err)
return
}
g.Log().Infof(ctx, "📨 收到消息: Subject=%s, Data=%v", msg.Subject, data)
// 处理业务逻辑
if err := handler(ctx, data); err != nil {
g.Log().Errorf(ctx, "❌ 处理消息失败: %v", err)
if !autoAck {
if err := msg.Nak(); err != nil {
g.Log().Errorf(ctx, "❌ Nak 失败: %v", err)
return
}
return
}
} else {
g.Log().Infof(ctx, "✅ 处理消息成功")
}
if err := msg.Ack(); err != nil {
g.Log().Errorf(ctx, "❌ Ack 失败: %v", err)
}
}
delayMsg := delayTime > 0
// 创建流
if err := c.createStream(ctx, subject, durable, delayMsg); err != nil {
return err
}
// 获取 JetStream 上下文
js := getNatsJS(c.name)
if js == nil {
g.Log().Errorf(ctx, "❌ NATS [%s] JetStream 不存在", c.name)
return fmt.Errorf("NATS JetStream 不存在")
}
// 创建推送订阅
var sub *nats.Subscription
var err error
// 配置订阅选项 - 使用 DeliverSubject 创建 Push Consumer
subOpts := []nats.SubOpt{
nats.Durable(consumerName),
nats.MaxAckPending(prefetchCount),
nats.DeliverSubject(consumerName),
}
if !autoAck {
subOpts = append(subOpts, nats.ManualAck())
}
// 使用 Subscribe 创建推送订阅
sub, err = js.Subscribe(subject, msgHandler, subOpts...)
if err != nil {
g.Log().Errorf(ctx, "创建推送订阅失败: %v", err)
return err
}
g.Log().Infof(ctx, "✅ NATS 推送订阅成功: Consumer=%s", consumerName)
// 启动后台 goroutine 监听上下文取消,用于清理订阅
go func() {
<-ctx.Done()
g.Log().Infof(ctx, "订阅上下文取消,取消订阅")
if err := sub.Unsubscribe(); err != nil {
return
}
}()
return nil
}
// createStream 内部创建消费组
func (c *natsMsg) createStream(ctx context.Context, subject string, durable, delayMsg bool) error {
streamName, storage := getStreamInfo(durable, delayMsg)
// 构建流配置
// 如果是延迟消息,需要包含两个 subjects:
// 1. subject.schedule - 用于发送调度消息
// 2. subject - 用于实际投递目标
subjects := []string{subject}
if delayMsg {
subjects = []string{subject, subject + ".schedule"}
}
jsConfig := &StreamConfig{
Name: streamName,
Subjects: subjects,
AllowMsgSchedules: delayMsg, // 延迟消息核心开关
Storage: storage,
Discard: DiscardNew, // 达到上限删除旧消息
}
nc := getNatsConn(c.name)
if !c.Ping(ctx) {
// 使用统一的重连函数
if err := commonConnect(ctx, MessageNATS, c.name, func(ctx context.Context) error {
return c.Connect(ctx)
}, func(ctx context.Context) error {
return c.Close(ctx)
}); err != nil {
g.Log().Errorf(ctx, "❌ [%s][%s] 连接失败: %v", MessageNATS, c.name, err)
return err
}
}
if nc == nil {
g.Log().Errorf(ctx, "❌ NATS [%s] 连接不存在", c.name)
return fmt.Errorf("NATS 连接不存在")
}
err := jsStreamCreate(nc, jsConfig)
if err != nil {
g.Log().Errorf(ctx, "❌ 创建 Stream 失败: err=%v", err)
return err
}
g.Log().Infof(ctx, "✅ 创建 Stream 成功: stream=%s, subjects=%v, allowSchedules=%v", streamName, subjects, delayMsg)
return nil
}
func getStreamInfo(durable, delayMsg bool) (string, StorageType) {
// Stream 不存在,创建新的
streamName := "ordinary_msg_memory"
storage := MemoryStorage
// 延迟消息必须使用 FileStorageNATS 官方要求)
if delayMsg {
if durable {
streamName = "delay_msg_file"
storage = FileStorage
} else {
streamName = "delay_msg_memory"
storage = MemoryStorage
}
} else {
if durable {
streamName = "ordinary_msg_file"
storage = FileStorage
}
}
return streamName, storage
}
const (
// JSApiStreamCreateT is the endpoint to create new streams.
// Will return JSON response.
JSApiStreamCreateT = "$JS.API.STREAM.CREATE.%s"
// JSApiStreamUpdateT is the endpoint to update existing streams.
// Will return JSON response.
JSApiStreamUpdateT = "$JS.API.STREAM.UPDATE.%s"
)
// jsStreamCreate is for sending a stream create for fields that nats.go does not know about yet.
func jsStreamCreate(nc *nats.Conn, cfg *StreamConfig) error {
j, err := json.Marshal(cfg)
if err != nil {
return err
}
msg, err := nc.Request(fmt.Sprintf(JSApiStreamCreateT, cfg.Name), j, time.Second*3)
if err != nil {
return err
}
// 检查 API 响应中的错误
var resp struct {
Error *struct {
Code int `json:"code"`
ErrCode int `json:"err_code"`
Description string `json:"description"`
} `json:"error,omitempty"`
}
if err := json.Unmarshal(msg.Data, &resp); err != nil {
return err
}
if resp.Error != nil {
// 如果 Stream 已存在,尝试更新
if resp.Error.ErrCode == 10058 { // JSStreamNameExistErr
return jsStreamUpdate(nc, cfg)
}
return fmt.Errorf("JS API error: %s", resp.Error.Description)
}
return nil
}
// jsStreamUpdate is for sending a stream create for fields that nats.go does not know about yet.
func jsStreamUpdate(nc *nats.Conn, cfg *StreamConfig) error {
j, err := json.Marshal(cfg)
if err != nil {
return err
}
msg, err := nc.Request(fmt.Sprintf(JSApiStreamUpdateT, cfg.Name), j, time.Second*3)
if err != nil {
return err
}
// 检查 API 响应中的错误
var resp struct {
Error *struct {
Code int `json:"code"`
ErrCode int `json:"err_code"`
Description string `json:"description"`
} `json:"error,omitempty"`
}
if err := json.Unmarshal(msg.Data, &resp); err != nil {
return err
}
if resp.Error != nil {
return fmt.Errorf("JS API error: %s", resp.Error.Description)
}
return nil
}

View File

@@ -1,770 +0,0 @@
package message
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gogf/gf/v2/frame/g"
"github.com/nats-io/nats.go"
"go.opentelemetry.io/otel/trace"
"reflect"
"sync"
)
// ============ RPC 服务封装 ============
// 以下方法提供了完全抽象的 RPC 调用接口
// 调用方和响应方完全不需要知道底层使用的是 NATS 的发布订阅模式
// RPC 服务注册表
var (
rpcServices map[string]rpcHandler
rpcSubs map[string]*nats.Subscription // 服务名 -> 订阅
rpcServicesMu sync.RWMutex
queueRPCServices map[string]map[string]rpcHandler // queueName -> subject -> handler
queueRPCSubs map[string]map[string]*nats.Subscription // queueName -> serviceName -> 订阅
queueRPCMu sync.RWMutex
// ============ TraceID 主动取消支持 ============
// 全局映射表TraceID -> CancelFunc并发安全
traceCancelMap map[string]context.CancelFunc
traceCancelMu sync.RWMutex
// 取消主题前缀
cancelSubjectPrefix = "ctx.cancel.otel."
// RPC 使用的默认数据源名称
rpcDefaultDatasource = "default"
)
// rpcHandler RPC 处理函数类型
// 实现方只需要关注请求参数和返回值,无需了解底层 NATS 实现
// 返回值可以是任意类型,会被自动序列化为 JSON
type rpcHandler func(ctx context.Context, req []byte) (any, error)
// registerRPCService 注册 RPC 服务(单实例)
// serviceName: 服务名称,调用方通过此名称调用服务
// handler: 服务处理函数,接收请求并返回响应
func registerRPCService(serviceName string, handler rpcHandler) (err error) {
if !natsPing(context.Background(), rpcDefaultDatasource) {
return fmt.Errorf("NATS 未连接")
}
rpcServicesMu.Lock()
if rpcServices == nil {
rpcServices = make(map[string]rpcHandler)
}
if rpcSubs == nil {
rpcSubs = make(map[string]*nats.Subscription)
}
// 如果已存在该服务,先取消之前的订阅
if oldSub, exists := rpcSubs[serviceName]; exists {
oldSub.Unsubscribe()
}
rpcServices[serviceName] = handler
rpcServicesMu.Unlock()
// 订阅服务主题
nc := getNatsConn(rpcDefaultDatasource)
if nc == nil {
return fmt.Errorf("NATS 连接不存在")
}
subject := fmt.Sprintf("rpc.%s", serviceName)
sub, err := nc.Subscribe(subject, func(msg *nats.Msg) {
// 执行处理函数
executeHandler(handler, msg)
})
if err != nil {
return fmt.Errorf("注册 RPC 服务失败: %w", err)
}
rpcSubs[serviceName] = sub
g.Log().Infof(context.Background(), "✅ RPC 服务已注册: %s", serviceName)
return nil
}
// registerQueueRPCService 注册 RPC 服务(集群模式)
// 多个服务实例注册同一服务时,请求会自动负载均衡
// serviceName: 服务名称
// queueName: 队列组名,同一队列组的实例共享请求
// handler: 服务处理函数
func registerQueueRPCService(serviceName, queueName string, handler rpcHandler) (err error) {
if !natsPing(context.Background(), rpcDefaultDatasource) {
return fmt.Errorf("NATS 未连接")
}
queueRPCMu.Lock()
if queueRPCServices == nil {
queueRPCServices = make(map[string]map[string]rpcHandler)
}
if queueRPCSubs == nil {
queueRPCSubs = make(map[string]map[string]*nats.Subscription)
}
if queueRPCServices[queueName] == nil {
queueRPCServices[queueName] = make(map[string]rpcHandler)
}
if queueRPCSubs[queueName] == nil {
queueRPCSubs[queueName] = make(map[string]*nats.Subscription)
}
// 如果已存在该服务,先取消之前的订阅
if oldSub, exists := queueRPCSubs[queueName][serviceName]; exists {
oldSub.Unsubscribe()
}
queueRPCServices[queueName][serviceName] = handler
queueRPCMu.Unlock()
// 订阅服务主题(队列模式)
nc := getNatsConn(rpcDefaultDatasource)
if nc == nil {
return fmt.Errorf("NATS 连接不存在")
}
subject := fmt.Sprintf("rpc.%s", serviceName)
sub, err := nc.QueueSubscribe(subject, queueName, func(msg *nats.Msg) {
// 执行处理函数
executeHandler(handler, msg)
})
if err != nil {
return fmt.Errorf("注册队列 RPC 服务失败: %w", err)
}
queueRPCMu.Lock()
queueRPCSubs[queueName][serviceName] = sub
queueRPCMu.Unlock()
g.Log().Infof(context.Background(), "✅ 队列 RPC 服务已注册: %s (队列组: %s)", serviceName, queueName)
return nil
}
// executeHandler 执行 RPC 处理函数
func executeHandler(handler rpcHandler, msg *nats.Msg) {
// 响应
var respData []byte
// 从消息头重建上下文
ctx := headersToContext(context.Background(), msg.Header)
// 提取 TraceID创建可取消的 context
ctx = createCancelContext(ctx, msg.Header.Get(traceIDKey))
// 检查 context 是否已取消(在调用 handler 之前)
select {
case <-ctx.Done():
// context 已取消,返回取消错误
g.Log().Infof(ctx, "RPC 请求已取消traceID: %s", msg.Header.Get(traceIDKey))
// 仍然需要发送响应以避免客户端超时
respData = []byte(`{"_err":"请求已取消"}`)
// 清理取消映射表
cleanupTraceCancel(msg.Header.Get(traceIDKey))
return
default:
}
// 执行业务处理
response, err := handler(ctx, msg.Data)
if err != nil {
// 错误时返回 {"_err": "错误信息"}
if respData, err = json.Marshal(map[string]any{"_err": err.Error()}); err != nil {
g.Log().Errorf(ctx, "RPC 错误响应序列化失败: %v", err)
respData = []byte(`{"_err":"错误响应序列化失败"}`)
}
} else if response == nil {
// 空响应时返回空对象(或 {"_err": ""}
respData = []byte(`{}`)
} else {
// 成功时返回业务数据
if respData, err = json.Marshal(response); err != nil {
g.Log().Errorf(ctx, "RPC 响应序列化失败: %v", err)
respData = []byte(`{"_err":"响应序列化失败"}`)
}
}
// 发送响应(必须执行) 如果客户端用 nc.Request(...) 发送消息 → 双向模式,服务端必须 msg.Respond
if err = msg.Respond(respData); err != nil {
g.Log().Errorf(ctx, "RPC 响应失败: %v", err)
}
// 请求结束,清理取消映射表
cleanupTraceCancel(msg.Header.Get(traceIDKey))
}
// createCancelContext 创建可取消的 context 并注册到取消映射表
// 返回可取消的 context如果 traceID 为空则返回原 context
func createCancelContext(ctx context.Context, traceID string) context.Context {
if g.IsEmpty(traceID) {
return ctx
}
// 创建带取消功能的 context
taskCtx, cancel := context.WithCancel(ctx)
// 注册到取消映射表
traceCancelMu.Lock()
if traceCancelMap == nil {
traceCancelMap = make(map[string]context.CancelFunc)
}
// 如果同一 TraceID 已有 CancelFunc先调用它
if oldCancel, exists := traceCancelMap[traceID]; exists {
oldCancel()
}
traceCancelMap[traceID] = cancel
traceCancelMu.Unlock()
return taskCtx
}
// ============ TraceID 主动取消功能 ============
// 以下函数实现了基于 OpenTelemetry TraceID 的跨进程任务取消机制
// SetupCancelListener 设置取消监听器
// 订阅取消主题,监听取消指令
// 使用示例:
//
// sub, err := nats.SetupCancelListener(ctx)
func setupCancelListener(ctx context.Context) (*nats.Subscription, error) {
if !natsPing(ctx, rpcDefaultDatasource) {
return nil, fmt.Errorf("NATS 未连接")
}
if traceCancelMap == nil {
traceCancelMap = make(map[string]context.CancelFunc)
}
// 修复问题3订阅取消主题格式: ctx.cancel.otel.*
// 使用 * 通配符而不是 >,因为 TraceID 是最后一部分
nc := getNatsConn(rpcDefaultDatasource)
if nc == nil {
return nil, fmt.Errorf("NATS 连接不存在")
}
cancelSubject := cancelSubjectPrefix + "*"
sub, err := nc.Subscribe(cancelSubject, func(msg *nats.Msg) {
// 从主题中解析 TraceID (去除前缀)
prefixLen := len(cancelSubjectPrefix)
if len(msg.Subject) <= prefixLen {
g.Log().Warningf(ctx, "取消消息主题格式错误: %s", msg.Subject)
return
}
traceID := msg.Subject[prefixLen:]
if traceID == "" {
g.Log().Warning(ctx, "取消消息主题缺少 TraceID")
return
}
// 从映射表获取 CancelFunc 并执行取消
traceCancelMu.RLock()
cancel, ok := traceCancelMap[traceID]
traceCancelMu.RUnlock()
if ok {
cancel()
g.Log().Infof(ctx, "📢 取消信号已发送traceID: %s", traceID)
} else {
g.Log().Infof(ctx, "⚠️ 未找到对应的可取消任务traceID: %s", traceID)
}
})
if err != nil {
return nil, fmt.Errorf("设置取消监听器失败: %w", err)
}
g.Log().Infof(ctx, "✅ 取消监听器已设置: %s", cancelSubject)
return sub, nil
}
// publishCancel 发布取消指令
// 向指定 TraceID 发送取消信号
// 使用示例:
//
// err := nats.publishCancel(ctx, traceID)
func publishCancel(ctx context.Context, traceID string) error {
if !natsPing(ctx, rpcDefaultDatasource) {
return fmt.Errorf("NATS 未连接")
}
if traceID == "" {
return fmt.Errorf("TraceID 不能为空")
}
nc := getNatsConn(rpcDefaultDatasource)
if nc == nil {
return fmt.Errorf("NATS 连接不存在")
}
cancelSubject := cancelSubjectPrefix + traceID
err := nc.Publish(cancelSubject, nil)
if err != nil {
return fmt.Errorf("发布取消信号失败: %w", err)
}
g.Log().Infof(ctx, "📤 已发送取消信号traceID: %s主题: %s", traceID, cancelSubject)
return nil
}
// cleanupTraceCancel 清理取消映射表中的条目
// 任务取消/正常结束后必须调用此函数,避免内存泄漏
// 使用示例:
//
// defer nats.cleanupTraceCancel(traceID)
func cleanupTraceCancel(traceID string) {
if traceID == "" {
return
}
traceCancelMu.Lock()
defer traceCancelMu.Unlock()
if _, ok := traceCancelMap[traceID]; ok {
delete(traceCancelMap, traceID)
g.Log().Infof(context.Background(), "✅ 已清理取消映射表traceID: %s", traceID)
}
}
// CallRPC 调用 RPC 服务
// serviceName: 服务名称
// req: 请求数据
// 返回: 响应数据(任意类型)和错误
func CallRPC(ctx context.Context, serviceName string, req any, resp any) (err error) {
if !natsPing(ctx, rpcDefaultDatasource) {
return fmt.Errorf("NATS 未连接")
}
// 验证 resp 必须是指针类型
respValue := reflect.ValueOf(resp)
if respValue.Kind() != reflect.Ptr {
return fmt.Errorf("resp 参数必须是指针类型(当前类型: %T", resp)
}
// 构建请求体
var reqBody []byte
if !g.IsEmpty(req) {
reqValue := reflect.ValueOf(req)
if !(reqValue.Kind() == reflect.Ptr && reqValue.IsNil()) && !reqValue.IsZero() {
reqData, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("序列化请求参数失败: %w", err)
}
reqBody = reqData
}
}
// 检查本地是否有注册的单实例服务,如果有则直接调用(优化性能)
rpcServicesMu.RLock()
if localHandler, exists := rpcServices[serviceName]; exists {
rpcServicesMu.RUnlock()
// 修复问题1本地调用也需要处理取消机制
var traceID string
if traceID, err = getTraceID(ctx); err != nil {
return err
}
// 提取 TraceID创建可取消的 context
cancelCtx := createCancelContext(ctx, traceID)
// 执行本地调用
var response interface{}
if response, err = localHandler(cancelCtx, reqBody); err != nil {
return fmt.Errorf("本地调用 RPC 服务失败 [%s]: %w", serviceName, err)
}
// 请求结束,清理取消映射表
cleanupTraceCancel(traceID)
// 检查是否为错误消息:尝试解析为 map看是否包含 "_err" 字段
var respMap map[string]any
if json.Unmarshal(response.([]byte), &respMap) == nil {
if errMsg, ok := respMap["_err"]; ok {
return fmt.Errorf("%v", errMsg)
}
}
// 正常数据直接返回
// responseMsg.Data 已经是 []byte 类型(来自 msg.Data直接反序列化
if err = json.Unmarshal(response.([]byte), resp); err != nil {
return fmt.Errorf("解析响应失败: %w (响应内容: %s)", err, response)
}
return
}
rpcServicesMu.RUnlock()
subject := fmt.Sprintf("rpc.%s", serviceName)
// 创建消息并将上下文元数据写入消息头
msg := nats.NewMsg(subject)
msg.Data = reqBody
headers, err := contextToHeaders(ctx)
if err != nil {
return fmt.Errorf("上下文转换失败: %w", err)
}
msg.Header = headers
// 修复问题5优化 go 协程避免资源泄漏
// 使用 done channel 来确保 goroutine 能正确退出
done := make(chan struct{})
var closeDoneOnce sync.Once
closeDone := func() {
closeDoneOnce.Do(func() {
close(done)
})
}
if msg.Header.Get(traceIDKey) != "" {
go func() {
defer closeDone()
select {
case <-ctx.Done():
// context 被取消时,发送取消信号给服务端
if errors.Is(ctx.Err(), context.Canceled) {
if err := publishCancel(context.Background(), msg.Header.Get(traceIDKey)); err != nil {
g.Log().Errorf(ctx, "发送 RPC 取消信号失败: %v", err)
} else {
g.Log().Infof(ctx, "RPC 调用已取消traceID: %s", msg.Header.Get(traceIDKey))
}
}
case <-done:
// 请求已完成,无需发送取消信号
return
}
}()
}
// 发送请求
nc := getNatsConn(rpcDefaultDatasource)
if nc == nil {
return fmt.Errorf("NATS 连接不存在")
}
responseMsg, err := nc.RequestMsgWithContext(ctx, msg)
// 关闭 done channel通知 goroutine 退出
closeDone()
if err != nil {
return fmt.Errorf("调用 RPC 服务失败 [%s]: %w", serviceName, err)
}
if responseMsg == nil {
return fmt.Errorf("RPC 响应为空 [%s]", serviceName)
}
// 解析响应
if len(responseMsg.Data) > 0 {
// 检查是否为错误消息:尝试解析为 map看是否包含 "_err" 字段
var respMap map[string]any
if json.Unmarshal(responseMsg.Data, &respMap) == nil {
if errMsg, ok := respMap["_err"]; ok {
return fmt.Errorf("%v", errMsg)
}
}
// 正常数据直接返回
// responseMsg.Data 已经是 []byte 类型(来自 msg.Data直接反序列化
if err = json.Unmarshal(responseMsg.Data, resp); err != nil {
return fmt.Errorf("解析响应失败: %w (响应内容: %s)", err, responseMsg.Data)
}
}
return
}
// RegisterServiceOption 注册选项类型
type registerServiceOption func(*registerServiceConfig)
type registerServiceConfig struct {
queueName string // 队列组名(用于集群模式)
excludeMethods []string
}
// WithQueueGroup 设置队列组名(集群模式)
func WithQueueGroup(queueName string) registerServiceOption {
return func(cfg *registerServiceConfig) {
cfg.queueName = queueName
}
}
// WithExcludeMethods 排除不需要注册的方法
func WithExcludeMethods(methods ...string) registerServiceOption {
return func(cfg *registerServiceConfig) {
cfg.excludeMethods = append(cfg.excludeMethods, methods...)
}
}
// AutoRegisterServices 自动注册多个服务的所有公开方法
// serviceInstances: map[包名]service实例如 map[string]interface{}{"user": userService, "order": orderService}
// options: 注册选项(可选)
// 示例:
//
// AutoRegisterServices(map[string]interface{}{
// "user": userService,
// "order": orderService,
// })
// 或
// AutoRegisterServices(map[string]interface{}{
// "order": orderService,
// }, WithQueueGroup("order-group"))
func AutoRegisterServices(ctx context.Context, serviceInstances map[string]interface{}, options ...registerServiceOption) error {
// 先注册 RPC 服务(如果 NATS 不可用则记录警告但不阻塞启动)
if !natsPing(ctx, rpcDefaultDatasource) {
return fmt.Errorf("NATS 未连接RPC 服务未注册")
}
if len(serviceInstances) == 0 {
return fmt.Errorf("service 实例列表不能为空")
}
totalRegistered := 0
// 遍历每个 service 实例
for pkgName, serviceInstance := range serviceInstances {
// 注册服务
err := registerService(serviceInstance, pkgName, options...)
if err != nil {
g.Log().Errorf(ctx, "注册 %s 服务失败: %v", pkgName, err)
continue
}
totalRegistered++
g.Log().Infof(ctx, "✅ %s 服务已自动注册", pkgName)
}
if totalRegistered == 0 {
return fmt.Errorf("未能注册任何服务")
}
// 设置取消监听器(监听基于 TraceID 的取消请求)
if _, err := setupCancelListener(ctx); err != nil {
g.Log().Errorf(ctx, "设置取消监听器失败: %v", err)
} else {
g.Log().Infof(ctx, "✅ 取消监听器已自动设置")
}
g.Log().Infof(ctx, "✅ 共自动注册了 %d 个服务", totalRegistered)
return nil
}
// registerService 注册单个服务的所有公开方法(内部函数)
func registerService(service interface{}, serviceNamePrefix string, options ...registerServiceOption) (err error) {
if !natsPing(context.Background(), rpcDefaultDatasource) {
return fmt.Errorf("NATS 未连接")
}
// 应用选项
cfg := &registerServiceConfig{}
for _, opt := range options {
opt(cfg)
}
// 创建排除方法集合
excludeSet := make(map[string]struct{})
for _, method := range cfg.excludeMethods {
excludeSet[method] = struct{}{}
}
// 获取 service 的类型
serviceType := reflect.TypeOf(service)
// 遍历所有方法
registeredCount := 0
for i := 0; i < serviceType.NumMethod(); i++ {
method := serviceType.Method(i)
// 只注册导出方法(首字母大写)
if !method.IsExported() {
continue
}
// 排除指定的方法
if _, exists := excludeSet[method.Name]; exists {
continue
}
// 检查方法签名:必须是 func(ctx context.Context, request) (response, error)
// 注意method.Type.NumIn() 包含接收者,所以实际参数数量需要减去 1
// 要求:接收者 + context.Context + request总共3个参数
if method.Type.NumIn() != 3 {
g.Log().Warningf(context.Background(), "方法 %s 必须有2个参数context.Context 和请求参数),跳过注册", method.Name)
continue
}
// 第一个参数(接收者之后的第一个参数)必须是 context.Context
// method.Type.In(0) 是接收者method.Type.In(1) 才是第一个参数
if !method.Type.In(1).Implements(reflect.TypeOf((*context.Context)(nil)).Elem()) {
g.Log().Warningf(context.Background(), "方法 %s 的第一个参数必须是 context.Context跳过注册", method.Name)
continue
}
// 第二个参数必须是结构体指针或数组
reqType := method.Type.In(2)
if reqType.Kind() != reflect.Ptr && reqType.Kind() != reflect.Slice && reqType.Kind() != reflect.Array {
g.Log().Warningf(context.Background(), "方法 %s 的第二个参数必须是结构体指针或数组,跳过注册", method.Name)
continue
}
// 返回值必须是 (result, error)即2个返回值
if method.Type.NumOut() != 2 {
g.Log().Warningf(context.Background(), "方法 %s 必须有2个返回值result 和 error跳过注册", method.Name)
continue
}
// 最后一个返回值必须是 error
if !method.Type.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
g.Log().Warningf(context.Background(), "方法 %s 的最后一个返回值必须是 error跳过注册", method.Name)
continue
}
// 生成服务名称:前缀.方法名(保持原始方法名)
serviceName := fmt.Sprintf("%s.%s", serviceNamePrefix, method.Name)
// 创建 RPC handler
handler := func(ctx context.Context, req []byte) (any, error) {
// 准备方法调用参数
// args[0] 是接收者, args[1] 是 ctx, args[2] 是请求参数
args := make([]reflect.Value, 3)
args[0] = reflect.ValueOf(service) // 接收者
args[1] = reflect.ValueOf(ctx) // context.Context
// 解析请求参数
if len(req) > 0 {
reqValuePtr := reflect.New(reqType)
// 解析 JSON
if err := json.Unmarshal(req, reqValuePtr.Interface()); err != nil {
// 根据参数类型提供更友好的错误提示
var typeHint string
if reqType.Kind() == reflect.Ptr {
typeHint = fmt.Sprintf("(期望类型: %s", reqType.Elem().Name())
} else { // reflect.Slice 或 reflect.Array
typeHint = fmt.Sprintf("(期望类型: %s请确保客户端传递的是JSON数组格式", reqType.String())
}
return nil, fmt.Errorf("解析请求参数失败%s: %w", typeHint, err)
}
args[2] = reqValuePtr.Elem()
} else {
// 请求为空,创建零值
args[2] = reflect.Zero(method.Type.In(2))
}
// 调用方法
results := method.Func.Call(args)
// 处理返回值
var result any
if len(results) == 1 {
// 只有 error
if !results[0].IsNil() {
err = results[0].Interface().(error)
}
} else if len(results) == 2 {
// (result, error)
result = results[0].Interface()
if !results[1].IsNil() {
err = results[1].Interface().(error)
}
}
if err != nil {
return nil, err
}
return result, nil
}
// 注册 RPC 服务
var err error
if cfg.queueName != "" {
err = registerQueueRPCService(serviceName, cfg.queueName, handler)
} else {
err = registerRPCService(serviceName, handler)
}
if err != nil {
g.Log().Errorf(context.Background(), "注册服务 %s 失败: %v", serviceName, err)
continue
}
registeredCount++
g.Log().Infof(context.Background(), "✅ 已自动注册 RPC 服务: %s -> %s", serviceName, method.Name)
}
if registeredCount == 0 {
g.Log().Warningf(context.Background(), "未注册任何方法,请检查 %v 的方法签名", serviceNamePrefix)
return fmt.Errorf("未找到可注册的方法")
}
g.Log().Infof(context.Background(), "✅ Service %v 共注册了 %d 个 RPC 方法", serviceNamePrefix, registeredCount)
return nil
}
// ============ 上下文元数据工具函数 ============
// 以下函数用于在 context 和 NATS 消息头之间互转元数据
// 定义常见的上下文元数据 key私有
const (
traceIDKey = "trace_id"
tokenKey = "token"
)
func getTraceID(ctx context.Context) (traceID string, err error) {
// 提取 traceId首先尝试从 OpenTelemetry Span 中提取,从 context 中提取 TraceID
span := trace.SpanFromContext(ctx)
if span != nil && span.SpanContext().HasTraceID() {
traceID = span.SpanContext().TraceID().String()
} else if tid := ctx.Value(traceIDKey); tid != nil {
traceID = fmt.Sprintf("%v", tid)
}
if traceID == "" {
return traceID, fmt.Errorf("context 中没有 TraceID")
}
return
}
// contextToHeaders 将 context 中的元数据转换为 NATS 消息头
// 支持提取 user_id、tenant_id、trace_id、token 等常见字段
func contextToHeaders(ctx context.Context) (nats.Header, error) {
headers := make(nats.Header)
// 提取 traceId首先尝试从 OpenTelemetry Span 中提取
if traceID, err := getTraceID(ctx); err != nil {
return headers, err
} else {
headers.Set(traceIDKey, traceID)
}
// 提取 token优先级context value > HTTP Authorization header
token := ""
if t := ctx.Value(tokenKey); t != nil {
token = fmt.Sprintf("%v", t)
} else if r := g.RequestFromCtx(ctx); r != nil {
// 从 HTTP 请求的 Authorization header 中提取 token
auth := r.GetHeader("Authorization")
if auth != "" {
// 移除 "Bearer " 前缀
if len(auth) > 7 && auth[:7] == "Bearer " {
token = auth[7:]
} else {
token = auth
}
}
}
if token != "" {
headers.Set(tokenKey, token)
}
return headers, nil
}
// headersToContext 从 NATS 消息头重建 context
// 支持还原 user_id、tenant_id、trace_id、token 等字段
func headersToContext(ctx context.Context, headers nats.Header) context.Context {
if headers == nil {
return ctx
}
// 恢复 trace_id
if traceID := headers.Get(traceIDKey); traceID != "" {
ctx = context.WithValue(ctx, traceIDKey, traceID)
}
// 恢复 token
if token := headers.Get(tokenKey); token != "" {
ctx = context.WithValue(ctx, tokenKey, token)
}
return ctx
}

View File

@@ -1,311 +0,0 @@
package message
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/gogf/gf/v2/frame/g"
amqp "github.com/rabbitmq/amqp091-go"
)
type RabbitMQPublishMsgConfig struct {
QueueName string
Durable bool
Data any
}
type RabbitMQPublishDelayMsgConfig struct {
QueueName string
Durable bool
DelayTime int
Data any
}
type RabbitMQSubscribeMsgConfig struct {
QueueName string
ConsumerName string
AutoAck bool
PrefetchCount int
HandleFunc func(ctx context.Context, message map[string]interface{}) error
}
func (*RabbitMQPublishMsgConfig) GetPublishMsgType() {
}
func (*RabbitMQPublishDelayMsgConfig) GetPublishDelayMsgType() {}
func (*RabbitMQSubscribeMsgConfig) GetSubscribeMsgType() {
}
type rabbitMQ struct {
name string // 数据源名称
}
func init() {
// 注册 RabbitMQ 插件(默认数据源)
RegisterPlugin(context.Background(), "default", MessageRabbitMQ, func() messageUtil {
return &rabbitMQ{name: "default"}
})
}
// Connect 连接 RabbitMQ
func (c *rabbitMQ) Connect(ctx context.Context) error {
return rabbitmqConnect(ctx, c.name)
}
// Ping 检测 RabbitMQ 连接状态
func (c *rabbitMQ) Ping(ctx context.Context) bool {
return rabbitmqPing(ctx, c.name)
}
// Close 关闭 RabbitMQ 连接
func (c *rabbitMQ) Close(ctx context.Context) error {
return rabbitmqClose(ctx, c.name)
}
// Publish 发布消息
func (c *rabbitMQ) Publish(ctx context.Context, msgConfig messagePublishConfig) error {
cfg, ok := msgConfig.(*RabbitMQPublishMsgConfig)
if !ok {
return fmt.Errorf("无效的 RabbitMQ 配置类型")
}
if g.IsEmpty(cfg.QueueName) {
return fmt.Errorf("队列名称不能为空")
}
if cfg.Data == nil {
return fmt.Errorf("数据不能为空")
}
return c.publishMessageInternal(ctx, cfg.QueueName, cfg.Durable, 0, cfg.Data)
}
// PublishDelay 发布延迟消息
func (c *rabbitMQ) PublishDelay(ctx context.Context, msgConfig messagePublishDelayConfig) error {
cfg, ok := msgConfig.(*RabbitMQPublishDelayMsgConfig)
if !ok {
return fmt.Errorf("无效的 RabbitMQ 配置类型")
}
if g.IsEmpty(cfg.QueueName) {
return fmt.Errorf("队列名称不能为空")
}
if cfg.Data == nil {
return fmt.Errorf("数据不能为空")
}
return c.publishMessageInternal(ctx, cfg.QueueName, cfg.Durable, cfg.DelayTime, cfg.Data)
}
// publishMessage 发布消息内部实现
func (c *rabbitMQ) publishMessageInternal(ctx context.Context, queueName string, durable bool, delayTime int, data interface{}) error {
if !c.Ping(ctx) {
if err := commonConnect(ctx, MessageRabbitMQ, c.name, func(ctx context.Context) error {
return c.Connect(ctx)
}, func(ctx context.Context) error {
return c.Close(ctx)
}); err != nil {
g.Log().Errorf(ctx, "❌ [%s][%s] 连接失败: %v", MessageRabbitMQ, c.name, err)
return err
}
}
channel := getRabbitMQChannel(c.name)
if channel == nil || channel.IsClosed() {
g.Log().Errorf(ctx, "❌ RabbitMQ [%s] Channel 不存在或已关闭", c.name)
return fmt.Errorf("RabbitMQ Channel 不存在或已关闭")
}
delayMsg := delayTime > 0
// 1. 决定 Exchange 类型
exchangeType := "fanout"
exchangeName := queueName
routingKey := queueName
args := amqp.Table{}
if delayMsg {
exchangeType = "x-delayed-message"
exchangeName = queueName + ".delayed"
args["x-delayed-type"] = "fanout"
}
// 2. 声明 Exchange使用 exchangeName 而不是 queueName
if err := channel.ExchangeDeclare(
exchangeName, // 修复:使用正确的交换机名称
exchangeType,
durable,
false, // autoDelete
false, // internal
false, // noWait
args,
); err != nil {
g.Log().Errorf(ctx, "❌ 声明 Exchange 失败: %v", err)
return err
}
// 3. 声明队列
if _, err := channel.QueueDeclare(
queueName,
durable,
false, // autoDelete
false, // exclusive
false, // noWait
nil, // args
); err != nil {
g.Log().Errorf(ctx, "❌ 声明队列失败: %v", err)
return err
}
// 4. 绑定队列
if err := channel.QueueBind(
queueName,
routingKey, // routingKey 路由键
exchangeName, // exchange 交换机名称
false, // noWait
nil, // args
); err != nil {
g.Log().Errorf(ctx, "❌ 绑定队列失败: %v", err)
return err
}
// 5. 序列化数据
body, err := json.Marshal(data)
if err != nil {
g.Log().Errorf(ctx, "❌ 序列化数据失败: %v", err)
return err
}
// 6. 发布消息
deliveryMode := amqp.Transient
if durable {
deliveryMode = amqp.Persistent
}
publishing := amqp.Publishing{
ContentType: "application/json",
Body: body,
DeliveryMode: deliveryMode,
Timestamp: time.Now(),
}
if delayMsg {
duration := delayTime * 1000 // 延迟时间(毫秒)= 秒 * 1000
publishing.Headers = amqp.Table{
"x-delay": duration,
}
}
err = channel.PublishWithContext(
ctx,
exchangeName,
routingKey,
false, false,
publishing,
)
if err != nil {
g.Log().Errorf(ctx, "❌ 发布消息失败: %v", err)
return err
}
g.Log().Infof(ctx, "📨 发布消息成功: queueName=%s, data=%v", queueName, data)
return err
}
// Subscribe 订阅消息
func (c *rabbitMQ) Subscribe(ctx context.Context, msgConfig messageSubscribeConfig) error {
cfg, ok := msgConfig.(*RabbitMQSubscribeMsgConfig)
if !ok {
return fmt.Errorf("无效的 RabbitMQ 配置类型")
}
if g.IsEmpty(cfg.QueueName) {
return fmt.Errorf("队列名称不能为空")
}
if g.IsEmpty(cfg.ConsumerName) {
return fmt.Errorf("消费者名称不能为空")
}
if g.IsEmpty(cfg.PrefetchCount) {
cfg.PrefetchCount = 1
}
if g.IsEmpty(cfg.HandleFunc) {
return fmt.Errorf("必须提供处理函数")
}
return c.createSubscribeInternal(ctx, cfg.QueueName, cfg.ConsumerName, cfg.PrefetchCount, cfg.AutoAck, cfg.HandleFunc)
}
// createSubscribe 内部订阅消息
func (c *rabbitMQ) createSubscribeInternal(ctx context.Context, queueName, consumerName string, prefetchCount int, autoAck bool, handler func(ctx context.Context, message map[string]interface{}) error) error {
g.Log().Infof(ctx, "🔔 RabbitMQ [%s] 开始订阅: queueName=%s, consumerName=%s", c.name, queueName, consumerName)
if !c.Ping(ctx) {
if err := commonConnect(ctx, MessageRabbitMQ, c.name, func(ctx context.Context) error {
return c.Connect(ctx)
}, func(ctx context.Context) error {
return c.Close(ctx)
}); err != nil {
g.Log().Errorf(ctx, "❌ [%s][%s] 连接失败: %v", MessageRabbitMQ, c.name, err)
return err
}
}
channel := getRabbitMQChannel(c.name)
if channel == nil || channel.IsClosed() {
g.Log().Errorf(ctx, "❌ RabbitMQ [%s] Channel 不存在或已关闭", c.name)
return fmt.Errorf("RabbitMQ Channel 不存在或已关闭")
}
if err := channel.Qos(prefetchCount, 0, false); err != nil {
g.Log().Errorf(ctx, "❌ 设置 Qos 失败: %v", err)
return err
}
g.Log().Infof(ctx, "📊 设置 Prefetch Count: %d", prefetchCount)
msg, err := channel.Consume(
queueName, // queue
consumerName, // consumer
autoAck, // auto-ack (根据配置决定)
false, // exclusive
false, // no-local
false, // no-wait
nil, // args
)
if err != nil {
g.Log().Errorf(ctx, "❌ 消费消息失败: %v", err)
return err
}
g.Log().Infof(ctx, "👀 开始监听消息")
for {
select {
case <-ctx.Done():
// Context 取消,退出
g.Log().Infof(ctx, "context cancel 监听消息退出")
return nil
case m, ok := <-msg:
if !ok {
// Channel 关闭,退出
g.Log().Infof(ctx, "channel close 监听消息退出")
return nil
}
g.Log().Infof(ctx, "📨 收到消息: %s", string(m.Body))
var data map[string]interface{}
if err := json.Unmarshal(m.Body, &data); err != nil {
// 如果不是 JSON直接使用原始内容
data = map[string]interface{}{
"data": string(m.Body),
}
}
err := handler(ctx, data)
if err != nil {
g.Log().Errorf(ctx, "❌ 消息处理失败: %v", err)
// 仅在手动 ACK 模式下拒绝消息
if !autoAck {
// 拒绝消息不再重新入队(避免死循环)
m.Nack(false, false)
continue
}
}
g.Log().Infof(ctx, "✅ 消息处理成功: %v", err)
// 仅在手动 ACK 模式下确认消息
if err := m.Ack(false); err != nil {
g.Log().Errorf(ctx, "❌ AUTO ACK 消息失败: %v", err)
} else {
g.Log().Infof(ctx, "✅ AUTO ACK 消息成功")
}
}
}
}

View File

@@ -1,73 +0,0 @@
package message
import (
"context"
"fmt"
"strings"
"time"
"github.com/gogf/gf/v2/frame/g"
)
// connectFunc 连接函数类型
type connectFunc func(ctx context.Context) error
// closeFunc 关闭函数类型
type closeFunc func(ctx context.Context) error
// reconnectOption 重连选项
type reconnectOption struct {
maxRetries int // 最大重试次数0 表示无限重试
interval time.Duration // 重试间隔
componentType messageType // 组件类型nats/redis/rabbitmq
componentName string // 组件名称(数据源名称)
}
// defaultReconnectOption 默认重连选项
func defaultReconnectOption(componentType messageType, componentName string) *reconnectOption {
return &reconnectOption{
maxRetries: 0, // 无限重试
interval: 3 * time.Second,
componentType: componentType,
componentName: componentName,
}
}
// commonReconnect 重连函数NATS、Redis、RabbitMQ 共用)
func commonReconnect(ctx context.Context, connectFn connectFunc, closeFn closeFunc, opt *reconnectOption) error {
if opt == nil {
opt = defaultReconnectOption("unknown", "default")
}
for attempt := 0; opt.maxRetries == 0 || attempt < opt.maxRetries; attempt++ {
err := connectFn(ctx)
if err == nil {
g.Log().Infof(ctx, "✅ 连接成功: type=%s, name=%s, attempt=%d",
opt.componentType, opt.componentName, attempt+1)
return nil
}
// 记录失败日志
g.Log().Warningf(ctx, "⚠️ 连接失败: type=%s, name=%s, attempt=%d, err=%v, 重试中...",
opt.componentType, opt.componentName, attempt+1, err)
// 如果错误信息中包含 "does not exist",则认为是连接失败,不再重试
if strings.Contains(err.Error(), "does not exist") {
return err
}
// 等待一段时间再重试
select {
case <-time.After(opt.interval):
case <-ctx.Done():
if err = closeFn(ctx); err != nil {
return err
}
return ctx.Err()
}
}
return fmt.Errorf("连接失败,已达最大重试次数")
}
// connect 连接函数,直接调用 commonReconnect
func commonConnect(ctx context.Context, componentType messageType, name string, connectFn func(ctx context.Context) error, closeFn closeFunc) error {
opt := defaultReconnectOption(componentType, name)
return commonReconnect(ctx, connectFn, closeFn, opt)
}

View File

@@ -1,279 +0,0 @@
package message
import (
"context"
"fmt"
"github.com/gogf/gf/v2/os/glog"
"strings"
"time"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
type RedisPublishMsgConfig struct {
QueueName string
Data any
}
type RedisPublishDelayMsgConfig struct {
}
type RedisSubscribeMsgConfig struct {
QueueName string
ConsumerName string
AutoAck bool
PrefetchCount int
HandleFunc func(ctx context.Context, message map[string]interface{}) error
}
func (*RedisPublishMsgConfig) GetPublishMsgType() {
}
func (*RedisPublishDelayMsgConfig) GetPublishDelayMsgType() {}
func (*RedisSubscribeMsgConfig) GetSubscribeMsgType() {
}
type redis struct {
name string // 数据源名称
}
func init() {
// 注册 Redis 插件(默认数据源)
RegisterPlugin(context.Background(), "default", MessageRedis, func() messageUtil {
return &redis{name: "default"}
})
}
// RedisStreamMessage Redis Stream 消息结构
type redisStreamMessage struct {
ID string
Values map[string]interface{}
}
// Connect 连接 Redis
func (c *redis) Connect(ctx context.Context) error {
return redisConnect(ctx, c.name)
}
// Ping 检测 Redis 连接状态
func (c *redis) Ping(ctx context.Context) bool {
return redisPing(ctx, c.name)
}
// Close 关闭 Redis 连接
func (c *redis) Close(ctx context.Context) error {
return redisClose(ctx, c.name)
}
// Publish 发布消息
func (c *redis) Publish(ctx context.Context, msgConfig messagePublishConfig) error {
cfg, ok := msgConfig.(*RedisPublishMsgConfig)
if !ok {
return fmt.Errorf("无效的 Redis 配置类型")
}
if g.IsEmpty(cfg.QueueName) {
return fmt.Errorf("队列名称不能为空")
}
if g.IsEmpty(cfg.Data) {
return fmt.Errorf("数据不能为空")
}
rc := getRedisConn(c.name)
if !c.Ping(ctx) {
if err := commonConnect(ctx, MessageRedis, c.name, func(ctx context.Context) error {
return c.Connect(ctx)
}, func(ctx context.Context) error {
return c.Close(ctx)
}); err != nil {
g.Log().Errorf(ctx, "❌ [%s][%s] 连接失败: %v", MessageRedis, c.name, err)
return err
}
}
values := gconv.Map(cfg.Data)
args := make([]interface{}, 0, len(values)*2+2)
args = append(args, cfg.QueueName, "*")
for key, val := range values {
args = append(args, key, val)
}
result, err := rc.Do(ctx, "XADD", args...)
if err != nil {
g.Log().Errorf(ctx, "❌ Redis 发布消息失败: key=%s, err=%v", cfg.QueueName, err)
return err
}
g.Log().Infof(ctx, "✅ Redis 发布消息成功: key=%s, messageID=%s", cfg.QueueName, gconv.String(result))
return nil
}
// PublishDelay 发布延迟消息
func (c *redis) PublishDelay(ctx context.Context, _ messagePublishDelayConfig) error {
g.Log().Errorf(ctx, "❌ Redis 不支持延迟消息")
return fmt.Errorf("❌ Redis 不支持延迟消息")
}
// Subscribe 订阅消息
func (c *redis) Subscribe(ctx context.Context, msgConfig messageSubscribeConfig) error {
cfg, ok := msgConfig.(*RedisSubscribeMsgConfig)
if !ok {
return fmt.Errorf("无效的 Redis 配置类型")
}
if g.IsEmpty(cfg.QueueName) {
return fmt.Errorf("队列名称不能为空")
}
if g.IsEmpty(cfg.ConsumerName) {
return fmt.Errorf("消费者名称不能为空")
}
if g.IsEmpty(cfg.HandleFunc) {
return fmt.Errorf("处理函数不能为空")
}
return c.createSubscribe(ctx, cfg.QueueName, cfg.ConsumerName, cfg.PrefetchCount, cfg.AutoAck, cfg.HandleFunc)
}
// createSubscribe 内部订阅消息
func (c *redis) createSubscribe(ctx context.Context, key, consumerName string, prefetchCount int, autoAck bool, handler func(ctx context.Context, message map[string]interface{}) error) error {
LOOP:
err := c.consumeMessages(ctx, key, consumerName, prefetchCount, autoAck, handler)
if err != nil {
// 对于超时错误,返回nil继续循环,而不是返回错误
if strings.Contains(err.Error(), "i/o timeout") || strings.Contains(err.Error(), "timeout") ||
strings.Contains(err.Error(), "context deadline exceeded") || strings.Contains(err.Error(), "context canceled") {
time.Sleep(time.Second)
goto LOOP
} else {
g.Log().Errorf(ctx, "❌ 严重错误: %v", err)
}
}
time.Sleep(time.Second)
goto LOOP
}
// consumeMessages 消费消息
func (c *redis) consumeMessages(ctx context.Context, key, consumerName string, prefetchCount int, autoAck bool, handler func(ctx context.Context, message map[string]interface{}) error) error {
if !c.Ping(ctx) {
if err := commonConnect(ctx, MessageRedis, c.name, func(ctx context.Context) error {
return c.Connect(ctx)
}, func(ctx context.Context) error {
return c.Close(ctx)
}); err != nil {
g.Log().Errorf(ctx, "❌ [%s][%s] 连接失败: %v", MessageRedis, c.name, err)
return err
}
}
rc := getRedisConn(c.name)
if rc == nil {
g.Log().Errorf(ctx, "❌ Redis [%s] 连接不存在", c.name)
return fmt.Errorf("Redis 连接不存在")
}
// 检查消费者组是否存在
groupName := "default"
_, err := rc.Do(ctx, "XGROUP", "CREATE", key, groupName, "0", "MKSTREAM")
if err != nil {
errStr := err.Error()
if strings.Contains(errStr, "BUSYGROUP") && strings.Contains(errStr, "already exists") {
glog.Infof(ctx, "✅ Redis [%s] 消费者组已存在: %s", c.name, key)
return nil
}
g.Log().Errorf(ctx, "❌ 创建消费组失败: key=%s, err=%v", key, err)
return err
}
glog.Infof(ctx, "✅ Redis [%s] 消费者组创建成功: %s", c.name, key)
// 使用带重试的命令执行
result, err := rc.Do(ctx, "XREADGROUP", "GROUP", groupName, consumerName, "COUNT", prefetchCount, "BLOCK", 0, "STREAMS", key, ">")
if err != nil {
return err
}
messages, err := c.parseStreamResult(result)
if err != nil {
g.Log().Errorf(ctx, "❌ 解析消息失败: %v", err)
return err
}
for _, msg := range messages {
// 处理消息
if err := handler(ctx, msg.Values); err != nil {
g.Log().Errorf(ctx, "❌ 消息处理失败: messageID=%s, err=%v", msg.ID, err)
// 如果不是自动ACK,则跳过当前消息
if !autoAck {
continue
}
} else {
g.Log().Infof(ctx, "✅ 消息处理成功: messageID=%s", msg.ID)
}
// ACK 消息
args := make([]interface{}, 0, len(msg.ID)+2)
args = append(args, key, groupName, msg.ID)
_, err = rc.Do(ctx, "XACK", args...)
if err != nil {
g.Log().Errorf(ctx, "❌ ACK 消息失败: messageID=%s, err=%v", msg.ID, err)
} else {
g.Log().Infof(ctx, "✅ ACK 消息成功: messageID=%s", msg.ID)
}
}
return nil
}
// parseStreamResult 解析 Stream 结果
func (c *redis) parseStreamResult(result interface{}) ([]redisStreamMessage, error) {
if result == nil {
return []redisStreamMessage{}, nil
}
var resultVal interface{}
// 尝试获取 Val() 方法
if valuer, ok := result.(interface{ Val() interface{} }); ok {
resultVal = valuer.Val()
} else {
resultVal = result
}
// 检查是否为空
if resultVal == nil {
return []redisStreamMessage{}, nil
}
// 预分配切片容量,避免多次扩容
messages := make([]redisStreamMessage, 0)
if streamsMap, ok := resultVal.(map[interface{}]interface{}); ok {
for _, streamData := range streamsMap {
msgArray, ok := streamData.([]interface{})
if !ok {
continue
}
for _, msgData := range msgArray {
msgArray, ok := msgData.([]interface{})
if !ok || len(msgArray) < 2 {
continue
}
msgID := gconv.String(msgArray[0])
fieldsArray, ok := msgArray[1].([]interface{})
if !ok {
continue
}
values := make(map[string]interface{}, len(fieldsArray)/2)
for i := 0; i < len(fieldsArray); i += 2 {
if i+1 < len(fieldsArray) {
key := gconv.String(fieldsArray[i])
values[key] = fieldsArray[i+1]
}
}
messages = append(messages, redisStreamMessage{
ID: msgID,
Values: values,
})
}
}
}
return messages, nil
}

View File

@@ -1,125 +0,0 @@
// Copyright 2019-2025 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import "fmt"
type RetentionPolicy int
const (
// LimitsPolicy (default) means that messages are retained until any given limit is reached.
// This could be one of MaxMsgs, MaxBytes, or MaxAge.
LimitsPolicy RetentionPolicy = iota
// InterestPolicy specifies that when all known consumers have acknowledged a message it can be removed.
InterestPolicy
// WorkQueuePolicy specifies that when the first worker or subscriber acknowledges the message it can be removed.
WorkQueuePolicy
)
// MarshalJSON 将 RetentionPolicy 序列化为字符串
func (rp RetentionPolicy) MarshalJSON() ([]byte, error) {
switch rp {
case LimitsPolicy:
return []byte(`"limits"`), nil
case InterestPolicy:
return []byte(`"interest"`), nil
case WorkQueuePolicy:
return []byte(`"workqueue"`), nil
default:
return nil, fmt.Errorf("can not marshal %v", rp)
}
}
// UnmarshalJSON 将字符串反序列化为 RetentionPolicy
func (rp *RetentionPolicy) UnmarshalJSON(data []byte) error {
switch string(data) {
case `"limits"`:
*rp = LimitsPolicy
case `"interest"`:
*rp = InterestPolicy
case `"workqueue"`:
*rp = WorkQueuePolicy
default:
return fmt.Errorf("unknown retention policy: %s", string(data))
}
return nil
}
type DiscardPolicy int
const (
// DiscardOld will remove older messages to return to the limits.
DiscardOld = iota
// DiscardNew will error on a StoreMsg call
DiscardNew
)
// MarshalJSON 将 DiscardPolicy 序列化为字符串
func (dp DiscardPolicy) MarshalJSON() ([]byte, error) {
switch dp {
case DiscardOld:
return []byte(`"old"`), nil
case DiscardNew:
return []byte(`"new"`), nil
default:
return nil, fmt.Errorf("can not marshal %v", dp)
}
}
// UnmarshalJSON 将字符串反序列化为 DiscardPolicy
func (dp *DiscardPolicy) UnmarshalJSON(data []byte) error {
switch string(data) {
case `"old"`:
*dp = DiscardOld
case `"new"`:
*dp = DiscardNew
default:
return fmt.Errorf("unknown discard policy: %s", string(data))
}
return nil
}
type StorageType int
const (
// FileStorage specifies on disk, designated by the JetStream config StoreDir.
FileStorage = StorageType(22)
// MemoryStorage specifies in memory only.
MemoryStorage = StorageType(33)
)
// MarshalJSON 将 StorageType 序列化为字符串
func (st StorageType) MarshalJSON() ([]byte, error) {
switch st {
case MemoryStorage:
return []byte(`"memory"`), nil
case FileStorage:
return []byte(`"file"`), nil
default:
return nil, fmt.Errorf("can not marshal %v", st)
}
}
// UnmarshalJSON 将字符串反序列化为 StorageType
func (st *StorageType) UnmarshalJSON(data []byte) error {
switch string(data) {
case `"memory"`:
*st = MemoryStorage
case `"file"`:
*st = FileStorage
default:
return fmt.Errorf("unknown storage type: %s", string(data))
}
return nil
}

View File

@@ -1,212 +0,0 @@
// Copyright 2019-2026 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"fmt"
"time"
)
// StreamConfig will determine the name, subjects and retention policy
// for a given stream. If subjects is empty the name will be used.
type StreamConfig struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Subjects []string `json:"subjects,omitempty"`
Retention RetentionPolicy `json:"retention"`
MaxConsumers int `json:"max_consumers"`
MaxMsgs int64 `json:"max_msgs"`
MaxBytes int64 `json:"max_bytes"`
MaxAge time.Duration `json:"max_age"`
MaxMsgsPer int64 `json:"max_msgs_per_subject"`
MaxMsgSize int32 `json:"max_msg_size,omitempty"`
Discard DiscardPolicy `json:"discard"`
Storage StorageType `json:"storage"`
Replicas int `json:"num_replicas"`
NoAck bool `json:"no_ack,omitempty"`
Duplicates time.Duration `json:"duplicate_window,omitempty"`
Placement *Placement `json:"placement,omitempty"`
Mirror *StreamSource `json:"mirror,omitempty"`
Sources []*StreamSource `json:"sources,omitempty"`
Compression StoreCompression `json:"compression"`
FirstSeq uint64 `json:"first_seq,omitempty"`
// Allow applying a subject transform to incoming messages before doing anything else
SubjectTransform *SubjectTransformConfig `json:"subject_transform,omitempty"`
// Allow republish of the message after being sequenced and stored.
RePublish *RePublish `json:"republish,omitempty"`
// Allow higher performance, direct access to get individual messages. E.g. KeyValue
AllowDirect bool `json:"allow_direct"`
// Allow higher performance and unified direct access for mirrors as well.
MirrorDirect bool `json:"mirror_direct"`
// Allow KV like semantics to also discard new on a per subject basis
DiscardNewPer bool `json:"discard_new_per_subject,omitempty"`
// Optional qualifiers. These can not be modified after set to true.
// Sealed will seal a stream so no messages can get out or in.
Sealed bool `json:"sealed"`
// DenyDelete will restrict the ability to delete messages.
DenyDelete bool `json:"deny_delete"`
// DenyPurge will restrict the ability to purge messages.
DenyPurge bool `json:"deny_purge"`
// AllowRollup allows messages to be placed into the system and purge
// all older messages using a special msg header.
AllowRollup bool `json:"allow_rollup_hdrs"`
// The following defaults will apply to consumers when created against
// this stream, unless overridden manually.
// TODO(nat): Can/should we name these better?
ConsumerLimits StreamConsumerLimits `json:"consumer_limits"`
// AllowMsgTTL allows header initiated per-message TTLs. If disabled,
// then the `NATS-TTL` header will be ignored.
AllowMsgTTL bool `json:"allow_msg_ttl"`
// SubjectDeleteMarkerTTL sets the TTL of delete marker messages left behind by
// subject delete markers.
SubjectDeleteMarkerTTL time.Duration `json:"subject_delete_marker_ttl,omitempty"`
// AllowMsgCounter allows a stream to use (only) counter CRDTs.
AllowMsgCounter bool `json:"allow_msg_counter,omitempty"`
// AllowAtomicPublish allows atomic batch publishing into the stream.
AllowAtomicPublish bool `json:"allow_atomic,omitempty"`
// AllowMsgSchedules allows the scheduling of messages.
AllowMsgSchedules bool `json:"allow_msg_schedules,omitempty"`
// PersistMode allows to opt-in to different persistence mode settings.
PersistMode PersistModeType `json:"persist_mode,omitempty"`
// Metadata is additional metadata for the Stream.
Metadata map[string]string `json:"metadata,omitempty"`
}
// Used to guide placement of streams and meta controllers in clustered JetStream.
type Placement struct {
Cluster string `json:"cluster,omitempty"`
Tags []string `json:"tags,omitempty"`
Preferred string `json:"preferred,omitempty"`
}
// StreamSource dictates how streams can source from other streams.
type StreamSource struct {
Name string `json:"name"`
OptStartSeq uint64 `json:"opt_start_seq,omitempty"`
OptStartTime *time.Time `json:"opt_start_time,omitempty"`
FilterSubject string `json:"filter_subject,omitempty"`
SubjectTransforms []SubjectTransformConfig `json:"subject_transforms,omitempty"`
External *ExternalStream `json:"external,omitempty"`
// Internal
iname string // For indexing when stream names are the same for multiple sources.
}
// SubjectTransformConfig is for applying a subject transform (to matching messages) before doing anything else when a new message is received
type SubjectTransformConfig struct {
Source string `json:"src"`
Destination string `json:"dest"`
}
// ExternalStream allows you to qualify access to a stream source in another account or domain.
type ExternalStream struct {
ApiPrefix string `json:"api"`
DeliverPrefix string `json:"deliver"`
}
// RePublish is for republishing messages once committed to a stream.
type RePublish struct {
Source string `json:"src,omitempty"`
Destination string `json:"dest"`
HeadersOnly bool `json:"headers_only,omitempty"`
}
type StreamConsumerLimits struct {
InactiveThreshold time.Duration `json:"inactive_threshold,omitempty"`
MaxAckPending int `json:"max_ack_pending,omitempty"`
}
// PersistModeType determines what persistence mode the stream uses.
type PersistModeType int
const (
// DefaultPersistMode specifies the default persist mode. Writes to the stream will immediately be flushed.
// The publish acknowledgement will be sent after the persisting completes.
DefaultPersistMode = PersistModeType(iota)
// AsyncPersistMode specifies writes to the stream will be flushed asynchronously.
// The publish acknowledgement may be sent before the persisting completes.
// This means writes could be lost if they weren't flushed prior to a hard kill of the server.
AsyncPersistMode
)
// MarshalJSON 将 PersistModeType 序列化为字符串
func (pm PersistModeType) MarshalJSON() ([]byte, error) {
switch pm {
case DefaultPersistMode:
return []byte(`"default"`), nil
case AsyncPersistMode:
return []byte(`"async"`), nil
default:
return nil, fmt.Errorf("can not marshal %v", pm)
}
}
// UnmarshalJSON 将字符串反序列化为 PersistModeType
func (pm *PersistModeType) UnmarshalJSON(data []byte) error {
switch string(data) {
case `"default"`:
*pm = DefaultPersistMode
case `"async"`:
*pm = AsyncPersistMode
default:
return fmt.Errorf("unknown persist mode: %s", string(data))
}
return nil
}
type StoreCompression uint8
const (
NoCompression StoreCompression = iota
S2Compression
)
// MarshalJSON 将 StoreCompression 序列化为字符串
func (sc StoreCompression) MarshalJSON() ([]byte, error) {
switch sc {
case NoCompression:
return []byte(`"none"`), nil
case S2Compression:
return []byte(`"s2"`), nil
default:
return nil, fmt.Errorf("can not marshal %v", sc)
}
}
// UnmarshalJSON 将字符串反序列化为 StoreCompression
func (sc *StoreCompression) UnmarshalJSON(data []byte) error {
switch string(data) {
case `"none"`:
*sc = NoCompression
case `"s2"`:
*sc = S2Compression
default:
return fmt.Errorf("unknown store compression: %s", string(data))
}
return nil
}

View File

@@ -1,25 +1,48 @@
package middleware
import (
"context"
"fmt"
"strings"
"gitea.com/red-future/common/redis"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/net/ghttp"
"github.com/gogf/gf/v2/util/gconv"
)
// 限流 Redis Key 常量
const (
RateLimitKeyPrefix = "ragflow:ratelimit:" // 限流Key前缀
RateLimitKeyIP = "ip:%s" // IP限流: ip:192.168.1.1
RateLimitKeyUser = "user:%s" // 用户限流: user:123 或 user:anon:192.168.1.1
RateLimitKeyService = "service:%s" // 服务限流: service:customerService
RateLimitKeyGlobal = "global:requests" // 全局限流: global:requests
)
func IncrRateLimit(ctx context.Context, key string, windowSeconds int64) (count int64, err error) {
fullKey := RateLimitKeyPrefix + key
count, err = g.Redis().Incr(ctx, fullKey)
if err != nil {
return
}
// 首次设置过期时间
if count == 1 {
g.Redis().Expire(ctx, fullKey, windowSeconds)
}
return
}
// GlobalLimiter 全局限流中间件使用Redis分布式控制
func GlobalLimiter(r *ghttp.Request) {
// 从配置文件读取全局限流参数
globalLimit := g.Cfg().MustGet(r.GetCtx(), "rate.limit", 800).Int64()
key := redis.RateLimitKeyGlobal
key := RateLimitKeyGlobal
// 使用Redis计数器进行全局限流
count, err := redis.IncrRateLimit(r.GetCtx(), key, 1) // 1秒窗口
count, err := IncrRateLimit(r.GetCtx(), key, 1) // 1秒窗口
if err != nil {
g.Log().Errorf(r.GetCtx(), "全局限流Redis错误: %v", err)
r.Middleware.Next()
@@ -38,13 +61,13 @@ func GlobalLimiter(r *ghttp.Request) {
// IPLimiter IP限流中间件防DDoS
func IPLimiter(r *ghttp.Request) {
ip := r.GetClientIp()
key := fmt.Sprintf(redis.RateLimitKeyIP, ip)
key := fmt.Sprintf(RateLimitKeyIP, ip)
// 从配置文件读取IP限流参数
ipLimit := g.Cfg().MustGet(r.GetCtx(), "rate.ip.limit", 100).Int64()
// 使用Redis计数器
count, err := redis.IncrRateLimit(r.GetCtx(), key, 1) // 1秒窗口
count, err := IncrRateLimit(r.GetCtx(), key, 1) // 1秒窗口
if err != nil {
g.Log().Errorf(r.GetCtx(), "IP限流Redis错误: %v", err)
r.Middleware.Next()
@@ -75,8 +98,8 @@ func UserLimiter(r *ghttp.Request) {
userName = gconv.String(user.UserName)
// 从配置文件读取用户限流参数
userLimit := g.Cfg().MustGet(r.GetCtx(), "rate.user.limit", 50).Int64()
key := fmt.Sprintf(redis.RateLimitKeyUser, userName)
count, err := redis.IncrRateLimit(r.GetCtx(), key, 1)
key := fmt.Sprintf(RateLimitKeyUser, userName)
count, err := IncrRateLimit(r.GetCtx(), key, 1)
if err != nil {
g.Log().Errorf(r.GetCtx(), "用户限流Redis错误: %v", err)
return
@@ -111,8 +134,8 @@ func ServiceLimiter(r *ghttp.Request) {
return
}
key := fmt.Sprintf(redis.RateLimitKeyService, serverName)
count, err := redis.IncrRateLimit(r.GetCtx(), key, 1)
key := fmt.Sprintf(RateLimitKeyService, serverName)
count, err := IncrRateLimit(r.GetCtx(), key, 1)
if err != nil {
g.Log().Errorf(r.GetCtx(), "服务限流Redis错误: %v", err)
r.Middleware.Next()

View File

@@ -1,129 +0,0 @@
package minio
import (
"context"
"fmt"
"net/http"
"path/filepath"
"strings"
"time"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/net/ghttp"
"github.com/gogf/gf/v2/os/glog"
"github.com/google/uuid"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
)
// IoConfig 映射 YAML 中的 minio 配置节点
type IoConfig struct {
Endpoint string `yaml:"endpoint"` // MinIO API 地址
AccessKey string `yaml:"accessKey"` // AK
SecretKey string `yaml:"secretKey"` // SK
Secure bool `yaml:"secure"` // 是否启用 SSL
Region string `yaml:"region"` // 区域
}
// 全局 MinIO 客户端(初始化一次,避免重复创建)
var minioClient *minio.Client
var minioCfg IoConfig
// initMinIO 初始化 MinIO 客户端。
func init() {
ctx := context.Background()
if !g.Cfg().MustGet(ctx, "minio").IsEmpty() {
// 加载 MinIO 配置(可从配置文件/环境变量读取,这里硬编码示例)
minioCfg = IoConfig{
Endpoint: g.Cfg().MustGet(ctx, "minio.endpoint").String(),
AccessKey: g.Cfg().MustGet(ctx, "minio.accessKey").String(),
SecretKey: g.Cfg().MustGet(ctx, "minio.secretKey").String(),
Secure: g.Cfg().MustGet(ctx, "minio.secure").Bool(),
Region: g.Cfg().MustGet(ctx, "minio.region").String(),
}
// 创建 MinIO 客户端
var err error
if minioClient, err = minio.New(minioCfg.Endpoint, &minio.Options{
Creds: credentials.NewStaticV4(minioCfg.AccessKey, minioCfg.SecretKey, ""),
Secure: minioCfg.Secure,
Region: minioCfg.Region,
}); err != nil {
glog.Errorf(ctx, "初始化 MinIO 客户端失败: %v", err)
}
}
}
func UploadFile(ctx context.Context, fileHeader *ghttp.UploadFile) (imagesUrl string, fileName string, fileFormat string, err error) {
return uploadFile(ctx, fileHeader)
}
func uploadFile(ctx context.Context, fileHeader *ghttp.UploadFile) (imagesUrl string, fileName string, fileFormat string, err error) {
bucketName, err := utils.GetBucketName(ctx)
if err != nil {
glog.Errorf(ctx, "获取桶名称失败: %v", err)
return
}
// 检查/创建桶
exists, err := minioClient.BucketExists(ctx, bucketName)
if err != nil {
glog.Errorf(ctx, "检查桶是否存在失败: %v", err)
return
}
if !exists {
if err = minioClient.MakeBucket(ctx, bucketName, minio.MakeBucketOptions{Region: minioCfg.Region}); err != nil {
glog.Errorf(ctx, "创建桶失败: %v", err)
return
}
glog.Infof(ctx, "成功创建 MinIO 桶: %s", bucketName)
}
// 打开文件,获取 io.Reader*os.File 实现了 io.Reader
file, err := fileHeader.Open()
if err != nil {
glog.Errorf(ctx, "打开文件失败: %v", err)
return
}
defer file.Close() // 必须关闭,避免文件句柄泄露
// 获取文件类型
buffer := make([]byte, 512)
_, err = file.Read(buffer)
if err != nil {
glog.Errorf(ctx, "读取文件头失败: %v", err)
return
}
contentType := http.DetectContentType(buffer)
// 重置文件读取位置,否则后续 PutObject 会从第512字节开始上传
if _, err = file.Seek(0, 0); err != nil {
glog.Errorf(ctx, "重置文件读取位置失败: %v", err)
return
}
// 生成唯一的 MinIO 对象名(避免覆盖)
fileExt := filepath.Ext(fileHeader.Filename) // 原文件后缀(如 .jpg
uniqueID := uuid.New().String()[:32] // 32位随机UUID
timestamp := time.Now().Format("2006-01-02") // 日期目录(便于管理)
objectName := fmt.Sprintf("/%s/%s%s", timestamp, uniqueID, fileExt) // 存储路径20251209/abc12345.jpg
// 设置存储桶公共读权限
policy := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Principal":{"AWS":["*"]},"Action":["s3:GetObject"],"Resource":["arn:aws:s3:::` + bucketName + `/*"]}]}`
if err = minioClient.SetBucketPolicy(ctx, bucketName, policy); err != nil {
glog.Errorf(ctx, "设置存储桶权限失败: %v", err)
return
}
// 执行图片上传
_, err = minioClient.PutObject(
ctx,
bucketName,
objectName,
file,
fileHeader.Size,
minio.PutObjectOptions{
ContentType: contentType, // 关键指定图片MIME类型S3会根据此类型处理
// 若需要图片可公开访问,添加如下配置(根据需求选择)
//ACL: minio.ACLPublicRead,
},
)
if err != nil {
glog.Errorf(ctx, "上传图片失败: %v", err)
return
}
return objectName, fileHeader.Filename, strings.ReplaceAll(fileExt, ".", ""), err
}

View File

@@ -1,216 +0,0 @@
package rabbitmq
import (
"context"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
amqp "github.com/rabbitmq/amqp091-go"
)
// MessageHandler 消息处理函数
type MessageHandler func(ctx context.Context, body []byte) error
// Consumer 消费者
type Consumer struct {
queue string
consumerTag string
prefetchCount int // QoS: 预取数量(并发控制)
autoAck bool // 是否自动确认
handler MessageHandler
workerCount int // worker 数量
cancel context.CancelFunc // 用于停止 worker
channel *amqp.Channel // 独立Channel避免并发冲突
}
// ConsumerOption 消费者配置选项
type ConsumerOption func(*Consumer)
// WithPrefetchCount 设置预取数量(并发控制)
func WithPrefetchCount(count int) ConsumerOption {
return func(c *Consumer) {
c.prefetchCount = count
}
}
// WithAutoAck 设置自动确认
func WithAutoAck(autoAck bool) ConsumerOption {
return func(c *Consumer) {
c.autoAck = autoAck
}
}
// WithWorkerCount 设置 worker 数量
func WithWorkerCount(count int) ConsumerOption {
return func(c *Consumer) {
c.workerCount = count
}
}
// WithConsumerTag 设置消费者标签
func WithConsumerTag(tag string) ConsumerOption {
return func(c *Consumer) {
c.consumerTag = tag
}
}
// NewConsumer 创建消费者
func NewConsumer(queue string, handler MessageHandler, opts ...ConsumerOption) *Consumer {
c := &Consumer{
queue: queue,
consumerTag: "",
prefetchCount: 1, // 默认 1 个
autoAck: false, // 默认手动确认
handler: handler,
workerCount: 1, // 默认 1 个 worker
}
// 应用选项
for _, opt := range opts {
opt(c)
}
return c
}
// Start 启动消费者
func (c *Consumer) Start(ctx context.Context) (err error) {
// 创建可取消的 context
workerCtx, cancel := context.WithCancel(ctx)
c.cancel = cancel
// 为每个消费者创建独立Channel避免并发冲突
conn, err := GetConnection()
if err != nil {
return gerror.Wrap(err, "获取RabbitMQ连接失败")
}
c.channel, err = conn.Channel()
if err != nil {
return gerror.Wrap(err, "创建独立Channel失败")
}
ch := c.channel
// 声明队列(如果不存在则创建)
// 注意Queue到Exchange的绑定应由message服务在发送响应时动态创建或通过运维工具提前配置
_, err = ch.QueueDeclare(
c.queue, // name
true, // durable持久化
false, // autoDelete不自动删除
false, // exclusive非独占
false, // noWait
nil, // arguments
)
if err != nil {
return gerror.Newf("声明队列失败: %v", err)
}
// 设置 QoS并发控制
err = ch.Qos(
c.prefetchCount, // prefetchCount: 每个 consumer 最多同时处理的消息数
0, // prefetchSize: 0 表示不限制
false, // global: false 表示仅应用于当前 channel
)
if err != nil {
return gerror.Newf("设置 QoS 失败: %v", err)
}
// 开始消费
msgs, err := ch.Consume(
c.queue, // queue
c.consumerTag, // consumer tag
c.autoAck, // auto-ack
false, // exclusive
false, // no-local
false, // no-wait
nil, // args
)
if err != nil {
return gerror.Newf("开始消费失败: %v", err)
}
g.Log().Infof(ctx, "消费者已启动: queue=%s, prefetch=%d, workers=%d",
c.queue, c.prefetchCount, c.workerCount)
// 启动多个 worker
for i := 0; i < c.workerCount; i++ {
go c.worker(workerCtx, i, msgs)
}
return
}
// worker 工作协程
func (c *Consumer) worker(ctx context.Context, workerID int, msgs <-chan amqp.Delivery) {
g.Log().Debugf(ctx, "Worker %d 已启动", workerID)
for {
select {
case <-ctx.Done():
// Context 取消,退出
g.Log().Infof(ctx, "Worker %d 收到停止信号,正在退出", workerID)
return
case msg, ok := <-msgs:
if !ok {
// Channel 关闭,退出
g.Log().Infof(ctx, "Worker %d 消息通道已关闭,退出", workerID)
return
}
// 处理消息
err := c.handler(ctx, msg.Body)
if err != nil {
g.Log().Errorf(ctx, "Worker %d 处理消息失败: %v", workerID, err)
// 如果不是自动确认,需要手动 Nack
if !c.autoAck {
// requeue=false: 不重新入队,进入死信队列
msg.Nack(false, false)
}
} else {
// 处理成功,手动确认
if !c.autoAck {
msg.Ack(false)
}
g.Log().Debugf(ctx, "Worker %d 处理消息成功", workerID)
}
}
}
}
// StartTypedConsumer 启动类型化消费者(自动反序列化)
func StartTypedConsumer[T any](
ctx context.Context,
queue string,
handler func(ctx context.Context, msg *T) error,
opts ...ConsumerOption,
) error {
// 包装处理函数
wrappedHandler := func(ctx context.Context, body []byte) error {
var msg T
if err := gjson.DecodeTo(body, &msg); err != nil {
return gerror.Newf("反序列化消息失败: %v", err)
}
return handler(ctx, &msg)
}
consumer := NewConsumer(queue, wrappedHandler, opts...)
return consumer.Start(ctx)
}
// Stop 停止消费者
func (c *Consumer) Stop(ctx context.Context) {
if c.cancel != nil {
c.cancel()
}
// 关闭独立Channel
if c.channel != nil && !c.channel.IsClosed() {
c.channel.Close()
g.Log().Debugf(ctx, "消费者Channel已关闭: queue=%s", c.queue)
}
g.Log().Infof(ctx, "正在停止消费者: queue=%s", c.queue)
c.cancel = nil
}

View File

@@ -1,175 +0,0 @@
// Package rabbitmq 提供 RabbitMQ 消费者管理功能
//
// 本文件实现消费者统一管理,简化业务层的启动逻辑
package rabbitmq
import (
"context"
"sync"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/glog"
)
// ManagedConsumer 托管消费者(包含启动和停止函数)
type ManagedConsumer struct {
Name string // 消费者名称
Start func(ctx context.Context) error // 启动函数
Stop func(ctx context.Context) // 停止函数
}
// ConsumerManager RabbitMQ 消费者管理器
//
// 职责:
// 1. 统一管理所有 RabbitMQ 消费者的生命周期
// 2. 初始化 RabbitMQ 连接和队列
// 3. 启动/停止所有消费者
// 4. 协调消费者的优雅退出
//
// 使用示例:
//
// mgr := rabbitmq.NewConsumerManager(ctx)
// mgr.Register("响应消费者", responseConsumer.Start, responseConsumer.Stop)
// mgr.Init()
// defer mgr.Stop()
type ConsumerManager struct {
ctx context.Context // 全局上下文
consumers []*ManagedConsumer // 消费者列表
wg sync.WaitGroup // 等待所有消费者协程退出
}
// NewConsumerManager 创建消费者管理器
//
// 参数:
//
// ctx: 上下文
//
// 返回:
//
// *ConsumerManager: 消费者管理器实例
func NewConsumerManager(ctx context.Context) *ConsumerManager {
return &ConsumerManager{
ctx: ctx,
consumers: make([]*ManagedConsumer, 0),
}
}
// Register 注册消费者
//
// 参数:
//
// name: 消费者名称(用于日志)
// startFunc: 启动函数
// stopFunc: 停止函数
//
// 使用示例:
//
// consumer := service.NewResponseConsumer(ctx)
// mgr.Register("响应消费者", consumer.Start, consumer.Stop)
func (cm *ConsumerManager) Register(name string, startFunc func(ctx context.Context) error, stopFunc func(ctx context.Context)) {
cm.consumers = append(cm.consumers, &ManagedConsumer{
Name: name,
Start: startFunc,
Stop: stopFunc,
})
}
// Init 初始化并启动所有消费者
//
// 执行流程:
// 1. 检查 RabbitMQ 配置(未配置则跳过)
// 2. 初始化 RabbitMQ 连接
// 3. 声明并绑定队列(响应队列、延时落库队列)
// 4. 异步启动所有已注册的消费者
//
// 返回:
//
// err: 错误信息,成功返回 nil
//
// 注意:
// - 如果 RabbitMQ 未配置,不会报错,只是跳过初始化
// - 响应队列初始化失败会导致 Fatal 退出
// - 延时落库队列失败只会 Warning不影响主流程
func (cm *ConsumerManager) Init() (err error) {
// 检查配置文件中是否配置了 RabbitMQ
if g.Cfg().MustGet(cm.ctx, "rabbitmq").IsEmpty() {
glog.Info(cm.ctx, "RabbitMQ未配置跳过消费者初始化")
return
}
// 初始化 RabbitMQ 连接(从 config.yml 读取配置)
if err = InitFromConfig(cm.ctx); err != nil {
glog.Fatalf(cm.ctx, "初始化 RabbitMQ 失败: %v", err)
return
}
glog.Info(cm.ctx, "RabbitMQ 连接已初始化")
// 声明响应Exchange队列由各消费者自己声明和绑定
if err = DeclareExchange(cm.ctx, &ExchangeConfig{
Name: "ragflow.response",
Type: "topic",
Durable: true,
}); err != nil {
glog.Fatalf(cm.ctx, "声明响应Exchange失败: %v", err)
return
}
// 设置延时落库队列(对话缓存兜底机制)
// 失败不影响主流程,只记录 Warning
if err = SetupDelayedFlushQueue(cm.ctx); err != nil {
glog.Warningf(cm.ctx, "设置延时落库队列失败: %v", err)
}
// 异步启动所有已注册的消费者
cm.startConsumers()
return
}
// startConsumers 启动所有消费者(内部方法)
//
// 实现:
// 1. 遍历已注册的消费者
// 2. 每个消费者在独立的 goroutine 中运行
// 3. 使用 WaitGroup 追踪所有消费者协程
func (cm *ConsumerManager) startConsumers() {
for _, c := range cm.consumers {
cm.wg.Add(1)
go func(consumer *ManagedConsumer) {
defer cm.wg.Done()
if err := consumer.Start(cm.ctx); err != nil {
glog.Errorf(cm.ctx, "%s启动失败: %v", consumer.Name, err)
}
}(c)
glog.Infof(cm.ctx, "%s已启动", c.Name)
}
}
// Stop 停止所有消费者(优雅退出)
//
// 执行流程:
// 1. 依次停止所有消费者(调用各自的 Stop 方法)
// 2. 等待所有消费者协程退出WaitGroup.Wait
// 3. 关闭 RabbitMQ 连接
//
// 使用场景:
// - 收到 SIGINT/SIGTERM 信号时
// - 程序正常退出时
// - defer mgr.Stop()
//
// 注意:
// - Stop 方法会阻塞直到所有消费者完全退出
// - 确保消费者能正确响应 Stop 信号
func (cm *ConsumerManager) Stop() {
// 依次停止所有消费者
for _, c := range cm.consumers {
c.Stop(cm.ctx)
glog.Infof(cm.ctx, "%s已停止", c.Name)
}
// 等待所有消费者协程退出
cm.wg.Wait()
// 关闭 RabbitMQ 连接
Close(cm.ctx)
glog.Info(cm.ctx, "所有消费者已停止RabbitMQ连接已关闭")
}

View File

@@ -1,95 +0,0 @@
// Package rabbitmq - RabbitMQ延时消息发布
package rabbitmq
import (
"context"
"time"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/errors/gerror"
amqp "github.com/rabbitmq/amqp091-go"
)
// PublishWithDelay 发布延时消息到RabbitMQ
// delaySeconds: 延时秒数
func PublishWithDelay(ctx context.Context, routingKey string, message interface{}, delaySeconds int) error {
ch, err := GetChannel()
if err != nil {
return gerror.Wrap(err, "获取RabbitMQ通道失败")
}
if ch == nil {
return gerror.New("RabbitMQ通道未初始化")
}
// 序列化消息
body, err := gjson.Encode(message)
if err != nil {
return gerror.Wrapf(err, "序列化消息失败")
}
// 声明延时交换机x-delayed-message类型
// 注意需要RabbitMQ安装延时插件 rabbitmq-plugins enable rabbitmq_delayed_message_exchange
exchangeName := "delayed.exchange"
err = ch.ExchangeDeclare(
exchangeName,
"x-delayed-message", // 延时交换机类型
true, // durable
false, // auto-deleted
false, // internal
false, // no-wait
amqp.Table{
"x-delayed-type": "direct", // 底层交换机类型
},
)
if err != nil {
return gerror.Wrapf(err, "声明延时交换机失败")
}
// 声明队列
queue, err := ch.QueueDeclare(
routingKey, // 队列名使用routingKey
true, // durable
false, // delete when unused
false, // exclusive
false, // no-wait
nil,
)
if err != nil {
return gerror.Wrapf(err, "声明队列失败")
}
// 绑定队列到交换机
err = ch.QueueBind(
queue.Name, // queue name
routingKey, // routing key
exchangeName, // exchange
false,
nil,
)
if err != nil {
return gerror.Wrapf(err, "绑定队列失败")
}
// 发布延时消息
err = ch.PublishWithContext(
ctx,
exchangeName, // exchange
routingKey, // routing key
false, // mandatory
false, // immediate
amqp.Publishing{
ContentType: "application/json",
Body: body,
DeliveryMode: amqp.Persistent, // 持久化消息
Headers: amqp.Table{
"x-delay": delaySeconds * 1000, // 延时时间(毫秒)
},
Timestamp: time.Now(),
},
)
if err != nil {
return gerror.Wrapf(err, "发布延时消息失败")
}
return nil
}

View File

@@ -1,59 +0,0 @@
package rabbitmq
import (
"context"
"fmt"
"os"
"sync"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/guid"
)
var (
instanceId string
instanceOnce sync.Once
)
// getInstanceId 获取当前实例的唯一标识(单例)
// 优先级:配置文件 > 环境变量 > 容器名/主机名 > 随机UUID
func getInstanceId() string {
instanceOnce.Do(func() {
ctx := context.Background()
// 1. 优先从配置文件读取(手动指定,最高优先级)
instanceId = g.Cfg().MustGet(ctx, "rabbitmq.instanceName").String()
if instanceId != "" {
return
}
// 2. 读取环境变量Docker/K8s部署时设置
instanceId = os.Getenv("INSTANCE_NAME")
if instanceId != "" {
return
}
// 3. 使用主机名Docker容器名/主机名)
hostname, err := os.Hostname()
if err != nil || hostname == "" {
hostname = "unknown"
}
// 4. 如果主机名是默认值(本地开发),添加随机后缀避免冲突
if hostname == "localhost" || hostname == "unknown" {
instanceId = hostname + "." + guid.S()[:4]
} else {
instanceId = hostname
}
})
return instanceId
}
// GetInstanceQueueName 获取当前实例的响应队列名
// 格式:{baseQueue}.{hostname}.{uuid8}
func GetInstanceQueueName(baseQueue string) string {
if baseQueue == "" {
baseQueue = "ragflow.response"
}
return fmt.Sprintf("%s.%s", baseQueue, getInstanceId())
}

View File

@@ -1,152 +0,0 @@
package rabbitmq
import (
"context"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
amqp "github.com/rabbitmq/amqp091-go"
)
// Publisher 消息发布器
type Publisher struct {
exchange string
routingKey string
}
// NewPublisher 创建发布器
func NewPublisher(exchange, routingKey string) *Publisher {
return &Publisher{
exchange: exchange,
routingKey: routingKey,
}
}
// Publish 发布消息(使用默认 routing key
func (p *Publisher) Publish(ctx context.Context, message interface{}) (err error) {
return p.PublishWithRoutingKey(ctx, p.routingKey, message)
}
// PublishWithRoutingKey 发布消息(指定 routing key
func (p *Publisher) PublishWithRoutingKey(ctx context.Context, routingKey string, message interface{}) (err error) {
ch, err := GetChannel()
if err != nil {
return err
}
// 序列化消息
body, err := gjson.Encode(message)
if err != nil {
return gerror.Newf("消息序列化失败: %v", err)
}
// 发布消息
err = ch.PublishWithContext(
ctx,
p.exchange, // exchange
routingKey, // routing key
false, // mandatory
false, // immediate
amqp.Publishing{
DeliveryMode: amqp.Persistent, // 持久化
ContentType: "application/json",
Body: body,
},
)
if err != nil {
g.Log().Errorf(ctx, "发布消息失败: exchange=%s, routingKey=%s, err=%v",
p.exchange, routingKey, err)
return err
}
g.Log().Debugf(ctx, "消息发布成功: exchange=%s, routingKey=%s",
p.exchange, routingKey)
return
}
// PublishDelayed 发布延时消息
// delaySeconds: 延时秒数
func (p *Publisher) PublishDelayed(ctx context.Context, message interface{}, delaySeconds int) (err error) {
ch, err := GetChannel()
if err != nil {
return err
}
// 序列化消息
body, err := gjson.Encode(message)
if err != nil {
return gerror.Newf("消息序列化失败: %v", err)
}
// 发布延时消息(需要 rabbitmq_delayed_message_exchange 插件)
err = ch.PublishWithContext(
ctx,
p.exchange, // exchange必须是 x-delayed-message 类型)
p.routingKey, // routing key
false, // mandatory
false, // immediate
amqp.Publishing{
DeliveryMode: amqp.Persistent,
ContentType: "application/json",
Body: body,
Headers: amqp.Table{
"x-delay": delaySeconds * 1000, // 延时(毫秒)
},
},
)
if err != nil {
g.Log().Errorf(ctx, "发布延时消息失败: exchange=%s, routingKey=%s, delay=%ds, err=%v",
p.exchange, p.routingKey, delaySeconds, err)
return err
}
g.Log().Debugf(ctx, "延时消息发布成功: exchange=%s, routingKey=%s, delay=%ds",
p.exchange, p.routingKey, delaySeconds)
return
}
// PublishBatch 批量发布消息
func (p *Publisher) PublishBatch(ctx context.Context, messages []interface{}) (err error) {
if len(messages) == 0 {
return
}
ch, err := GetChannel()
if err != nil {
return err
}
for i, message := range messages {
body, err := gjson.Encode(message)
if err != nil {
g.Log().Errorf(ctx, "消息 %d 序列化失败: %v", i, err)
continue
}
err = ch.PublishWithContext(
ctx,
p.exchange,
p.routingKey,
false,
false,
amqp.Publishing{
DeliveryMode: amqp.Persistent,
ContentType: "application/json",
Body: body,
},
)
if err != nil {
g.Log().Errorf(ctx, "消息 %d 发布失败: %v", i, err)
continue
}
}
g.Log().Infof(ctx, "批量发布完成: 共 %d 条消息", len(messages))
return
}

View File

@@ -1,111 +0,0 @@
// Package rabbitmq 提供 RabbitMQ 队列初始化的封装方法
//
// 本文件包含常用队列的声明和绑定逻辑,简化业务层的队列配置代码
package rabbitmq
import (
"context"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/glog"
)
// SetupResponseQueue 初始化 RAGFlow 响应队列
//
// 功能:
// 1. 声明持久化队列(从配置文件读取队列名,默认 ragflow.response.queue
// 2. 绑定到 ragflow.response ExchangeTopic 类型)
// 3. 使用通配符 # 匹配所有 routing keyuserId
//
// 参数:
//
// ctx: 上下文
//
// 返回:
//
// err: 错误信息,成功返回 nil
//
// 配置示例config.yml
//
// rabbitmq:
// responseQueue: "ragflow.response.queue" # 可选,默认值
func SetupResponseQueue(ctx context.Context) (err error) {
// 从配置文件读取队列名(支持每个开发者配置独立队列名)
responseQueue := g.Cfg().MustGet(ctx, "rabbitmq.responseQueue", "ragflow.response.queue").String()
// 声明持久化队列(服务器重启后队列仍存在)
if err = DeclareQueue(ctx, &QueueConfig{
Name: responseQueue,
Durable: true, // 持久化,防止数据丢失
}); err != nil {
glog.Errorf(ctx, "声明响应队列失败: %v", err)
return
}
// 绑定队列到 Exchange
// Exchange 类型为 topicrouting key 使用通配符 # 匹配所有 userId
if err = BindQueue(ctx, &BindingConfig{
Queue: responseQueue,
Exchange: "ragflow.response", // RAGFlow 响应 Exchange
RoutingKey: "#", // 通配符,匹配所有消息
}); err != nil {
glog.Errorf(ctx, "绑定响应队列失败: %v", err)
return
}
glog.Infof(ctx, "响应队列已绑定: %s -> ragflow.response (routingKey=#)", responseQueue)
return
}
// SetupDelayedFlushQueue 初始化延时落库队列
//
// 功能:
// 1. 声明延时 Exchangex-delayed-message 插件)
// 2. 声明持久化队列 conversation.flush.queue
// 3. 绑定队列到延时 Exchange
//
// 用途:
//
// 对话缓存延时落库机制的兜底策略
// 当对话少于5句时10分钟后触发延时消息将缓存写入MongoDB
//
// 参数:
//
// ctx: 上下文
//
// 返回:
//
// err: 错误信息,成功返回 nil
//
// 相关:
// - service/conversation_service.go: handleResponse()
// - service/conversation_service.go: handleDelayedFlush()
func SetupDelayedFlushQueue(ctx context.Context) (err error) {
// 声明延时 Exchange需要 RabbitMQ 安装 x-delayed-message 插件)
if err = SetupDelayExchange(ctx, "conversation.flush.delayed"); err != nil {
glog.Warningf(ctx, "声明延时落库 Exchange 失败: %v", err)
return
}
// 声明持久化队列
if err = DeclareQueue(ctx, &QueueConfig{
Name: "conversation.flush.queue",
Durable: true, // 持久化,防止延时消息丢失
}); err != nil {
glog.Warningf(ctx, "声明延时落库 Queue 失败: %v", err)
return
}
// 绑定队列到延时 Exchange
if err = BindQueue(ctx, &BindingConfig{
Queue: "conversation.flush.queue",
Exchange: "conversation.flush.delayed",
RoutingKey: "flush", // 延时落库消息的 routing key
}); err != nil {
glog.Warningf(ctx, "绑定延时落库 Queue 失败: %v", err)
return
}
glog.Info(ctx, "延时落库队列已配置")
return
}

View File

@@ -1,210 +0,0 @@
package rabbitmq
import (
"context"
"sync"
"time"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
amqp "github.com/rabbitmq/amqp091-go"
)
var (
conn *amqp.Connection
channel *amqp.Channel
rabbitmqOnce sync.Once
rabbitmqMu sync.RWMutex
closeWatcher chan struct{} // 用于停止监听 goroutine
watcherStarted bool // 防止重复启动监听
)
// Config RabbitMQ 配置
type Config struct {
Host string
Port int
Username string
Password string
VHost string
}
// Init 初始化 RabbitMQ 连接
func Init(ctx context.Context, cfg *Config) error {
var err error
rabbitmqOnce.Do(func() {
// 构建连接字符串
url := "amqp://" + cfg.Username + ":" + cfg.Password + "@" + cfg.Host + ":" + gconv.String(cfg.Port) + "/" + cfg.VHost
// 创建连接
conn, err = amqp.Dial(url)
if err != nil {
g.Log().Errorf(ctx, "RabbitMQ 连接失败: %v", err)
return
}
// 创建 Channel
channel, err = conn.Channel()
if err != nil {
g.Log().Errorf(ctx, "创建 RabbitMQ Channel 失败: %v", err)
return
}
// 初始化关闭监听器
closeWatcher = make(chan struct{})
// 监听连接关闭(只启动一次)
if !watcherStarted {
go handleConnectionClose(ctx)
watcherStarted = true
}
g.Log().Info(ctx, "RabbitMQ 连接成功")
})
return err
}
// InitFromConfig 从配置文件初始化
func InitFromConfig(ctx context.Context) error {
cfg := &Config{
Host: g.Cfg().MustGet(ctx, "rabbitmq.host").String(),
Port: g.Cfg().MustGet(ctx, "rabbitmq.port").Int(),
Username: g.Cfg().MustGet(ctx, "rabbitmq.username").String(),
Password: g.Cfg().MustGet(ctx, "rabbitmq.password").String(),
VHost: g.Cfg().MustGet(ctx, "rabbitmq.vhost", "/").String(),
}
return Init(ctx, cfg)
}
// GetChannel 获取 Channel
func GetChannel() (*amqp.Channel, error) {
rabbitmqMu.RLock()
defer rabbitmqMu.RUnlock()
if channel == nil || channel.IsClosed() {
return nil, gerror.New("RabbitMQ Channel 未初始化或已关闭")
}
return channel, nil
}
// GetConnection 获取连接
func GetConnection() (*amqp.Connection, error) {
rabbitmqMu.RLock()
defer rabbitmqMu.RUnlock()
if conn == nil || conn.IsClosed() {
return nil, gerror.New("RabbitMQ 连接未初始化或已关闭")
}
return conn, nil
}
// handleConnectionClose 监听连接关闭并重连
func handleConnectionClose(ctx context.Context) {
for {
// 检查是否需要停止监听
select {
case <-closeWatcher:
g.Log().Info(ctx, "停止监听 RabbitMQ 连接状态")
return
default:
}
rabbitmqMu.RLock()
currentConn := conn
rabbitmqMu.RUnlock()
if currentConn == nil {
return
}
// 创建关闭通知 channel
closeErr := make(chan *amqp.Error, 1)
currentConn.NotifyClose(closeErr)
// 等待连接关闭或停止信号
select {
case err := <-closeErr:
if err != nil {
g.Log().Errorf(ctx, "RabbitMQ 连接关闭: %v尝试重连...", err)
reconnect(ctx)
}
case <-closeWatcher:
g.Log().Info(ctx, "停止监听 RabbitMQ 连接状态")
return
}
}
}
// reconnect 重新连接
func reconnect(ctx context.Context) {
rabbitmqMu.Lock()
defer rabbitmqMu.Unlock()
for i := 0; i < 10; i++ {
time.Sleep(time.Duration(i+1) * time.Second)
cfg := &Config{
Host: g.Cfg().MustGet(ctx, "rabbitmq.host").String(),
Port: g.Cfg().MustGet(ctx, "rabbitmq.port").Int(),
Username: g.Cfg().MustGet(ctx, "rabbitmq.username").String(),
Password: g.Cfg().MustGet(ctx, "rabbitmq.password").String(),
VHost: g.Cfg().MustGet(ctx, "rabbitmq.vhost", "/").String(),
}
url := "amqp://" + cfg.Username + ":" + cfg.Password + "@" + cfg.Host + ":" + gconv.String(cfg.Port) + "/" + cfg.VHost
var err error
conn, err = amqp.Dial(url)
if err != nil {
g.Log().Errorf(ctx, "重连失败 (尝试 %d/10): %v", i+1, err)
continue
}
channel, err = conn.Channel()
if err != nil {
g.Log().Errorf(ctx, "创建 Channel 失败 (尝试 %d/10): %v", i+1, err)
continue
}
g.Log().Info(ctx, "RabbitMQ 重连成功")
// 不再重复启动监听 goroutine
return
}
g.Log().Fatal(ctx, "RabbitMQ 重连失败,已达到最大重试次数")
}
// Close 关闭连接
func Close(ctx context.Context) (err error) {
rabbitmqMu.Lock()
defer rabbitmqMu.Unlock()
// 停止监听 goroutine
if closeWatcher != nil {
close(closeWatcher)
closeWatcher = nil
}
if channel != nil {
if err = channel.Close(); err != nil {
g.Log().Errorf(ctx, "关闭 RabbitMQ Channel 失败: %v", err)
}
channel = nil
}
if conn != nil {
if err = conn.Close(); err != nil {
g.Log().Errorf(ctx, "关闭 RabbitMQ 连接失败: %v", err)
return
}
conn = nil
}
watcherStarted = false
g.Log().Info(ctx, "RabbitMQ 连接已关闭")
return
}

View File

@@ -1,231 +0,0 @@
package rabbitmq
import (
"context"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
amqp "github.com/rabbitmq/amqp091-go"
)
// QueueConfig 队列配置
type QueueConfig struct {
Name string
Durable bool // 持久化
AutoDelete bool // 自动删除
Exclusive bool // 排他
Args amqp.Table // 额外参数
}
// ExchangeConfig Exchange 配置
type ExchangeConfig struct {
Name string
Type string // direct/topic/fanout/x-delayed-message
Durable bool
AutoDelete bool
Args amqp.Table
}
// BindingConfig 绑定配置
type BindingConfig struct {
Queue string
Exchange string
RoutingKey string
Args amqp.Table
}
// DeclareQueue 声明队列
func DeclareQueue(ctx context.Context, cfg *QueueConfig) (err error) {
ch, err := GetChannel()
if err != nil {
return err
}
_, err = ch.QueueDeclare(
cfg.Name,
cfg.Durable,
cfg.AutoDelete,
cfg.Exclusive,
false, // no-wait
cfg.Args,
)
if err != nil {
g.Log().Errorf(ctx, "声明队列失败: %s, err=%v", cfg.Name, err)
return err
}
g.Log().Infof(ctx, "队列声明成功: %s", cfg.Name)
return
}
// DeclareExchange 声明 Exchange
func DeclareExchange(ctx context.Context, cfg *ExchangeConfig) (err error) {
ch, err := GetChannel()
if err != nil {
return err
}
err = ch.ExchangeDeclare(
cfg.Name,
cfg.Type,
cfg.Durable,
cfg.AutoDelete,
false, // internal
false, // no-wait
cfg.Args,
)
if err != nil {
g.Log().Errorf(ctx, "声明 Exchange 失败: %s, err=%v", cfg.Name, err)
return err
}
g.Log().Infof(ctx, "Exchange 声明成功: %s (type=%s)", cfg.Name, cfg.Type)
return
}
// BindQueue 绑定队列到 Exchange
func BindQueue(ctx context.Context, cfg *BindingConfig) (err error) {
ch, err := GetChannel()
if err != nil {
return err
}
err = ch.QueueBind(
cfg.Queue,
cfg.RoutingKey,
cfg.Exchange,
false, // no-wait
cfg.Args,
)
if err != nil {
g.Log().Errorf(ctx, "绑定队列失败: queue=%s, exchange=%s, routingKey=%s, err=%v",
cfg.Queue, cfg.Exchange, cfg.RoutingKey, err)
return err
}
g.Log().Infof(ctx, "队列绑定成功: queue=%s → exchange=%s (routingKey=%s)",
cfg.Queue, cfg.Exchange, cfg.RoutingKey)
return
}
// SetupDelayExchange 设置延时 Exchange需要 rabbitmq_delayed_message_exchange 插件)
func SetupDelayExchange(ctx context.Context, exchangeName string) error {
return DeclareExchange(ctx, &ExchangeConfig{
Name: exchangeName,
Type: "x-delayed-message",
Durable: true,
Args: amqp.Table{
"x-delayed-type": "direct",
},
})
}
// SetupDeadLetterQueue 设置死信队列
func SetupDeadLetterQueue(ctx context.Context, queueName, exchangeName string) error {
// 1. 声明死信 Exchange
err := DeclareExchange(ctx, &ExchangeConfig{
Name: exchangeName,
Type: "direct",
Durable: true,
})
if err != nil {
return err
}
// 2. 声明死信队列
err = DeclareQueue(ctx, &QueueConfig{
Name: queueName,
Durable: true,
})
if err != nil {
return err
}
// 3. 绑定
return BindQueue(ctx, &BindingConfig{
Queue: queueName,
Exchange: exchangeName,
RoutingKey: queueName,
})
}
// SetupQueueWithDLX 创建带死信队列的普通队列
func SetupQueueWithDLX(ctx context.Context, queueName, dlxExchange, dlxRoutingKey string) error {
return DeclareQueue(ctx, &QueueConfig{
Name: queueName,
Durable: true,
Args: amqp.Table{
"x-dead-letter-exchange": dlxExchange,
"x-dead-letter-routing-key": dlxRoutingKey,
},
})
}
// SetupBasicTopology 设置基础拓扑(适用于小红书客服场景)
func SetupBasicTopology(ctx context.Context) (err error) {
// 1. 声明普通 Exchange
err = DeclareExchange(ctx, &ExchangeConfig{
Name: "ragflow_exchange",
Type: "direct",
Durable: true,
})
if err != nil {
return err
}
// 2. 声明延时 Exchange
err = SetupDelayExchange(ctx, "delay_exchange")
if err != nil {
return gerror.Newf("延时 Exchange 声明失败(可能未安装插件): %v", err)
}
// 3. 声明死信队列
err = SetupDeadLetterQueue(ctx, "dead_letter_queue", "dlx_exchange")
if err != nil {
return err
}
// 4. 声明业务队列
queues := []struct {
name string
dlx bool // 是否需要死信队列
}{
{"ragflow_request_queue", true},
{"follow_up_queue", true},
{"archive_queue", true},
}
for _, q := range queues {
if q.dlx {
err = SetupQueueWithDLX(ctx, q.name, "dlx_exchange", "dead_letter_queue")
} else {
err = DeclareQueue(ctx, &QueueConfig{
Name: q.name,
Durable: true,
})
}
if err != nil {
return err
}
}
// 5. 绑定队列
bindings := []BindingConfig{
{Queue: "ragflow_request_queue", Exchange: "ragflow_exchange", RoutingKey: "ragflow_request_queue"},
{Queue: "follow_up_queue", Exchange: "delay_exchange", RoutingKey: "follow_up_queue"},
{Queue: "archive_queue", Exchange: "delay_exchange", RoutingKey: "archive_queue"},
}
for _, b := range bindings {
err = BindQueue(ctx, &b)
if err != nil {
return err
}
}
g.Log().Info(ctx, "RabbitMQ 拓扑结构设置完成")
return
}

View File

@@ -1,141 +0,0 @@
package ragflow
import (
"context"
"github.com/gogf/gf/v2/errors/gerror"
)
// Agent AGENT 管理
// 参考: https://ragflow.com.cn/docs/dev/http_api_reference#agent-管理
// Agent Agent 结构体
type Agent struct {
ID string `json:"id"` // Agent ID
Title string `json:"title"` // Agent 标题
Description string `json:"description"` // Agent 描述
Avatar string `json:"avatar"` // 头像Base64 编码)
CanvasType string `json:"canvas_type"` // 画布类型
CreateDate string `json:"create_date"` // 创建日期(格式化字符串)
CreateTime int64 `json:"create_time"` // 创建时间Unix 时间戳)
UpdateDate string `json:"update_date"` // 更新日期(格式化字符串)
UpdateTime int64 `json:"update_time"` // 更新时间Unix 时间戳)
UserID string `json:"user_id"` // 用户 ID
DSL map[string]interface{} `json:"dsl"` // Canvas DSL 对象,定义 Agent 的工作流
}
// CreateAgentReq 创建 Agent 请求
type CreateAgentReq struct {
Title string `json:"title"` // 必需
Description string `json:"description,omitempty"` // 可选,默认为 None
DSL map[string]interface{} `json:"dsl"` // 必需Canvas DSL 对象
}
// UpdateAgentReq 更新 Agent 请求
type UpdateAgentReq struct {
Title string `json:"title,omitempty"`
Description string `json:"description,omitempty"`
DSL map[string]interface{} `json:"dsl,omitempty"`
}
// ListAgentsReq 列出 Agent 请求
type ListAgentsReq struct {
Page int `json:"page,omitempty"`
PageSize int `json:"page_size,omitempty"`
OrderBy string `json:"orderby,omitempty"`
Desc bool `json:"desc,omitempty"`
Title string `json:"title,omitempty"`
ID string `json:"id,omitempty"`
}
// ListAgentsRes 列出 Agent 响应
// 注意API 不返回 total 字段,仅返回 data 数组
type ListAgentsRes struct {
Code int `json:"code"` // 状态码0 表示成功
Data []*Agent `json:"data"` // Agent 列表
}
// CreateAgent 创建 Agent
// POST /api/v1/agents
func (c *Client) CreateAgent(ctx context.Context, req *CreateAgentReq) (err error) {
var res CommonResponse
if err = c.request(ctx, "POST", "/api/v1/agents", req, &res); err != nil {
return gerror.Newf("create agent failed: %v", err)
}
if !res.IsSuccess() {
return gerror.Newf("create agent failed: %s", res.Message)
}
return
}
// UpdateAgent 更新 Agent
// PUT /api/v1/agents/{agent_id}
func (c *Client) UpdateAgent(ctx context.Context, agentID string, req *UpdateAgentReq) (err error) {
path := "/api/v1/agents/" + agentID
var res CommonResponse
if err = c.request(ctx, "PUT", path, req, &res); err != nil {
return gerror.Newf("update agent failed: %v", err)
}
if !res.IsSuccess() {
return gerror.Newf("update agent failed: %s", res.Message)
}
return
}
// DeleteAgent 删除 Agent
// DELETE /api/v1/agents/{agent_id}
func (c *Client) DeleteAgent(ctx context.Context, agentID string) (err error) {
path := "/api/v1/agents/" + agentID
var res CommonResponse
// 官方文档要求传空对象,不是 nil
if err = c.request(ctx, "DELETE", path, map[string]interface{}{}, &res); err != nil {
return gerror.Newf("delete agent failed: %v", err)
}
if !res.IsSuccess() {
return gerror.Newf("delete agent failed: %s", res.Message)
}
return
}
// ListAgents 列出 Agent
// GET /api/v1/agents
func (c *Client) ListAgents(ctx context.Context, req *ListAgentsReq) (*ListAgentsRes, error) {
path := "/api/v1/agents"
if req != nil {
params := map[string]interface{}{}
if req.Page > 0 {
params["page"] = req.Page
}
if req.PageSize > 0 {
params["page_size"] = req.PageSize
}
if req.OrderBy != "" {
params["orderby"] = req.OrderBy
}
if req.Desc {
params["desc"] = "true"
} else {
params["desc"] = "false"
}
if req.Title != "" {
params["title"] = req.Title
}
if req.ID != "" {
params["id"] = req.ID
}
query := buildQueryString(params)
if query != "" {
path += "?" + query
}
}
var res ListAgentsRes
if err := c.request(ctx, "GET", path, nil, &res); err != nil {
return nil, gerror.Newf("list agents failed: %v", err)
}
if res.Code != 0 {
return nil, gerror.Newf("list agents failed: code=%d", res.Code)
}
return &res, nil
}

View File

@@ -1,198 +0,0 @@
package ragflow
import (
"context"
"github.com/gogf/gf/v2/errors/gerror"
)
// CreateChatReq 创建对话配置请求
type CreateChatReq struct {
Name string `json:"name"` // 对话配置名称(助理姓名)
Description string `json:"description,omitempty"` // 助理描述
DatasetIds []string `json:"dataset_ids"` // 关联的知识库ID列表
Prompt *PromptConfig `json:"prompt"` // 提示词配置
Llm *Llm `json:"llm,omitempty"` // LLM配置
}
// PromptConfig 提示词配置
type PromptConfig struct {
Prompt string `json:"prompt"` // 提示词内容
SimilarityThreshold float64 `json:"similarity_threshold"` // 相似度阈值
KeywordsSimilarityWeight float64 `json:"keywords_similarity_weight"` // 关键词相似度权重
TopN int `json:"top_n"` // 返回顶部N个chunk
EmptyResponse string `json:"empty_response"` // 无匹配时回复必须显式传入空字符串才能让LLM自由发挥不传入会使用RAGFlow默认提示词
Opener string `json:"opener,omitempty"` // 开场白
ShowQuote bool `json:"show_quote,omitempty"` // 是否显示引用
Variables []map[string]interface{} `json:"variables,omitempty"` // 变量列表
}
// CreateChatRes 创建对话配置响应
type CreateChatRes struct {
ChatId string `json:"id"` // 对话配置ID
}
// UpdateChatReq 更新对话配置请求
type UpdateChatReq struct {
Name string `json:"name,omitempty"` // 对话配置名称
Description string `json:"description,omitempty"` // 对话描述
DatasetIds []string `json:"dataset_ids,omitempty"` // 关联的知识库ID列表RAGFlow API使用下划线格式
Prompt *PromptConfig `json:"prompt,omitempty"` // 提示词配置
}
// 聊天助手管理
// 参考: https://ragflow.com.cn/docs/dev/http_api_reference#聊天助手管理
// Chat 聊天助手结构体
type Chat struct {
Id string `json:"id"`
Name string `json:"name"`
Avatar string `json:"avatar"`
DatasetIds []string `json:"dataset_ids"`
Llm Llm `json:"llm"`
Prompt Prompt `json:"prompt"`
Description string `json:"description"`
DoRefer string `json:"do_refer"`
Language string `json:"language"`
PromptType string `json:"prompt_type"`
Status string `json:"status"`
TenantId string `json:"tenant_id"`
TopK int `json:"top_k"`
CreateDate string `json:"create_date"`
CreateTime int64 `json:"create_time"`
UpdateDate string `json:"update_date"`
UpdateTime int64 `json:"update_time"`
}
type Llm struct {
ModelName string `json:"model_name,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
}
type Prompt struct {
SimilarityThreshold float64 `json:"similarity_threshold,omitempty"`
KeywordsSimilarityWeight float64 `json:"keywords_similarity_weight,omitempty"`
Opener string `json:"opener,omitempty"`
Prompt string `json:"prompt,omitempty"`
RerankModel string `json:"rerank_model,omitempty"`
TopN int `json:"top_n,omitempty"`
Variables []Variable `json:"variables,omitempty"`
EmptyResponse string `json:"empty_response,omitempty"`
}
type Variable struct {
Key string `json:"key"`
Optional bool `json:"optional"`
}
// ListChatsReq 列出聊天助手请求
type ListChatsReq struct {
Page int `json:"page,omitempty"`
PageSize int `json:"page_size,omitempty"`
OrderBy string `json:"orderby,omitempty"`
Desc bool `json:"desc,omitempty"`
Name string `json:"name,omitempty"`
Id string `json:"id,omitempty"`
}
// ListChatsRes 列出聊天助手响应
// 注意API 不返回 total 字段,仅返回 data 数组
type ListChatsRes struct {
Code int `json:"code"` // 状态码0 表示成功
Data []*Chat `json:"data"` // 聊天助手列表
}
// DeleteChatsReq 删除聊天助手请求
type DeleteChatsReq struct {
Ids []string `json:"ids"`
}
// CreateChat 创建聊天助手
func (c *Client) CreateChat(ctx context.Context, req *CreateChatReq) (*Chat, error) {
var res struct {
Code int `json:"code"`
Data *Chat `json:"data"`
Msg string `json:"message"`
}
if err := c.request(ctx, "POST", "/api/v1/chats", req, &res); err != nil {
return nil, err
}
if res.Code != 0 {
return nil, gerror.Newf("create chat failed: %s", res.Msg)
}
// 检查响应数据是否为空防止RAGFlow API返回 {"code":0, "data":null}
// 如果不检查直接返回,调用方会收到 (nil, nil),导致空指针异常
if res.Data == nil {
return nil, gerror.Newf("create chat returned null data: %s", res.Msg)
}
return res.Data, nil
}
// ListChats 列出聊天助手
func (c *Client) ListChats(ctx context.Context, req *ListChatsReq) (*ListChatsRes, error) {
path := "/api/v1/chats"
params := map[string]interface{}{}
if req.Page > 0 {
params["page"] = req.Page
}
if req.PageSize > 0 {
params["page_size"] = req.PageSize
}
if req.OrderBy != "" {
params["orderby"] = req.OrderBy
}
if req.Desc {
params["desc"] = "true"
} else {
params["desc"] = "false"
}
if req.Name != "" {
params["name"] = req.Name
}
if req.Id != "" {
params["id"] = req.Id
}
query := buildQueryString(params)
if query != "" {
path += "?" + query
}
var res ListChatsRes
if err := c.request(ctx, "GET", path, nil, &res); err != nil {
return nil, err
}
if res.Code != 0 {
return nil, gerror.Newf("list chats failed: code=%d", res.Code)
}
return &res, nil
}
// DeleteChats 删除聊天助手
func (c *Client) DeleteChats(ctx context.Context, ids []string) (err error) {
req := DeleteChatsReq{Ids: ids}
var res CommonResponse
if err = c.request(ctx, "DELETE", "/api/v1/chats", req, &res); err != nil {
return
}
if !res.IsSuccess() {
return gerror.Newf("delete chats failed: %s", res.Message)
}
return
}
// UpdateChat 更新聊天助手
func (c *Client) UpdateChat(ctx context.Context, id string, req *UpdateChatReq) (err error) {
var res CommonResponse
path := "/api/v1/chats/" + id
if err = c.request(ctx, "PUT", path, req, &res); err != nil {
return
}
if !res.IsSuccess() {
return gerror.Newf("update chat failed: %s", res.Message)
}
return
}

View File

@@ -1,180 +0,0 @@
package ragflow
import (
"context"
"github.com/gogf/gf/v2/errors/gerror"
)
// 数据集内知识块管理
// 参考: https://ragflow.com.cn/docs/dev/http_api_reference#数据集内知识块管理
// Chunk 知识块结构体
type Chunk struct {
Id string `json:"id"`
Content string `json:"content"`
DocumentId string `json:"document_id"`
DatasetId string `json:"dataset_id"`
CreateTime string `json:"create_time"`
CreateTimestamp float64 `json:"create_timestamp"`
ImportantKeywords []string `json:"important_keywords"`
Questions []string `json:"questions"`
Available bool `json:"available"`
ImageId string `json:"image_id"`
Positions []string `json:"positions"`
}
// AddChunkReq 添加知识块请求
type AddChunkReq struct {
Content string `json:"content"`
ImportantKeywords []string `json:"important_keywords,omitempty"`
Questions []string `json:"questions,omitempty"`
}
// ListChunksReq 列出知识块请求
type ListChunksReq struct {
Keywords string `json:"keywords,omitempty"`
Page int `json:"page,omitempty"`
PageSize int `json:"page_size,omitempty"`
Id string `json:"id,omitempty"`
}
// ListChunksRes 列出知识块响应
// 注意:响应结构包含 chunks知识块列表、doc关联文档信息和 total总数
type ListChunksRes struct {
Code int `json:"code"` // 状态码0 表示成功
Data struct {
Chunks []*Chunk `json:"chunks"` // 知识块列表
Doc interface{} `json:"doc"` // 关联文档信息(完整的 Document 对象)
Total int `json:"total"` // 知识块总数
} `json:"data"`
}
// DeleteChunksReq 删除知识块请求
type DeleteChunksReq struct {
ChunkIds []string `json:"chunk_ids,omitempty"` // 如果为空,删除所有
}
// UpdateChunkReq 更新知识块请求
type UpdateChunkReq struct {
Content string `json:"content,omitempty"`
ImportantKeywords []string `json:"important_keywords,omitempty"`
Available *bool `json:"available,omitempty"`
}
// RetrieveChunksReq 检索知识块请求
type RetrieveChunksReq struct {
Question string `json:"question"`
DatasetIds []string `json:"dataset_ids,omitempty"`
DocumentIds []string `json:"document_ids,omitempty"`
Page int `json:"page,omitempty"`
PageSize int `json:"page_size,omitempty"`
SimilarityThreshold float64 `json:"similarity_threshold,omitempty"`
VectorSimilarityWeight float64 `json:"vector_similarity_weight,omitempty"`
TopK int `json:"top_k,omitempty"`
RerankId string `json:"rerank_id,omitempty"`
Keyword bool `json:"keyword,omitempty"`
Highlight bool `json:"highlight,omitempty"`
CrossLanguages []string `json:"cross_languages,omitempty"`
MetadataCondition map[string]interface{} `json:"metadata_condition,omitempty"`
}
// RetrieveChunksRes 检索知识块响应 (结构比较复杂,暂时简化,根据实际返回调整)
// 官方文档未给出详细响应结构,假设返回 chunks 列表
type RetrieveChunksRes struct {
Code int `json:"code"`
Data struct {
Chunks []interface{} `json:"chunks"` // 检索结果可能包含额外信息
Total int `json:"total"`
} `json:"data"`
}
// AddChunk 添加知识块
func (c *Client) AddChunk(ctx context.Context, datasetId, documentId string, req *AddChunkReq) (*Chunk, error) {
path := "/api/v1/datasets/" + datasetId + "/documents/" + documentId + "/chunks"
var res struct {
Code int `json:"code"`
Data struct {
Chunk *Chunk `json:"chunk"`
} `json:"data"`
Msg string `json:"message"`
}
if err := c.request(ctx, "POST", path, req, &res); err != nil {
return nil, err
}
if res.Code != 0 {
return nil, gerror.Newf("add chunk failed: %s", res.Msg)
}
return res.Data.Chunk, nil
}
// ListChunks 列出知识块
func (c *Client) ListChunks(ctx context.Context, datasetId, documentId string, req *ListChunksReq) (*ListChunksRes, error) {
path := "/api/v1/datasets/" + datasetId + "/documents/" + documentId + "/chunks"
params := map[string]interface{}{}
if req.Keywords != "" {
params["keywords"] = req.Keywords
}
if req.Page > 0 {
params["page"] = req.Page
}
if req.PageSize > 0 {
params["page_size"] = req.PageSize
}
if req.Id != "" {
params["id"] = req.Id
}
query := buildQueryString(params)
if query != "" {
path += "?" + query
}
var res ListChunksRes
if err := c.request(ctx, "GET", path, nil, &res); err != nil {
return nil, err
}
if res.Code != 0 {
return nil, gerror.Newf("list chunks failed: code=%d", res.Code)
}
return &res, nil
}
// DeleteChunks 删除知识块
func (c *Client) DeleteChunks(ctx context.Context, datasetId, documentId string, chunkIds []string) (err error) {
req := DeleteChunksReq{ChunkIds: chunkIds}
var res CommonResponse
path := "/api/v1/datasets/" + datasetId + "/documents/" + documentId + "/chunks"
if err = c.request(ctx, "DELETE", path, req, &res); err != nil {
return
}
if !res.IsSuccess() {
return gerror.Newf("delete chunks failed: %s", res.Message)
}
return
}
// UpdateChunk 更新知识块
func (c *Client) UpdateChunk(ctx context.Context, datasetId, documentId, chunkId string, req *UpdateChunkReq) (err error) {
var res CommonResponse
path := "/api/v1/datasets/" + datasetId + "/documents/" + documentId + "/chunks/" + chunkId
if err = c.request(ctx, "PUT", path, req, &res); err != nil {
return
}
if !res.IsSuccess() {
return gerror.Newf("update chunk failed: %s", res.Message)
}
return
}
// RetrieveChunks 检索知识块
func (c *Client) RetrieveChunks(ctx context.Context, req *RetrieveChunksReq) (*RetrieveChunksRes, error) {
var res RetrieveChunksRes
if err := c.request(ctx, "POST", "/api/v1/retrieval", req, &res); err != nil {
return nil, err
}
if res.Code != 0 {
return nil, gerror.Newf("retrieve chunks failed: code=%d", res.Code)
}
return &res, nil
}

View File

@@ -1,195 +0,0 @@
package ragflow
import (
"context"
"encoding/json"
"net/url"
"strings"
"sync"
"sync/atomic"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/net/gclient"
)
var (
// globalClient 全局 RAGFlow 客户端(单例,延迟初始化)
globalClient *Client
clientOnce sync.Once
)
// initClient 延迟初始化客户端
func initClient() {
clientOnce.Do(func() {
ctx := context.Background()
// 读取配置
endpoints, apiKey := loadConfig(ctx)
// 如果配置不完整,跳过初始化
if len(endpoints) == 0 || apiKey == "" {
g.Log().Warning(ctx, "⚠️ RAGFlow 配置未找到,请在 config.yml 中添加 ragflow.base_url 或在 Consul 中配置 ragflow.endpoints")
return
}
globalClient = &Client{
Endpoints: endpoints,
APIKey: apiKey,
}
if len(endpoints) == 1 {
g.Log().Infof(ctx, "✅ RAGFlow 客户端初始化成功: endpoint=%s", endpoints[0])
} else {
g.Log().Infof(ctx, "✅ RAGFlow 客户端初始化成功: endpoints=%v (负载均衡)", endpoints)
}
})
}
// loadConfig 从配置加载 RAGFlow 配置(支持实例级配置)
// 优先级:
// 1. Consul实例级配置 ragflow.endpoints (数组)
// 2. Consul全局配置 ragflow.endpoints (数组)
// 3. config.yml的 ragflow.base_url (单个URL向后兼容)
func loadConfig(ctx context.Context) (endpoints []string, apiKey string) {
// 尝试从Consul读取endpoints支持实例级配置
// 注意这里不能直接导入customerService/service包会造成循环依赖
// 所以只能从config.yml读取Consul配置需要在customerservice层面调用时传入
// 读取API Key
apiKey = g.Cfg().MustGet(ctx, "ragflow.api_key", "").String()
// 尝试读取endpoints数组从config.yml或Consul同步的配置
endpointsConfig := g.Cfg().MustGet(ctx, "ragflow.endpoints")
if !endpointsConfig.IsEmpty() {
endpoints = endpointsConfig.Strings()
// 去除尾部斜杠
for i := range endpoints {
endpoints[i] = strings.TrimSuffix(endpoints[i], "/")
}
return
}
// Fallback到单个base_url向后兼容
baseURL := g.Cfg().MustGet(ctx, "ragflow.base_url", "").String()
if baseURL != "" {
endpoints = []string{strings.TrimSuffix(baseURL, "/")}
}
return
}
// GetGlobalClient 获取全局客户端(延迟初始化)
func GetGlobalClient() *Client {
initClient()
return globalClient
}
// Client RAGFlow API 客户端(支持负载均衡)
type Client struct {
Endpoints []string // RAGFlow实例列表
APIKey string // API密钥
currentIndex atomic.Uint64 // 当前轮询索引(原子操作)
}
// getNextEndpoint 获取下一个endpoint轮询算法
func (c *Client) getNextEndpoint() string {
if len(c.Endpoints) == 0 {
return ""
}
if len(c.Endpoints) == 1 {
return c.Endpoints[0]
}
// 原子递增并取模,实现轮询
idx := c.currentIndex.Add(1) % uint64(len(c.Endpoints))
return c.Endpoints[idx]
}
// CommonResponse 通用响应结构
type CommonResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
}
// IsSuccess 检查响应是否成功
func (r *CommonResponse) IsSuccess() bool {
return r.Code == 0
}
// request 发送 HTTP 请求
//
// 为什么不使用 common/http 包:
// common/http包统一处理内部API响应格式ghttp.DefaultHandlerResponse
// RAGFlow API返回格式为{code,data,message}一层结构与内部API不同。
// 因此直接使用 g.Client() 调用第三方API在此处理RAGFlow特有的响应格式。
func (c *Client) request(ctx context.Context, method, path string, body interface{}, result interface{}) (err error) {
endpoint := c.getNextEndpoint()
if endpoint == "" {
return gerror.New("RAGFlow endpoints not configured")
}
fullURL := endpoint + path
// 创建HTTP客户端并设置RAGFlow专用请求头
client := g.Client()
client.SetHeader("Authorization", "Bearer "+c.APIKey)
client.SetHeader("Content-Type", "application/json")
// 发送HTTP请求避免data展开导致的双重包装
var response *gclient.Response
switch method {
case "GET":
if body != nil {
response, err = client.Get(ctx, fullURL, body)
} else {
response, err = client.Get(ctx, fullURL)
}
case "POST":
if body != nil {
response, err = client.Post(ctx, fullURL, body)
} else {
response, err = client.Post(ctx, fullURL)
}
case "PUT":
if body != nil {
response, err = client.Put(ctx, fullURL, body)
} else {
response, err = client.Put(ctx, fullURL)
}
case "DELETE":
if body != nil {
response, err = client.Delete(ctx, fullURL, body)
} else {
response, err = client.Delete(ctx, fullURL)
}
default:
return gerror.Newf("unsupported method: %s", method)
}
if err != nil {
return
}
defer response.Close()
// RAGFlow API响应格式{code,data,message}一层结构,直接解析
responseBody := response.ReadAll()
if err = json.Unmarshal(responseBody, result); err != nil {
return gerror.Newf("RAGFlow响应解析失败: %v, 原始响应: %s", err, string(responseBody))
}
return
}
// buildQueryString 构建查询字符串
func buildQueryString(params map[string]interface{}) string {
if len(params) == 0 {
return ""
}
parts := make([]string, 0, len(params))
for k, v := range params {
parts = append(parts, url.QueryEscape(k)+"="+url.QueryEscape(g.NewVar(v).String()))
}
return strings.Join(parts, "&")
}

View File

@@ -1,190 +0,0 @@
package ragflow
import (
"context"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
)
// 数据集管理
// 参考: https://ragflow.com.cn/docs/dev/http_api_reference#数据集管理
// Dataset 数据集结构体
type Dataset struct {
Id string `json:"id"`
Name string `json:"name"`
Avatar string `json:"avatar"`
TenantId string `json:"tenant_id"`
Description string `json:"description"`
Language string `json:"language"`
EmbeddingModel string `json:"embedding_model"`
Permission string `json:"permission"`
DocumentCount int `json:"document_count"`
ChunkCount int `json:"chunk_count"`
ParseStatus string `json:"parse_status"`
CreatedBy string `json:"created_by"`
CreateTime int64 `json:"create_time"`
UpdateDate string `json:"update_date"`
UpdateTime int64 `json:"update_time"`
Status string `json:"status"`
ChunkMethod string `json:"chunk_method"`
ParserConfig map[string]interface{} `json:"parser_config"`
VectorSimilarityWeight float64 `json:"vector_similarity_weight"`
SimilarityThreshold float64 `json:"similarity_threshold"`
TokenNum int `json:"token_num"`
}
// CreateDatasetReq 创建数据集请求
type CreateDatasetReq struct {
Name string `json:"name"`
Avatar string `json:"avatar,omitempty"`
Description string `json:"description,omitempty"`
EmbeddingModel string `json:"embedding_model,omitempty"`
Permission string `json:"permission,omitempty"`
ChunkMethod string `json:"chunk_method,omitempty"`
ParserConfig map[string]interface{} `json:"parser_config,omitempty"`
}
// UpdateDatasetReq 更新数据集请求
type UpdateDatasetReq struct {
Name string `json:"name,omitempty"`
Avatar string `json:"avatar,omitempty"`
Description string `json:"description,omitempty"`
EmbeddingModel string `json:"embedding_model,omitempty"`
Permission string `json:"permission,omitempty"`
ChunkMethod string `json:"chunk_method,omitempty"`
PageRank int `json:"pagerank,omitempty"`
ParserConfig map[string]interface{} `json:"parser_config,omitempty"`
}
// ListDatasetsReq 列出数据集请求
type ListDatasetsReq struct {
Page int `json:"page,omitempty"`
PageSize int `json:"page_size,omitempty"`
OrderBy string `json:"orderby,omitempty"`
Desc bool `json:"desc,omitempty"`
Name string `json:"name,omitempty"`
Id string `json:"id,omitempty"`
}
// ListDatasetsRes 列出数据集响应
// 注意:与 Agent/Chat 等接口不同Dataset API 会返回 total 字段
type ListDatasetsRes struct {
Code int `json:"code"` // 状态码0 表示成功
Data []*Dataset `json:"data"` // 数据集列表
Total int `json:"total"` // 总数据集数
}
// DeleteDatasetsReq 删除数据集请求
type DeleteDatasetsReq struct {
Ids []string `json:"ids"`
}
// CreateDataset 创建数据集
func (c *Client) CreateDataset(ctx context.Context, req *CreateDatasetReq) (*Dataset, error) {
g.Log().Infof(ctx, "CreateDataset请求: name=%s, description=%s, embedding_model=%s", req.Name, req.Description, req.EmbeddingModel)
var res struct {
Code int `json:"code"`
Data *Dataset `json:"data"`
Msg string `json:"message"`
}
if err := c.request(ctx, "POST", "/api/v1/datasets", req, &res); err != nil {
g.Log().Errorf(ctx, "CreateDataset请求失败: %v", err)
return nil, err
}
g.Log().Infof(ctx, "CreateDataset响应: code=%d, msg=%s, data_is_nil=%v", res.Code, res.Msg, res.Data == nil)
// code=101表示dataset名称已存在正常业务场景不是错误
// 调用方应该通过ListDatasets查找已有dataset并复用
if res.Code == 101 {
return nil, gerror.Newf("Dataset名称已存在: %s", res.Msg)
}
// 其他非0的code表示真正的错误
if res.Code != 0 {
return nil, gerror.Newf("创建知识库失败(code=%d): %s", res.Code, res.Msg)
}
// code=0但data=null表示创建异常可能是RAGFlow配置问题如embedding模型不可用、权限不足等
// 这不是正常状态,应该返回错误而不是(nil, nil)
if res.Data == nil {
return nil, gerror.Newf("创建知识库返回空数据(code=0,data=null)可能是RAGFlow配置问题: %s", res.Msg)
}
g.Log().Infof(ctx, "CreateDataset成功: id=%s, name=%s", res.Data.Id, res.Data.Name)
return res.Data, nil
}
// ListDatasets 列出数据集
func (c *Client) ListDatasets(ctx context.Context, req *ListDatasetsReq) (*ListDatasetsRes, error) {
// 构建查询参数
path := "/api/v1/datasets"
params := map[string]interface{}{}
if req.Page > 0 {
params["page"] = req.Page
}
if req.PageSize > 0 {
params["page_size"] = req.PageSize
}
if req.OrderBy != "" {
params["orderby"] = req.OrderBy
}
// desc 默认为 true如果显式设置为 false 才传递,或者根据 API 行为调整
// 这里简单处理,如果设置了就传
if req.Desc {
params["desc"] = "true"
} else {
params["desc"] = "false"
}
if req.Name != "" {
params["name"] = req.Name
}
if req.Id != "" {
params["id"] = req.Id
}
// 拼接 query string
query := buildQueryString(params)
if query != "" {
path += "?" + query
}
var res ListDatasetsRes
if err := c.request(ctx, "GET", path, nil, &res); err != nil {
return nil, err
}
if res.Code != 0 {
return nil, gerror.Newf("list datasets failed: code=%d", res.Code)
}
return &res, nil
}
// DeleteDataset 删除数据集
func (c *Client) DeleteDataset(ctx context.Context, ids []string) (err error) {
req := DeleteDatasetsReq{Ids: ids}
var res CommonResponse
if err = c.request(ctx, "DELETE", "/api/v1/datasets", req, &res); err != nil {
return
}
if !res.IsSuccess() {
return gerror.Newf("delete dataset failed: %s", res.Message)
}
return
}
// UpdateDataset 更新数据集
func (c *Client) UpdateDataset(ctx context.Context, id string, req *UpdateDatasetReq) (err error) {
var res CommonResponse
path := "/api/v1/datasets/" + id
if err = c.request(ctx, "PUT", path, req, &res); err != nil {
return
}
if !res.IsSuccess() {
return gerror.Newf("update dataset failed: %s", res.Message)
}
return
}

View File

@@ -1,274 +0,0 @@
// Package ragflow - RAGFlow文档管理
// 功能RAGFlow知识库文档的上传、列表、删除操作
package ragflow
import (
"bytes"
"context"
"encoding/json"
"mime/multipart"
"strings"
commonHttp "gitea.com/red-future/common/http"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
)
// 数据集内文件管理
// 参考: https://ragflow.com.cn/docs/dev/http_api_reference#数据集内文件管理
// ... (rest of the code remains the same)
type Document struct {
Id string `json:"id"`
DatasetId string `json:"dataset_id"`
Name string `json:"name"`
Size int64 `json:"size"`
Location string `json:"location"`
CreatedBy string `json:"created_by"`
CreateTime int64 `json:"create_time"`
Thumbnail string `json:"thumbnail"`
Type string `json:"type"`
RunStatus string `json:"run_status"` // 对应 API 返回的 "run" 字段,可能需要确认
Status string `json:"status"`
ChunkMethod string `json:"chunk_method"`
ParserConfig map[string]interface{} `json:"parser_config"`
TokenNum int `json:"token_num"`
ChunkCount int `json:"chunk_count"`
ProcessBegin int64 `json:"process_begin"`
ProcessDu int64 `json:"process_du"`
Progress float64 `json:"progress"`
ProgressMsg string `json:"progress_msg"`
}
// UploadDocumentReq 上传文档请求
// 注意:上传文件通常需要 multipart/form-data这里仅定义结构实际逻辑在方法中处理
type UploadDocumentReq struct {
FilePaths []string // 本地文件路径列表
}
// UploadDocumentRes 上传文档响应
type UploadDocumentRes struct {
Id string `json:"id"` // 文档ID
}
// ListDocumentsReq 列出文档请求
type ListDocumentsReq struct {
Page int `json:"page,omitempty"` // 页码,默认 1
PageSize int `json:"page_size,omitempty"` // 每页数量,默认 30
OrderBy string `json:"orderby,omitempty"` // 排序字段create_time默认或 update_time
Desc bool `json:"desc,omitempty"` // 是否降序,默认 true
Keywords string `json:"keywords,omitempty"` // 关键词过滤(匹配文档标题)
Id string `json:"id,omitempty"` // 文档 ID 过滤
Name string `json:"name,omitempty"` // 文档名称过滤
CreateTimeFrom int64 `json:"create_time_from,omitempty"` // 创建时间起始Unix 时间戳0 表示无限制
CreateTimeTo int64 `json:"create_time_to,omitempty"` // 创建时间截止Unix 时间戳0 表示无限制
Suffix []string `json:"suffix,omitempty"` // 文件后缀过滤,如 ["pdf", "txt", "docx"]
Run []string `json:"run,omitempty"` // 处理状态过滤,支持 ["UNSTART", "RUNNING", "CANCEL", "DONE", "FAIL"] 或数字格式 ["0", "1", "2", "3", "4"]
}
// ListDocumentsRes 列出文档响应
// 注意:响应结构与其他 List 接口不同data 是一个对象而非数组
type ListDocumentsRes struct {
Code int `json:"code"` // 状态码0 表示成功
Data struct {
Docs []*Document `json:"docs"` // 文档列表
TotalDatasets int `json:"total_datasets"` // 总文档数
} `json:"data"`
}
// DeleteDocumentsReq 删除文档请求
type DeleteDocumentsReq struct {
Ids []string `json:"ids"`
}
// ListDocuments 列出文档
func (c *Client) ListDocuments(ctx context.Context, datasetId string, req *ListDocumentsReq) (*ListDocumentsRes, error) {
path := "/api/v1/datasets/" + datasetId + "/documents"
params := map[string]interface{}{}
if req.Page > 0 {
params["page"] = req.Page
}
if req.PageSize > 0 {
params["page_size"] = req.PageSize
}
if req.OrderBy != "" {
params["orderby"] = req.OrderBy
}
if req.Desc {
params["desc"] = "true"
} else {
params["desc"] = "false"
}
if req.Keywords != "" {
params["keywords"] = req.Keywords
}
if req.Id != "" {
params["id"] = req.Id
}
if req.Name != "" {
params["name"] = req.Name
}
if req.CreateTimeFrom > 0 {
params["create_time_from"] = req.CreateTimeFrom
}
if req.CreateTimeTo > 0 {
params["create_time_to"] = req.CreateTimeTo
}
// 构造查询字符串
query := buildQueryString(params)
var queryParts []string
if query != "" {
queryParts = append(queryParts, query)
}
// 处理数组参数suffix文件后缀过滤
// API 要求多个值时重复参数名suffix=pdf&suffix=txt
for _, suffix := range req.Suffix {
queryParts = append(queryParts, "suffix="+suffix)
}
// 处理数组参数run处理状态过滤
// 支持数字格式("0"-"4")或文本格式("UNSTART", "RUNNING", "CANCEL", "DONE", "FAIL"
for _, run := range req.Run {
queryParts = append(queryParts, "run="+run)
}
// 构造最终请求路径
if len(queryParts) > 0 {
path += "?" + strings.Join(queryParts, "&")
}
// 发送请求并处理响应
var res ListDocumentsRes
if err := c.request(ctx, "GET", path, nil, &res); err != nil {
return nil, err
}
if res.Code != 0 {
return nil, gerror.Newf("list documents failed: code=%d", res.Code)
}
return &res, nil
}
// UploadDocumentFromText 上传文本内容作为文档
func (c *Client) UploadDocumentFromText(ctx context.Context, datasetId, content, filename string) (documentId string, err error) {
if datasetId == "" {
return "", gerror.New("datasetId不能为空")
}
if content == "" {
return "", gerror.New("文档内容不能为空")
}
if filename == "" {
filename = "document.txt"
}
// 构造URL使用负载均衡
endpoint := c.getNextEndpoint()
if endpoint == "" {
return "", gerror.New("RAGFlow endpoints not configured")
}
url := endpoint + "/api/v1/datasets/" + datasetId + "/documents"
// 创建multipart writer
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
// 添加文件字段
part, err := writer.CreateFormFile("file", filename)
if err != nil {
return "", gerror.Wrap(err, "创建form file失败")
}
// 写入内容
if _, err = part.Write([]byte(content)); err != nil {
return "", gerror.Wrap(err, "写入文件内容失败")
}
// 关闭multipart writer
if err = writer.Close(); err != nil {
return "", gerror.Wrap(err, "关闭multipart writer失败")
}
// 发送请求
client := commonHttp.Httpclient.Clone()
client.SetHeader("Authorization", "Bearer "+c.APIKey)
client.SetHeader("Content-Type", writer.FormDataContentType())
resp, err := client.Post(ctx, url, body.Bytes())
if err != nil {
return "", gerror.Wrap(err, "上传文档请求失败")
}
defer resp.Close()
// 解析响应
var response struct {
Code int `json:"code"`
Message string `json:"message"`
Data []UploadDocumentRes `json:"data"` // RAGFlow返回数组
}
respBody := resp.ReadAll()
if err := json.Unmarshal(respBody, &response); err != nil {
g.Log().Errorf(ctx, "解析RAGFlow响应失败: %v, 原始响应: %s", err, string(respBody))
return "", gerror.Newf("json Decode failed: %v", err)
}
// 先检查code再检查data
if response.Code != 0 {
g.Log().Errorf(ctx, "RAGFlow返回错误: code=%d, message=%s", response.Code, response.Message)
return "", gerror.Newf("上传文档失败 (code=%d): %s", response.Code, response.Message)
}
if len(response.Data) == 0 {
g.Log().Errorf(ctx, "RAGFlow返回data为空, 完整响应: %s", string(respBody))
return "", gerror.New("上传文档返回data为空")
}
return response.Data[0].Id, nil
}
// UploadDocument 上传文档(保留兼容)
func (c *Client) UploadDocument(ctx context.Context, datasetId string, filePaths []string) (err error) {
return gerror.New("upload document from file not implemented yet, use UploadDocumentFromText instead")
}
// ParseDocumentsReq 解析文档请求
type ParseDocumentsReq struct {
DocumentIds []string `json:"document_ids"` // 要解析的文档ID列表
}
// ParseDocuments 解析文档(上传后必须调用此接口才会开始解析)
func (c *Client) ParseDocuments(ctx context.Context, datasetId string, documentIds []string) error {
if datasetId == "" {
return gerror.New("datasetId不能为空")
}
if len(documentIds) == 0 {
return gerror.New("documentIds不能为空")
}
req := ParseDocumentsReq{DocumentIds: documentIds}
var res CommonResponse
path := "/api/v1/datasets/" + datasetId + "/chunks"
if err := c.request(ctx, "POST", path, req, &res); err != nil {
return err
}
if !res.IsSuccess() {
return gerror.Newf("解析文档失败: %s", res.Message)
}
return nil
}
// DeleteDocument 删除文档
func (c *Client) DeleteDocument(ctx context.Context, datasetId string, ids []string) (err error) {
req := DeleteDocumentsReq{Ids: ids}
var res CommonResponse
path := "/api/v1/datasets/" + datasetId + "/documents"
if err = c.request(ctx, "DELETE", path, req, &res); err != nil {
return
}
if !res.IsSuccess() {
return gerror.Newf("delete document failed: %s", res.Message)
}
return
}

View File

@@ -1,117 +0,0 @@
package ragflow
import (
"context"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/errors/gerror"
)
// OpenAICompatibleAPI 与 OpenAI 兼容的 API
// 参考: https://ragflow.com.cn/docs/dev/http_api_reference#与-openai-兼容的-api
// ChatCompletionMessage OpenAI 格式的消息
type ChatCompletionMessage struct {
Role string `json:"role"` // "user", "assistant", "system"
Content string `json:"content"`
}
// ChatCompletionRequest OpenAI 格式的聊天补全请求
type ChatCompletionRequest struct {
Model string `json:"model"` // 模型名称(服务器会自动解析,可设置为任意值)
Messages []ChatCompletionMessage `json:"messages"` // 消息列表,必须至少包含一条 user 消息
Stream bool `json:"stream,omitempty"` // 是否流式返回,默认 false
}
// ChatCompletionResponse OpenAI 格式的聊天补全响应(非流式)
type ChatCompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []struct {
Index int `json:"index"`
Message ChatCompletionMessage `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
// ChatCompletionChunk 流式响应块
type ChatCompletionChunk struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []struct {
Index int `json:"index"`
Delta struct {
Content string `json:"content"`
Role string `json:"role"`
} `json:"delta"`
FinishReason *string `json:"finish_reason"`
} `json:"choices"`
Usage *struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage,omitempty"`
}
// CreateChatCompletion 创建聊天补全(与聊天助手)
// POST /api/v1/chats_openai/{chat_id}/chat/completions
func (c *Client) CreateChatCompletion(ctx context.Context, chatID string, req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
path := "/api/v1/chats_openai/" + chatID + "/chat/completions"
var resp ChatCompletionResponse
if err := c.request(ctx, "POST", path, req, &resp); err != nil {
return nil, gerror.Newf("create chat completion failed: %v", err)
}
return &resp, nil
}
// CreateAgentCompletion 创建 Agent 补全
// POST /api/v1/agents_openai/{agent_id}/chat/completions
func (c *Client) CreateAgentCompletion(ctx context.Context, agentID string, req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
path := "/api/v1/agents_openai/" + agentID + "/chat/completions"
var resp ChatCompletionResponse
if err := c.request(ctx, "POST", path, req, &resp); err != nil {
return nil, gerror.Newf("create agent completion failed: %v", err)
}
return &resp, nil
}
// CreateChatCompletionStream 创建流式聊天补全(与聊天助手)
// 注意:流式响应需要特殊处理,这里返回一个可用于读取流的接口
func (c *Client) CreateChatCompletionStream(ctx context.Context, chatID string, req *ChatCompletionRequest) (*StreamReader, error) {
req.Stream = true
// TODO: 实现流式读取逻辑
return nil, gerror.New("stream mode not implemented yet")
}
// StreamReader 流式响应读取器
type StreamReader struct {
_ *gjson.Json // TODO: 实现流式读取时使用
close func() error
}
// ReadChunk 读取下一个响应块
// TODO: 实现流式读取逻辑
func (sr *StreamReader) ReadChunk() (*ChatCompletionChunk, error) {
return nil, gerror.New("stream mode not implemented yet")
}
// Close 关闭流
func (sr *StreamReader) Close() (err error) {
if sr.close != nil {
return sr.close()
}
return
}

View File

@@ -1,178 +0,0 @@
package ragflow
import (
"context"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
)
// 会话管理
// 参考: https://ragflow.com.cn/docs/dev/http_api_reference#会话管理
// Session 会话结构体
type Session struct {
Id string `json:"id"`
Name string `json:"name"`
ChatId string `json:"chat_id"` // 响应中是 "chat" 或 "chat_id",根据文档示例调整
Messages []Message `json:"messages"`
CreateDate string `json:"create_date"`
CreateTime int64 `json:"create_time"`
UpdateDate string `json:"update_date"`
UpdateTime int64 `json:"update_time"`
}
type Message struct {
Content string `json:"content"`
Role string `json:"role"`
}
// CreateSessionReq 创建会话请求
type CreateSessionReq struct {
Name string `json:"name"`
UserId string `json:"user_id,omitempty"`
}
// ListSessionsReq 列出会话请求
type ListSessionsReq struct {
Page int `json:"page,omitempty"`
PageSize int `json:"page_size,omitempty"`
OrderBy string `json:"orderby,omitempty"`
Desc bool `json:"desc,omitempty"`
Name string `json:"name,omitempty"`
Id string `json:"id,omitempty"`
UserId string `json:"user_id,omitempty"`
}
// ListSessionsRes 列出会话响应
// 注意API 不返回 total 字段,仅返回 data 数组
type ListSessionsRes struct {
Code int `json:"code"` // 状态码0 表示成功
Data []*Session `json:"data"` // 会话列表
}
// DeleteSessionsReq 删除会话请求
type DeleteSessionsReq struct {
Ids []string `json:"ids"`
}
// ChatCompletionReq 对话请求
type ChatCompletionReq struct {
Question string `json:"question"`
Stream bool `json:"stream"`
SessionId string `json:"session_id,omitempty"`
UserId string `json:"user_id,omitempty"`
}
// ChatCompletionRes 对话响应 (非流式)
type ChatCompletionRes struct {
Code int `json:"code"`
Message string `json:"message"` // 错误信息
Data struct {
Answer string `json:"answer"`
Reference interface{} `json:"reference"`
AudioBinary interface{} `json:"audio_binary"`
Id interface{} `json:"id"`
SessionId string `json:"session_id"`
} `json:"data"`
}
// CreateSession 创建会话
func (c *Client) CreateSession(ctx context.Context, chatId string, req *CreateSessionReq) (*Session, error) {
path := "/api/v1/chats/" + chatId + "/sessions"
var res struct {
Code int `json:"code"`
Data *Session `json:"data"`
Msg string `json:"message"`
}
if err := c.request(ctx, "POST", path, req, &res); err != nil {
g.Log().Errorf(ctx, "❌ CreateSession请求失败: chatId=%s, req=%+v, error=%v", chatId, req, err)
return nil, err
}
if res.Code != 0 {
g.Log().Errorf(ctx, "❌ CreateSession返回失败: chatId=%s, req=%+v, code=%d, msg=%s", chatId, req, res.Code, res.Msg)
return nil, gerror.Newf("create session failed: %s", res.Msg)
}
// 检查响应数据是否为空防止RAGFlow API返回 {"code":0, "data":null}
// 如果不检查直接返回,调用方会收到 (nil, nil),导致空指针异常
if res.Data == nil {
return nil, gerror.Newf("create session returned null data: %s", res.Msg)
}
return res.Data, nil
}
// ListSessions 列出会话
func (c *Client) ListSessions(ctx context.Context, chatId string, req *ListSessionsReq) (*ListSessionsRes, error) {
path := "/api/v1/chats/" + chatId + "/sessions"
params := map[string]interface{}{}
if req.Page > 0 {
params["page"] = req.Page
}
if req.PageSize > 0 {
params["page_size"] = req.PageSize
}
if req.OrderBy != "" {
params["orderby"] = req.OrderBy
}
if req.Desc {
params["desc"] = "true"
} else {
params["desc"] = "false"
}
if req.Name != "" {
params["name"] = req.Name
}
if req.Id != "" {
params["id"] = req.Id
}
if req.UserId != "" {
params["user_id"] = req.UserId
}
query := buildQueryString(params)
if query != "" {
path += "?" + query
}
var res ListSessionsRes
if err := c.request(ctx, "GET", path, nil, &res); err != nil {
return nil, err
}
if res.Code != 0 {
return nil, gerror.Newf("list sessions failed: code=%d", res.Code)
}
return &res, nil
}
// DeleteSessions 删除会话
func (c *Client) DeleteSessions(ctx context.Context, chatId string, ids []string) (err error) {
req := DeleteSessionsReq{Ids: ids}
var res CommonResponse
path := "/api/v1/chats/" + chatId + "/sessions"
if err = c.request(ctx, "DELETE", path, req, &res); err != nil {
return
}
if !res.IsSuccess() {
return gerror.Newf("delete sessions failed: %s", res.Message)
}
return
}
// ChatCompletion 对话 (目前仅支持非流式)
func (c *Client) ChatCompletion(ctx context.Context, chatId string, req *ChatCompletionReq) (*ChatCompletionRes, error) {
path := "/api/v1/chats/" + chatId + "/completions"
var res ChatCompletionRes
// 如果需要流式支持,需要使用 gclient 的流式处理能力,这里暂只实现非流式
if req.Stream {
return nil, gerror.New("stream mode not supported yet")
}
if err := c.request(ctx, "POST", path, req, &res); err != nil {
return nil, err
}
if res.Code != 0 {
return nil, gerror.Newf("chat completion failed: code=%d, message=%s", res.Code, res.Message)
}
return &res, nil
}

View File

@@ -1,39 +0,0 @@
package ragflow
import (
"context"
"github.com/gogf/gf/v2/errors/gerror"
)
// System 系统管理
// 参考: https://ragflow.com.cn/docs/dev/http_api_reference#系统
// HealthStatus 健康状态
type HealthStatus struct {
DB string `json:"db"` // "ok" 或 "nok"
Redis string `json:"redis"` // "ok" 或 "nok"
DocEngine string `json:"doc_engine"` // "ok" 或 "nok"
Storage string `json:"storage"` // "ok" 或 "nok"
Status string `json:"status"` // 整体状态: "ok" 或 "nok"
Meta map[string]interface{} `json:"_meta,omitempty"` // 详细错误信息
}
// CheckHealth 检查系统健康状况
// GET /v1/system/healthz
func (c *Client) CheckHealth(ctx context.Context) (*HealthStatus, error) {
var status HealthStatus
if err := c.request(ctx, "GET", "/v1/system/healthz", nil, &status); err != nil {
return nil, gerror.Newf("check health failed: %v", err)
}
return &status, nil
}
// IsHealthy 检查系统是否健康
func (c *Client) IsHealthy(ctx context.Context) (bool, error) {
status, err := c.CheckHealth(ctx)
if err != nil {
return false, err
}
return status.Status == "ok", nil
}

View File

@@ -1,154 +0,0 @@
package ragflow
import (
"context"
"runtime/debug"
"strings"
"sync"
"time"
"gitea.com/red-future/common/redis"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/os/grpool"
)
// 默认批量大小(每次从 Redis 读取并发送的消息数)
const defaultBatchSize = 200
// QueueProcessor Stream 处理器,批量读取消息并发送到 RAGFlow
type QueueProcessor struct {
streamKey string // Stream 键名
groupName string // 消费者组名称
consumerName string // 消费者名称
timeout int64 // 阻塞超时时间(毫秒)
batchSize int64 // 最大并发数(协程池大小)
stopChan chan struct{} // 停止信号
pool *grpool.Pool // GoFrame协程池
handleFunc func(ctx context.Context, message map[string]interface{}) error
processingMsgs sync.Map // 正在处理的消息ID去重用
}
// NewQueueProcessor 创建 Stream 处理器
func NewQueueProcessor(streamKey, groupName, consumerName string, timeout, batchSize int64, handleFunc func(ctx context.Context, message map[string]interface{}) error) *QueueProcessor {
// 创建协程池固定大小避免频繁创建销毁goroutine
pool := grpool.New(int(batchSize))
return &QueueProcessor{
streamKey: streamKey,
groupName: groupName,
consumerName: consumerName,
timeout: timeout,
batchSize: batchSize,
stopChan: make(chan struct{}),
pool: pool, // 使用GoFrame协程池
handleFunc: handleFunc,
}
}
// Start 启动 Stream 处理器
// 削峰填谷:每次读取 batchSize 条消息,并发发送,发完立刻读下一批
func (q *QueueProcessor) Start(ctx context.Context) error {
glog.Infof(ctx, "Stream 处理器启动 - Stream: %s, 消费者组: %s, 消费者: %s, 批量大小: %d",
q.streamKey, q.groupName, q.consumerName, q.batchSize)
// 确保 Consumer Group 存在(重试直到成功)
for {
if err := redis.CreateConsumerGroup(ctx, q.streamKey, q.groupName); err != nil {
// BUSYGROUP 表示已存在,不是错误
if strings.Contains(err.Error(), "BUSYGROUP") {
glog.Debugf(ctx, "Consumer Group 已存在")
break
}
glog.Warningf(ctx, "创建 Consumer Group 失败: %v1秒后重试", err)
time.Sleep(time.Second)
continue
}
glog.Infof(ctx, "Consumer Group 创建成功")
break
}
for {
select {
case <-q.stopChan:
glog.Info(ctx, "Stream 处理器收到停止信号")
return nil
default:
// 1. 从 Redis Stream 读取一批消息
messages, err := redis.ReadFromStream(ctx, q.streamKey, q.groupName, q.consumerName, q.batchSize, q.timeout)
if err != nil {
glog.Errorf(ctx, "从 Stream 读取消息失败: %v", err)
continue
}
if len(messages) == 0 {
continue
}
glog.Infof(ctx, "✅ 从Stream读取到 %d 条消息,开始处理", len(messages))
// 2. 去重+立即ACK对话场景优先实时性失败不重试
for i, msg := range messages {
m := msg // 捕获循环变量
msgIndex := i + 1
// 去重:如果消息正在处理,跳过
if _, exists := q.processingMsgs.LoadOrStore(m.ID, true); exists {
glog.Debugf(ctx, "⏭️ 跳过正在处理的消息 - ID: %s", m.ID)
continue
}
// 立即ACK对话场景不需要重试避免重复消费
if err := redis.AckMessage(ctx, q.streamKey, q.groupName, m.ID); err != nil {
glog.Errorf(ctx, "确认消息失败: %v, 消息ID: %s", err, m.ID)
}
glog.Infof(ctx, "📨 准备处理第 %d/%d 条消息 - ID: %s", msgIndex, len(messages), m.ID)
// 提交到协程池池满时会阻塞等待空闲worker
q.pool.Add(ctx, func(ctx context.Context) {
defer q.processingMsgs.Delete(m.ID) // 处理完成后移除标记
q.processMessage(ctx, m)
})
}
// 3. 立刻读下一批(不等待,协程池自动控制并发数)
}
}
}
// processMessage 处理单条消息(异步执行)
func (q *QueueProcessor) processMessage(ctx context.Context, message redis.StreamMessage) {
// 捕获panic防止协程崩溃
defer func() {
if r := recover(); r != nil {
glog.Errorf(ctx, "❌ PANIC: 消息处理发生panic - 消息ID: %s, panic内容: %v\n堆栈:\n%s",
message.ID, r, debug.Stack())
}
}()
glog.Infof(ctx, "🔄 开始处理消息 - ID: %s", message.ID)
// 打印实际字段名(调试用)
var fieldNames []string
for key := range message.Values {
fieldNames = append(fieldNames, key)
}
glog.Infof(ctx, "📋 消息字段名列表: %v", fieldNames)
glog.Infof(ctx, "📦 消息完整内容: %+v", message.Values)
// 调用处理函数发送到 RAGFlow
if err := q.handleFunc(ctx, message.Values); err != nil {
glog.Errorf(ctx, "❌ 消息处理失败: %v, 消息ID: %s", err, message.ID)
} else {
glog.Infof(ctx, "✅ 消息处理成功 - ID: %s", message.ID)
}
// ACK已在读取后立即执行此处无需重复ACK
// 对话场景:失败直接丢弃,不重试(实时性优先)
}
// Stop 停止队列处理器
func (q *QueueProcessor) Stop() {
close(q.stopChan)
// 关闭协程池,等待所有任务完成
q.pool.Close()
}

View File

@@ -1,19 +0,0 @@
package redis
// Redis 数据缓存 Key 常量
const (
CleanList = "list:tenantId-%v:collection-%s:*" // 清理列表Key
CleanCount = "count:tenantId-%v:collection-%s:*" // 清理计数Key
List = "list:tenantId-%v:collection-%s:filter:%s:options:%s" // 列表查询Key
Count = "count:tenantId-%v:collection-%s:filter:%s" // 计数查询Key
One = "one:tenantId-%v:collection-%s:filter:%s" // 单条查询Key
)
// 限流 Redis Key 常量
const (
RateLimitKeyPrefix = "ragflow:ratelimit:" // 限流Key前缀
RateLimitKeyIP = "ip:%s" // IP限流: ip:192.168.1.1
RateLimitKeyUser = "user:%s" // 用户限流: user:123 或 user:anon:192.168.1.1
RateLimitKeyService = "service:%s" // 服务限流: service:customerService
RateLimitKeyGlobal = "global:requests" // 全局限流: global:requests
)

View File

@@ -1,13 +0,0 @@
package redis
import "context"
type QueueMessage struct {
StreamKey string // Stream 键名
GroupName string // 消费者组名称
ConsumerName string // 消费者名称
BatchSize int64 // 最大并发数(信号量容量)
BlockMs int64 // 阻塞时间
AutoAck bool // ACK确认,true自动确认,false手动确认
HandleFunc func(ctx context.Context, message map[string]interface{}) error
}

View File

@@ -1,743 +0,0 @@
package redis
import (
"context"
"strings"
"sync"
"time"
"github.com/gogf/gf/v2/database/gredis"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/gconv"
)
var (
// redisClient 内部使用的 Redis 客户端(单例模式)
redisClient *gredis.Redis
redisOnce sync.Once
)
// getClient 获取 Redis 客户端(延迟初始化)
func getClient() *gredis.Redis {
redisOnce.Do(func() {
redisClient = g.Redis()
})
return redisClient
}
// getClient 获取 Redis 客户端 临时方法
func GetRedisClientTest(name string) *gredis.Redis {
return g.Redis(name)
}
// RedisClient 获取 Redis 客户端(函数式,确保单例正确初始化)
func RedisClient() *gredis.Redis {
return getClient()
}
func GetReadStream(ctx context.Context, msg ...QueueMessage) error {
for _, t := range msg {
err := GetReadFromStream(ctx, t.StreamKey, t.GroupName, t.ConsumerName, t.BatchSize, t.BlockMs, t.AutoAck, t.HandleFunc)
if err != nil {
glog.Infof(ctx, "读取ReadFromStream数据失败-> 键名: %s, 消费者组: %s, 消费者名称%v\n, 失败err:%v\n", t.StreamKey, t.GroupName, t.ConsumerName, err)
continue
}
}
return nil
}
// GetReadFromStream 读取ReadFromStream数据
func GetReadFromStream(ctx context.Context, streamKey, groupName, consumerName string, count, blockMs int64, autoAck bool, fn func(ctx context.Context, message map[string]interface{}) error) (err error) {
glog.Infof(ctx, "初始化 Stream: %s, 消费者组: %s", streamKey, groupName)
err = InitStreamGroup(ctx, streamKey, groupName)
if err != nil {
return err
}
for {
// 从 Redis Stream 读取一批消息
messages, err := ReadFromStream(ctx, streamKey, groupName, consumerName, count, blockMs)
if err != nil {
glog.Errorf(ctx, "[DEBUG Redis] XREADGROUP 错误: %v", err)
return err
}
// 处理消息
for _, msg := range messages {
glog.Infof(ctx, "消费者 '%s' -> 接收到消息 ID: %s, 内容: %v\n", consumerName, msg.ID, msg.Values)
// 业务处理
if err = fn(ctx, msg.Values); err != nil {
glog.Infof(ctx, "业务处理失败-> err:%v\n", err)
continue
}
// 确认消息 (ACK)
if autoAck {
// 处理成功后,必须调用 XAck否则消息会一直留在 PEL 中
err = AckMessage(ctx, streamKey, groupName, msg.ID)
if err != nil {
glog.Infof(ctx, "消费者 '%s' 确认消息 ID %s 失败: %v\n", consumerName, msg.ID, err)
} else {
glog.Infof(ctx, "消费者 '%s' -> 已确认消息 ID: %s\n", consumerName, msg.ID)
}
}
}
}
return
}
// Stream 和消费者组常量
const (
// RAGFlow 请求 Stream Key
RAGFlowRequestStreamKey = "ragflow:request:stream"
// RAGFlow 响应 Stream Key
RAGFlowResponseStreamKey = "ragflow:response:stream"
// RAGFlow 请求消费者组名称
RAGFlowRequestConsumerGroup = "ragflow:request:consumer:group"
// RAGFlow 响应消费者组名称
RAGFlowResponseConsumerGroup = "ragflow:response:consumer:group"
// RAGFlow 消费者组名称(兼容旧代码)
RAGFlowConsumerGroup = "ragflow:consumer:group"
// 会话最后活跃时间 Key 前缀
SessionLastActiveKeyPrefix = "ragflow:session:"
)
// StreamMessage Redis Stream 消息结构
type StreamMessage struct {
ID string // 消息ID自动生成
Values map[string]interface{} // 消息内容
}
// InitStreamGroup 初始化 Stream 和消费者组
// 使用 gredis Do() 方法执行 XGROUP CREATE 命令
func InitStreamGroup(ctx context.Context, streamKey, groupName string) error {
// XGROUP CREATE streamKey groupName 0 MKSTREAM
_, err := getClient().Do(ctx, "XGROUP", "CREATE", streamKey, groupName, "0", "MKSTREAM")
if err != nil {
// 如果组已存在,忽略错误
errStr := err.Error()
if strings.Contains(errStr, "BUSYGROUP") || strings.Contains(errStr, "already exists") {
return nil
}
return err
}
return nil
}
// AddToStream 将消息添加到 Stream
// 使用 gredis Do() 方法执行 XADD 命令
// msg 可以是结构体或 map内部自动转换
func AddToStream(ctx context.Context, streamKey string, msg interface{}) (messageID string, err error) {
// 将结构体转换为 map
values := gconv.Map(msg)
// XADD streamKey * field1 value1 field2 value2 ...
args := make([]interface{}, 0, len(values)*2+2)
args = append(args, streamKey, "*") // "*" 自动生成ID
for key, val := range values {
args = append(args, key, val)
}
result, err := getClient().Do(ctx, "XADD", args...)
if err != nil {
return
}
messageID = result.String()
return
}
// CreateConsumerGroup 创建消费者组(如果不存在)
// XGROUP CREATE streamKey groupName 0 MKSTREAM
// 使用0作为起始ID从Stream开头读取所有未消费消息
func CreateConsumerGroup(ctx context.Context, streamKey, groupName string) error {
_, err := getClient().Do(ctx, "XGROUP", "CREATE", streamKey, groupName, "0", "MKSTREAM")
return err
}
// ReadFromStream 从 Stream 读取消息(消费者组模式)
// 使用 gredis Do() 方法执行 XREADGROUP 命令
func ReadFromStream(ctx context.Context, streamKey, groupName, consumerName string, count int64, blockMs int64) ([]StreamMessage, error) {
// 检查是否需要记录trace避免轮询产生大量trace
execCtx := ctx
if !g.Cfg().MustGet(ctx, "jaeger.traceStream", true).Bool() {
// 不记录trace使用background context不继承span
execCtx = context.Background()
}
RECONNECT:
// 先尝试读取pending消息ID=0处理积压
result, err := getClient().Do(execCtx,
"XREADGROUP", "GROUP", groupName, consumerName,
"COUNT", count,
"BLOCK", 0, // 不阻塞,立即返回
"STREAMS", streamKey, "0", // ID=0 读取pending消息
)
if err != nil {
g.Log().Errorf(ctx, "❌ XREADGROUP读取pending失败: stream=%s, error=%v", streamKey, err)
time.Sleep(time.Second)
goto RECONNECT
}
// 检查pending结果是否为空需要检查消息数组是否为空
hasPending := false
if result != nil && !result.IsEmpty() {
// 尝试解析map格式
if resultVal := result.Val(); resultVal != nil {
if streamsMap, ok := resultVal.(map[interface{}]interface{}); ok {
for _, streamMsgs := range streamsMap {
if msgsArray, ok := streamMsgs.([]interface{}); ok && len(msgsArray) > 0 {
hasPending = true
break
}
}
}
}
}
// 如果没有pending消息读取新消息
if !hasPending {
result, err = getClient().Do(execCtx,
"XREADGROUP", "GROUP", groupName, consumerName,
"COUNT", count,
"BLOCK", blockMs,
"STREAMS", streamKey, ">",
)
if err != nil {
g.Log().Errorf(ctx, "❌ XREADGROUP读取新消息失败: stream=%s, error=%v", streamKey, err)
time.Sleep(time.Second)
goto RECONNECT
}
}
// 预分配容量,避免动态扩容
messages := make([]StreamMessage, 0, int(count))
if result == nil || result.IsEmpty() {
// 超时或没有数据
return messages, nil
}
// GoFrame gredis 返回格式: map[streamKey:[[msgID [field1 value1 field2 value2 ...]] ...]]
resultVal := result.Val()
// 尝试 map 格式GoFrame gredis 返回)
if streamsMap, ok := resultVal.(map[interface{}]interface{}); ok {
for streamKey, streamMsgs := range streamsMap {
msgsArray, ok := streamMsgs.([]interface{})
if !ok {
g.Log().Errorf(ctx, "❌ streamMsgs类型转换失败: streamKey=%v, 实际类型=%T", streamKey, streamMsgs)
continue
}
for i, msgData := range msgsArray {
msgArray, ok := msgData.([]interface{})
if !ok {
g.Log().Errorf(ctx, "❌ msgData类型转换失败: index=%d, 实际类型=%T", i, msgData)
continue
}
if len(msgArray) < 2 {
g.Log().Errorf(ctx, "❌ msgArray长度不足: index=%d, len=%d", i, len(msgArray))
continue
}
msgID := gconv.String(msgArray[0])
fieldsArray, ok := msgArray[1].([]interface{})
if !ok {
g.Log().Errorf(ctx, "❌ fieldsArray类型转换失败: msgID=%s, msgArray[1]类型=%T", msgID, msgArray[1])
continue
}
values := make(map[string]interface{}, len(fieldsArray)/2)
for i := 0; i < len(fieldsArray); i += 2 {
if i+1 < len(fieldsArray) {
key := gconv.String(fieldsArray[i])
values[key] = fieldsArray[i+1]
}
}
messages = append(messages, StreamMessage{
ID: msgID,
Values: values,
})
}
}
if len(messages) == 0 {
g.Log().Errorf(ctx, "❌ [ReadFromStream] map格式解析失败: streamsMap长度=%d, 但未提取到消息", len(streamsMap))
}
return messages, nil
}
// 尝试数组格式(标准 Redis 返回)
if streamsArray, ok := resultVal.([]interface{}); ok && len(streamsArray) > 0 {
for _, streamData := range streamsArray {
streamArray, ok := streamData.([]interface{})
if !ok || len(streamArray) < 2 {
continue
}
messagesArray, ok := streamArray[1].([]interface{})
if !ok {
continue
}
for _, msgData := range messagesArray {
msgArray, ok := msgData.([]interface{})
if !ok || len(msgArray) < 2 {
continue
}
msgID := gconv.String(msgArray[0])
fieldsArray, ok := msgArray[1].([]interface{})
if !ok {
continue
}
values := make(map[string]interface{}, len(fieldsArray)/2)
for i := 0; i < len(fieldsArray); i += 2 {
if i+1 < len(fieldsArray) {
key := gconv.String(fieldsArray[i])
values[key] = fieldsArray[i+1]
}
}
messages = append(messages, StreamMessage{
ID: msgID,
Values: values,
})
}
}
if len(messages) == 0 {
g.Log().Errorf(ctx, "❌ [ReadFromStream] 数组格式解析失败: streamsArray长度=%d, 但未提取到消息", len(streamsArray))
}
return messages, nil
}
g.Log().Errorf(ctx, "❌ [ReadFromStream] 无法识别的result格式, resultVal类型: %T, 值: %+v", resultVal, resultVal)
return messages, nil
}
// AckMessage 确认消息已处理
// 使用 gredis Do() 方法执行 XACK 命令
func AckMessage(ctx context.Context, streamKey, groupName string, messageIDs ...string) error {
// XACK streamKey groupName messageID1 messageID2 ...
// 预分配容量,避免动态扩容
args := make([]interface{}, 0, len(messageIDs)+2)
args = append(args, streamKey, groupName)
for _, id := range messageIDs {
args = append(args, id)
}
_, err := getClient().Do(ctx, "XACK", args...)
return err
}
// GetStreamLength 获取 Stream 当前长度
// 使用 gredis Do() 方法执行 XLEN 命令
func GetStreamLength(ctx context.Context, streamKey string) (int64, error) {
// XLEN streamKey
result, err := getClient().Do(ctx, "XLEN", streamKey)
if err != nil {
return 0, err
}
length := gconv.Int64(result)
return length, nil
}
// PendingMessage Pending 消息结构
type PendingMessage struct {
ID string // 消息ID
Consumer string // 消费者名称
Idle int64 // 空闲时间(毫秒)
RetryCount int64 // 重试次数
}
// GetPendingMessages 获取待处理消息
// 使用 gredis Do() 方法执行 XPENDING 命令
func GetPendingMessages(ctx context.Context, streamKey, groupName string, start, end string, count int64) ([]PendingMessage, error) {
// XPENDING streamKey groupName start end count
result, err := getClient().Do(ctx, "XPENDING", streamKey, groupName, start, end, count)
if err != nil {
return nil, err
}
if result == nil {
return nil, nil
}
// 解析返回值:[[ID, consumer, idle, retryCount], ...]
pendingArray, ok := result.Val().([]interface{})
if !ok {
return nil, nil
}
messages := make([]PendingMessage, 0, len(pendingArray))
for _, item := range pendingArray {
itemArray, ok := item.([]interface{})
if !ok || len(itemArray) < 4 {
continue
}
messages = append(messages, PendingMessage{
ID: gconv.String(itemArray[0]),
Consumer: gconv.String(itemArray[1]),
Idle: gconv.Int64(itemArray[2]),
RetryCount: gconv.Int64(itemArray[3]),
})
}
return messages, nil
}
// ClaimPendingMessage 认领超时的 Pending 消息
// 使用 gredis Do() 方法执行 XCLAIM 命令
func ClaimPendingMessage(ctx context.Context, streamKey, groupName, consumerName string, minIdleTime int64, messageIDs ...string) ([]StreamMessage, error) {
// XCLAIM streamKey groupName consumerName minIdleTime messageID1 messageID2 ...
args := []interface{}{streamKey, groupName, consumerName, minIdleTime}
for _, id := range messageIDs {
args = append(args, id)
}
result, err := getClient().Do(ctx, "XCLAIM", args...)
if err != nil {
return nil, err
}
if result == nil {
return nil, nil
}
// 解析返回值:类似 XREADGROUP
messagesArray, ok := result.Val().([]interface{})
if !ok {
return nil, nil
}
// 预分配容量,避免动态扩容
messages := make([]StreamMessage, 0, len(messagesArray))
for _, msgData := range messagesArray {
msgArray, ok := msgData.([]interface{})
if !ok || len(msgArray) < 2 {
continue
}
msgID := gconv.String(msgArray[0])
fieldsArray, ok := msgArray[1].([]interface{})
if !ok {
continue
}
// 预分配 map 容量 ,避免动态扩容
values := make(map[string]interface{}, len(fieldsArray)/2)
for i := 0; i < len(fieldsArray); i += 2 {
if i+1 < len(fieldsArray) {
key := gconv.String(fieldsArray[i])
values[key] = fieldsArray[i+1]
}
}
messages = append(messages, StreamMessage{
ID: msgID,
Values: values,
})
}
return messages, nil
}
// SetSessionLastActive 设置用户最后活跃时间
// 使用 gredis SetEX 方法
func SetSessionLastActive(ctx context.Context, userId string) error {
key := SessionLastActiveKeyPrefix + userId + ":last_active"
timestamp := gtime.Now().Timestamp()
// SETEX key 7200 value (7200秒 = 2小时)
_, err := getClient().Do(ctx, "SETEX", key, 7200, timestamp)
return err
}
// GetSessionLastActive 获取用户最后活跃时间
// 使用 gredis Get 方法
func GetSessionLastActive(ctx context.Context, userId string) (int64, error) {
key := SessionLastActiveKeyPrefix + userId + ":last_active"
result, err := getClient().Get(ctx, key)
if err != nil {
return 0, err
}
if result.IsEmpty() {
return 0, nil
}
timestamp := gconv.Int64(result.Val())
return timestamp, nil
}
// IsUserActive 检查用户是否在指定时间范围内活跃过
// 用于追问逻辑:如果用户最近活跃过,则不发送追问消息
// 参数:
// - userId: 用户ID
// - seconds: 时间范围例如传入300表示检查5分钟内是否活跃
//
// 返回:
// - bool: true表示用户在指定时间内活跃过
// - error: 操作失败时返回错误
func IsUserActive(ctx context.Context, userId string, seconds int64) (bool, error) {
lastActive, err := GetSessionLastActive(ctx, userId)
if err != nil {
return false, err
}
if lastActive == 0 {
return false, nil // 未找到记录,视为不活跃
}
// 检查时间差
now := gtime.Now().Timestamp()
return (now - lastActive) < seconds, nil
}
// ============== 限流相关 ==============
// IncrRateLimit 增加限流计数器,返回当前计数
// key: 限流key需要包含完整路径如 "ip:192.168.1.1"
// windowSeconds: 时间窗口(秒)
func IncrRateLimit(ctx context.Context, key string, windowSeconds int64) (count int64, err error) {
fullKey := RateLimitKeyPrefix + key
result, err := getClient().Do(ctx, "INCR", fullKey)
if err != nil {
return
}
count = result.Int64()
// 首次设置过期时间
if count == 1 {
getClient().Do(ctx, "EXPIRE", fullKey, windowSeconds)
}
return
}
// GetRateLimit 获取当前限流计数
func GetRateLimit(ctx context.Context, key string) (count int64, err error) {
fullKey := RateLimitKeyPrefix + key
result, err := getClient().Get(ctx, fullKey)
if err != nil {
return
}
if result.IsEmpty() {
return 0, nil
}
count = result.Int64()
return
}
// SetSessionCache 缓存 RAGFlow Session ID租户+用户隔离)
func SetSessionCache(ctx context.Context, tenantId, userId, sessionId string) error {
key := SessionLastActiveKeyPrefix + tenantId + ":" + userId + ":session_id"
// SETEX key 7200 value (7200秒 = 2小时与last_active保持一致)
_, err := getClient().Do(ctx, "SETEX", key, 7200, sessionId)
return err
}
// GetSessionCache 获取缓存的 RAGFlow Session ID租户+用户隔离)
func GetSessionCache(ctx context.Context, tenantId, userId string) (string, error) {
key := SessionLastActiveKeyPrefix + tenantId + ":" + userId + ":session_id"
result, err := getClient().Get(ctx, key)
if err != nil {
return "", err
}
if result.IsEmpty() {
return "", nil
}
return result.String(), nil
}
// DelSessionCache 删除缓存的 RAGFlow Session ID归档时调用租户+用户隔离)
func DelSessionCache(ctx context.Context, tenantId, userId string) error {
key := SessionLastActiveKeyPrefix + tenantId + ":" + userId + ":session_id"
_, err := getClient().Del(ctx, key)
return err
}
// TryLock 尝试获取分布式锁(非阻塞)
// key: 锁的键名
// expireSeconds: 锁的过期时间(秒),防止死锁
// 返回 true 表示获取成功false 表示锁已被其他节点持有
func TryLock(ctx context.Context, key string, expireSeconds int) bool {
// SET key value NX EX expireSeconds
result, err := getClient().Do(ctx, "SET", key, gtime.Now().String(), "NX", "EX", expireSeconds)
if err != nil {
glog.Errorf(ctx, "获取分布式锁失败: %v", err)
return false
}
return result.String() == "OK"
}
// Unlock 释放分布式锁
func Unlock(ctx context.Context, key string) {
if _, err := getClient().Del(ctx, key); err != nil {
glog.Errorf(ctx, "释放分布式锁失败: %v", err)
}
}
// ============== 对话计数相关(用于卡片触发)==============
const (
// UserStateKeyPrefix 用户会话状态 Key 前缀(融合阶段+计数)
UserStateKeyPrefix = "ragflow:user:state:"
// UserStateExpireSeconds 用户状态过期时间5分钟
UserStateExpireSeconds = 300
)
// UserState 用户会话状态(阶段+对话计数+咨询方向,统一5分钟过期
type UserState struct {
Stage int `json:"stage"` // 当前阶段
Direction string `json:"direction"` // 咨询方向
Count int64 `json:"count"` // 对话计数v5.2卡片触发)
AccountName string `json:"accountName"` // 用户选择的方向对应的客服账号名称
}
// GetUserState 获取用户状态(阶段+计数)
func GetUserState(ctx context.Context, userId, platform string) (state *UserState, err error) {
key := UserStateKeyPrefix + userId + "_" + platform
result, err := getClient().Do(ctx, "HGETALL", key)
if err != nil {
return
}
state = &UserState{Stage: 5} // 默认状态5未选择方向
if result.IsEmpty() {
// Redis为空初始化默认状态
if initErr := SetUserStage(ctx, userId, platform, 5); initErr != nil {
err = initErr
return
}
return
}
m := result.Map()
state.Stage = gconv.Int(m["stage"])
state.Count = gconv.Int64(m["count"])
state.Direction = gconv.String(m["direction"])
return
}
// SetUserStage 设置用户阶段,并刷新过期时间
func SetUserStage(ctx context.Context, userId, platform string, stage int) error {
key := UserStateKeyPrefix + userId + "_" + platform
_, err := getClient().Do(ctx, "HSET", key, "stage", stage)
if err != nil {
return err
}
_, err = getClient().Do(ctx, "EXPIRE", key, UserStateExpireSeconds)
return err
}
// SetUserAccountName 设置用户对应的客服账号名称,并刷新过期时间
func SetUserAccountName(ctx context.Context, userId, platform, accountName string) error {
key := UserStateKeyPrefix + userId + "_" + platform
_, err := getClient().Do(ctx, "HSET", key, "accountName", accountName)
if err != nil {
return err
}
_, err = getClient().Do(ctx, "EXPIRE", key, UserStateExpireSeconds)
return err
}
// SetUserDirection 设置用户选择的咨询方向,并刷新过期时间
func SetUserDirection(ctx context.Context, userId, platform, direction string) error {
key := UserStateKeyPrefix + userId + "_" + platform
_, err := getClient().Do(ctx, "HSET", key, "direction", direction)
if err != nil {
return err
}
_, err = getClient().Do(ctx, "EXPIRE", key, UserStateExpireSeconds)
return err
}
// IncrUserCount 增加用户对话计数,返回当前轮数,并刷新过期时间
func IncrUserCount(ctx context.Context, userId, platform string) (count int64, err error) {
key := UserStateKeyPrefix + userId + "_" + platform
result, err := getClient().Do(ctx, "HINCRBY", key, "count", 1)
if err != nil {
return
}
count = result.Int64()
_, err = getClient().Do(ctx, "EXPIRE", key, UserStateExpireSeconds)
return
}
// ResetUserState 重置用户状态(归档时调用)
func ResetUserState(ctx context.Context, userId, platform string) error {
key := UserStateKeyPrefix + userId + "_" + platform
_, err := getClient().Del(ctx, key)
return err
}
// ============== 对话缓存相关5句落库==============
const (
// ConversationCacheKeyPrefix 对话缓存 Key 前缀
ConversationCacheKeyPrefix = "ragflow:conversation:cache:"
// ConversationCacheExpireSeconds 对话缓存过期时间10分钟
ConversationCacheExpireSeconds = 600
)
// CacheConversation 缓存单条对话到Redis List按sessionId存储
func CacheConversation(ctx context.Context, sessionId string, data []byte) error {
key := ConversationCacheKeyPrefix + sessionId
_, err := getClient().Do(ctx, "RPUSH", key, string(data))
if err != nil {
return err
}
_, err = getClient().Do(ctx, "EXPIRE", key, ConversationCacheExpireSeconds)
return err
}
// GetCachedConversations 获取缓存的对话列表并清空按sessionId查询
func GetCachedConversations(ctx context.Context, sessionId string) (list []string, err error) {
key := ConversationCacheKeyPrefix + sessionId
result, err := getClient().Do(ctx, "LRANGE", key, 0, -1)
if err != nil {
return
}
if result.IsEmpty() {
return
}
list = result.Strings()
// 清空缓存
getClient().Del(ctx, key)
return
}
// GetCachedConversationCount 获取缓存的对话数量按sessionId查询
func GetCachedConversationCount(ctx context.Context, sessionId string) (count int64, err error) {
key := ConversationCacheKeyPrefix + sessionId
result, err := getClient().Do(ctx, "LLEN", key)
if err != nil {
return
}
return result.Int64(), nil
}
// ClearCachedConversations 清空对话缓存归档时调用按sessionId
func ClearCachedConversations(ctx context.Context, sessionId string) error {
key := ConversationCacheKeyPrefix + sessionId
_, err := getClient().Del(ctx, key)
return err
}
// ========== 以下为兼容旧接口(内部调用新实现)==========
// IncrConversationCount 增加用户对话计数(兼容旧接口)
func IncrConversationCount(ctx context.Context, userId, platform string, _ int64) (count int64, err error) {
return IncrUserCount(ctx, userId, platform)
}
// GetConversationCount 获取用户当前对话轮数(兼容旧接口)
func GetConversationCount(ctx context.Context, userId, platform string) (count int64, err error) {
state, err := GetUserState(ctx, userId, platform)
if err != nil {
return
}
return state.Count, nil
}
// ResetConversationCount 重置用户对话计数(兼容旧接口)
func ResetConversationCount(ctx context.Context, userId, platform string) error {
return ResetUserState(ctx, userId, platform)
}

View File

@@ -1,129 +0,0 @@
package redis
import (
"context"
"github.com/gogf/gf/v2/frame/g"
)
// HistoryMessage 历史消息结构(用于上下文注入)
type HistoryMessage struct {
Question string `json:"question"` // 用户问题
Answer string `json:"answer"` // AI 回复
}
// SendStreamMessage 发送到 Redis Stream 的消息结构
type SendStreamMessage struct {
UserId string `json:"userId"` // 用户ID
Content string `json:"content"` // 消息内容
Timestamp int64 `json:"timestamp"` // 时间戳(秒)
MessageId string `json:"messageId"` // 消息唯一ID
Platform string `json:"platform,omitempty"` // 平台标识
AccountId string `json:"accountId,omitempty"` // 账号ID
TenantId string `json:"tenantId,omitempty"` // 租户ID数据隔离
AccountName string `json:"accountName,omitempty"` // 客服账号名称
ChatId string `json:"chatId,omitempty"` // RAGFlow Chat ID从ragflow_config查询
ReplyQueue string `json:"replyQueue,omitempty"` // 响应队列名称(支持多实例独立队列)
History []HistoryMessage `json:"history,omitempty"` // 历史对话(归档后恢复时携带)
}
// BatchStreamMessage 批量消息结构
type BatchStreamMessage struct {
UserId string `json:"userId"` // 用户ID
Content string `json:"content"` // 消息内容
Timestamp int64 `json:"timestamp"` // 时间戳(秒)
BatchId string `json:"batchId"` // 批次ID
Index int `json:"index"` // 批次内序号
}
// ResponseStreamMessage RAGFlow 响应消息结构MQ 消息)
type ResponseStreamMessage struct {
UserId string `json:"userId"` // 用户ID
Platform string `json:"platform"` // 平台标识
TenantId string `json:"tenantId"` // 租户ID
AccountId string `json:"accountId,omitempty"` // 账号ID
AccountName string `json:"accountName,omitempty"` // 客服账号名称
Question string `json:"question"` // 用户问题
Content string `json:"content"` // RAGFlow 回复内容
SessionId string `json:"sessionId"` // RAGFlow Session ID
Timestamp int64 `json:"timestamp"` // 时间戳(秒)
MessageId string `json:"messageId"` // 原始消息ID
}
// FollowUpMessage 追问消息结构RabbitMQ 延时队列)
type FollowUpMessage struct {
TenantId string `json:"tenantId"` // 租户ID
UserId string `json:"userId"` // 用户ID
Platform string `json:"platform"` // 平台标识
Content string `json:"content"` // 追问内容
FollowUpType int `json:"followUpType"` // 追问类型1=30s, 2=60s, 3=180s
Timestamp int64 `json:"timestamp"` // 发送时间戳
}
// 追问类型常量
const (
FollowUpType1 = 1 // 第一次追问
FollowUpType2 = 2 // 第二次追问
FollowUpType3 = 3 // 第三次追问
)
// GetFollowUpContent 获取追问话术(从 config.yml 读取)
func GetFollowUpContent(followUpType int) string {
ctx := context.Background()
contents := g.Cfg().MustGet(ctx, "followUp.contents").Strings()
if len(contents) == 0 {
return ""
}
// followUpType: 1,2,3 对应数组索引 0,1,2
index := followUpType - 1
if index >= 0 && index < len(contents) {
return contents[index]
}
return ""
}
// GetFollowUpDelay 获取追问延时(从 config.yml 读取)
func GetFollowUpDelay(followUpType int) int {
ctx := context.Background()
delays := g.Cfg().MustGet(ctx, "followUp.delays").Ints()
if len(delays) == 0 {
return 30 // 默认30秒
}
// followUpType: 1,2,3 对应数组索引 0,1,2
index := followUpType - 1
if index >= 0 && index < len(delays) {
return delays[index]
}
return 30
}
// ArchiveMessage 会话归档消息结构RabbitMQ 延时队列)
type ArchiveMessage struct {
UserId string `json:"userId"` // 用户ID
Platform string `json:"platform"` // 平台标识
SessionId string `json:"sessionId"` // RAGFlow Session ID
TenantId string `json:"tenantId"` // 租户ID
Timestamp int64 `json:"timestamp"` // 发送时间戳
}
// GetArchiveDelay 获取归档延时(从 config.yml 读取)
func GetArchiveDelay() int {
ctx := context.Background()
return g.Cfg().MustGet(ctx, "archive.delay", 3600).Int() // 默认3600秒1小时
}
// GetHistoryContextLimit 获取历史上下文轮数(从 config.yml 读取)
func GetHistoryContextLimit() int64 {
ctx := context.Background()
return g.Cfg().MustGet(ctx, "history.contextLimit", 5).Int64() // 默认5轮对话
}
// DocSyncMessage 文档同步消息结构RAGFlow与MongoDB同步
type DocSyncMessage struct {
DocId string `json:"docId"` // MongoDB文档ID
RagflowDocId string `json:"ragflowDocId"` // RAGFlow文档ID
TenantId string `json:"tenantId"` // 租户ID
DocType string `json:"docType"` // 文档类型speechcraft/product
Action string `json:"action"` // 操作类型sync_ragflow_id
Timestamp int64 `json:"timestamp"` // 时间戳
}