Files
model-gateway/service/task/worker.go

459 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package task
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"unicode/utf8"
"model-gateway/common/util"
"model-gateway/consts/public"
"model-gateway/dao"
"model-gateway/model/dto"
"model-gateway/model/entity"
"model-gateway/service/gateway"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// asyncWorker carries no state; it only groups the task-execution methods
// defined in this file.
type asyncWorker struct{}

// AsyncWorker is the shared package-level worker instance.
var AsyncWorker = new(asyncWorker)
// handleOne runs one task end to end: invoke the model, parse/validate the
// response, upload the result to OSS, then persist the outcome and trigger
// the callbacks. Every failure path is reported through failTask.
//
// Retry policy:
//   - Model invocation is retried only when the error text contains
//     "Timeout" or "InternalServiceError", at most model.RetryTimes times,
//     sleeping `attempt` seconds before each retry. Any other error fails
//     the task immediately.
//   - Parsing (parseAndRetry) and the OSS upload are each retried up to
//     model.RetryTimes times.
func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq) {
	var (
		body      = task.RequestPayload.Body
		maxRetry  = model.RetryTimes
		startTime = time.Now()
		rawBytes  []byte
		result    map[string]any
		err       error
	)
	g.Log().Infof(ctx, "[handleOne] 开始 taskId=%s model=%s", task.TaskID, task.ModelName)

	// ============================================
	// 1) Invoke the model
	// ============================================
	for attempt := 0; ; attempt++ {
		if attempt > 0 {
			g.Log().Infof(ctx, "[handleOne] 调模型重试 第%d次 taskId=%s", attempt, task.TaskID)
			// Linear backoff: 1s, 2s, 3s, ...
			time.Sleep(time.Duration(attempt) * time.Second)
		}
		switch {
		case model.CallMode != nil && *model.CallMode == public.CallModeStream:
			rawBytes, err = InvokeModel(ctx, model, body)
			if err == nil {
				result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
			}
		case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
			result, err = w.callModel(ctx, task, model, body)
			if err == nil {
				result, err = util.PullTaskResult(ctx, result, model.QueryConfig, model.HeadMsg)
			}
		default:
			result, err = w.callModel(ctx, task, model, body)
		}
		if err == nil {
			break
		}
		// Only these two error classes are considered transient.
		if !strings.Contains(err.Error(), "Timeout") &&
			!strings.Contains(err.Error(), "InternalServiceError") {
			w.failTask(ctx, task, startTime, err.Error())
			return
		}
		g.Log().Warningf(ctx, "[handleOne] 调模型失败 taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
		// Bound the retries: without this check a persistently failing
		// upstream would keep this worker looping (and sleeping) forever.
		if attempt >= maxRetry {
			w.failTask(ctx, task, startTime, fmt.Sprintf("调模型重试耗尽: %v", err))
			return
		}
	}

	// ============================================
	// 2) Parse + validate + response mapping (retryable)
	// ============================================
	parsed, err := w.parseAndRetry(ctx, result, task, model, req, maxRetry, startTime)
	if err != nil {
		// Keep the raw model response on the failed task. parseAndRetry
		// returns a nil map on error, so its return value must not be used here.
		task.TextResult = result
		w.failTask(ctx, task, startTime, err.Error())
		return
	}

	// ============================================
	// 3) Upload to OSS (retryable)
	// ============================================
	var oss *gateway.UploadFileResponse
	for attempt := 0; attempt <= maxRetry; attempt++ {
		if attempt > 0 {
			g.Log().Infof(ctx, "[handleOne] OSS上传重试 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
		}
		oss, err = gateway.UploadByTask(ctx, gjson.New(parsed).MustToJson(), "json")
		if err == nil {
			break
		}
		g.Log().Errorf(ctx, "[handleOne] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
		if attempt == maxRetry {
			w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err))
			return
		}
	}

	// ============================================
	// 4) Success bookkeeping
	// ============================================
	task.State = public.TaskStatusSuccess
	task.DurationSeconds = int64(time.Since(startTime).Seconds())
	task.ResultFile = &entity.ResultFile{
		OssFile:  oss.FileAddressPrefix + oss.FileURL,
		FileType: oss.FileFormat,
		FileSize: int64(oss.FileSize),
	}
	task.TextResult = parsed
	if _, err = dao.ModelGatewayTask.Update(ctx, task); err != nil {
		// No callback is sent when the final state could not be persisted.
		g.Log().Errorf(ctx, "[handleOne] 更新DB失败 taskId=%s err=%v", task.TaskID, err)
		return
	}
	// Callbacks run in their own goroutines on a context derived via util.AsyncCtx.
	go gateway.TriggerCallback(util.AsyncCtx(ctx), task)
	if req.EpicycleId != 0 {
		go gateway.TriggerPromptsCallback(util.AsyncCtx(ctx), task, req.EpicycleId)
	}
	g.Log().Infof(ctx, "[handleOne] 成功 taskId=%s duration=%ds fileType=%s",
		task.TaskID, task.DurationSeconds, oss.FileFormat)
}
// asyncResult is the value delivered through an async task's wait channel:
// the result map and/or the error passed to NotifyAsyncResult.
type asyncResult struct {
result map[string]any
err error
}
// asyncTaskChan is the package-wide registry of pending async waits,
// mapping taskID → chan asyncResult. The zero value of sync.Map is ready to use.
var asyncTaskChan sync.Map
// callModelAsync submits an asynchronous model task and blocks until the
// result is delivered through NotifyAsyncResult or model.TimeoutSeconds
// elapses.
//
// The upstream task id is read from the submit response at path
// entity.ResponseBody and used as the key in asyncTaskChan. An empty id is
// rejected, because waiting under the "" key would collide with every other
// task whose submit response lacks an id.
//
// The wait channel is intentionally never closed: this function is the
// receiver, and closing here could race with a concurrent NotifyAsyncResult
// that has already loaded the channel, causing a "send on closed channel"
// panic. Removing the map entry is sufficient; the channel is garbage
// collected afterwards.
func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) (map[string]any, error) {
	// 1. Submit the async task.
	submitResp, err := w.callModel(ctx, task, model, body)
	if err != nil {
		return nil, err
	}

	// 2. Extract the upstream task id.
	taskID := gjson.New(submitResp).Get(entity.ResponseBody).String()
	if taskID == "" {
		return nil, fmt.Errorf("异步任务提交响应缺少任务ID")
	}

	// 3. Register the wait channel. Capacity 1 lets the notifier deliver
	// without needing this goroutine to be parked on the receive yet.
	ch := make(chan asyncResult, 1)
	asyncTaskChan.Store(taskID, ch)
	defer asyncTaskChan.Delete(taskID)

	// 4. Block until the callback arrives or the timeout fires.
	timeout := time.Duration(model.TimeoutSeconds) * time.Second
	waitCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	g.Log().Infof(waitCtx, "[异步任务] 开始等待结果 taskID=%s timeout=%v", taskID, timeout)

	select {
	case res := <-ch:
		g.Log().Infof(waitCtx, "[异步任务] 获取结果成功 taskID=%s", taskID)
		return res.result, res.err
	case <-waitCtx.Done():
		return nil, fmt.Errorf("异步任务超时: taskID=%s", taskID)
	}
}
// NotifyAsyncResult is called by the callback endpoint to hand a finished
// async task's outcome to the goroutine waiting in callModelAsync.
//
// Delivery is at-most-once per registration: the entry is removed from
// asyncTaskChan as it is loaded, so duplicate callbacks for the same taskID
// are ignored. The send is non-blocking, so a callback can never hang on a
// waiter that already gave up. Unknown task ids are a no-op.
func NotifyAsyncResult(taskID string, result map[string]any, err error) {
	v, found := asyncTaskChan.LoadAndDelete(taskID)
	if !found {
		return
	}
	ch, ok := v.(chan asyncResult)
	if !ok {
		// Foreign value under this key; nothing sensible to notify.
		return
	}
	select {
	case ch <- asyncResult{result: result, err: err}:
	default:
		// Buffer already holds a result; drop the duplicate instead of blocking.
	}
}
// callModel invokes the model and decodes its textual response into a map.
//
// The response is accepted only when it is non-empty, valid UTF-8, and its
// detected content type is "text/*" or "application/json"; anything else is
// reported as a non-text response. The error from util.DetectFileType is
// ignored, so the decision relies on the returned content type alone.
func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) (map[string]any, error) {
	raw, invokeErr := InvokeModel(ctx, model, body)
	if invokeErr != nil {
		return nil, invokeErr
	}

	contentType, _ := util.DetectFileType(raw)
	textual := strings.HasPrefix(contentType, "text/") || contentType == "application/json"
	if len(raw) == 0 || !textual || !utf8.Valid(raw) {
		return nil, fmt.Errorf("模型返回非文本内容contentType=%s", contentType)
	}

	return gjson.New(string(raw)).Map(), nil
}
// parseAndRetry maps and validates a model response, re-invoking the model
// with the validation error appended to the prompt when validation fails.
//
// Per attempt (0..maxRetry):
//  1. Apply model.ResponseMapping to body.
//  2. If the mapped response carries entity.TotalTokens, record it on the
//     task and persist it.
//  3. Depending on req.BuildType: validate (Prompt/Node), convert via
//     util.ParseStructResult (Struct), or return the mapped response as is.
//  4. On validation failure, bump task.RetryCount, re-invoke the model with
//     the error injected into the request, and parse the new response on the
//     next attempt.
//
// body always holds a model *response*. The retry request is kept in its own
// variable so that a failed re-invocation or an undecodable reply leaves the
// previous response in place instead of feeding the request payload back
// into the response mapping.
//
// Returns the parsed result, or a nil map and an error once retries are
// exhausted. startTime is currently unused.
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
	var lastErr error
	for attempt := 0; attempt <= maxRetry; attempt++ {
		if attempt > 0 {
			g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
		}

		// 1) Response mapping.
		mapped, err := util.MapResponsePayload(model.ResponseMapping, body)
		if err != nil {
			lastErr = err
			g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
			if attempt == maxRetry {
				return nil, fmt.Errorf("响应映射重试耗尽: %w", err)
			}
			continue
		}

		// 2) Persist token usage (best effort: a failure here must not fail the task).
		if tokens, ok := mapped[entity.TotalTokens]; ok {
			task.ExpendTokens = gconv.Int64(tokens)
			if _, upErr := dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
				SQLBaseDO:    beans.SQLBaseDO{Id: task.Id},
				ExpendTokens: task.ExpendTokens,
			}); upErr != nil {
				g.Log().Warningf(ctx, "[执行任务][更新token失败] taskId=%s err=%v", task.TaskID, upErr)
			}
		}

		// 3) Parse + validate.
		var parsed map[string]any
		switch req.BuildType {
		case public.BuildTypePrompt, public.BuildTypeNode:
			parsed, err = util.ParseAndValidate(mapped, model)
			if err == nil {
				return parsed, nil
			}
			lastErr = err
		case public.BuildTypeStruct:
			return util.ParseStructResult(mapped, entity.ResponseBody), nil
		default:
			return mapped, nil
		}
		g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
		if attempt == maxRetry {
			return nil, fmt.Errorf("JSON解析重试耗尽: %w", lastErr)
		}

		// 4) Append the error to the request and call the model again.
		task.RetryCount++
		if _, upErr := dao.ModelGatewayTask.Update(ctx, task); upErr != nil {
			g.Log().Warningf(ctx, "[执行任务][更新重试次数失败] taskId=%s err=%v", task.TaskID, upErr)
		}
		retryReq := injectErrorMessage(task.RequestPayload.Body, lastErr)
		rawData, callErr := InvokeModel(ctx, model, retryReq)
		if callErr != nil {
			g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
			continue
		}
		var rawResp map[string]any
		if err = json.Unmarshal(rawData, &rawResp); err != nil {
			g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
			continue
		}
		body = rawResp
	}
	// Only reachable when maxRetry is negative and the loop never ran.
	return body, nil
}
// messageRole returns the "role" value of a chat message as a string.
func messageRole(v any) string {
	switch r := v.(type) {
	case string:
		return r
	case []byte:
		return string(r)
	default:
		return fmt.Sprint(v)
	}
}

// injectErrorMessage returns a payload in which the error text has been
// appended to the last message whose role is "user".
//
// The input payload is never modified: the returned map is a shallow copy
// with a fresh "messages" slice and a fresh copy of the patched message.
// This matters because callers pass the task's stored request payload;
// mutating it in place would make error hints pile up across retries and
// leak into whatever persists that payload later.
//
// String content gets the hint concatenated; list content ([]any) gets an
// extra {"type":"text"} part. The payload is returned unchanged when err is
// nil, when "messages" is missing/empty/not a []any, when no user message
// exists, or when the user message content has an unsupported type.
func injectErrorMessage(payload map[string]any, err error) map[string]any {
	if err == nil {
		return payload
	}
	messages, _ := payload["messages"].([]any)
	if len(messages) == 0 {
		return payload
	}
	errMsg := fmt.Sprintf("\n\n【上一轮输出错误请修正】%s", err.Error())

	// Walk backwards to find the most recent user message.
	for i := len(messages) - 1; i >= 0; i-- {
		msg, ok := messages[i].(map[string]any)
		if !ok || messageRole(msg["role"]) != "user" {
			continue
		}

		var content any
		switch c := msg["content"].(type) {
		case string:
			content = c + errMsg
		case []any:
			// Copy first so append can never write into the caller's backing array.
			parts := make([]any, 0, len(c)+1)
			parts = append(parts, c...)
			content = append(parts, map[string]any{
				"type": "text",
				"text": errMsg,
			})
		default:
			// Unsupported content shape: nothing to patch.
			return payload
		}

		patchedMsg := make(map[string]any, len(msg))
		for k, v := range msg {
			patchedMsg[k] = v
		}
		patchedMsg["content"] = content

		patchedMessages := make([]any, len(messages))
		copy(patchedMessages, messages)
		patchedMessages[i] = patchedMsg

		out := make(map[string]any, len(payload))
		for k, v := range payload {
			out[k] = v
		}
		out["messages"] = patchedMessages
		return out
	}
	return payload
}
// InvokeModel sends body to the model service described by model and returns
// the raw response bytes.
//
// Request shape:
//   - HttpMethod GET: body is converted to query parameters and appended to
//     BaseURL. Any other method (including empty) is sent as a POST with a
//     JSON body.
//   - Headers come from model.HeadMsg first; a non-empty model.ApiKey then
//     sets "Authorization: Bearer <key>", and non-GET requests get
//     "Content-Type: application/json".
//
// Timeout: model.TimeoutSeconds, falling back to 600s when it is not
// positive (a zero timeout would otherwise create an already-expired
// context and fail every call immediately).
//
// The ctx parameter is deliberately NOT used for the HTTP request: the call
// runs on its own timeout context so that an outer deadline or cancellation
// does not abort an in-flight model call.
//
// Returns an error when the request cannot be built or sent, when the body
// cannot be read, or when the status code is outside 2xx (the response body
// is included in the error text).
func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {
	const defaultTimeout = 600 * time.Second

	// 1) Call-count statistics are currently disabled:
	//_ = dao.ModelGatewayLogsStat.IncRequestCount(ctx, time.Now(), model.TenantId, model.Creator, model.ModelName)
	// 2) Request mapping is not applied here; the request is assembled during
	// prompt building, where the extra/concatenated fields are available:
	//mappedPayload := util.ReverseMap(model.RequestMapping, payload)

	// 3) Resolve URL, timeout and method.
	baseURL := strings.TrimRight(model.BaseURL, "/")
	timeout := time.Duration(model.TimeoutSeconds) * time.Second
	if timeout <= 0 {
		timeout = defaultTimeout
	}
	client := &http.Client{Timeout: timeout}
	method := strings.ToUpper(strings.TrimSpace(model.HttpMethod))

	// Independent timeout context, isolated from the caller's deadline.
	reqCtx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// 4) Build the HTTP request.
	var (
		req *http.Request
		err error
	)
	switch method {
	case http.MethodGet:
		q, qErr := util.BodyToQuery(body)
		if qErr != nil {
			return nil, qErr
		}
		if len(q) > 0 {
			sep := "?"
			if strings.Contains(baseURL, "?") {
				sep = "&"
			}
			baseURL = baseURL + sep + q.Encode()
		}
		req, err = http.NewRequestWithContext(reqCtx, http.MethodGet, baseURL, nil)
	default:
		bodyBytes, mErr := json.Marshal(body)
		if mErr != nil {
			return nil, mErr
		}
		req, err = http.NewRequestWithContext(reqCtx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
	}
	// Must be checked: on failure req is nil and the header code below would panic.
	if err != nil {
		return nil, fmt.Errorf("构建模型请求失败: %w", err)
	}

	// 5) Headers: static model configuration first, then credentials.
	for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) {
		req.Header.Set(hk, hv)
	}
	if model.ApiKey != "" {
		req.Header.Set("Authorization", "Bearer "+model.ApiKey)
	}
	if method != http.MethodGet {
		req.Header.Set("Content-Type", "application/json")
	}

	// 6) Send.
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	// 7) Read the whole response body.
	b, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	// 8) Reject non-2xx responses, surfacing the upstream body for diagnosis.
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, string(b))
	}
	return b, nil
}
// // InvokeModel 调用模型服务,返回二进制结果
//
// func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
// if m == nil || m.BaseURL == "" {
// return nil, fmt.Errorf("模型配置不完整")
// }
// // 请求参数映射
// mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
// if err != nil {
// return nil, fmt.Errorf("请求参数映射失败: %w", err)
// }
// // 合并请求头
// headers := util.ForwardHeaders(ctx)
// for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
// headers[hk] = hv
// }
// for hk, hv := range parseHeadMsgHeaders(modelKey) {
// headers[hk] = hv
// }
//
// // 设置超时
// timeout := time.Duration(m.TimeoutSeconds) * time.Second
// if timeout <= 0 {
// timeout = 600 * time.Second
// }
// ctx, cancel := context.WithTimeout(ctx, timeout)
// defer cancel()
//
// invokeUrl := strings.TrimRight(m.BaseURL, "/")
// method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
// if method == "" {
// method = http.MethodPost
// }
//
// var respBytes []byte
//
// switch method {
// case http.MethodGet:
// err = commonHttp.Get(ctx, invokeUrl, headers, &respBytes, mappedPayload)
// default:
// err = commonHttp.Post(ctx, invokeUrl, headers, &respBytes, mappedPayload)
// }
// if err != nil {
// return nil, err
// }
// // 响应参数映射
// mappedResponse, err := mapResponsePayload(m.ResponseMapping, respBytes)
// if err != nil {
// g.Log().Warningf(ctx, "响应参数映射失败: %v返回原始数据", err)
// return respBytes, nil
// }
// return mappedResponse, nil
// }
// failTask is the single exit path for failed tasks: it stamps the task with
// the failure state, error message and elapsed duration, persists it, and
// triggers the result callback.
//
// A persistence failure is logged but does not suppress the callback, so the
// caller-side consumer is still told about the failure.
func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask, startTime time.Time, errMsg string) {
	// NOTE(review): 3 is presumably the "failed" task state; no named
	// constant for it is visible in this file — confirm against consts/public.
	t.State = 3
	t.ErrorMsg = errMsg
	t.DurationSeconds = int64(time.Since(startTime).Seconds())

	if _, err := dao.ModelGatewayTask.Update(ctx, t); err != nil {
		g.Log().Errorf(ctx, "[failTask] 更新DB失败 taskId=%s err=%v", t.TaskID, err)
	}

	// The callback runs in its own goroutine on a context derived via util.AsyncCtx.
	go gateway.TriggerCallback(util.AsyncCtx(ctx), t)
}