refactor(prompts-core): 重构代码结构和优化工具函数
This commit is contained in:
@@ -2,7 +2,6 @@ package prompt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"prompts-core/service/session"
|
||||
@@ -80,7 +79,7 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *gateway.AsynchModel) e
|
||||
// handleBuild 通用构建处理
|
||||
func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
||||
// 1) 处理表单分批
|
||||
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
|
||||
processedReq, _, err := ProcessUserFormBatches(ctx, req, aiModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
|
||||
}
|
||||
@@ -90,7 +89,7 @@ func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, ai
|
||||
var taskReq map[string]any
|
||||
switch req.BuildType {
|
||||
case public.BuildTypePrompt:
|
||||
taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir, totalBatches)
|
||||
taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir)
|
||||
case public.BuildTypeNode:
|
||||
taskReq, err = buildNodeTypeRequest(ctx, req, chatModel, ir)
|
||||
case public.BuildTypeStruct:
|
||||
@@ -118,7 +117,7 @@ func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, ai
|
||||
SkillName: req.SkillName,
|
||||
BuildType: req.BuildType,
|
||||
CallbackUrl: req.CallbackUrl,
|
||||
RequestPayload: util.MustMarshalToMap(req),
|
||||
RequestPayload: gconv.Map(req),
|
||||
Status: public.ComposeStatusPending,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
@@ -164,6 +163,7 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
|
||||
return err
|
||||
}
|
||||
|
||||
// handleCallbackSuccess 处理回调成功
|
||||
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
|
||||
// 1) 获取模型配置
|
||||
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||
@@ -180,12 +180,15 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
Status: 1,
|
||||
})
|
||||
|
||||
// 3) 获取历史消息
|
||||
// 3) 获取历史消息 + 保存当前轮
|
||||
payload := composeTask.RequestPayload
|
||||
sessionId := gconv.String(payload["sessionId"])
|
||||
nodeId := gconv.String(payload["nodeId"])
|
||||
var history []dto.FlatMessage
|
||||
if sessionId != "" && nodeId != "" {
|
||||
var epicycleId int64
|
||||
|
||||
if sessionId != "" && nodeId != "" && model.ModelType == public.ModelTypeInference {
|
||||
// 3.1 获取历史
|
||||
h, _ := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
|
||||
SessionId: sessionId,
|
||||
NodeId: nodeId,
|
||||
@@ -193,12 +196,21 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
if h != nil {
|
||||
history = h.Messages
|
||||
}
|
||||
|
||||
// 3.2 保存当前轮(先存,下次查询就能拿到)
|
||||
if userMsg := util.ExtractUserText(req.Messages); userMsg != nil {
|
||||
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
NodeId: nodeId,
|
||||
SessionId: sessionId,
|
||||
RequestContent: userMsg,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 4) 合并附加结构
|
||||
messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping)
|
||||
// 5) 注入历史到 rounds 中
|
||||
if protocol != nil && len(history) > 0 {
|
||||
// 5) 注入历史
|
||||
if len(history) > 0 {
|
||||
messages = InjectHistory(messages, history, protocol)
|
||||
}
|
||||
|
||||
@@ -215,18 +227,6 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
return err
|
||||
}
|
||||
|
||||
// 7) 存储历史
|
||||
var epicycleId int64
|
||||
if sessionId != "" && nodeId != "" {
|
||||
if userMsg := util.ExtractUserText(req.Messages); userMsg != nil {
|
||||
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
NodeId: nodeId,
|
||||
SessionId: sessionId,
|
||||
RequestContent: userMsg,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 8) 回调业务方
|
||||
if composeTask.CallbackUrl != "" {
|
||||
composeTask.Status = public.ComposeStatusSuccess
|
||||
@@ -237,77 +237,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetComposeTask 查询任务结果
|
||||
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
|
||||
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询任务失败: %w", err)
|
||||
}
|
||||
if record == nil {
|
||||
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
|
||||
}
|
||||
|
||||
messages := parseMessagesForResponse(record.ResultJson)
|
||||
|
||||
return &dto.GetComposeTaskRes{
|
||||
TaskId: record.TaskId,
|
||||
Status: record.Status,
|
||||
ErrorMessage: record.ErrorMessage,
|
||||
Messages: messages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseMessagesForResponse 解析用于响应的消息
|
||||
func parseMessagesForResponse(messages any) any {
|
||||
str, ok := messages.(string)
|
||||
if !ok || str == "" {
|
||||
return messages
|
||||
}
|
||||
|
||||
var parsed any
|
||||
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
|
||||
return parsed
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetPromptTextRes, error) {
|
||||
// 1) 获取协议配置
|
||||
protocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||
ProviderName: "火山引擎",
|
||||
Status: 1,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2) 获取历史消息
|
||||
history, err := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
|
||||
SessionId: "88888888",
|
||||
NodeId: "node1",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3) 模拟roundsData数据
|
||||
task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
TaskId: "0e1872f0-0e73-42f1-9aa8-63d317300ffc",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Println("[打印数据]", task.ResultJson)
|
||||
fmt.Println("[打印历史]", history.Messages)
|
||||
fmt.Println("[打印协议]", protocol)
|
||||
return &dto.GetPromptTextRes{
|
||||
Messages: InjectHistory(task.ResultJson, history.Messages, protocol),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// InjectHistory 插入历史会话
|
||||
func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protocol *entity.ProviderProtocol) map[string]any {
|
||||
if protocol == nil || len(history) == 0 {
|
||||
return roundsData
|
||||
@@ -363,3 +293,19 @@ func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protoco
|
||||
firstRound["messages"] = result
|
||||
return roundsData
|
||||
}
|
||||
|
||||
// GetComposeTask 查询任务结果
|
||||
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
|
||||
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询任务失败: %w", err)
|
||||
}
|
||||
return &dto.GetComposeTaskRes{
|
||||
TaskId: record.TaskId,
|
||||
Status: record.Status,
|
||||
ErrorMessage: record.ErrorMessage,
|
||||
Messages: record.ResultJson,
|
||||
}, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user