refactor(model): 优化模型网关的数据解析和任务处理逻辑
This commit is contained in:
@@ -2,6 +2,7 @@ package dao
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"prompts-core/consts/public"
|
"prompts-core/consts/public"
|
||||||
"prompts-core/model/entity"
|
"prompts-core/model/entity"
|
||||||
|
|
||||||
@@ -28,23 +29,57 @@ func (d *providerProtocolDao) Insert(ctx context.Context, req *entity.ProviderPr
|
|||||||
return r.LastInsertId()
|
return r.LastInsertId()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get 查询协议配置
|
//// Get 查询协议配置
|
||||||
func (d *providerProtocolDao) Get(ctx context.Context, req *entity.ProviderProtocol, fields ...string) (res *entity.ProviderProtocol, err error) {
|
//func (d *providerProtocolDao) Get(ctx context.Context, req *entity.ProviderProtocol, fields ...string) (res *entity.ProviderProtocol, err error) {
|
||||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).
|
// r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).
|
||||||
NoTenantId(ctx).
|
// NoTenantId(ctx).
|
||||||
OmitEmpty().
|
// OmitEmpty().
|
||||||
Where(entity.ProviderProtocolCol.Id, req.Id).
|
// Where(entity.ProviderProtocolCol.Id, req.Id).
|
||||||
Where(entity.ProviderProtocolCol.ProviderName, req.ProviderName). //主要是根据运营商查询
|
// Where(entity.ProviderProtocolCol.ProviderName, req.ProviderName). //主要是根据运营商查询
|
||||||
Where(entity.ProviderProtocolCol.Status, 1).
|
// Where(entity.ProviderProtocolCol.Status, 1).
|
||||||
Fields(fields).One()
|
// Fields(fields).One()
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
// if r.IsEmpty() {
|
||||||
|
// return nil, nil
|
||||||
|
// }
|
||||||
|
// err = r.Struct(&res)
|
||||||
|
// return
|
||||||
|
//}
|
||||||
|
|
||||||
|
// Get 获取协议配置
|
||||||
|
func (d *providerProtocolDao) Get(ctx context.Context, req *entity.ProviderProtocol, fields ...string) (*entity.ProviderProtocol, error) {
|
||||||
|
sql := fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND status = 1`, public.TableNameProviderProtocol)
|
||||||
|
args := make([]any, 0)
|
||||||
|
|
||||||
|
if req.Id != 0 {
|
||||||
|
sql += ` AND id = ?`
|
||||||
|
args = append(args, req.Id)
|
||||||
|
}
|
||||||
|
if req.ProviderName != "" {
|
||||||
|
sql += ` AND provider_name = ?`
|
||||||
|
args = append(args, req.ProviderName)
|
||||||
|
}
|
||||||
|
|
||||||
|
sql += ` LIMIT 1`
|
||||||
|
|
||||||
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if r.IsEmpty() {
|
if r.IsEmpty() {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
err = r.Struct(&res)
|
|
||||||
return
|
var list []*entity.ProviderProtocol
|
||||||
|
if err = r.Structs(&list); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(list) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return list[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// List 列表查询
|
// List 列表查询
|
||||||
|
|||||||
@@ -151,6 +151,7 @@ type SendCallbackReq struct {
|
|||||||
TaskId string `json:"taskId"`
|
TaskId string `json:"taskId"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
EpicycleId int64 `json:"epicycleId"`
|
EpicycleId int64 `json:"epicycleId"`
|
||||||
|
Messages map[string]any `json:"messages,omitempty"`
|
||||||
ErrorMsg string `json:"errorMsg,omitempty"`
|
ErrorMsg string `json:"errorMsg,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,6 +166,7 @@ func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycle
|
|||||||
TaskId: composeTask.TaskId,
|
TaskId: composeTask.TaskId,
|
||||||
Status: composeTask.Status,
|
Status: composeTask.Status,
|
||||||
ErrorMsg: composeTask.ErrorMessage,
|
ErrorMsg: composeTask.ErrorMessage,
|
||||||
|
Messages: composeTask.ResultJson,
|
||||||
EpicycleId: epicycleId,
|
EpicycleId: epicycleId,
|
||||||
}
|
}
|
||||||
// 3. 发送 POST 请求
|
// 3. 发送 POST 请求
|
||||||
|
|||||||
@@ -81,11 +81,9 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel,
|
|||||||
if err != nil || providerProtocol == nil {
|
if err != nil || providerProtocol == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{})).MustToJsonIndentString()
|
outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{})).MustToJsonString()
|
||||||
|
|
||||||
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
|
return fmt.Sprintf(providerProtocol.SystemPromptTemplate, outputJSON)
|
||||||
outputJSON, //【输出结构】 %s
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkOverallContent 检查整体内容是否超出窗口
|
// checkOverallContent 检查整体内容是否超出窗口
|
||||||
@@ -110,9 +108,6 @@ func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt st
|
|||||||
if userFormText := buildUserFormText(req.UserForm); userFormText != "" {
|
if userFormText := buildUserFormText(req.UserForm); userFormText != "" {
|
||||||
b.WriteString(fmt.Sprintf("用户需求:\n%s\n", userFormText))
|
b.WriteString(fmt.Sprintf("用户需求:\n%s\n", userFormText))
|
||||||
}
|
}
|
||||||
if len(req.Consult) > 0 {
|
|
||||||
b.WriteString(fmt.Sprintf("参考附件:%s\n", gjson.New(req.Consult).String()))
|
|
||||||
}
|
|
||||||
if fileTexts := ExtractFileTexts(ctx, req.Consult); fileTexts != "" {
|
if fileTexts := ExtractFileTexts(ctx, req.Consult); fileTexts != "" {
|
||||||
b.WriteString(fmt.Sprintf("附件内容:\n%s\n", fileTexts))
|
b.WriteString(fmt.Sprintf("附件内容:\n%s\n", fileTexts))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,7 +41,9 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*gateway
|
|||||||
return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
|
return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
|
||||||
}
|
}
|
||||||
chatModel, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
chatModel, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||||
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
|
SQLBaseDO: beans.SQLBaseDO{
|
||||||
|
Creator: userInfo.UserName,
|
||||||
|
},
|
||||||
IsChatModel: 1,
|
IsChatModel: 1,
|
||||||
})
|
})
|
||||||
if err != nil || chatModel == nil {
|
if err != nil || chatModel == nil {
|
||||||
@@ -148,7 +150,7 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
|||||||
// 3) 解析 OSS 内容为消息
|
// 3) 解析 OSS 内容为消息
|
||||||
var messages map[string]any
|
var messages map[string]any
|
||||||
if len(ossContent) > 0 {
|
if len(ossContent) > 0 {
|
||||||
messages, _ = gjson.New(ossContent).Map(), nil
|
messages = gjson.New(ossContent).Map()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4) 处理失败
|
// 4) 处理失败
|
||||||
|
|||||||
Reference in New Issue
Block a user