refactor(model): 优化模型网关的数据解析和任务处理逻辑

This commit is contained in:
2026-06-17 14:34:49 +08:00
parent 0d52b631b9
commit eb28c2d1e0
4 changed files with 58 additions and 24 deletions

View File

@@ -2,6 +2,7 @@ package dao
import ( import (
"context" "context"
"fmt"
"prompts-core/consts/public" "prompts-core/consts/public"
"prompts-core/model/entity" "prompts-core/model/entity"
@@ -28,23 +29,57 @@ func (d *providerProtocolDao) Insert(ctx context.Context, req *entity.ProviderPr
return r.LastInsertId() return r.LastInsertId()
} }
// Get 查询协议配置 //// Get 查询协议配置
func (d *providerProtocolDao) Get(ctx context.Context, req *entity.ProviderProtocol, fields ...string) (res *entity.ProviderProtocol, err error) { //func (d *providerProtocolDao) Get(ctx context.Context, req *entity.ProviderProtocol, fields ...string) (res *entity.ProviderProtocol, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol). // r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).
NoTenantId(ctx). // NoTenantId(ctx).
OmitEmpty(). // OmitEmpty().
Where(entity.ProviderProtocolCol.Id, req.Id). // Where(entity.ProviderProtocolCol.Id, req.Id).
Where(entity.ProviderProtocolCol.ProviderName, req.ProviderName). //主要是根据运营商查询 // Where(entity.ProviderProtocolCol.ProviderName, req.ProviderName). //主要是根据运营商查询
Where(entity.ProviderProtocolCol.Status, 1). // Where(entity.ProviderProtocolCol.Status, 1).
Fields(fields).One() // Fields(fields).One()
// if err != nil {
// return nil, err
// }
// if r.IsEmpty() {
// return nil, nil
// }
// err = r.Struct(&res)
// return
//}
// Get 获取协议配置
func (d *providerProtocolDao) Get(ctx context.Context, req *entity.ProviderProtocol, fields ...string) (*entity.ProviderProtocol, error) {
sql := fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND status = 1`, public.TableNameProviderProtocol)
args := make([]any, 0)
if req.Id != 0 {
sql += ` AND id = ?`
args = append(args, req.Id)
}
if req.ProviderName != "" {
sql += ` AND provider_name = ?`
args = append(args, req.ProviderName)
}
sql += ` LIMIT 1`
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if r.IsEmpty() { if r.IsEmpty() {
return nil, nil return nil, nil
} }
err = r.Struct(&res)
return var list []*entity.ProviderProtocol
if err = r.Structs(&list); err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
return list[0], nil
} }
// List 列表查询 // List 列表查询

View File

@@ -151,6 +151,7 @@ type SendCallbackReq struct {
TaskId string `json:"taskId"` TaskId string `json:"taskId"`
Status string `json:"status"` Status string `json:"status"`
EpicycleId int64 `json:"epicycleId"` EpicycleId int64 `json:"epicycleId"`
Messages map[string]any `json:"messages,omitempty"`
ErrorMsg string `json:"errorMsg,omitempty"` ErrorMsg string `json:"errorMsg,omitempty"`
} }
@@ -165,6 +166,7 @@ func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycle
TaskId: composeTask.TaskId, TaskId: composeTask.TaskId,
Status: composeTask.Status, Status: composeTask.Status,
ErrorMsg: composeTask.ErrorMessage, ErrorMsg: composeTask.ErrorMessage,
Messages: composeTask.ResultJson,
EpicycleId: epicycleId, EpicycleId: epicycleId,
} }
// 3. 发送 POST 请求 // 3. 发送 POST 请求

View File

@@ -81,11 +81,9 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel,
if err != nil || providerProtocol == nil { if err != nil || providerProtocol == nil {
return "" return ""
} }
outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{})).MustToJsonIndentString() outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{})).MustToJsonString()
return fmt.Sprintf(providerProtocol.SystemPromptTemplate, return fmt.Sprintf(providerProtocol.SystemPromptTemplate, outputJSON)
outputJSON, //【输出结构】 %s
)
} }
// checkOverallContent 检查整体内容是否超出窗口 // checkOverallContent 检查整体内容是否超出窗口
@@ -110,9 +108,6 @@ func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt st
if userFormText := buildUserFormText(req.UserForm); userFormText != "" { if userFormText := buildUserFormText(req.UserForm); userFormText != "" {
b.WriteString(fmt.Sprintf("用户需求:\n%s\n", userFormText)) b.WriteString(fmt.Sprintf("用户需求:\n%s\n", userFormText))
} }
if len(req.Consult) > 0 {
b.WriteString(fmt.Sprintf("参考附件:%s\n", gjson.New(req.Consult).String()))
}
if fileTexts := ExtractFileTexts(ctx, req.Consult); fileTexts != "" { if fileTexts := ExtractFileTexts(ctx, req.Consult); fileTexts != "" {
b.WriteString(fmt.Sprintf("附件内容:\n%s\n", fileTexts)) b.WriteString(fmt.Sprintf("附件内容:\n%s\n", fileTexts))
} }

View File

@@ -41,7 +41,9 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*gateway
return nil, nil, fmt.Errorf("获取用户信息失败: %w", err) return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
} }
chatModel, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{ chatModel, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName}, SQLBaseDO: beans.SQLBaseDO{
Creator: userInfo.UserName,
},
IsChatModel: 1, IsChatModel: 1,
}) })
if err != nil || chatModel == nil { if err != nil || chatModel == nil {
@@ -148,7 +150,7 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
// 3) 解析 OSS 内容为消息 // 3) 解析 OSS 内容为消息
var messages map[string]any var messages map[string]any
if len(ossContent) > 0 { if len(ossContent) > 0 {
messages, _ = gjson.New(ossContent).Map(), nil messages = gjson.New(ossContent).Map()
} }
// 4) 处理失败 // 4) 处理失败