Files
common/ragflow/worker_pool.go
2026-03-12 08:51:00 +08:00

191 lines
5.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package ragflow
import (
"context"
"gitee.com/red-future---jilin-g/common/redis"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/os/grpool"
)
// 默认协程池大小
const defaultPoolSize = 200
// workerPool 协程池单例grpool.New 是原型模式,需要变量引用)
var workerPool = grpool.New(defaultPoolSize)
// WorkerPool RAGFlow 请求处理协程池(封装 grpool
type WorkerPool struct {
pool *grpool.Pool
size int
}
// Pool 协程池单例实例(直接引用使用)
var Pool = &WorkerPool{
pool: workerPool,
size: defaultPoolSize,
}
// Submit 提交任务到协程池
// 参数:
// - ctx: 上下文
// - task: 要执行的任务函数
//
// 返回error 提交失败时返回错误
func (w *WorkerPool) Submit(ctx context.Context, task func(ctx context.Context)) error {
return w.pool.Add(ctx, func(ctx context.Context) {
defer func() {
if r := recover(); r != nil {
glog.Errorf(ctx, "协程池任务执行 panic: %v", r)
}
}()
task(ctx)
})
}
// Size 获取协程池大小
func (w *WorkerPool) Size() int {
return w.size
}
// Jobs 获取当前等待执行的任务数量
func (w *WorkerPool) Jobs() int {
return w.pool.Jobs()
}
// Close 关闭协程池
func (w *WorkerPool) Close() {
w.pool.Close()
}
// WorkerStats 协程池统计信息
type WorkerStats struct {
PoolSize int // 协程池大小
Jobs int // 等待执行的任务数
}
// Stats 获取协程池统计信息
func (w *WorkerPool) Stats() WorkerStats {
return WorkerStats{
PoolSize: w.size,
Jobs: w.pool.Jobs(),
}
}
// PrintStats 打印协程池统计信息
func (w *WorkerPool) PrintStats(ctx context.Context) {
stats := w.Stats()
glog.Infof(ctx, "协程池统计 - 池大小: %d, 等待任务: %d", stats.PoolSize, stats.Jobs)
}
// QueueProcessor Stream 处理器,从 Redis Stream 中取出任务并提交到协程池
type QueueProcessor struct {
pool *WorkerPool
streamKey string // Stream 键名
groupName string // 消费者组名称
consumerName string // 消费者名称
timeout int64 // 阻塞超时时间(毫秒)
batchSize int64 // 每次读取的消息数量
stopChan chan struct{}
handleFunc func(ctx context.Context, message map[string]interface{}) error
}
// NewQueueProcessor 创建 Stream 处理器
// 参数:
// - pool: 协程池
// - streamKey: Redis Stream 键名
// - groupName: 消费者组名称
// - consumerName: 消费者名称(唯一标识)
// - timeout: 从 Stream 取消息的超时时间(毫秒)
// - batchSize: 每次读取的消息数量
// - handleFunc: 消息处理函数
func NewQueueProcessor(pool *WorkerPool, streamKey, groupName, consumerName string, timeout int64, batchSize int64, handleFunc func(ctx context.Context, message map[string]interface{}) error) *QueueProcessor {
return &QueueProcessor{
pool: pool,
streamKey: streamKey,
groupName: groupName,
consumerName: consumerName,
timeout: timeout,
batchSize: batchSize,
stopChan: make(chan struct{}),
handleFunc: handleFunc,
}
}
// Start 启动 Stream 处理器
// 会阻塞运行,持续从 Redis Stream 中取出消息并提交到协程池处理
func (q *QueueProcessor) Start(ctx context.Context) error {
glog.Infof(ctx, "Stream 处理器启动 - Stream: %s, 消费者组: %s, 消费者: %s, 超时: %dms",
q.streamKey, q.groupName, q.consumerName, q.timeout)
loopCount := 0
for {
select {
case <-q.stopChan:
glog.Info(ctx, "Stream 处理器收到停止信号")
return nil
default:
loopCount++
if loopCount%10 == 1 {
glog.Debugf(ctx, "[DEBUG] 第 %d 次循环,准备读取消息...", loopCount)
}
// 从 Redis Stream 中读取消息
messages, err := q.fetchMessages(ctx)
if err != nil {
glog.Errorf(ctx, "从 Stream 读取消息失败: %v", err)
continue
}
// 没有新消息,继续等待
if len(messages) == 0 {
if loopCount%10 == 1 {
glog.Debugf(ctx, "[DEBUG] 第 %d 次循环,无新消息", loopCount)
}
continue
}
glog.Infof(ctx, "[DEBUG] 收到 %d 条消息", len(messages))
// 处理每条消息
for _, msg := range messages {
glog.Infof(ctx, "[DEBUG] 处理消息 ID: %s, Values: %+v", msg.ID, msg.Values)
// 提交到协程池处理
if err := q.submitTask(ctx, msg); err != nil {
glog.Errorf(ctx, "提交任务到协程池失败: %v, 消息ID: %s", err, msg.ID)
}
}
}
}
}
// Stop 停止队列处理器
func (q *QueueProcessor) Stop() {
close(q.stopChan)
}
// fetchMessages 从 Redis Stream 中读取消息
func (q *QueueProcessor) fetchMessages(ctx context.Context) ([]redis.StreamMessage, error) {
// 从消费者组读取消息
return redis.ReadFromStream(ctx, q.streamKey, q.groupName, q.consumerName, q.batchSize, q.timeout)
}
// submitTask 将消息处理任务提交到协程池
func (q *QueueProcessor) submitTask(ctx context.Context, message redis.StreamMessage) error {
return q.pool.Submit(ctx, func(ctx context.Context) {
// 处理消息
if err := q.handleFunc(ctx, message.Values); err != nil {
glog.Errorf(ctx, "处理消息失败: %v, 消息ID: %s", err, message.ID)
return
}
// 处理成功后确认消息
if err := redis.AckMessage(ctx, q.streamKey, q.groupName, message.ID); err != nil {
glog.Errorf(ctx, "确认消息失败: %v, 消息ID: %s", err, message.ID)
} else {
glog.Debugf(ctx, "消息处理完成并已确认: %s", message.ID)
}
})
}