refactor(prompt): 重构提示词构建服务与数据模型

This commit is contained in:
2026-05-20 11:36:39 +08:00
parent c49144794d
commit 35bc3bd6ec
24 changed files with 1682 additions and 759 deletions

View File

@@ -4,65 +4,113 @@ import (
"context"
"errors"
"fmt"
"prompts-core/consts/public"
"strings"
"prompts-core/common/util"
"prompts-core/dao"
"prompts-core/model/dto/prompt"
"prompts-core/model/dto"
"prompts-core/model/entity"
"github.com/gogf/gf/v2/util/gconv"
)
// buildInferenceRequest 构建返回请求
func buildInferenceRequest(ctx context.Context, req *prompt.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
// buildInferenceRequest 构建推理请求
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, targetModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, targetModel)
if err != nil {
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
}
ir := NewPromptIR()
// 1. 统一 Prompt IR
switch req.BuildType {
case 1: //构建提示词请求
ir.AddSystem(promptBuild(ctx, req, model))
for _, msg := range history {
role := gconv.String(msg["role"])
if role != "user" && role != "assistant" {
continue
}
ir.AddHistory(role, gconv.String(msg["content"]))
}
ir.AddUser(buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, model.ModelType)))
case 2: //构建节点请求
ir.AddUser(NodeBuild(ctx, req))
case public.BuildTypePrompt:
return buildPromptTypeRequest(ctx, processedReq, targetModel, history, ir, totalBatches)
case public.BuildTypeNode:
return buildNodeTypeRequest(ctx, req, ir)
default:
return nil, errors.New("不支持的构建类型")
}
}
// 2. 获取协议配置
protocol, err := GetProtocolByProvider(ctx, "qwen")
// buildPromptTypeRequest 构建提示词类型请求BuildType=1
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, targetModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
systemPrompt := promptBuildWithRounds(ctx, req, targetModel, totalBatches)
ir.AddSystem(systemPrompt)
for _, msg := range history {
role := gconv.String(msg["role"])
if role != "user" && role != "assistant" {
continue
}
ir.AddHistory(role, gconv.String(msg["content"]))
}
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, targetModel.ModelType))
ir.AddUser(userPrompt)
if !checkOverallContent(ir, targetModel) {
availableWindow := util.GetAvailableWindow(targetModel.TokenConfig)
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
}
return compileToProviderRequest(ctx, ir, targetModel.OperatorName, targetModel)
}
// buildNodeTypeRequest 构建节点类型请求BuildType=2
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ir *PromptIR) (map[string]any, error) {
ir.AddUser(NodeBuild(ctx, req))
protocol, err := GetProtocolByProvider(ctx, req.ModelName)
if err != nil {
return nil, err
return nil, fmt.Errorf("获取协议配置失败: %w", err)
}
if protocol == nil {
return nil, errors.New("协议配置不存在")
}
// 3. 编译为 Provider Request
providerReq, err := Compile(ir, protocol, chatModel)
providerReq, err := Compile(ir, protocol, nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("编译请求失败: %w", err)
}
// 4. 构建请求体
return map[string]any{
"modelName": chatModel.ModelName,
"modelName": req.ModelName,
"bizName": "prompts-core",
"callbackUrl": "/prompt/callback",
"requestPayload": providerReq,
}, nil
}
// promptBuild 构建系统提示词
func promptBuild(ctx context.Context, req *prompt.ComposeMessagesReq, model *entity.AsynchModel) string {
// compileToProviderRequest 编译为 Provider 请求
func compileToProviderRequest(ctx context.Context, ir *PromptIR, providerName string, model *entity.AsynchModel) (map[string]any, error) {
protocol, err := GetProtocolByProvider(ctx, providerName)
if err != nil {
return nil, fmt.Errorf("获取协议配置失败: %w", err)
}
if protocol == nil {
return nil, errors.New("协议配置不存在")
}
providerReq, err := Compile(ir, protocol, model)
if err != nil {
return nil, fmt.Errorf("编译请求失败: %w", err)
}
fmt.Println("providerReq打印:", util.MustMarshal(providerReq))
return map[string]any{
"modelName": model.ModelName,
"bizName": "prompts-core",
"callbackUrl": "/prompt/callback",
"requestPayload": providerReq,
}, nil
}
// promptBuildWithRounds 构建系统提示词(包含轮次信息)
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel, totalRounds int) string {
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: "qwen",
ProviderName: model.OperatorName,
Status: 1,
})
if err != nil || providerProtocol == nil {
@@ -70,43 +118,104 @@ func promptBuild(ctx context.Context, req *prompt.ComposeMessagesReq, model *ent
}
outputJSON := util.JSONPretty(model.RequestMapping)
var userFormContent strings.Builder
for k, v := range req.UserForm {
userFormContent.WriteString(fmt.Sprintf("%s=%v", k, v))
}
userFormFullText := strings.TrimSuffix(userFormContent.String(), "")
maxWindowSize := util.GetMaxWindowSize(model.TokenConfig)
availableWindow := util.GetAvailableWindow(model.TokenConfig)
userFormContent := buildUserFormContent(req.UserForm)
formInfo := fmt.Sprintf(`
【系统表单(系统提示词/参数)】
%s
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
%s
`, util.FormToJSON(req.Form), userFormFullText)
`, util.FormToJSON(req.Form), userFormContent)
return fmt.Sprintf(providerProtocol.SystemPromptTemplate, outputJSON, formInfo)
inputInfo := fmt.Sprintf(`
目标模型: %s
%s
技能名称: %s
用户文件: %v
`, req.ModelName, formInfo, req.SkillName, req.UserFiles)
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
req.ModelName,
maxWindowSize,
availableWindow,
totalRounds,
totalRounds,
totalRounds,
outputJSON,
inputInfo,
totalRounds,
)
}
// 构建用户提示词
func buildUserPrompt(ctx context.Context, req *prompt.ComposeMessagesReq, prompt string) string {
payload := map[string]any{
"model": req.ModelName, // 请求模型名称
"promptInfo": prompt, // 数据库提示信息
"form": req.Form, // 系统表单
"userForm": req.UserForm, // 用户表单
"userFiles": req.UserFiles, //文件url
"userFilesText": FetchFileTexts(ctx, req.UserFiles), //解读文件(只支持可读类型 如xmljson,yaml
"skills": SkillMdContent(ctx, req.SkillName), //skill 相关(根据传入的 skillName 获取 zip 内所有 md 文件拼接内容)
// buildUserFormContent 构建用户表单内容字符串
func buildUserFormContent(userForm []map[string]any) string {
var builder strings.Builder
for _, item := range userForm {
builder.WriteString(fmt.Sprintf("%v\n", item))
}
return builder.String()
}
// checkOverallContent 检查整体内容是否超出窗口
func checkOverallContent(ir *PromptIR, model *entity.AsynchModel) bool {
fullContent := ir.String()
return util.CountToken(fullContent, model.TokenConfig)
}
// buildUserPrompt 构建用户提示词
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
userFormForPayload := prepareUserFormPayload(req.UserForm)
payload := map[string]any{
"model": req.ModelName,
"promptInfo": prompt,
"form": req.Form,
"userForm": userFormForPayload,
"userFiles": req.UserFiles,
"userFilesText": FetchFileTexts(ctx, req.UserFiles),
"skills": SkillMdContent(ctx, req.SkillName),
}
return util.MustMarshal(payload)
}
// prepareUserFormPayload 准备用户表单载荷
func prepareUserFormPayload(userForm []map[string]any) any {
if len(userForm) == 0 {
return nil
}
if _, ok := userForm[0]["batch_index"]; ok {
return userForm
}
return mergeUserFormTexts(userForm)
}
// mergeUserFormTexts 合并 UserForm 中的所有文本内容
func mergeUserFormTexts(userForm []map[string]any) string {
var builder strings.Builder
for i, item := range userForm {
text := getItemText(item)
if i > 0 {
builder.WriteString("\n\n")
}
builder.WriteString(text)
}
return builder.String()
}
// NodeBuild 节点构建
func NodeBuild(ctx context.Context, req *prompt.ComposeMessagesReq) string {
promptTpl := util.GetBuildPrompt(ctx, req.BuildType)
func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string {
promptTpl := util.GetBuildPrompt(ctx)
if promptTpl == "" {
return ""
}
formStr := util.FormToJSON(req.Form)
userFormStr := util.FormToJSON(req.UserForm)
userFormStr := util.UserFormToJSON(req.UserForm)
return fmt.Sprintf(promptTpl, formStr, userFormStr)
}

View File

@@ -5,171 +5,229 @@ import (
"encoding/json"
"errors"
"fmt"
"prompts-core/dao"
"prompts-core/model/entity"
"strings"
"time"
"prompts-core/common/util"
"prompts-core/consts/public"
promptDto "prompts-core/model/dto/prompt"
"prompts-core/service/gateway"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/frame/g"
"prompts-core/common/util"
"prompts-core/consts/public"
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
"prompts-core/service/gateway"
)
// ComposeMessages 核心拼接提示词主流程
func ComposeMessages(ctx context.Context, req *promptDto.ComposeMessagesReq) (*promptDto.ComposeMessagesRes, error) {
var (
epicycleId int64
taskID string
history []map[string]any
message map[string]any
err error
taskRecord *entity.ComposeTask
)
// 获取模型信息
func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
chatModel, aiModel, err := GetModelMessage(ctx, req)
if err != nil {
return nil, err
}
// 根据构建类型进行判断处理
switch req.BuildType {
//提示词构建
case 1:
maxRetryTimes := g.Cfg().MustGet(ctx, "promptsRetry.maxRetryTimes", 3).Int()
//1. 获取历史会话
history, err = GetHistoryMessages(ctx, req.SessionId)
if err != nil {
g.Log().Errorf(ctx, "获取历史会话失败: %v将不使用历史会话", err)
history = nil // 出错就用空的,不影响主流程
}
// 重试循环
for attempt := 0; attempt <= 0; attempt++ {
if attempt > 0 {
g.Log().Warningf(ctx, "[重试]第 %d/%d 次调用推理模型", attempt, maxRetryTimes)
}
// 2. 调用推理模型
taskID, err = callInferenceModel(ctx, req, chatModel, aiModel, history)
if err != nil {
g.Log().Errorf(ctx, "调用推理模型失败(第%d次): %v", attempt+1, err)
continue
}
// 3. 保存记录
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
RequestPayload: util.MustMarshal(req),
Status: public.ComposeStatusPending,
})
if err != nil {
g.Log().Errorf(ctx, "保存任务记录失败(第%d次): %v", attempt+1, err)
continue
}
// 4. 等待结果
taskRecord, err = waitForResult(ctx, taskID)
if err != nil {
g.Log().Errorf(ctx, "等待结果失败(第%d次): %v", attempt+1, err)
continue
}
// 校验结果
message = parsePromptBuild(taskRecord, chatModel)
if message != nil && util.IsMessageValid(message) {
break
}
g.Log().Warningf(ctx, "[重试] 推理结果不合法(第%d次),准备重新请求", attempt+1)
message = nil
}
if message == nil {
return nil, errors.New("推理模型调用失败,请稍后再试")
}
//5.创建会话记录
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
RequestContent: message,
})
//节点构建
case 2:
//1. 调用推理模型
taskID, err = callInferenceModel(ctx, req, chatModel, aiModel, nil)
if err != nil {
return nil, err
}
//2. 保存相关记录
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
RequestPayload: util.MustMarshal(req),
Status: public.ComposeStatusPending,
})
//5. 等待结果
taskRecord, err := waitForResult(ctx, taskID)
if err != nil {
return nil, err
}
message = parseNodeBuild(taskRecord)
default:
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
Remark: req.Cause,
})
return &promptDto.ComposeMessagesRes{
EpicycleId: epicycleId,
}, nil
if err = validateUserForm(ctx, req, aiModel); err != nil {
return nil, err
}
return &promptDto.ComposeMessagesRes{
switch req.BuildType {
case public.BuildTypePrompt:
return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建
case public.BuildTypeNode:
return handleNodeBuild(ctx, req, chatModel, aiModel) // 节点构建
default:
return handleDefaultCase(ctx, req)
}
}
// validateUserForm 校验用户表单
func validateUserForm(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) error {
if len(req.UserForm) == 0 {
return nil
}
isValid, exceedTokens, err := util.CheckUserFormWithinWindow(req.UserForm, model.TokenConfig)
if err != nil {
return fmt.Errorf("校验用户表单失败: %w", err)
}
if !isValid {
availableWindow := util.GetAvailableWindow(model.TokenConfig)
return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens可用窗口 %d tokens请精简后重试",
exceedTokens, availableWindow)
}
return nil
}
// handlePromptBuild 处理提示词构建BuildType=1
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
maxRetryTimes := g.Cfg().MustGet(ctx, "promptsRetry.maxRetryTimes", 3).Int()
history, err := GetHistoryMessages(ctx, req.SessionId)
if err != nil {
g.Log().Errorf(ctx, "获取历史会话失败: %v将不使用历史会话", err)
history = nil
}
var message *dto.MultiRoundResult
var taskRecord *entity.ComposeTask
for attempt := 0; attempt <= 0; attempt++ {
if attempt > 0 {
g.Log().Warningf(ctx, "[重试]第 %d/%d 次调用推理模型", attempt, maxRetryTimes)
}
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, history)
if err != nil {
g.Log().Errorf(ctx, "调用推理模型失败(第%d次): %v", attempt+1, err)
continue
}
if err = saveComposeTask(ctx, taskID, req); err != nil {
g.Log().Errorf(ctx, "保存任务记录失败(第%d次): %v", attempt+1, err)
continue
}
taskRecord, err = waitForResult(ctx, taskID)
if err != nil {
g.Log().Errorf(ctx, "等待结果失败(第%d次): %v", attempt+1, err)
continue
}
message = parsePromptBuild(taskRecord, chatModel)
if message != nil {
break
}
g.Log().Warningf(ctx, "[重试] 推理结果不合法(第%d次),准备重新请求", attempt+1)
}
if message == nil {
return nil, errors.New("推理模型调用失败,请稍后再试")
}
epicycleId, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
RequestContent: message,
})
if err != nil {
g.Log().Errorf(ctx, "创建会话记录失败: %v", err)
}
return &dto.ComposeMessagesRes{
Messages: message,
EpicycleId: epicycleId,
}, nil
}
// handleNodeBuild 处理节点构建BuildType=2
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
if err != nil {
return nil, fmt.Errorf("调用推理模型失败: %w", err)
}
if err := saveComposeTask(ctx, taskID, req); err != nil {
return nil, fmt.Errorf("保存任务记录失败: %w", err)
}
taskRecord, err := waitForResult(ctx, taskID)
if err != nil {
return nil, fmt.Errorf("等待结果失败: %w", err)
}
message := parseNodeBuild(taskRecord)
return &dto.ComposeMessagesRes{
Messages: message,
EpicycleId: 0,
}, nil
}
// handleDefaultCase 处理默认情况
func handleDefaultCase(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
epicycleId, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
Remark: req.Cause,
})
if err != nil {
return nil, fmt.Errorf("创建会话记录失败: %w", err)
}
return &dto.ComposeMessagesRes{
EpicycleId: epicycleId,
}, nil
}
// saveComposeTask 保存组合任务
func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessagesReq) error {
_, err := dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
RequestPayload: util.MustMarshal(req),
Status: public.ComposeStatusPending,
})
return err
}
// GetModelMessage 获取模型信息
func GetModelMessage(ctx context.Context, req *promptDto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) {
func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) {
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, nil, err
return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
}
// 1. 获取当前用户的会话模型
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
IsChatModel: 1,
})
chatModel, err := getChatModel(ctx, userInfo.UserName)
if err != nil {
return nil, nil, err
}
if chatModel == nil {
return nil, nil, errors.New("当前没有对话模型,请添加")
}
// 2. 获取要构建的模型信息
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
ModelName: req.ModelName,
})
aiModel, err := getAIModel(ctx, userInfo.UserName, req.ModelName)
if err != nil {
return nil, nil, err
}
if aiModel == nil {
return nil, nil, fmt.Errorf("需要构建的模型 %s 不存在", req.ModelName)
}
return chatModel, aiModel, nil
}
// getChatModel 获取聊天模型
func getChatModel(ctx context.Context, userName string) (*entity.AsynchModel, error) {
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
IsChatModel: new(1),
})
if err != nil {
return nil, fmt.Errorf("查询聊天模型失败: %w", err)
}
if chatModel == nil {
return nil, errors.New("当前没有对话模型,请添加")
}
return chatModel, nil
}
// getAIModel 获取AI模型
func getAIModel(ctx context.Context, userName, modelName string) (*entity.AsynchModel, error) {
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
ModelName: modelName,
})
if err != nil {
return nil, fmt.Errorf("查询AI模型失败: %w", err)
}
if aiModel == nil {
return nil, fmt.Errorf("需要构建的模型 %s 不存在", modelName)
}
return aiModel, nil
}
// callInferenceModel 调用推理模型
func callInferenceModel(ctx context.Context, req *promptDto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (string, error) {
// 构建推理模型请求
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (string, error) {
taskReq, err := buildInferenceRequest(ctx, req, chatModel, model, history)
if err != nil {
return "", fmt.Errorf("构建推理请求失败: %w", err)
}
// 创建网关任务
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
if err != nil {
return "", fmt.Errorf("创建网关任务失败: %w", err)
@@ -186,96 +244,131 @@ func callInferenceModel(ctx context.Context, req *promptDto.ComposeMessagesReq,
func waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) {
timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second
pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond
deadline := time.Now().Add(timeout)
ticker := time.NewTicker(pollInterval)
defer ticker.Stop()
for {
// ===================== 修复点 1检查上下文是否取消 =====================
select {
case <-ctx.Done():
// 请求已被取消,直接返回,不继续查库
return nil, ctx.Err()
default:
}
// 1. 查数据库
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: taskID,
})
if err != nil {
// ===================== 修复点 2如果是上下文取消直接返回 =====================
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err
}
return nil, err
return nil, fmt.Errorf("查询任务失败: %w", err)
}
if record != nil {
switch record.Status {
case public.ComposeStatusSuccess:
return record, nil
case public.ComposeStatusFailed:
if strings.TrimSpace(record.ErrorMessage) == "" {
return nil, fmt.Errorf("任务失败(taskId=%s)", taskID)
}
return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage)
if completed, result := checkTaskCompletion(record); completed {
return result, nil
}
}
// 2. 查网关状态
state, err := gateway.QueryGatewayTaskState(ctx, taskID)
if err != nil {
// 网关不可达不终止,继续轮询
g.Log().Warningf(ctx, "[waitForResult] 查询网关失败 taskId=%s err=%v", taskID, err)
} else {
switch state {
case 2: // 网关成功
// 网关已成功,主动更新数据库
if record != nil {
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: taskID,
Status: public.ComposeStatusSuccess,
})
if err != nil {
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
}
}
case 3: // 网关失败
if record != nil {
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: taskID,
Status: public.ComposeStatusFailed,
ErrorMessage: "model-gateway 任务执行失败",
})
if err != nil {
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
}
}
return nil, fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID)
}
if err = syncGatewayTaskState(ctx, taskID, record); err != nil {
g.Log().Warningf(ctx, "[waitForResult] 同步网关状态失败 taskId=%s err=%v", taskID, err)
}
// 3. 超时检查
if time.Now().After(deadline) {
return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID)
}
// ===================== 修复点3sleep 也要监听 ctx 取消 =====================
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(pollInterval):
case <-ticker.C:
}
}
}
// checkTaskCompletion 检查任务是否完成
func checkTaskCompletion(record *entity.ComposeTask) (bool, *entity.ComposeTask) {
if record == nil {
return false, nil
}
switch record.Status {
case public.ComposeStatusSuccess:
return true, record
case public.ComposeStatusFailed:
errMsg := strings.TrimSpace(record.ErrorMessage)
if errMsg == "" {
return true, nil
}
return true, nil
default:
return false, nil
}
}
// syncGatewayTaskState 同步网关任务状态
func syncGatewayTaskState(ctx context.Context, taskID string, record *entity.ComposeTask) error {
state, err := gateway.QueryGatewayTaskState(ctx, taskID)
if err != nil {
return fmt.Errorf("查询网关状态失败: %w", err)
}
switch state {
case 2:
return updateTaskStatus(ctx, taskID, public.ComposeStatusSuccess, "")
case 3:
updateTaskStatus(ctx, taskID, public.ComposeStatusFailed, "model-gateway 任务执行失败")
return fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID)
}
return nil
}
// updateTaskStatus 更新任务状态
func updateTaskStatus(ctx context.Context, taskID string, status string, errorMsg string) error {
task := &entity.ComposeTask{
TaskId: taskID,
Status: status,
}
if errorMsg != "" {
task.ErrorMessage = errorMsg
}
_, err := dao.ComposeTask.Update(ctx, task)
return err
}
// parsePromptBuild 解析提示词构建结果BuildType == 1
func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) map[string]any {
func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) *dto.MultiRoundResult {
if taskRecord == nil {
return nil
}
mapped := parseTaskMessages(taskRecord.Messages)
if mapped == nil {
return createDefaultResult(nil)
}
// 1. 解析 Messages
contentField := getContentField(model)
contentStr, ok := mapped[contentField].(string)
if !ok || contentStr == "" {
return createDefaultResult(mapped)
}
if roundsArray := tryParseAsArray(contentStr); roundsArray != nil {
return &dto.MultiRoundResult{
TotalRounds: len(roundsArray),
Rounds: roundsArray,
}
}
if singleRound := tryParseAsObject(contentStr); singleRound != nil {
return &dto.MultiRoundResult{
TotalRounds: 1,
Rounds: []any{singleRound},
}
}
return createDefaultResult(map[string]any{"content": contentStr})
}
// parseTaskMessages 解析任务消息
func parseTaskMessages(messages any) map[string]any {
var mapped map[string]any
switch v := taskRecord.Messages.(type) {
switch v := messages.(type) {
case *gvar.Var:
if v != nil {
json.Unmarshal([]byte(v.String()), &mapped)
@@ -289,115 +382,137 @@ func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel)
json.Unmarshal(b, &mapped)
}
// 2. 解析模型 ResponseMapping 获取 content 字段名
contentField := "content" // 默认值
if model != nil {
var respMapping map[string]string
switch v := model.ResponseMapping.(type) {
case *gvar.Var:
if v != nil {
json.Unmarshal([]byte(v.String()), &respMapping)
}
case string:
json.Unmarshal([]byte(v), &respMapping)
case map[string]interface{}:
respMapping = make(map[string]string)
for k, val := range v {
if s, ok := val.(string); ok {
respMapping[k] = s
}
}
}
// 从映射中找到 content 对应的字段名
for k, v := range respMapping {
if strings.Contains(v, "content") {
contentField = k
break
}
}
}
// 3. 提取 content 的值
contentStr, ok := mapped[contentField].(string)
if !ok || contentStr == "" {
return mapped
}
// 4. 解析 content 内的 JSON
var innerData map[string]any
json.Unmarshal([]byte(contentStr), &innerData)
return innerData
return mapped
}
// parseNodeBuild 解析节点构建结果BuildType == 2
func parseNodeBuild(taskRecord *entity.ComposeTask) map[string]any {
if taskRecord == nil {
// tryParseAsArray 尝试将字符串解析为数组
func tryParseAsArray(contentStr string) []any {
var roundsArray []any
if err := json.Unmarshal([]byte(contentStr), &roundsArray); err != nil {
return nil
}
var result map[string]any
switch v := taskRecord.Messages.(type) {
return roundsArray
}
// tryParseAsObject 尝试将字符串解析为对象
func tryParseAsObject(contentStr string) any {
var singleRound any
if err := json.Unmarshal([]byte(contentStr), &singleRound); err != nil {
return nil
}
return singleRound
}
// createDefaultResult 创建默认结果
func createDefaultResult(data any) *dto.MultiRoundResult {
if data == nil {
data = make(map[string]any)
}
return &dto.MultiRoundResult{
TotalRounds: 1,
Rounds: []any{data},
}
}
// getContentField 从模型 ResponseMapping 中获取 content 字段名
func getContentField(model *entity.AsynchModel) string {
if model == nil {
return "content"
}
respMapping := parseResponseMapping(model.ResponseMapping)
for k, v := range respMapping {
if strings.Contains(v, "content") {
return k
}
}
return "content"
}
// parseResponseMapping 解析响应映射
func parseResponseMapping(mapping any) map[string]string {
result := make(map[string]string)
switch v := mapping.(type) {
case *gvar.Var:
if v != nil {
json.Unmarshal([]byte(v.String()), &result)
}
case string:
json.Unmarshal([]byte(v), &result)
case map[string]any:
result = v
default:
b, _ := json.Marshal(v)
json.Unmarshal(b, &result)
case map[string]interface{}:
for k, val := range v {
if s, ok := val.(string); ok {
result[k] = s
}
}
}
return result
}
// parseNodeBuild 解析节点构建结果BuildType == 2
func parseNodeBuild(taskRecord *entity.ComposeTask) *dto.MultiRoundResult {
if taskRecord == nil {
return nil
}
result := parseTaskMessages(taskRecord.Messages)
if result == nil {
result = make(map[string]any)
}
return &dto.MultiRoundResult{
TotalRounds: 1,
Rounds: []any{result},
}
}
// Callback 回调处理
func Callback(ctx context.Context, req *promptDto.CallbackReq) error {
func Callback(ctx context.Context, req *dto.CallbackReq) error {
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
// ============ 先查任务是否存在 ============
task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
})
if err != nil {
return err
return fmt.Errorf("查询任务失败: %w", err)
}
if task == nil {
return fmt.Errorf("任务不存在: %s", req.TaskId)
}
// ============ 根据状态区分处理 ============
if req.State == 3 {
// 失败:直接更新状态
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
})
return err
}
// ======================================
// 成功:解析模型输出
result, err := util.ParseOutput(req.Text)
if err != nil {
_, updateErr := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
})
if updateErr != nil {
g.Log().Warningf(ctx, "[Callback] 更新失败状态出错 taskId=%s err=%v", req.TaskId, updateErr)
}
return err
return handleCallbackFailure(ctx, req.TaskId, req.ErrorMsg)
}
return handleCallbackSuccess(ctx, req)
}
// handleCallbackFailure 处理回调失败
func handleCallbackFailure(ctx context.Context, taskID, errorMsg string) error {
_, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: taskID,
Status: public.ComposeStatusFailed,
ErrorMessage: errorMsg,
})
return err
}
// handleCallbackSuccess 处理回调成功
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq) error {
result, err := util.ParseOutput(req.Text)
if err != nil {
handleCallbackFailure(ctx, req.TaskId, req.ErrorMsg)
return fmt.Errorf("解析模型输出失败: %w", err)
}
// ============ result 可能为 nil ============
var messages any
if result != nil {
messages = result
}
// =======================================
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
@@ -407,34 +522,43 @@ func Callback(ctx context.Context, req *promptDto.CallbackReq) error {
if err != nil {
g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err)
}
return err
}
// GetComposeTask 查询任务结果
func GetComposeTask(ctx context.Context, taskID string) (*promptDto.GetComposeTaskRes, error) {
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: taskID,
})
if err != nil {
return nil, err
return nil, fmt.Errorf("查询任务失败: %w", err)
}
if record == nil {
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
}
// 如果 Messages 是字符串,反序列化为 JSON 数组
messages := record.Messages
if str, ok := messages.(string); ok && str != "" {
var parsed any
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
messages = parsed
}
}
messages := parseMessagesForResponse(record.Messages)
return &promptDto.GetComposeTaskRes{
return &dto.GetComposeTaskRes{
TaskId: record.TaskId,
Status: record.Status,
ErrorMessage: record.ErrorMessage,
Messages: messages,
}, nil
}
// parseMessagesForResponse 解析用于响应的消息
func parseMessagesForResponse(messages any) any {
str, ok := messages.(string)
if !ok || str == "" {
return messages
}
var parsed any
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
return parsed
}
return messages
}

View File

@@ -10,10 +10,15 @@ import (
"strings"
"time"
"github.com/gogf/gf/v2/frame/g"
"prompts-core/common/util"
"prompts-core/service/gateway"
)
"github.com/gogf/gf/v2/frame/g"
const (
bytesPerKB = 1024
bytesPerMB = 1024 * 1024
)
// FetchFileTexts 从 URL 列表获取文件内容,支持 zip 内文件
@@ -24,51 +29,49 @@ func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
return result
}
client := &http.Client{
Timeout: time.Duration(g.Cfg().MustGet(ctx, "userFiles.httpTimeoutSec", 8).Int()) * time.Second,
}
client := createHTTPClient(ctx, "userFiles.httpTimeoutSec", 8)
for _, rawURL := range urls {
url := util.SanitizeURL(rawURL)
if url == "" {
continue
}
if util.IsBannedExtension(url) {
if url == "" || util.IsBannedExtension(url) {
continue
}
if util.IsZipExtension(url) {
zipTexts := fetchZipFileTexts(ctx, client, url)
for k, v := range zipTexts {
result[k] = v
}
mergeMap(result, fetchZipFileTexts(ctx, client, url))
continue
}
text, err := fetchFileContent(ctx, client, url)
if err != nil {
continue
if text := fetchAndCleanFileContent(ctx, client, url); text != "" {
result[url] = text
}
if text == "" {
continue
}
text = util.CleanSymbols(text)
result[url] = text
}
return result
}
// mergeMap 合并 map
func mergeMap(dst, src map[string]string) {
for k, v := range src {
dst[k] = v
}
}
// fetchAndCleanFileContent 获取并清理文件内容
func fetchAndCleanFileContent(ctx context.Context, client *http.Client, url string) string {
text, err := fetchFileContent(ctx, client, url)
if err != nil || text == "" {
return ""
}
return util.CleanSymbols(text)
}
// fetchZipFileTexts 下载并解压 zip 文件,提取可读文本内容
func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map[string]string {
result := make(map[string]string)
zipBytes, err := downloadFile(client, url,
int64(g.Cfg().MustGet(ctx, "userFiles.zipMaxSizeMB", 10).Int())*1024*1024,
)
maxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipMaxSizeMB", 10).Int()) * bytesPerMB
zipBytes, err := downloadFile(client, url, maxSize)
if err != nil {
return result
}
@@ -78,61 +81,61 @@ func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map
return result
}
entryMaxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipEntryMaxSizeKB", 500).Int()) * 1024
entryMaxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipEntryMaxSizeKB", 500).Int()) * bytesPerKB
for _, file := range reader.File {
if file.FileInfo().IsDir() {
if shouldSkipZipEntry(file.Name) {
continue
}
fileName := file.Name
if util.IsBannedExtension(fileName) {
continue
if text := extractZipEntryContent(file, entryMaxSize); text != "" {
result[url+"::"+file.Name] = text
}
if util.IsZipExtension(fileName) {
continue
}
rc, err := file.Open()
if err != nil {
continue
}
content, err := io.ReadAll(io.LimitReader(rc, entryMaxSize))
rc.Close()
if err != nil {
continue
}
contentType := http.DetectContentType(content)
if !util.IsReadableContentType(contentType) {
continue
}
text := util.CleanSymbols(string(content))
if text == "" {
continue
}
key := url + "::" + fileName
result[key] = text
}
return result
}
// shouldSkipZipEntry 判断是否应该跳过 zip 条目
func shouldSkipZipEntry(fileName string) bool {
return util.IsBannedExtension(fileName) || util.IsZipExtension(fileName)
}
// extractZipEntryContent 提取 zip 条目内容
func extractZipEntryContent(file *zip.File, maxSize int64) string {
rc, err := file.Open()
if err != nil {
return ""
}
defer rc.Close()
content, err := io.ReadAll(io.LimitReader(rc, maxSize))
if err != nil {
return ""
}
if !util.IsReadableContentType(http.DetectContentType(content)) {
return ""
}
text := util.CleanSymbols(string(content))
if text == "" {
return ""
}
return text
}
// downloadFile 下载文件,限制最大大小
func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("创建请求失败: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, err
return nil, fmt.Errorf("执行请求失败: %w", err)
}
defer resp.Body.Close()
@@ -140,19 +143,24 @@ func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
}
return io.ReadAll(io.LimitReader(resp.Body, maxSize))
body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
return body, nil
}
// fetchFileContent 获取单个文本文件内容
func fetchFileContent(ctx context.Context, client *http.Client, url string) (string, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return "", err
return "", fmt.Errorf("创建请求失败: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return "", err
return "", fmt.Errorf("执行请求失败: %w", err)
}
defer resp.Body.Close()
@@ -162,16 +170,13 @@ func fetchFileContent(ctx context.Context, client *http.Client, url string) (str
contentType := resp.Header.Get("Content-Type")
if !util.IsReadableContentType(contentType) {
return "", fmt.Errorf("unreadable content-type: %s", contentType)
return "", fmt.Errorf("不可读的内容类型: %s", contentType)
}
body, err := io.ReadAll(
io.LimitReader(resp.Body,
int64(g.Cfg().MustGet(ctx, "userFiles.textFileMaxSizeKB", 500).Int())*1024,
),
)
maxSize := int64(g.Cfg().MustGet(ctx, "userFiles.textFileMaxSizeKB", 500).Int()) * bytesPerKB
body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
if err != nil {
return "", err
return "", fmt.Errorf("读取响应失败: %w", err)
}
return strings.TrimSpace(string(body)), nil
@@ -186,27 +191,26 @@ func SkillMdContent(ctx context.Context, skillName string) string {
fullUrl := skillResp.ImgAddressPrefix + skillResp.FileUrl
client := &http.Client{
Timeout: time.Duration(g.Cfg().MustGet(ctx, "skillFiles.httpTimeoutSec", 30).Int()) * time.Second,
}
client := createHTTPClient(ctx, "skillFiles.httpTimeoutSec", 30)
maxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.zipMaxSizeMB", 10).Int()) * bytesPerMB
zipBytes, err := downloadFile(client, fullUrl,
int64(g.Cfg().MustGet(ctx, "skillFiles.zipMaxSizeMB", 10).Int())*1024*1024,
)
zipBytes, err := downloadFile(client, fullUrl, maxSize)
if err != nil {
return ""
}
mdContents, err := extractMdFiles(ctx, zipBytes)
if err != nil {
if err != nil || len(mdContents) == 0 {
return ""
}
if len(mdContents) == 0 {
return ""
}
return buildSkillMarkdown(skillResp, mdContents)
}
// buildSkillMarkdown 构建技能 Markdown 内容
func buildSkillMarkdown(skillResp *gateway.SkillUserVO, mdContents map[string]string) string {
var builder strings.Builder
builder.WriteString(fmt.Sprintf("# Skill: %s\n\n", skillResp.Name))
if skillResp.Description != "" {
builder.WriteString(fmt.Sprintf("> %s\n\n", skillResp.Description))
@@ -227,35 +231,53 @@ func extractMdFiles(ctx context.Context, zipBytes []byte) (map[string]string, er
reader, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
if err != nil {
return nil, err
return nil, fmt.Errorf("创建 zip 阅读器失败: %w", err)
}
entryMaxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.mdMaxSizeKB", 500).Int()) * 1024
entryMaxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.mdMaxSizeKB", 500).Int()) * bytesPerKB
for _, file := range reader.File {
if file.FileInfo().IsDir() {
if file.FileInfo().IsDir() || !isMarkdownFile(file.Name) {
continue
}
if !strings.HasSuffix(strings.ToLower(file.Name), ".md") {
continue
}
rc, err := file.Open()
if err != nil {
continue
}
content, err := io.ReadAll(io.LimitReader(rc, entryMaxSize))
rc.Close()
if err != nil {
continue
}
if len(content) > 0 {
result[file.Name] = strings.TrimSpace(string(content))
if content := readMarkdownFileContent(file, entryMaxSize); content != "" {
result[file.Name] = content
}
}
return result, nil
}
// isMarkdownFile 判断是否为 Markdown 文件
func isMarkdownFile(fileName string) bool {
return strings.HasSuffix(strings.ToLower(fileName), ".md")
}
// readMarkdownFileContent 读取 Markdown 文件内容
func readMarkdownFileContent(file *zip.File, maxSize int64) string {
rc, err := file.Open()
if err != nil {
return ""
}
defer rc.Close()
content, err := io.ReadAll(io.LimitReader(rc, maxSize))
if err != nil {
return ""
}
if len(content) == 0 {
return ""
}
return strings.TrimSpace(string(content))
}
// createHTTPClient 创建 HTTP 客户端
func createHTTPClient(ctx context.Context, configKey string, defaultSeconds int) *http.Client {
timeout := time.Duration(g.Cfg().MustGet(ctx, configKey, defaultSeconds).Int()) * time.Second
return &http.Client{
Timeout: timeout,
}
}

View File

@@ -0,0 +1,75 @@
# Prompts-Core提示词核心服务
> 智能提示词构建与管理系统,支持多模态 AI 模型的提示词组装、会话管理和协议适配。
---
## 项目简介
**Prompts-Core** 是一个基于 Go 语言开发的提示词核心服务,作为 AI 应用层与模型网关之间的桥梁,负责将业务需求转换为标准化的模型请求。
### 核心价值
- **统一提示词管理**:集中化管理不同模型类型的提示词模板
- **智能会话维护**:基于 Redis + PostgreSQL 的双层会话存储
- **多协议适配**:支持 OpenAI、DeepSeek、Qwen、Gemini 等多种模型协议
- **文件处理能力**:自动提取文本文件和 ZIP 压缩包内容
- **技能系统集成**:支持从外部加载 Markdown 格式的技能描述
---
## 核心功能
### 1. 提示词构建引擎
#### 多模态支持
| 类型 | 说明 | 适用场景 |
|------|------|----------|
| Type 1 | 文字处理助手 | 文章撰写、文案优化、翻译等 |
| Type 2 | 图片处理助手 | 图像生成、风格迁移等 |
| Type 3 | 音频处理助手 | 语音合成、识别、降噪等 |
| Type 4 | 向量化处理助手 | 语义检索、知识索引等 |
| Type 5 | 全模态助手 | 跨模态转换、多模态融合等 |
#### 构建模式
- **BuildType 1提示词构建**:完整流程,包含系统提示词、历史会话、用户输入的智能组装
- **BuildType 2节点构建**:工作流路由决策,根据上下文选择节点 ID
#### 分批处理
当用户表单内容超出模型窗口限制时,自动按 Token 大小分批处理。
### 2. 会话管理系统
- **双层存储**Redis 缓存(最近 N 轮)+ PostgreSQL 持久化
- **自动管理**:最大轮数控制(默认 10 轮)、自动过期(默认 30 分钟)
### 3. 协议适配器
通过配置动态支持多种模型协议:
- 角色映射system/user/assistant → 目标协议角色
- 内容字段映射content → parts.text 等
- 消息顺序控制:灵活配置拼接顺序
- 请求模板渲染:支持占位符替换
### 4. 任务调度
- **异步流程**:创建网关任务 → 轮询等待 → 接收回调 → 返回结果
- **重试机制**:可配置最大重试次数(默认 3 次)
- **超时保护**:默认 300 秒超时
---
## 技术架构
### 技术栈
| 组件 | 版本 | 用途 |
|------|------|------|
| Go | 1.26.0 | 编程语言 |
| GoFrame | v2.10.0 | Web 框架 |
| PostgreSQL | - | 关系型数据库 |
| Redis | - | 缓存与会话存储 |
| Consul | - | 服务注册与发现 |
| Jaeger | - | 分布式链路追踪 |
### 架构图

View File

@@ -20,11 +20,27 @@ type PromptIR struct {
// Segment 消息片段
type Segment struct {
Type string `json:"type"` // text/image
Type string `json:"type"`
Content string `json:"content"`
Role string `json:"role,omitempty"`
}
// ProviderProtocol 协议编译配置(从 DB JSONB 字段解析)
type ProviderProtocol struct {
TargetField string `json:"target_field"`
MergeOrder []string `json:"merge_order"`
RoleMapping map[string]string `json:"role_mapping"`
ContentMapping ContentMapping `json:"content_mapping"`
RequestTemplate map[string]any `json:"request_template"`
SystemPromptTemplate string `json:"system_prompt_template"`
}
// ContentMapping 内容字段映射
type ContentMapping struct {
Type string `json:"type"`
Field string `json:"field"`
}
// NewPromptIR 创建空 PromptIR
func NewPromptIR() *PromptIR {
return &PromptIR{
@@ -34,6 +50,54 @@ func NewPromptIR() *PromptIR {
}
}
// String 返回 PromptIR 的完整内容字符串(用于 token 计算)
func (ir *PromptIR) String() string {
var builder strings.Builder
for _, seg := range ir.System {
builder.WriteString("System: ")
builder.WriteString(seg.Content)
builder.WriteString("\n")
}
for _, seg := range ir.History {
builder.WriteString(seg.Role)
builder.WriteString(": ")
builder.WriteString(seg.Content)
builder.WriteString("\n")
}
for _, seg := range ir.User {
builder.WriteString("User: ")
builder.WriteString(seg.Content)
builder.WriteString("\n")
}
return builder.String()
}
// GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算)
func (ir *PromptIR) GetTotalContent() string {
var builder strings.Builder
for _, seg := range ir.System {
builder.WriteString(seg.Content)
builder.WriteString("\n")
}
for _, seg := range ir.History {
builder.WriteString(seg.Content)
builder.WriteString("\n")
}
for _, seg := range ir.User {
builder.WriteString(seg.Content)
builder.WriteString("\n")
}
return builder.String()
}
// AddSystem 添加系统提示
func (ir *PromptIR) AddSystem(content string) *PromptIR {
if content != "" {
@@ -62,7 +126,6 @@ func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
func (ir *PromptIR) ToMessages() []map[string]any {
var messages []map[string]any
// 1. 系统消息
for _, seg := range ir.System {
messages = append(messages, map[string]any{
"role": "system",
@@ -70,7 +133,6 @@ func (ir *PromptIR) ToMessages() []map[string]any {
})
}
// 2. 历史消息
for _, seg := range ir.History {
messages = append(messages, map[string]any{
"role": seg.Role,
@@ -78,13 +140,13 @@ func (ir *PromptIR) ToMessages() []map[string]any {
})
}
// 3. 用户消息
for _, seg := range ir.User {
messages = append(messages, map[string]any{
"role": "user",
"content": seg.Content,
})
}
return messages
}
@@ -97,74 +159,35 @@ func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderP
if err != nil || entity == nil {
return nil, err
}
entity.MergeOrder = util.ParseJSONField(entity.MergeOrder)
entity.RoleMapping = util.ParseJSONField(entity.RoleMapping)
entity.ContentMapping = util.ParseJSONField(entity.ContentMapping)
entity.RequestTemplate = util.ParseJSONField(entity.RequestTemplate)
entity.ContentMapping = util.ParseJSONField(entity.ContentMapping)
fmt.Println("entity打印", entity)
return parseProtocol(entity), nil
}
// parseProtocol 将 DB entity 转为编译用协议配置
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
p := &ProviderProtocol{
TargetField: e.TargetField,
TargetField: e.TargetField,
SystemPromptTemplate: e.SystemPromptTemplate,
}
// MergeOrder: any → []string
if e.MergeOrder != nil {
b, _ := json.Marshal(e.MergeOrder)
json.Unmarshal(b, &p.MergeOrder)
}
// RoleMapping: any → map[string]string
if e.RoleMapping != nil {
b, _ := json.Marshal(e.RoleMapping)
json.Unmarshal(b, &p.RoleMapping)
}
// ContentMapping: any → ContentMapping
if e.ContentMapping != nil {
b, _ := json.Marshal(e.ContentMapping)
json.Unmarshal(b, &p.ContentMapping)
}
// RequestTemplate: any → map[string]any
if e.RequestTemplate != nil {
b, _ := json.Marshal(e.RequestTemplate)
json.Unmarshal(b, &p.RequestTemplate)
}
fmt.Printf("parseProtocol: %+v\n", p)
// 使用通用解析方法处理各个字段
util.ParseJSONFieldFromGvar(e.MergeOrder, &p.MergeOrder)
util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping)
util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping)
util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate)
return p
}
// ProviderProtocol 协议编译配置(从 DB JSONB 字段解析)
type ProviderProtocol struct {
TargetField string `json:"target_field"`
MergeOrder []string `json:"merge_order"`
RoleMapping map[string]string `json:"role_mapping"`
ContentMapping ContentMapping `json:"content_mapping"`
RequestTemplate map[string]any `json:"request_template"`
}
// ContentMapping 内容字段映射
type ContentMapping struct {
Type string `json:"type"` // direct/parts
Field string `json:"field"` // content/text
}
// Compile 将 PromptIR 按协议配置编译为 Provider Request
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (map[string]any, error) {
if ir == nil || p == nil {
return nil, fmt.Errorf("ir and protocol are required")
}
// 1. 按 merge_order 拼接消息
messages := mergeByOrder(ir, p.MergeOrder)
// 2. 角色映射
messages = mapRoles(messages, p.RoleMapping)
// 3. 内容字段映射
messages = mapContent(messages, p.ContentMapping)
// 4. 按 target_field + request_template 构建请求体
return buildRequest(messages, p, chatModel), nil
}
@@ -197,6 +220,7 @@ func mergeByOrder(ir *PromptIR, order []string) []map[string]any {
}
}
}
return messages
}
@@ -205,15 +229,18 @@ func mapRoles(messages []map[string]any, mapping map[string]string) []map[string
if len(mapping) == 0 {
return messages
}
for i, msg := range messages {
role, ok := msg["role"].(string)
if !ok {
continue
}
if mapped, exists := mapping[role]; exists {
messages[i]["role"] = mapped
}
}
return messages
}
@@ -225,15 +252,14 @@ func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
switch cm.Type {
case "parts":
// Gemini 格式: {"parts": [{"text": "..."}]}
msg["parts"] = []map[string]any{
{cm.Field: content},
}
default:
// direct: {"content": "..."}
msg[cm.Field] = content
}
}
return messages
}
@@ -242,6 +268,7 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *ent
if len(p.RequestTemplate) > 0 {
return renderTemplate(p.RequestTemplate, messages, chatModel)
}
return map[string]any{
p.TargetField: messages,
}
@@ -252,13 +279,13 @@ func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *e
b, _ := json.Marshal(tmpl)
str := string(b)
// 替换 {{model}}
str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`)
// 替换 {{messages}}
msgBytes, _ := json.Marshal(messages)
str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes))
var result map[string]any
json.Unmarshal([]byte(str), &result)
return result
}

View File

@@ -9,15 +9,16 @@ import (
"github.com/gogf/gf/v2/frame/g"
)
// ==================== Redis 操作 ====================
const (
redisKeyPrefix = "chat:session:%s"
)
// saveToRedis 保存会话数据到Redis
func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error {
key := fmt.Sprintf("chat:session:%s", sessionId)
key := formatRedisKey(sessionId)
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
expireTime := time.Duration(expireSeconds) * time.Second
data := map[string]any{
"sessionId": sessionId,
@@ -31,18 +32,29 @@ func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[st
return fmt.Errorf("序列化会话数据失败: %w", err)
}
_, err = g.Redis().Do(ctx, "LPUSH", key, string(b))
if err != nil {
if err := executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil {
return err
}
return nil
}
// formatRedisKey 格式化Redis键
func formatRedisKey(sessionId string) string {
return fmt.Sprintf(redisKeyPrefix, sessionId)
}
// executeRedisCommands 执行Redis命令
func executeRedisCommands(ctx context.Context, key string, value string, maxRounds int, expireSeconds int64) error {
if _, err := g.Redis().Do(ctx, "LPUSH", key, value); err != nil {
return fmt.Errorf("写入Redis失败: %w", err)
}
_, err = g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1)
if err != nil {
if _, err := g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1); err != nil {
return fmt.Errorf("裁剪Redis列表失败: %w", err)
}
_, err = g.Redis().Do(ctx, "EXPIRE", key, int64(expireTime.Seconds()))
if err != nil {
if _, err := g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
return fmt.Errorf("设置过期时间失败: %w", err)
}
@@ -51,7 +63,7 @@ func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[st
// getFromRedis 从Redis获取会话历史
func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) {
key := fmt.Sprintf("chat:session:%s", sessionId)
key := formatRedisKey(sessionId)
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
if err != nil {
@@ -62,8 +74,17 @@ func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, erro
return []map[string]any{}, nil
}
sessions := parseRedisSessions(ctx, result.Strings())
reverseSlice(sessions)
return sessions, nil
}
// parseRedisSessions 解析Redis会话数据
func parseRedisSessions(ctx context.Context, values []string) []map[string]any {
var sessions []map[string]any
values := result.Strings()
for _, str := range values {
var data map[string]any
if err := json.Unmarshal([]byte(str), &data); err != nil {
@@ -73,12 +94,14 @@ func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, erro
sessions = append(sessions, data)
}
// 反转Redis 最新在前 → 时间正序)
for i, j := 0, len(sessions)-1; i < j; i, j = i+1, j-1 {
sessions[i], sessions[j] = sessions[j], sessions[i]
}
return sessions
}
return sessions, nil
// reverseSlice 反转切片
func reverseSlice(s []map[string]any) {
for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {
s[i], s[j] = s[j], s[i]
}
}
// GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用)
@@ -92,23 +115,31 @@ func GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map
return []map[string]any{}, nil
}
return flattenHistoryMessages(historyData), nil
}
// flattenHistoryMessages 扁平化历史消息
func flattenHistoryMessages(historyData []map[string]any) []map[string]any {
var messages []map[string]any
for _, round := range historyData {
if reqMsgs, ok := round["requestContent"].([]interface{}); ok {
for _, m := range reqMsgs {
if msg, ok := m.(map[string]interface{}); ok {
messages = append(messages, msg)
}
}
}
if respMsgs, ok := round["responseContent"].([]interface{}); ok {
for _, m := range respMsgs {
if msg, ok := m.(map[string]interface{}); ok {
messages = append(messages, msg)
}
}
}
appendMessagesFromField(round, "requestContent", &messages)
appendMessagesFromField(round, "responseContent", &messages)
}
return messages, nil
return messages
}
// appendMessagesFromField 从指定字段追加消息
func appendMessagesFromField(data map[string]any, field string, messages *[]map[string]any) {
msgs, ok := data[field].([]interface{})
if !ok {
return
}
for _, m := range msgs {
if msg, ok := m.(map[string]interface{}); ok {
*messages = append(*messages, msg)
}
}
}

View File

@@ -3,112 +3,164 @@ package prompt
import (
"context"
"fmt"
sessionDao "prompts-core/dao"
"prompts-core/model/entity"
"prompts-core/common/util"
sessionDto "prompts-core/model/dto/prompt"
"gitea.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
"prompts-core/common/util"
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
)
func SessionCallback(ctx context.Context, req *sessionDto.SessionCallbackReq) (res *sessionDto.SessionCallbackRes, err error) {
// 1. 解析AI返回的文本
// SessionCallback 会话回调
func SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
result, err := util.ParseOutput(req.Text)
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, err
return nil, fmt.Errorf("解析模型输出失败: %w", err)
}
// 2. 更新数据库
result["role"] = "assistant"
_, err = sessionDao.ComposeSession.Update(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
ResponseContent: result,
})
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err)
if err := updateSessionResponse(ctx, req.EpicycleId, result); err != nil {
return nil, err
}
// 3. 获取当前轮次完整数据
session, err := sessionDao.ComposeSession.Get(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
})
session, err := getSessionById(ctx, req.EpicycleId)
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, err
}
// 4. 转换 json 并存入 Redis
if err := saveSessionToRedis(ctx, session); err != nil {
return nil, err
}
requestMessages := util.ConvertToMessages(session.RequestContent)
responseMessages := util.ConvertToMessages(session.ResponseContent)
if err = saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil {
g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v",
session.SessionId, session.Id, err)
return nil, err
}
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
session.SessionId, session.Id, len(requestMessages), len(responseMessages))
return &sessionDto.SessionCallbackRes{}, nil
return &dto.SessionCallbackRes{}, nil
}
// updateSessionResponse 更新会话响应
func updateSessionResponse(ctx context.Context, epicycleId int64, response any) error {
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: epicycleId},
ResponseContent: response,
})
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", epicycleId, err)
return fmt.Errorf("更新数据库失败: %w", err)
}
return nil
}
// getSessionById 根据ID获取会话
func getSessionById(ctx context.Context, epicycleId int64) (*entity.ComposeSession, error) {
session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: epicycleId},
})
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", epicycleId, err)
return nil, fmt.Errorf("获取会话数据失败: %w", err)
}
return session, nil
}
// saveSessionToRedis 保存会话到Redis
func saveSessionToRedis(ctx context.Context, session *entity.ComposeSession) error {
requestMessages := util.ConvertToMessages(session.RequestContent)
responseMessages := util.ConvertToMessages(session.ResponseContent)
if err := saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil {
g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v",
session.SessionId, session.Id, err)
return fmt.Errorf("Redis存储失败: %w", err)
}
return nil
}
// GetHistoryMessages 获取历史信息
func GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) {
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
// 1. 先从 Redis 拿
redisHistory, err := GetSessionHistoryForInference(ctx, sessionId)
if err == nil && len(redisHistory) > 0 {
return redisHistory, nil
}
// 2. Redis 没有 → fallback DB
sessions, _, err := sessionDao.ComposeSession.List(ctx, &entity.ComposeSession{
return getHistoryFromDatabase(ctx, sessionId, maxRounds)
}
// getHistoryFromDatabase 从数据库获取历史记录
func getHistoryFromDatabase(ctx context.Context, sessionId string, maxRounds int) ([]map[string]any, error) {
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
SessionId: sessionId,
}, 1, maxRounds)
if err != nil {
return nil, fmt.Errorf("DB获取历史失败: %w", err)
}
messages := extractMessagesFromSessions(sessions)
cacheSessionsToRedis(ctx, sessions)
return messages, nil
}
// extractMessagesFromSessions 从会话列表中提取消息
func extractMessagesFromSessions(sessions []*entity.ComposeSession) []map[string]any {
var messages []map[string]any
for _, session := range sessions {
// request
reqMsgs := util.ConvertToMessages(session.RequestContent)
for _, m := range reqMsgs {
role := gconv.String(m["role"])
if role == "user" || role == "assistant" {
messages = append(messages, m)
}
}
// response
respMsgs := util.ConvertToMessages(session.ResponseContent)
for _, m := range respMsgs {
if m["role"] == nil {
m["role"] = "assistant"
}
messages = append(messages, m)
}
appendRequestMessages(session.RequestContent, &messages)
appendResponseMessages(session.ResponseContent, &messages)
}
// 3. 回写 Redis
return messages
}
// appendRequestMessages 追加请求消息
func appendRequestMessages(requestContent any, messages *[]map[string]any) {
reqMsgs := util.ConvertToMessages(requestContent)
for _, m := range reqMsgs {
role := gconv.String(m["role"])
if role == "user" || role == "assistant" {
*messages = append(*messages, m)
}
}
}
// appendResponseMessages 追加响应消息
func appendResponseMessages(responseContent any, messages *[]map[string]any) {
respMsgs := util.ConvertToMessages(responseContent)
for _, m := range respMsgs {
if m["role"] == nil {
m["role"] = "assistant"
}
*messages = append(*messages, m)
}
}
// cacheSessionsToRedis 将会话缓存到Redis
func cacheSessionsToRedis(ctx context.Context, sessions []*entity.ComposeSession) {
for _, session := range sessions {
reqMsgs := util.ConvertToMessages(session.RequestContent)
respMsgs := util.ConvertToMessages(session.ResponseContent)
for i := range respMsgs {
if respMsgs[i]["role"] == nil {
respMsgs[i]["role"] = "assistant"
}
}
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
_ = saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs)
}
}
return messages, nil
}

View File

@@ -0,0 +1,135 @@
package prompt
import (
"context"
"fmt"
"strings"
"github.com/gogf/gf/v2/frame/g"
"prompts-core/common/util"
"prompts-core/model/dto"
"prompts-core/model/entity"
)
// ProcessUserFormBatches 处理 UserForm 分批(按 token 大小拼接内容)
func ProcessUserFormBatches(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) (*dto.ComposeMessagesReq, int, error) {
if model.TokenConfig == nil || len(req.UserForm) == 0 {
return req, 1, nil
}
availableWindow := util.GetAvailableWindow(model.TokenConfig)
batches := splitUserFormByTokenSize(req.UserForm, availableWindow, model.TokenConfig)
if len(batches) <= 1 {
return req, 1, nil
}
newUserForm := buildBatchedUserForm(batches)
newReq := *req
newReq.UserForm = newUserForm
g.Log().Infof(ctx, "[ProcessUserFormBatches] UserForm分批完成: 原始%d条 -> %d批 (按token大小拼接)",
len(req.UserForm), len(batches))
return &newReq, len(batches), nil
}
// buildBatchedUserForm 构建分批后的用户表单
func buildBatchedUserForm(batches [][]map[string]any) []map[string]any {
newUserForm := make([]map[string]any, 0, len(batches))
for i, batch := range batches {
combinedText := combineBatchText(batch)
newUserForm = append(newUserForm, map[string]any{
"batch_index": i + 1,
"total_batches": len(batches),
"text": combinedText,
"item_count": len(batch),
})
}
return newUserForm
}
// combineBatchText 合并批次中的所有文本(合并所有字段的值)
func combineBatchText(batch []map[string]any) string {
var builder strings.Builder
for j, item := range batch {
itemText := getItemText(item)
if itemText == "" {
continue
}
if j > 0 {
builder.WriteString("\n\n")
}
builder.WriteString(itemText)
}
return builder.String()
}
// splitUserFormByTokenSize 按 token 大小将 UserForm 内容拼接后分批
func splitUserFormByTokenSize(userForm []map[string]any, maxTokens int, tokenConfig any) [][]map[string]any {
if len(userForm) == 0 {
return [][]map[string]any{}
}
batches := make([][]map[string]any, 0)
currentBatch := make([]map[string]any, 0)
currentTokens := 0
for i, item := range userForm {
itemText := getItemText(item)
itemTokens := util.CalculateTokens(itemText, tokenConfig)
// 单个元素超过窗口,单独成一批
if itemTokens > maxTokens {
if len(currentBatch) > 0 {
batches = append(batches, currentBatch)
currentBatch = make([]map[string]any, 0)
currentTokens = 0
}
batches = append(batches, []map[string]any{item})
continue
}
// 判断是否需要新开一批
if currentTokens+itemTokens > maxTokens && len(currentBatch) > 0 {
batches = append(batches, currentBatch)
currentBatch = make([]map[string]any, 0)
currentTokens = 0
}
currentBatch = append(currentBatch, item)
currentTokens += itemTokens
// 最后一批
if i == len(userForm)-1 && len(currentBatch) > 0 {
batches = append(batches, currentBatch)
}
}
return batches
}
// getItemText 获取 item 中的所有文本内容(合并所有字段的值)
func getItemText(item map[string]any) string {
if len(item) == 0 {
return ""
}
var parts []string
for key, value := range item {
// 跳过分批时添加的元数据字段
if key == "batch_index" || key == "total_batches" || key == "item_count" {
continue
}
parts = append(parts, fmt.Sprintf("%v", value))
}
return strings.Join(parts, "\n")
}