126 lines
2.3 KiB
Go
126 lines
2.3 KiB
Go
package prompt
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"sync"
|
|
)
|
|
|
|
var (
|
|
ErrTaskNotFound = errors.New("task not found")
|
|
ErrAlreadyNotified = errors.New("task already notified")
|
|
TaskWaiter = NewManager()
|
|
)
|
|
|
|
// Result is the outcome of a task as handed from Notify to Wait: Data on
// success, or Error on failure.
type Result struct {
	Data  interface{} // payload passed to Notify; left nil when Error is set
	Error error       // failure passed to Notify, or nil on success
}
|
|
|
|
// Manager 管理异步任务等待
|
|
type Manager struct {
|
|
mu sync.Mutex
|
|
waiters map[string]*waiter
|
|
}
|
|
|
|
// waiter 单个等待者
|
|
type waiter struct {
|
|
result chan Result
|
|
closed chan struct{}
|
|
notifyOnce sync.Once
|
|
}
|
|
|
|
// NewManager 创建管理器
|
|
func NewManager() *Manager {
|
|
return &Manager{
|
|
waiters: make(map[string]*waiter),
|
|
}
|
|
}
|
|
|
|
// Wait 等待任务结果
|
|
func (m *Manager) Wait(ctx context.Context, taskID string) (interface{}, error) {
|
|
w := m.getOrCreate(taskID)
|
|
defer m.remove(taskID)
|
|
|
|
select {
|
|
case result := <-w.result:
|
|
if result.Error != nil {
|
|
return nil, result.Error
|
|
}
|
|
return result.Data, nil
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-w.closed:
|
|
// context取消后notify才到达的边缘情况
|
|
select {
|
|
case result := <-w.result:
|
|
if result.Error != nil {
|
|
return nil, result.Error
|
|
}
|
|
return result.Data, nil
|
|
default:
|
|
return nil, ctx.Err()
|
|
}
|
|
}
|
|
}
|
|
|
|
// Notify 通知任务完成(安全,无阻塞)
|
|
func (m *Manager) Notify(taskID string, data interface{}, err error) error {
|
|
m.mu.Lock()
|
|
w, exists := m.waiters[taskID]
|
|
if !exists {
|
|
m.mu.Unlock()
|
|
return ErrTaskNotFound
|
|
}
|
|
|
|
var notified bool
|
|
w.notifyOnce.Do(func() {
|
|
notified = true
|
|
close(w.closed) // 先关闭信号channel
|
|
// 根据err构造Result
|
|
if err != nil {
|
|
w.result <- Result{Error: err}
|
|
} else {
|
|
w.result <- Result{Data: data}
|
|
}
|
|
})
|
|
m.mu.Unlock()
|
|
|
|
if !notified {
|
|
return ErrAlreadyNotified
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// getOrCreate 获取或创建等待者
|
|
func (m *Manager) getOrCreate(taskID string) *waiter {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if w, exists := m.waiters[taskID]; exists {
|
|
return w
|
|
}
|
|
|
|
w := &waiter{
|
|
result: make(chan Result, 1),
|
|
closed: make(chan struct{}),
|
|
}
|
|
m.waiters[taskID] = w
|
|
return w
|
|
}
|
|
|
|
// remove 安全移除等待者
|
|
func (m *Manager) remove(taskID string) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
delete(m.waiters, taskID)
|
|
}
|
|
|
|
// ActiveCount 当前活跃等待数量
|
|
func (m *Manager) ActiveCount() int {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
return len(m.waiters)
|
|
}
|