refactor(prompt): 优化任务等待机制并改进数据结构
This commit is contained in:
125
service/prompt/prompt_task_waiter.go
Normal file
125
service/prompt/prompt_task_waiter.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTaskNotFound = errors.New("task not found")
|
||||
ErrAlreadyNotified = errors.New("task already notified")
|
||||
TaskWaiter = NewManager()
|
||||
)
|
||||
|
||||
// Result 任务结果
|
||||
type Result struct {
|
||||
Data interface{}
|
||||
Error error
|
||||
}
|
||||
|
||||
// Manager 管理异步任务等待
|
||||
type Manager struct {
|
||||
mu sync.Mutex
|
||||
waiters map[string]*waiter
|
||||
}
|
||||
|
||||
// waiter 单个等待者
|
||||
type waiter struct {
|
||||
result chan Result
|
||||
closed chan struct{}
|
||||
notifyOnce sync.Once
|
||||
}
|
||||
|
||||
// NewManager 创建管理器
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
waiters: make(map[string]*waiter),
|
||||
}
|
||||
}
|
||||
|
||||
// Wait 等待任务结果
|
||||
func (m *Manager) Wait(ctx context.Context, taskID string) (interface{}, error) {
|
||||
w := m.getOrCreate(taskID)
|
||||
defer m.remove(taskID)
|
||||
|
||||
select {
|
||||
case result := <-w.result:
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
return result.Data, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-w.closed:
|
||||
// context取消后notify才到达的边缘情况
|
||||
select {
|
||||
case result := <-w.result:
|
||||
if result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
return result.Data, nil
|
||||
default:
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Notify 通知任务完成(安全,无阻塞)
|
||||
func (m *Manager) Notify(taskID string, data interface{}, err error) error {
|
||||
m.mu.Lock()
|
||||
w, exists := m.waiters[taskID]
|
||||
if !exists {
|
||||
m.mu.Unlock()
|
||||
return ErrTaskNotFound
|
||||
}
|
||||
|
||||
var notified bool
|
||||
w.notifyOnce.Do(func() {
|
||||
notified = true
|
||||
close(w.closed) // 先关闭信号channel
|
||||
// 根据err构造Result
|
||||
if err != nil {
|
||||
w.result <- Result{Error: err}
|
||||
} else {
|
||||
w.result <- Result{Data: data}
|
||||
}
|
||||
})
|
||||
m.mu.Unlock()
|
||||
|
||||
if !notified {
|
||||
return ErrAlreadyNotified
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getOrCreate 获取或创建等待者
|
||||
func (m *Manager) getOrCreate(taskID string) *waiter {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if w, exists := m.waiters[taskID]; exists {
|
||||
return w
|
||||
}
|
||||
|
||||
w := &waiter{
|
||||
result: make(chan Result, 1),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
m.waiters[taskID] = w
|
||||
return w
|
||||
}
|
||||
|
||||
// remove 安全移除等待者
|
||||
func (m *Manager) remove(taskID string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.waiters, taskID)
|
||||
}
|
||||
|
||||
// ActiveCount 当前活跃等待数量
|
||||
func (m *Manager) ActiveCount() int {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return len(m.waiters)
|
||||
}
|
||||
Reference in New Issue
Block a user