Files
model-gateway/service/task/worker.go

463 lines
14 KiB
Go
Raw Normal View History

package task
2026-04-29 15:54:14 +08:00
import (
"bytes"
2026-04-29 15:54:14 +08:00
"context"
"encoding/json"
2026-04-29 15:54:14 +08:00
"fmt"
"io"
"net/http"
2026-04-29 15:54:14 +08:00
"strings"
"sync"
2026-04-29 15:54:14 +08:00
"time"
2026-05-12 13:45:08 +08:00
"unicode/utf8"
2026-04-29 15:54:14 +08:00
"model-gateway/common/util"
"model-gateway/consts/public"
"model-gateway/dao"
"model-gateway/model/dto"
"model-gateway/model/entity"
"model-gateway/service/gateway"
2026-04-29 15:54:14 +08:00
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/encoding/gjson"
2026-04-29 15:54:14 +08:00
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
2026-04-29 15:54:14 +08:00
)
// AsyncWorker is the package-level worker instance whose methods execute model
// gateway tasks (see handleOne).
var AsyncWorker = new(asyncWorker)

// asyncWorker holds no state of its own; every method receives the task data
// it needs through its parameters.
type asyncWorker struct{}
// handleOne 执行一次完整的任务
func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq) {
var (
body = task.RequestPayload.Body
maxRetry = model.RetryTimes
startTime = time.Now()
rawBytes []byte
result map[string]any
err error
)
g.Log().Infof(ctx, "[handleOne] 开始 taskId=%s model=%s", task.TaskID, task.ModelName)
2026-04-29 15:54:14 +08:00
// ============================================
// 1) 调用模型
// ============================================
for attempt := 0; ; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[handleOne] 调模型重试 第%d次 taskId=%s", attempt, task.TaskID)
time.Sleep(time.Duration(attempt) * time.Second)
}
switch {
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
rawBytes, err = InvokeModel(ctx, model, body)
if err == nil {
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
}
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
result, err = w.callModel(ctx, task, model, body)
if err == nil {
result, err = util.PullTaskResult(ctx, result, model.QueryConfig, model.HeadMsg)
}
default:
result, err = w.callModel(ctx, task, model, body)
}
if err == nil {
break
}
if !strings.Contains(err.Error(), "Timeout") &&
!strings.Contains(err.Error(), "InternalServiceError") {
w.failTask(ctx, task, startTime, err.Error())
return
}
g.Log().Warningf(ctx, "[handleOne] 调模型失败 taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
}
// ============================================
// 2) 解析校验 + 响应映射(可重试)
// ============================================
result, err = w.parseAndRetry(ctx, result, task, model, req, maxRetry, startTime)
if err != nil {
task.TextResult = result
w.failTask(ctx, task, startTime, err.Error())
return
}
// ============================================
// 3) 上传 OSS可重试
// ============================================
var oss *gateway.UploadFileResponse
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[handleOne] OSS上传重试 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
}
oss, err = gateway.UploadByTask(ctx, gjson.New(result).MustToJson(), "json")
if err == nil {
break
}
g.Log().Errorf(ctx, "[handleOne] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err))
return
2026-04-29 15:54:14 +08:00
}
}
// ============================================
// 4) 成功收尾
// ============================================
task.State = public.TaskStatusSuccess
task.DurationSeconds = int64(time.Since(startTime).Seconds())
task.ResultFile = &entity.ResultFile{
OssFile: oss.FileAddressPrefix + oss.FileURL,
FileType: oss.FileFormat,
FileSize: int64(oss.FileSize),
}
task.TextResult = result
if _, err = dao.ModelGatewayTask.Update(ctx, task); err != nil {
g.Log().Errorf(ctx, "[handleOne] 更新DB失败 taskId=%s err=%v", task.TaskID, err)
return
}
go gateway.TriggerCallback(util.AsyncCtx(ctx), task)
if req.EpicycleId != 0 {
go gateway.TriggerPromptsCallback(util.AsyncCtx(ctx), task, req.EpicycleId)
}
g.Log().Infof(ctx, "[handleOne] 成功 taskId=%s duration=%ds fileType=%s",
task.TaskID, task.DurationSeconds, oss.FileFormat)
}
// asyncResult is what NotifyAsyncResult hands to a waiting callModelAsync:
// the result map for one upstream async task, or the error reported for it.
type asyncResult struct {
	result map[string]any
	err    error
}

// asyncTaskChan maps an upstream task ID (string) to the chan asyncResult a
// callModelAsync call is blocked on. The zero sync.Map is ready for use.
var asyncTaskChan sync.Map
// callModelAsync submits an async job through callModel and then blocks until
// NotifyAsyncResult delivers the outcome for the upstream task ID, or until
// model.TimeoutSeconds elapses / ctx is cancelled.
//
// The upstream task ID is read from the submit response at entity.ResponseBody;
// an empty ID is rejected so unrelated calls never share the "" key.
func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) (map[string]any, error) {
	// 1. Submit the async job.
	submitted, err := w.callModel(ctx, task, model, body)
	if err != nil {
		return nil, err
	}

	// 2. Extract the upstream task ID.
	taskID := gjson.New(submitted).Get(entity.ResponseBody).String()
	if taskID == "" {
		return nil, fmt.Errorf("异步任务提交响应缺少taskID")
	}

	// 3. Register the wait channel. It is intentionally never closed: a
	// concurrent NotifyAsyncResult may already hold it (Load before our Delete),
	// and sending on a closed channel panics. The buffer of 1 lets such a late
	// send complete without a receiver; the channel is then garbage-collected.
	ch := make(chan asyncResult, 1)
	asyncTaskChan.Store(taskID, ch)
	defer asyncTaskChan.Delete(taskID)

	// 4. Block until the callback arrives or the wait times out.
	timeout := time.Duration(model.TimeoutSeconds) * time.Second
	waitCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	g.Log().Infof(waitCtx, "[异步任务] 开始等待结果 taskID=%s timeout=%v", taskID, timeout)
	select {
	case res := <-ch:
		g.Log().Infof(waitCtx, "[异步任务] 获取结果成功 taskID=%s", taskID)
		return res.result, res.err
	case <-waitCtx.Done():
		return nil, fmt.Errorf("异步任务超时: taskID=%s", taskID)
	}
}
// NotifyAsyncResult delivers result/err for the upstream task taskID to the
// callModelAsync call waiting on it; it does nothing when no waiter is
// registered.
//
// The send is non-blocking: the waiter's channel is created with a buffer of
// one, so a duplicate or late notification that finds the buffer full is
// dropped (and logged) instead of blocking the caller forever.
func NotifyAsyncResult(taskID string, result map[string]any, err error) {
	v, ok := asyncTaskChan.Load(taskID)
	if !ok {
		return
	}
	select {
	case v.(chan asyncResult) <- asyncResult{result: result, err: err}:
	default:
		g.Log().Warningf(context.Background(), "[异步任务] 重复通知已丢弃 taskID=%s", taskID)
	}
}
// callModel 调用模型 + 提取文本结果
func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) (map[string]any, error) {
data, err := InvokeModel(ctx, model, body)
if err != nil {
return nil, err
2026-04-29 15:54:14 +08:00
}
contentType, _ := util.DetectFileType(data)
var textResult string
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
textResult = string(data)
2026-04-29 15:54:14 +08:00
}
if textResult == "" {
return nil, fmt.Errorf("模型返回非文本内容contentType=%s", contentType)
}
return gjson.New(textResult).Map(), nil
}
// parseAndRetry maps, parses and validates a model response, re-invoking the
// model with the previous error injected into the request when validation
// fails.
//
// Each attempt:
//  1. maps body through model.ResponseMapping (a mapping failure is retried on
//     the same body);
//  2. stores total tokens on the task when the mapping exposes them;
//  3. dispatches on req.BuildType — Prompt/Node results are validated and may
//     fail, Struct and any other build type return immediately;
//  4. on validation failure, increments task.RetryCount, injects the error into
//     the original request payload and calls the model again.
//
// body always holds the latest model *response*; the retry request lives in
// its own variable so a failed re-invoke never causes the request payload to
// be mapped as a response. A negative maxRetry is treated as 0. startTime is
// unused.
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
	if maxRetry < 0 {
		maxRetry = 0
	}
	var lastErr error
	for attempt := 0; attempt <= maxRetry; attempt++ {
		if attempt > 0 {
			g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
		}

		// 1) Response mapping.
		mapped, err := util.MapResponsePayload(model.ResponseMapping, body)
		if err != nil {
			lastErr = err
			g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
			if attempt == maxRetry {
				return nil, fmt.Errorf("响应映射重试耗尽: %w", err)
			}
			continue
		}

		// 2) Persist token usage (best effort, but no longer silent on failure).
		if tokens, ok := mapped[entity.TotalTokens]; ok {
			task.ExpendTokens = gconv.Int64(tokens)
			if _, dbErr := dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
				SQLBaseDO:    beans.SQLBaseDO{Id: task.Id},
				ExpendTokens: task.ExpendTokens,
			}); dbErr != nil {
				g.Log().Warningf(ctx, "[执行任务][更新token失败] taskId=%s err=%v", task.TaskID, dbErr)
			}
		}

		// 3) Parse + validate according to the build type.
		switch req.BuildType {
		case public.BuildTypePrompt, public.BuildTypeNode:
			parsed, parseErr := util.ParseAndValidate(mapped, model)
			if parseErr == nil {
				return parsed, nil
			}
			lastErr = parseErr
		case public.BuildTypeStruct:
			return util.ParseStructResult(mapped, entity.ResponseBody), nil
		default:
			return mapped, nil
		}
		g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, lastErr)
		if attempt == maxRetry {
			return nil, fmt.Errorf("JSON解析重试耗尽: %w", lastErr)
		}

		// 4) Inject the error into the request and call the model again.
		task.RetryCount++
		if _, dbErr := dao.ModelGatewayTask.Update(ctx, task); dbErr != nil {
			g.Log().Warningf(ctx, "[执行任务][更新重试次数失败] taskId=%s err=%v", task.TaskID, dbErr)
		}
		retryReq := injectErrorMessage(task.RequestPayload.Body, lastErr)
		rawData, callErr := InvokeModel(ctx, model, retryReq)
		if callErr != nil {
			// body still holds the previous response, so the next attempt fails
			// validation again and re-invokes the model.
			g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
			continue
		}
		var rawResp map[string]any
		if err = json.Unmarshal(rawData, &rawResp); err != nil {
			g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
			continue
		}
		body = rawResp
	}
	// Not reached: with maxRetry >= 0 every path of the final attempt returns.
	return body, nil
}
// injectErrorMessage returns a payload whose "messages" list has an extra user
// message describing err, placed immediately before the last message whose
// role is "user".
//
// payload is returned unchanged when err is nil, when payload["messages"] is
// missing / empty / not a []any, or when no message has role "user".
//
// The caller's payload map and messages slice are never modified: a shallow
// copy of the map with a freshly allocated messages slice is returned. The
// caller passes task.RequestPayload.Body, which is persisted. Mutating it
// previously made error messages accumulate across retries and get saved into
// the stored request.
func injectErrorMessage(payload map[string]any, err error) map[string]any {
	if err == nil {
		return payload
	}
	messages, _ := payload["messages"].([]any)
	if len(messages) == 0 {
		return payload
	}

	// Locate the last "user" message.
	lastUserIdx := -1
	for i := len(messages) - 1; i >= 0; i-- {
		msg, ok := messages[i].(map[string]any)
		if ok && gconv.String(msg["role"]) == "user" {
			lastUserIdx = i
			break
		}
	}
	if lastUserIdx == -1 {
		return payload
	}

	errMsg := fmt.Sprintf("【上一轮输出错误,请修正】%s", err.Error())
	errMsgObj := map[string]any{
		"role":    "user",
		"content": []map[string]any{{"type": "text", "text": errMsg}},
	}

	// Build a new slice so the caller's backing array is left untouched.
	newMessages := make([]any, 0, len(messages)+1)
	newMessages = append(newMessages, messages[:lastUserIdx]...)
	newMessages = append(newMessages, errMsgObj)
	newMessages = append(newMessages, messages[lastUserIdx:]...)

	out := make(map[string]any, len(payload))
	for k, v := range payload {
		out[k] = v
	}
	out["messages"] = newMessages
	return out
}
// InvokeModel sends body to model.BaseURL and returns the raw response bytes.
//
// Request construction:
//   - GET: body is converted to query parameters with util.BodyToQuery and
//     appended to the URL (using "&" if the URL already has a query);
//   - any other configured method (including empty): body is JSON-encoded and
//     sent as a POST with Content-Type: application/json.
//
// Headers come from model.HeadMsg; a non-empty model.ApiKey is then set as
// "Authorization: Bearer <key>", overriding any Authorization from HeadMsg.
// A non-2xx status is returned as an error that includes the response body.
//
// The request deliberately runs on its own timeout context derived from
// context.Background(), isolating it from ctx's deadline/cancellation (ctx is
// currently unused). A non-positive model.TimeoutSeconds falls back to 600s,
// the default of the earlier implementation. Without that fallback the
// zero-length deadline would expire before the request was sent.
func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {
	// model.RequestMapping is intentionally not applied here: prompt building
	// already produces the model-specific payload, including extra fields.
	baseURL := strings.TrimRight(model.BaseURL, "/")
	timeout := time.Duration(model.TimeoutSeconds) * time.Second
	if timeout <= 0 {
		timeout = 600 * time.Second
	}
	client := &http.Client{Timeout: timeout}
	method := strings.ToUpper(strings.TrimSpace(model.HttpMethod))

	reqCtx, reqCancel := context.WithTimeout(context.Background(), timeout)
	defer reqCancel()

	var (
		req *http.Request
		err error
	)
	switch method {
	case http.MethodGet:
		q, qErr := util.BodyToQuery(body)
		if qErr != nil {
			return nil, qErr
		}
		if len(q) > 0 {
			sep := "?"
			if strings.Contains(baseURL, "?") {
				sep = "&"
			}
			baseURL += sep + q.Encode()
		}
		req, err = http.NewRequestWithContext(reqCtx, http.MethodGet, baseURL, nil)
	default:
		bodyBytes, mErr := json.Marshal(body)
		if mErr != nil {
			return nil, mErr
		}
		req, err = http.NewRequestWithContext(reqCtx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
	}
	// Previously unchecked: an invalid BaseURL left req nil and the header
	// setup below panicked.
	if err != nil {
		return nil, fmt.Errorf("构建模型请求失败: %w", err)
	}

	// Static headers from the model config first, then the API key.
	for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) {
		req.Header.Set(hk, hv)
	}
	if model.ApiKey != "" {
		req.Header.Set("Authorization", "Bearer "+model.ApiKey)
	}
	if method != http.MethodGet {
		req.Header.Set("Content-Type", "application/json")
	}

	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	b, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		msg := string(b)
		return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
	}
	return b, nil
}
// // InvokeModel 调用模型服务,返回二进制结果
//
// func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
// if m == nil || m.BaseURL == "" {
// return nil, fmt.Errorf("模型配置不完整")
// }
// // 请求参数映射
// mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
// if err != nil {
// return nil, fmt.Errorf("请求参数映射失败: %w", err)
// }
// // 合并请求头
// headers := util.ForwardHeaders(ctx)
// for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
// headers[hk] = hv
// }
// for hk, hv := range parseHeadMsgHeaders(modelKey) {
// headers[hk] = hv
// }
//
// // 设置超时
// timeout := time.Duration(m.TimeoutSeconds) * time.Second
// if timeout <= 0 {
// timeout = 600 * time.Second
// }
// ctx, cancel := context.WithTimeout(ctx, timeout)
// defer cancel()
//
// invokeUrl := strings.TrimRight(m.BaseURL, "/")
// method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
// if method == "" {
// method = http.MethodPost
// }
//
// var respBytes []byte
//
// switch method {
// case http.MethodGet:
// err = commonHttp.Get(ctx, invokeUrl, headers, &respBytes, mappedPayload)
// default:
// err = commonHttp.Post(ctx, invokeUrl, headers, &respBytes, mappedPayload)
// }
// if err != nil {
// return nil, err
// }
// // 响应参数映射
// mappedResponse, err := mapResponsePayload(m.ResponseMapping, respBytes)
// if err != nil {
// g.Log().Warningf(ctx, "响应参数映射失败: %v返回原始数据", err)
// return respBytes, nil
// }
// return mappedResponse, nil
// }
// failTask is the common failure path for a task. It sets t.State to 3,
// records errMsg and the elapsed time since startTime, persists t, and fires
// the task callback in a new goroutine.
//
// The callback is triggered even if the DB update fails; the DB error is
// logged instead of being silently discarded.
func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask, startTime time.Time, errMsg string) {
	// NOTE(review): 3 is a bare state code, presumably the "failed" status —
	// confirm and replace with the matching constant from consts/public.
	t.State = 3
	t.ErrorMsg = errMsg
	t.DurationSeconds = int64(time.Since(startTime).Seconds())
	if _, err := dao.ModelGatewayTask.Update(ctx, t); err != nil {
		g.Log().Errorf(ctx, "[failTask] 更新DB失败 taskId=%s err=%v", t.TaskID, err)
	}
	go gateway.TriggerCallback(util.AsyncCtx(ctx), t)
}