Files
prompts-core/service/prompt/prompt_compose_service.go

355 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package prompt
import (
"context"
"encoding/json"
"errors"
"fmt"
"prompts-core/service/session"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
"prompts-core/common/util"
"prompts-core/consts/public"
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
"prompts-core/service/gateway"
)
// ComposeMessages is the main prompt-composition flow.
//
// It resolves the current user's chat model and the target model named by
// req.ModelName, validates req.UserForm against the target model's token
// window, and then dispatches on req.BuildType:
//   - public.BuildTypePrompt → handlePromptBuild (prompt build)
//   - public.BuildTypeNode   → handleNodeBuild   (node build)
//
// Any other BuildType is rejected with an error.
func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
	chatModel, aiModel, err := GetModelMessage(ctx, req)
	if err != nil {
		return nil, err
	}
	if validationErr := validateUserForm(req, aiModel); validationErr != nil {
		return nil, validationErr
	}

	if req.BuildType == public.BuildTypePrompt {
		return handlePromptBuild(ctx, req, chatModel, aiModel)
	}
	if req.BuildType == public.BuildTypeNode {
		return handleNodeBuild(ctx, req, chatModel, aiModel)
	}
	return nil, errors.New("BuildType 不支持")
}
// GetModelMessage loads the two models a compose request needs, both scoped to
// models created by the current user:
//   - the chat model (the record with IsChatModel = 1), and
//   - the target model whose ModelName equals req.ModelName.
//
// It returns (chatModel, aiModel, nil) on success. It returns an error when the
// user info cannot be read, when either lookup fails, or when either model does
// not exist.
func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) {
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
	}
	// Both lookups are restricted to records created by this user.
	owner := beans.SQLBaseDO{Creator: user.UserName}

	chat, err := dao.Model.Get(ctx, &entity.AsynchModel{
		SQLBaseDO:   owner,
		IsChatModel: new(1),
	})
	switch {
	case err != nil:
		return nil, nil, err
	case chat == nil:
		return nil, nil, errors.New("当前没有对话模型,请添加")
	}

	target, err := dao.Model.Get(ctx, &entity.AsynchModel{
		SQLBaseDO: owner,
		ModelName: req.ModelName,
	})
	switch {
	case err != nil:
		return nil, nil, err
	case target == nil:
		return nil, nil, errors.New("需要构建的模型不存在")
	}

	return chat, target, nil
}
// validateUserForm checks that req.UserForm fits within the token window
// described by model.TokenConfig.
//
// An empty UserForm is always accepted. Otherwise it returns an error when the
// check itself fails, or when the form exceeds the window. In the latter case
// the message reports the overflow and the available window size.
func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) error {
	if len(req.UserForm) > 0 {
		withinWindow, overflow, err := util.CheckUserFormWithinWindow(req.UserForm, model.TokenConfig)
		switch {
		case err != nil:
			return fmt.Errorf("校验用户表单失败: %w", err)
		case !withinWindow:
			return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens可用窗口 %d tokens请精简后重试",
				overflow, util.GetAvailableWindow(model.TokenConfig))
		}
	}
	return nil
}
// handlePromptBuild handles BuildType=1 (prompt build).
//
// The session history for req.SessionId is loaded on a best-effort basis. If it
// cannot be read, the error is logged and the build proceeds without history.
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
	history, err := session.GetHistoryMessages(ctx, req.SessionId)
	if err != nil {
		// History is optional context: degrade instead of failing the request.
		g.Log().Errorf(ctx, "获取历史会话失败: %v将不使用历史会话", err)
		history = nil
	}
	return submitComposeTask(ctx, req, chatModel, aiModel, history)
}

// handleNodeBuild handles BuildType=2 (node build). Node builds are submitted
// without any session history.
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
	return submitComposeTask(ctx, req, chatModel, aiModel, nil)
}

// submitComposeTask runs the steps shared by both build types.
//
// It first calls the inference model, which creates the gateway task. It then
// records a pending ComposeTask keyed by the returned gateway task id, so that
// Callback can look the task up later. The response carries the gateway task id
// and the id returned by callInferenceModel as EpicycleId.
func submitComposeTask(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel, history []map[string]any) (*dto.ComposeMessagesRes, error) {
	taskID, epicycleID, err := callInferenceModel(ctx, req, chatModel, aiModel, history)
	if err != nil {
		return nil, fmt.Errorf("调用推理模型失败: %w", err)
	}

	// NOTE(review): the gateway task already exists at this point. If this insert
	// fails, the gateway's callback will not find a task row. Confirm whether the
	// gateway task should be cancelled or the insert retried.
	_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
		TaskId:         taskID,
		ModelName:      req.ModelName,
		SkillName:      req.SkillName,
		BuildType:      req.BuildType,
		CallbackUrl:    req.CallbackUrl,
		RequestPayload: util.MustMarshalToMap(req),
		Status:         public.ComposeStatusPending,
	})
	if err != nil {
		return nil, fmt.Errorf("保存任务记录失败: %w", err)
	}

	return &dto.ComposeMessagesRes{
		TaskId:     taskID,
		EpicycleId: epicycleID,
	}, nil
}
// callInferenceModel builds the inference request and submits it to the gateway.
//
// Steps, in order:
//  1. Build the gateway request from the request, both models and the optional history.
//  2. Record the user message in ComposeSession under req.SessionId.
//  3. Create the gateway task.
//
// It returns the gateway task id and the id of the inserted ComposeSession row.
// An empty task id from the gateway is treated as an error.
//
// NOTE(review): the session row is written before the gateway call, so a
// gateway failure leaves that row behind. Confirm this is intended.
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (string, int64, error) {
	gatewayReq, err := buildInferenceRequest(ctx, req, chatModel, aiModel, history)
	if err != nil {
		return "", 0, fmt.Errorf("构建推理请求失败: %w", err)
	}

	sessionRecord := &entity.ComposeSession{
		SessionId:      req.SessionId,
		RequestContent: util.GetUserMessage(gatewayReq),
	}
	sessionRowID, err := dao.ComposeSession.Insert(ctx, sessionRecord)
	if err != nil {
		return "", 0, fmt.Errorf("保存历史会话失败: %w", err)
	}

	taskID, err := gateway.CreateGatewayTask(ctx, gatewayReq)
	switch {
	case err != nil:
		return "", 0, fmt.Errorf("创建网关任务失败: %w", err)
	case taskID == "":
		return "", 0, errors.New("网关未返回taskId")
	}
	return taskID, sessionRowID, nil
}
// Callback handles the gateway's completion notification for a compose task.
//
// The task is looked up by req.TaskId, and then:
//   - req.State == 3 (failure): the task is marked failed with the gateway's
//     error message and raw payload.
//   - req.State == 2 (success): req.Messages is parsed according to the task's
//     BuildType. It is then merged with fields from the original request payload
//     via the model's ExtendMapping, and stored as a success.
//
// In both cases the outcome is forwarded to the task's CallbackUrl when one was
// recorded. Any other state is only logged on receipt.
//
// It returns an error when the task is missing or cannot be read, when the model
// needed to post-process a successful result is missing or cannot be read, or
// when the status update fails.
//
// NOTE(review): 2 and 3 are gateway state codes; consider named constants.
func Callback(ctx context.Context, req *dto.CallbackReq) error {
	g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
		req.TaskId, req.State, req.OssFile, req.FileType, len(req.Messages))

	composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
		TaskId: req.TaskId,
	})
	if err != nil {
		return fmt.Errorf("查询任务失败: %w", err)
	}
	// A callback can arrive for an unknown taskId. It can also arrive before the
	// creating request has inserted the task row, because the gateway task is
	// created first. Without this check the field accesses below would
	// dereference a nil record and panic.
	if composeTask == nil {
		return fmt.Errorf("未找到任务(taskId=%s)", req.TaskId)
	}

	switch req.State {
	case 3: // failure
		_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
			TaskId:       req.TaskId,
			Status:       public.ComposeStatusFailed,
			ErrorMessage: req.ErrorMsg,
			GatewayState: req.State,
			OssFile:      req.OssFile,
			FileType:     req.FileType,
			ResultText:   req.Messages,
		})
		// The business side is notified even if the status update failed; the
		// update error is still returned to the caller afterwards.
		if composeTask.CallbackUrl != "" {
			gateway.SendCallback(ctx, &entity.ComposeTask{
				TaskId:       req.TaskId,
				Status:       public.ComposeStatusFailed,
				ErrorMessage: req.ErrorMsg,
				CallbackUrl:  composeTask.CallbackUrl,
				Messages:     composeTask.Messages,
			})
		}
		return err

	case 2: // success
		// The model is only needed to post-process a successful result. Looking it
		// up here means a missing or unreadable model no longer blocks recording a
		// failure.
		model, err := dao.Model.Get(ctx, &entity.AsynchModel{
			SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
			ModelName: composeTask.ModelName,
		})
		if err != nil {
			return fmt.Errorf("查询模型失败: %w", err)
		}
		if model == nil {
			return fmt.Errorf("查询模型失败: 模型不存在(modelName=%s)", composeTask.ModelName)
		}

		// 1. Parse the raw result according to the task's BuildType.
		var messages map[string]any
		switch composeTask.BuildType {
		case public.BuildTypePrompt:
			messages = ParsePromptResult(req.Messages)
		case public.BuildTypeNode:
			messages = ParseNodeResult(req.Messages)
		default:
			messages = req.Messages
		}
		// 2. Merge extension fields from the original request payload.
		messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)

		// 3. Persist the success state.
		_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
			TaskId:       req.TaskId,
			Status:       public.ComposeStatusSuccess,
			Messages:     messages,
			GatewayState: req.State,
			OssFile:      req.OssFile,
			FileType:     req.FileType,
			ResultText:   req.Messages,
		})
		if err != nil {
			g.Log().Errorf(ctx, "[Callback] 更新成功状态失败 taskId=%s err=%v", req.TaskId, err)
			return err
		}

		// 4. Forward the result to the business side.
		if composeTask.CallbackUrl != "" {
			gateway.SendCallback(ctx, &entity.ComposeTask{
				TaskId:      req.TaskId,
				Status:      public.ComposeStatusSuccess,
				Messages:    messages,
				CallbackUrl: composeTask.CallbackUrl,
			})
		}
		return nil

	default:
		// Other states are informational only; nothing to persist.
		return nil
	}
}
// ParsePromptResult normalizes a prompt-build result.
//
// Behavior depends on raw["content"]:
//   - missing, not a string, or empty: raw is returned unchanged.
//   - a non-empty JSON array of objects: {"total_rounds": len, "rounds": array}.
//   - a non-empty JSON object: that object becomes the single round.
//   - anything else: {"content": <the string>}. Other keys of raw are dropped.
func ParsePromptResult(raw map[string]any) map[string]any {
	content, isString := raw["content"].(string)
	if !isString || len(content) == 0 {
		return raw
	}

	var rounds []map[string]any
	if arr := tryParseAsMapArray(content); arr != nil {
		rounds = arr
	} else if obj := tryParseAsMap(content); obj != nil {
		rounds = []map[string]any{obj}
	}

	if rounds == nil {
		return map[string]any{"content": content}
	}
	return map[string]any{
		"total_rounds": len(rounds),
		"rounds":       rounds,
	}
}

// tryParseAsMapArray decodes s as a JSON array of objects. It returns nil when
// s does not decode into that shape or the decoded array is empty.
func tryParseAsMapArray(s string) []map[string]any {
	var decoded []map[string]any
	if json.Unmarshal([]byte(s), &decoded) != nil || len(decoded) == 0 {
		return nil
	}
	return decoded
}

// tryParseAsMap decodes s as a JSON object. It returns nil when s does not
// decode into an object or the decoded object has no keys.
func tryParseAsMap(s string) map[string]any {
	var decoded map[string]any
	if json.Unmarshal([]byte(s), &decoded) != nil || len(decoded) == 0 {
		return nil
	}
	return decoded
}
// ParseNodeResult normalizes a node-build result into the same
// {"total_rounds", "rounds"} shape that ParsePromptResult produces. It always
// returns exactly one round.
//
// If raw["content"] is a non-empty string that decodes to a JSON object, that
// object is the round. Otherwise (missing, non-string, empty, not a JSON
// object, or JSON `null`) raw itself is the round.
func ParseNodeResult(raw map[string]any) map[string]any {
	round := raw
	if content, ok := raw["content"].(string); ok && content != "" {
		var inner map[string]any
		// json.Unmarshal accepts the literal `null` without error and leaves inner
		// nil. Treat that like undecodable content rather than emitting a nil round.
		if err := json.Unmarshal([]byte(content), &inner); err == nil && inner != nil {
			round = inner
		}
	}
	return map[string]any{
		"total_rounds": 1,
		"rounds":       []map[string]any{round},
	}
}
// GetComposeTask loads the compose task identified by taskID and returns its
// id, status, error message and messages.
//
// Messages go through parseMessagesForResponse, so a JSON-encoded string is
// returned decoded. It returns an error when the lookup fails or no task exists
// for taskID.
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
	task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{TaskId: taskID})
	switch {
	case err != nil:
		return nil, fmt.Errorf("查询任务失败: %w", err)
	case task == nil:
		return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
	}

	res := &dto.GetComposeTaskRes{
		TaskId:       task.TaskId,
		Status:       task.Status,
		ErrorMessage: task.ErrorMessage,
	}
	res.Messages = parseMessagesForResponse(task.Messages)
	return res, nil
}
// parseMessagesForResponse prepares stored messages for an API response.
//
// When messages is a non-empty string holding valid JSON, the decoded value is
// returned. Any other input is returned unchanged: non-string values, the empty
// string, and strings that are not valid JSON.
func parseMessagesForResponse(messages any) any {
	if text, isText := messages.(string); isText && text != "" {
		var decoded any
		if json.Unmarshal([]byte(text), &decoded) == nil {
			return decoded
		}
	}
	return messages
}