Files
common/ragflow/worker_pool.go
2026-03-12 08:50:55 +08:00

166 lines
4.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package ragflow
import (
"context"
"gitee.com/red-future---jilin-g/common/redis"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/os/grpool"
)
// WorkerPool RAGFlow 请求处理协程池
type WorkerPool struct {
pool *grpool.Pool
size int
}
// NewWorkerPool 创建协程池
// 参数:
// - size: 协程池大小,建议设置为 CPU 核心数的 2-4 倍
//
// 返回:
// - *WorkerPool: 协程池实例
// - error: 创建失败时返回错误
func NewWorkerPool(size int) (*WorkerPool, error) {
if size <= 0 {
return nil, gerror.New("协程池大小必须大于0")
}
pool := grpool.New(size)
return &WorkerPool{
pool: pool,
size: size,
}, nil
}
// Submit 提交任务到协程池
// 参数:
// - ctx: 上下文
// - task: 要执行的任务函数
//
// 返回error 提交失败时返回错误
func (w *WorkerPool) Submit(ctx context.Context, task func(ctx context.Context)) error {
return w.pool.Add(ctx, func(ctx context.Context) {
defer func() {
if r := recover(); r != nil {
glog.Errorf(ctx, "协程池任务执行 panic: %v", r)
}
}()
task(ctx)
})
}
// Size 获取协程池大小
func (w *WorkerPool) Size() int {
return w.size
}
// Jobs 获取当前等待执行的任务数量
func (w *WorkerPool) Jobs() int {
return w.pool.Jobs()
}
// Close 关闭协程池
func (w *WorkerPool) Close() {
w.pool.Close()
}
// WorkerStats 协程池统计信息
type WorkerStats struct {
PoolSize int // 协程池大小
Jobs int // 等待执行的任务数
}
// Stats 获取协程池统计信息
func (w *WorkerPool) Stats() WorkerStats {
return WorkerStats{
PoolSize: w.size,
Jobs: w.pool.Jobs(),
}
}
// PrintStats 打印协程池统计信息
func (w *WorkerPool) PrintStats(ctx context.Context) {
stats := w.Stats()
glog.Infof(ctx, "协程池统计 - 池大小: %d, 等待任务: %d", stats.PoolSize, stats.Jobs)
}
// QueueProcessor 队列处理器,从 Redis 队列中取出任务并提交到协程池
type QueueProcessor struct {
pool *WorkerPool
queueKey string
timeout int
stopChan chan struct{}
handleFunc func(ctx context.Context, message string) error
}
// NewQueueProcessor 创建队列处理器
// 参数:
// - pool: 协程池
// - queueKey: Redis 队列键名
// - timeout: 从队列取消息的超时时间(秒)
// - handleFunc: 消息处理函数
func NewQueueProcessor(pool *WorkerPool, queueKey string, timeout int, handleFunc func(ctx context.Context, message string) error) *QueueProcessor {
return &QueueProcessor{
pool: pool,
queueKey: queueKey,
timeout: timeout,
stopChan: make(chan struct{}),
handleFunc: handleFunc,
}
}
// Start 启动队列处理器
// 会阻塞运行,持续从 Redis 队列中取出消息并提交到协程池处理
func (q *QueueProcessor) Start(ctx context.Context) error {
glog.Infof(ctx, "队列处理器启动 - 队列: %s, 超时: %ds", q.queueKey, q.timeout)
for {
select {
case <-q.stopChan:
glog.Info(ctx, "队列处理器收到停止信号")
return nil
default:
// 从 Redis 队列中取出消息
message, err := q.fetchMessage(ctx)
if err != nil {
glog.Errorf(ctx, "从队列取消息失败: %v", err)
continue
}
// 队列为空,继续等待
if message == "" {
continue
}
// 提交到协程池处理
if err := q.submitTask(ctx, message); err != nil {
glog.Errorf(ctx, "提交任务到协程池失败: %v", err)
}
}
}
}
// Stop 停止队列处理器
func (q *QueueProcessor) Stop() {
close(q.stopChan)
}
// fetchMessage 从 Redis 队列中取出消息
func (q *QueueProcessor) fetchMessage(ctx context.Context) (string, error) {
// 调用 Redis 队列的 PopFromQueue 方法从队列中取出消息
return redis.PopFromQueue(ctx, q.queueKey, q.timeout)
}
// submitTask 将消息处理任务提交到协程池
func (q *QueueProcessor) submitTask(ctx context.Context, message string) error {
return q.pool.Submit(ctx, func(ctx context.Context) {
if err := q.handleFunc(ctx, message); err != nil {
glog.Errorf(ctx, "处理消息失败: %v, 消息: %s", err, message)
}
})
}