Files
prompts-core/service/prompt/prompt_task_waiter.go

126 lines
2.3 KiB
Go

package prompt
import (
"context"
"errors"
"sync"
)
// Sentinel errors returned by Manager.Notify.
var (
	// ErrTaskNotFound is returned when no waiter is registered for the task ID.
	ErrTaskNotFound = errors.New("task not found")

	// ErrAlreadyNotified is returned when the waiter registered for the task ID
	// has already been notified once.
	ErrAlreadyNotified = errors.New("task already notified")
)

// TaskWaiter is the package-level Manager instance, created at package
// initialization.
var TaskWaiter = NewManager()
// Result is the outcome of a task as delivered from Notify to Wait.
type Result struct {
// Data is the task's payload; Wait returns it only when Error is nil.
Data interface{}
// Error is the task's failure; when non-nil, Wait returns it and discards Data.
Error error
}
// Manager coordinates goroutines that wait for asynchronous task results with
// the code that reports those results, keyed by task ID.
type Manager struct {
// mu guards waiters; every method locks it before touching the map.
mu sync.Mutex
// waiters maps a task ID to the waiter currently registered for it.
waiters map[string]*waiter
}
// waiter holds the synchronization state for one task ID.
type waiter struct {
// result carries the task outcome from Notify to Wait; getOrCreate
// allocates it with a buffer of one.
result chan Result
// closed is closed by Notify to signal that a notification has been issued.
closed chan struct{}
// notifyOnce ensures only the first Notify call for this waiter takes effect.
notifyOnce sync.Once
}
// NewManager returns a Manager with an empty, ready-to-use waiter registry.
func NewManager() *Manager {
	mgr := new(Manager)
	mgr.waiters = map[string]*waiter{}
	return mgr
}
// Wait blocks until Notify delivers the result for taskID or ctx ends.
//
// It registers a waiter for taskID (or reuses the one already registered) and
// removes the taskID entry from the manager when it returns.
//
// Returns:
//   - (data, nil) when the task was reported without an error;
//   - (nil, err) when the task was reported with an error;
//   - (nil, ctx.Err()) when ctx ends before a result is received.
//
// The result is handed to a single receiver. If several goroutines wait on the
// same taskID at once, only one of them obtains the result; the others block
// until their own ctx ends.
func (m *Manager) Wait(ctx context.Context, taskID string) (interface{}, error) {
	w := m.getOrCreate(taskID)
	defer m.remove(taskID)

	// unpack converts a delivered Result into Wait's return values.
	unpack := func(res Result) (interface{}, error) {
		if res.Error != nil {
			return nil, res.Error
		}
		return res.Data, nil
	}

	select {
	case res := <-w.result:
		return unpack(res)
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-w.closed:
		// A notification has been issued, but the result may not be receivable
		// at this exact instant. Block for it instead of polling: a
		// non-blocking receive here could miss the value and return
		// ctx.Err(), which is nil while ctx is still live, i.e. a silent
		// (nil, nil) with the result lost.
		select {
		case res := <-w.result:
			return unpack(res)
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}
}
// Notify reports the outcome of taskID to the goroutine blocked in Wait.
//
// A non-nil err is delivered as Result.Error (data is dropped); otherwise data
// is delivered as Result.Data. Only the first call for a given waiter takes
// effect.
//
// Returns:
//   - nil when the result was handed to the waiter;
//   - ErrTaskNotFound when no waiter is registered for taskID;
//   - ErrAlreadyNotified when the registered waiter was already notified.
//
// Notify does not block on the receiver: notifyOnce permits a single send and
// the waiter's result channel is created with a buffer of one.
func (m *Manager) Notify(taskID string, data interface{}, err error) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	w, exists := m.waiters[taskID]
	if !exists {
		return ErrTaskNotFound
	}

	notified := false
	w.notifyOnce.Do(func() {
		notified = true

		res := Result{Data: data}
		if err != nil {
			res = Result{Error: err}
		}

		// Publish the result BEFORE closing the signal channel. A goroutine
		// that wakes up on the closed channel immediately tries to receive
		// the result; if the close came first it could find the buffer still
		// empty and return without the value.
		w.result <- res
		close(w.closed)
	})
	if !notified {
		return ErrAlreadyNotified
	}
	return nil
}
// getOrCreate returns the waiter registered for taskID, registering a fresh
// one first when none exists. The lookup and insert happen under m.mu.
func (m *Manager) getOrCreate(taskID string) *waiter {
	m.mu.Lock()
	defer m.mu.Unlock()

	entry, found := m.waiters[taskID]
	if !found {
		entry = &waiter{
			// Buffer of one so the single Notify send never blocks.
			result: make(chan Result, 1),
			closed: make(chan struct{}),
		}
		m.waiters[taskID] = entry
	}
	return entry
}
// remove drops the waiter registered under taskID, if any, while holding m.mu.
func (m *Manager) remove(taskID string) {
	m.mu.Lock()
	delete(m.waiters, taskID) // no-op when taskID is absent
	m.mu.Unlock()
}
// ActiveCount reports how many task IDs currently have a registered waiter.
func (m *Manager) ActiveCount() int {
	m.mu.Lock()
	count := len(m.waiters)
	m.mu.Unlock()
	return count
}