refactor(prompt): 重构提示词构建服务和回调处理
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// ComposeMessages 核心拼接提示词主流程
|
||||
@@ -157,14 +158,14 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
|
||||
if composeTask.CallbackUrl != "" {
|
||||
composeTask.Status = public.ComposeStatusFailed
|
||||
composeTask.ErrorMessage = req.ErrorMsg
|
||||
_ = gateway.SendCallback(ctx, composeTask)
|
||||
_ = gateway.SendCallback(ctx, composeTask, 0)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// handleCallbackSuccess 处理回调成功
|
||||
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
|
||||
// 1) 查模型配置
|
||||
// 1) 获取模型配置
|
||||
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
|
||||
ModelName: composeTask.ModelName,
|
||||
@@ -172,6 +173,10 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询模型失败: %w", err)
|
||||
}
|
||||
// 2) 根据运营商获取协议配置
|
||||
//protocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||
// ProviderName: model.OperatorName,
|
||||
//})
|
||||
|
||||
// 2) 解析结果
|
||||
var messages map[string]any
|
||||
@@ -179,10 +184,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
case public.BuildTypePrompt, public.BuildTypeNode:
|
||||
messages = ParseResult(req.Messages, model.ResponseBody)
|
||||
case public.BuildTypeStruct:
|
||||
messages = map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": []map[string]any{req.Messages},
|
||||
}
|
||||
messages = ParseStructResult(req.Messages, model.ResponseBody)
|
||||
default:
|
||||
messages = req.Messages
|
||||
}
|
||||
@@ -201,80 +203,113 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 5) 存储历史结果
|
||||
|
||||
// 6) 回调业务方
|
||||
// 5) 存储提示词结果作为历史请求
|
||||
var epicycleId int64
|
||||
payload := composeTask.RequestPayload
|
||||
sessionId := gconv.String(payload["sessionId"])
|
||||
nodeId := gconv.String(payload["nodeId"])
|
||||
buildType := gconv.Int(payload["buildType"])
|
||||
if buildType == public.BuildTypePrompt && sessionId != "" && nodeId != "" {
|
||||
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
NodeId: nodeId,
|
||||
SessionId: sessionId,
|
||||
RequestContent: messages,
|
||||
})
|
||||
}
|
||||
// 6) 拼接历史内容
|
||||
// 7) 回调业务方
|
||||
if composeTask.CallbackUrl != "" {
|
||||
composeTask.Status = public.ComposeStatusSuccess
|
||||
composeTask.Messages = messages
|
||||
_ = gateway.SendCallback(ctx, composeTask)
|
||||
_ = gateway.SendCallback(ctx, composeTask, epicycleId)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseResult 解析回调结果
|
||||
// ParseResult 解析结果
|
||||
func ParseResult(raw map[string]any, responseBody string) map[string]any {
|
||||
// responseBody 为空,直接返回原始数据
|
||||
if responseBody == "" {
|
||||
return raw
|
||||
}
|
||||
|
||||
// 按 responseBody 路径取值
|
||||
contentStr, ok := raw[responseBody].(string)
|
||||
if !ok || contentStr == "" {
|
||||
contentVal := raw[responseBody]
|
||||
if contentVal == nil {
|
||||
return raw
|
||||
}
|
||||
|
||||
if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil {
|
||||
return map[string]any{
|
||||
"total_rounds": len(roundsArray),
|
||||
"rounds": roundsArray,
|
||||
// 已经是数组
|
||||
if arr, ok := contentVal.([]any); ok {
|
||||
rounds := gconv.Maps(arr)
|
||||
if len(rounds) > 0 {
|
||||
return map[string]any{"total_rounds": len(rounds), "rounds": rounds}
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
if singleRound := tryParseAsMap(contentStr); singleRound != nil {
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": []map[string]any{singleRound},
|
||||
}
|
||||
// 是字符串
|
||||
contentStr := gconv.String(contentVal)
|
||||
if contentStr == "" {
|
||||
return raw
|
||||
}
|
||||
|
||||
// 尝试解析为数组
|
||||
var arr []map[string]any
|
||||
if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 {
|
||||
return map[string]any{"total_rounds": len(arr), "rounds": arr}
|
||||
}
|
||||
|
||||
// 尝试解析为单对象
|
||||
var obj map[string]any
|
||||
if err := json.Unmarshal([]byte(contentStr), &obj); err == nil && len(obj) > 0 {
|
||||
return map[string]any{"total_rounds": 1, "rounds": []map[string]any{obj}}
|
||||
}
|
||||
|
||||
return map[string]any{"content": contentStr}
|
||||
}
|
||||
|
||||
func tryParseAsMapArray(jsonStr string) []map[string]any {
|
||||
var arr []map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonStr), &arr); err != nil {
|
||||
return nil
|
||||
func ParseStructResult(raw map[string]any, responseBody string) map[string]any {
|
||||
// 如果外层已有 rounds,直接返回
|
||||
if _, ok := raw["rounds"]; ok {
|
||||
return raw
|
||||
}
|
||||
if len(arr) == 0 {
|
||||
return nil
|
||||
}
|
||||
return arr
|
||||
}
|
||||
|
||||
func tryParseAsMap(jsonStr string) map[string]any {
|
||||
var obj map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil {
|
||||
return nil
|
||||
}
|
||||
if len(obj) == 0 {
|
||||
return nil
|
||||
}
|
||||
return obj
|
||||
}
|
||||
|
||||
func ParseNodeResult(raw map[string]any) map[string]any {
|
||||
contentStr, ok := raw["content"].(string)
|
||||
if ok && contentStr != "" {
|
||||
var inner map[string]any
|
||||
if err := json.Unmarshal([]byte(contentStr), &inner); err == nil {
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": []map[string]any{inner},
|
||||
}
|
||||
contentVal := raw[responseBody]
|
||||
if contentVal == nil {
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": []map[string]any{raw},
|
||||
}
|
||||
}
|
||||
|
||||
contentStr := gconv.String(contentVal)
|
||||
if contentStr == "" || contentStr == "0" {
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": []map[string]any{raw},
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试解析为数组
|
||||
var arr []map[string]any
|
||||
if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 {
|
||||
// 数组的每个元素包一层 content
|
||||
var rounds []map[string]any
|
||||
for _, item := range arr {
|
||||
rounds = append(rounds, map[string]any{"content": item})
|
||||
}
|
||||
return map[string]any{
|
||||
"total_rounds": len(rounds),
|
||||
"rounds": rounds,
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试解析为单个对象
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(contentStr), &parsed); err == nil {
|
||||
raw[responseBody] = parsed
|
||||
}
|
||||
|
||||
// 兜底:包标准结构
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": []map[string]any{raw},
|
||||
|
||||
Reference in New Issue
Block a user