diff --git a/dao/provider_protocol_dao.go b/dao/provider_protocol_dao.go index 00242e8..d0ea50d 100644 --- a/dao/provider_protocol_dao.go +++ b/dao/provider_protocol_dao.go @@ -2,6 +2,7 @@ package dao import ( "context" + "fmt" "prompts-core/consts/public" "prompts-core/model/entity" @@ -28,23 +29,57 @@ func (d *providerProtocolDao) Insert(ctx context.Context, req *entity.ProviderPr return r.LastInsertId() } -// Get 查询协议配置 -func (d *providerProtocolDao) Get(ctx context.Context, req *entity.ProviderProtocol, fields ...string) (res *entity.ProviderProtocol, err error) { - r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol). - NoTenantId(ctx). - OmitEmpty(). - Where(entity.ProviderProtocolCol.Id, req.Id). - Where(entity.ProviderProtocolCol.ProviderName, req.ProviderName). //主要是根据运营商查询 - Where(entity.ProviderProtocolCol.Status, 1). - Fields(fields).One() +//// Get 查询协议配置 +//func (d *providerProtocolDao) Get(ctx context.Context, req *entity.ProviderProtocol, fields ...string) (res *entity.ProviderProtocol, err error) { +// r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol). +// NoTenantId(ctx). +// OmitEmpty(). +// Where(entity.ProviderProtocolCol.Id, req.Id). +// Where(entity.ProviderProtocolCol.ProviderName, req.ProviderName). //主要是根据运营商查询 +// Where(entity.ProviderProtocolCol.Status, 1). +// Fields(fields).One() +// if err != nil { +// return nil, err +// } +// if r.IsEmpty() { +// return nil, nil +// } +// err = r.Struct(&res) +// return +//} + +// Get 获取协议配置 +func (d *providerProtocolDao) Get(ctx context.Context, req *entity.ProviderProtocol, fields ...string) (*entity.ProviderProtocol, error) { + sql := fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND status = 1`, public.TableNameProviderProtocol) + args := make([]any, 0) + + if req.Id != 0 { + sql += ` AND id = ?` + args = append(args, req.Id) + } + if req.ProviderName != "" { + sql += ` AND provider_name = ?` + args = append(args, req.ProviderName) + } + + sql += ` LIMIT 1` + + r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...) if err != nil { return nil, err } if r.IsEmpty() { return nil, nil } - err = r.Struct(&res) - return + + var list []*entity.ProviderProtocol + if err = r.Structs(&list); err != nil { + return nil, err + } + if len(list) == 0 { + return nil, nil + } + return list[0], nil } // List 列表查询 diff --git a/service/gateway/gateway_http_service.go b/service/gateway/gateway_http_service.go index 6f19e2f..f2641e9 100644 --- a/service/gateway/gateway_http_service.go +++ b/service/gateway/gateway_http_service.go @@ -148,10 +148,11 @@ func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) { // SendCallbackReq 发送回调的请求体 type SendCallbackReq struct { - TaskId string `json:"taskId"` - Status string `json:"status"` - EpicycleId int64 `json:"epicycleId"` - ErrorMsg string `json:"errorMsg,omitempty"` + TaskId string `json:"taskId"` + Status string `json:"status"` + EpicycleId int64 `json:"epicycleId"` + Messages map[string]any `json:"messages,omitempty"` + ErrorMsg string `json:"errorMsg,omitempty"` } // SendCallback 向业务方发送回调 @@ -165,6 +166,7 @@ func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycle TaskId: composeTask.TaskId, Status: composeTask.Status, ErrorMsg: composeTask.ErrorMessage, + Messages: composeTask.ResultJson, EpicycleId: epicycleId, } // 3. 发送 POST 请求 diff --git a/service/prompt/prompt_build_service.go b/service/prompt/prompt_build_service.go index 4fb896e..937d95b 100644 --- a/service/prompt/prompt_build_service.go +++ b/service/prompt/prompt_build_service.go @@ -81,11 +81,9 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, if err != nil || providerProtocol == nil { return "" } - outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{})).MustToJsonIndentString() + outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{})).MustToJsonString() - return fmt.Sprintf(providerProtocol.SystemPromptTemplate, - outputJSON, //【输出结构】 %s - ) + return fmt.Sprintf(providerProtocol.SystemPromptTemplate, outputJSON) } // checkOverallContent 检查整体内容是否超出窗口 @@ -110,9 +108,6 @@ func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt st if userFormText := buildUserFormText(req.UserForm); userFormText != "" { b.WriteString(fmt.Sprintf("用户需求:\n%s\n", userFormText)) } - if len(req.Consult) > 0 { - b.WriteString(fmt.Sprintf("参考附件:%s\n", gjson.New(req.Consult).String())) - } if fileTexts := ExtractFileTexts(ctx, req.Consult); fileTexts != "" { b.WriteString(fmt.Sprintf("附件内容:\n%s\n", fileTexts)) } diff --git a/service/prompt/prompt_compose_service.go b/service/prompt/prompt_compose_service.go index b23310c..42eaf1b 100644 --- a/service/prompt/prompt_compose_service.go +++ b/service/prompt/prompt_compose_service.go @@ -41,7 +41,9 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*gateway return nil, nil, fmt.Errorf("获取用户信息失败: %w", err) } chatModel, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName}, + SQLBaseDO: beans.SQLBaseDO{ + Creator: userInfo.UserName, + }, IsChatModel: 1, }) if err != nil || chatModel == nil { @@ -148,7 +150,7 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error { // 3) 解析 OSS 内容为消息 var messages map[string]any if len(ossContent) > 0 { - messages, _ = gjson.New(ossContent).Map(), nil + messages = gjson.New(ossContent).Map() } // 4) 处理失败