refactor(service): 重构服务代码结构并更新配置

This commit is contained in:
2026-05-18 19:19:17 +08:00
parent 5f98e52b34
commit c49144794d
35 changed files with 1281 additions and 1162 deletions

View File

@@ -0,0 +1,112 @@
package prompt
import (
"context"
"errors"
"fmt"
"strings"
"prompts-core/common/util"
"prompts-core/dao"
"prompts-core/model/dto/prompt"
"prompts-core/model/entity"
"github.com/gogf/gf/v2/util/gconv"
)
// buildInferenceRequest 构建返回请求
func buildInferenceRequest(ctx context.Context, req *prompt.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
ir := NewPromptIR()
// 1. 统一 Prompt IR
switch req.BuildType {
case 1: //构建提示词请求
ir.AddSystem(promptBuild(ctx, req, model))
for _, msg := range history {
role := gconv.String(msg["role"])
if role != "user" && role != "assistant" {
continue
}
ir.AddHistory(role, gconv.String(msg["content"]))
}
ir.AddUser(buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, model.ModelType)))
case 2: //构建节点请求
ir.AddUser(NodeBuild(ctx, req))
default:
return nil, errors.New("不支持的构建类型")
}
// 2. 获取协议配置
protocol, err := GetProtocolByProvider(ctx, "qwen")
if err != nil {
return nil, err
}
if protocol == nil {
return nil, errors.New("协议配置不存在")
}
// 3. 编译为 Provider Request
providerReq, err := Compile(ir, protocol, chatModel)
if err != nil {
return nil, err
}
// 4. 构建请求体
return map[string]any{
"modelName": chatModel.ModelName,
"bizName": "prompts-core",
"callbackUrl": "/prompt/callback",
"requestPayload": providerReq,
}, nil
}
// promptBuild 构建系统提示词
func promptBuild(ctx context.Context, req *prompt.ComposeMessagesReq, model *entity.AsynchModel) string {
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: "qwen",
Status: 1,
})
if err != nil || providerProtocol == nil {
return ""
}
outputJSON := util.JSONPretty(model.RequestMapping)
var userFormContent strings.Builder
for k, v := range req.UserForm {
userFormContent.WriteString(fmt.Sprintf("%s=%v", k, v))
}
userFormFullText := strings.TrimSuffix(userFormContent.String(), "")
formInfo := fmt.Sprintf(`
【系统表单(系统提示词/参数)】
%s
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
%s
`, util.FormToJSON(req.Form), userFormFullText)
return fmt.Sprintf(providerProtocol.SystemPromptTemplate, outputJSON, formInfo)
}
// 构建用户提示词
func buildUserPrompt(ctx context.Context, req *prompt.ComposeMessagesReq, prompt string) string {
payload := map[string]any{
"model": req.ModelName, // 请求模型名称
"promptInfo": prompt, // 数据库提示信息
"form": req.Form, // 系统表单
"userForm": req.UserForm, // 用户表单
"userFiles": req.UserFiles, //文件url
"userFilesText": FetchFileTexts(ctx, req.UserFiles), //解读文件(只支持可读类型 如xmljson,yaml
"skills": SkillMdContent(ctx, req.SkillName), //skill 相关(根据传入的 skillName 获取 zip 内所有 md 文件拼接内容)
}
return util.MustMarshal(payload)
}
// NodeBuild 节点构建
func NodeBuild(ctx context.Context, req *prompt.ComposeMessagesReq) string {
promptTpl := util.GetBuildPrompt(ctx, req.BuildType)
if promptTpl == "" {
return ""
}
formStr := util.FormToJSON(req.Form)
userFormStr := util.FormToJSON(req.UserForm)
return fmt.Sprintf(promptTpl, formStr, userFormStr)
}

View File

@@ -0,0 +1,440 @@
package prompt
import (
"context"
"encoding/json"
"errors"
"fmt"
"prompts-core/dao"
"prompts-core/model/entity"
"strings"
"time"
"prompts-core/common/util"
"prompts-core/consts/public"
promptDto "prompts-core/model/dto/prompt"
"prompts-core/service/gateway"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/frame/g"
)
// ComposeMessages 核心拼接提示词主流程
func ComposeMessages(ctx context.Context, req *promptDto.ComposeMessagesReq) (*promptDto.ComposeMessagesRes, error) {
var (
epicycleId int64
taskID string
history []map[string]any
message map[string]any
err error
taskRecord *entity.ComposeTask
)
// 获取模型信息
chatModel, aiModel, err := GetModelMessage(ctx, req)
if err != nil {
return nil, err
}
// 根据构建类型进行判断处理
switch req.BuildType {
//提示词构建
case 1:
maxRetryTimes := g.Cfg().MustGet(ctx, "promptsRetry.maxRetryTimes", 3).Int()
//1. 获取历史会话
history, err = GetHistoryMessages(ctx, req.SessionId)
if err != nil {
g.Log().Errorf(ctx, "获取历史会话失败: %v将不使用历史会话", err)
history = nil // 出错就用空的,不影响主流程
}
// 重试循环
for attempt := 0; attempt <= 0; attempt++ {
if attempt > 0 {
g.Log().Warningf(ctx, "[重试]第 %d/%d 次调用推理模型", attempt, maxRetryTimes)
}
// 2. 调用推理模型
taskID, err = callInferenceModel(ctx, req, chatModel, aiModel, history)
if err != nil {
g.Log().Errorf(ctx, "调用推理模型失败(第%d次): %v", attempt+1, err)
continue
}
// 3. 保存记录
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
RequestPayload: util.MustMarshal(req),
Status: public.ComposeStatusPending,
})
if err != nil {
g.Log().Errorf(ctx, "保存任务记录失败(第%d次): %v", attempt+1, err)
continue
}
// 4. 等待结果
taskRecord, err = waitForResult(ctx, taskID)
if err != nil {
g.Log().Errorf(ctx, "等待结果失败(第%d次): %v", attempt+1, err)
continue
}
// 校验结果
message = parsePromptBuild(taskRecord, chatModel)
if message != nil && util.IsMessageValid(message) {
break
}
g.Log().Warningf(ctx, "[重试] 推理结果不合法(第%d次),准备重新请求", attempt+1)
message = nil
}
if message == nil {
return nil, errors.New("推理模型调用失败,请稍后再试")
}
//5.创建会话记录
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
RequestContent: message,
})
//节点构建
case 2:
//1. 调用推理模型
taskID, err = callInferenceModel(ctx, req, chatModel, aiModel, nil)
if err != nil {
return nil, err
}
//2. 保存相关记录
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
RequestPayload: util.MustMarshal(req),
Status: public.ComposeStatusPending,
})
//5. 等待结果
taskRecord, err := waitForResult(ctx, taskID)
if err != nil {
return nil, err
}
message = parseNodeBuild(taskRecord)
default:
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
Remark: req.Cause,
})
return &promptDto.ComposeMessagesRes{
EpicycleId: epicycleId,
}, nil
}
return &promptDto.ComposeMessagesRes{
Messages: message,
EpicycleId: epicycleId,
}, nil
}
// GetModelMessage 获取模型信息
func GetModelMessage(ctx context.Context, req *promptDto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) {
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, nil, err
}
// 1. 获取当前用户的会话模型
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
IsChatModel: 1,
})
if err != nil {
return nil, nil, err
}
if chatModel == nil {
return nil, nil, errors.New("当前没有对话模型,请添加")
}
// 2. 获取要构建的模型信息
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
ModelName: req.ModelName,
})
if err != nil {
return nil, nil, err
}
if aiModel == nil {
return nil, nil, fmt.Errorf("需要构建的模型 %s 不存在", req.ModelName)
}
return chatModel, aiModel, nil
}
// callInferenceModel 调用推理模型
func callInferenceModel(ctx context.Context, req *promptDto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (string, error) {
// 构建推理模型请求
taskReq, err := buildInferenceRequest(ctx, req, chatModel, model, history)
if err != nil {
return "", fmt.Errorf("构建推理请求失败: %w", err)
}
// 创建网关任务
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
if err != nil {
return "", fmt.Errorf("创建网关任务失败: %w", err)
}
if taskID == "" {
return "", errors.New("网关未返回taskId")
}
return taskID, nil
}
// waitForResult 等待结果
func waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) {
timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second
pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond
deadline := time.Now().Add(timeout)
for {
// ===================== 修复点 1检查上下文是否取消 =====================
select {
case <-ctx.Done():
// 请求已被取消,直接返回,不继续查库
return nil, ctx.Err()
default:
}
// 1. 查数据库
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: taskID,
})
if err != nil {
// ===================== 修复点 2如果是上下文取消直接返回 =====================
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err
}
return nil, err
}
if record != nil {
switch record.Status {
case public.ComposeStatusSuccess:
return record, nil
case public.ComposeStatusFailed:
if strings.TrimSpace(record.ErrorMessage) == "" {
return nil, fmt.Errorf("任务失败(taskId=%s)", taskID)
}
return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage)
}
}
// 2. 查网关状态
state, err := gateway.QueryGatewayTaskState(ctx, taskID)
if err != nil {
// 网关不可达不终止,继续轮询
g.Log().Warningf(ctx, "[waitForResult] 查询网关失败 taskId=%s err=%v", taskID, err)
} else {
switch state {
case 2: // 网关成功
// 网关已成功,主动更新数据库
if record != nil {
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: taskID,
Status: public.ComposeStatusSuccess,
})
if err != nil {
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
}
}
case 3: // 网关失败
if record != nil {
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: taskID,
Status: public.ComposeStatusFailed,
ErrorMessage: "model-gateway 任务执行失败",
})
if err != nil {
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
}
}
return nil, fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID)
}
}
// 3. 超时检查
if time.Now().After(deadline) {
return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID)
}
// ===================== 修复点3sleep 也要监听 ctx 取消 =====================
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(pollInterval):
}
}
}
// parsePromptBuild 解析提示词构建结果BuildType == 1
func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) map[string]any {
if taskRecord == nil {
return nil
}
// 1. 解析 Messages
var mapped map[string]any
switch v := taskRecord.Messages.(type) {
case *gvar.Var:
if v != nil {
json.Unmarshal([]byte(v.String()), &mapped)
}
case string:
json.Unmarshal([]byte(v), &mapped)
case map[string]any:
mapped = v
default:
b, _ := json.Marshal(v)
json.Unmarshal(b, &mapped)
}
// 2. 解析模型 ResponseMapping 获取 content 字段名
contentField := "content" // 默认值
if model != nil {
var respMapping map[string]string
switch v := model.ResponseMapping.(type) {
case *gvar.Var:
if v != nil {
json.Unmarshal([]byte(v.String()), &respMapping)
}
case string:
json.Unmarshal([]byte(v), &respMapping)
case map[string]interface{}:
respMapping = make(map[string]string)
for k, val := range v {
if s, ok := val.(string); ok {
respMapping[k] = s
}
}
}
// 从映射中找到 content 对应的字段名
for k, v := range respMapping {
if strings.Contains(v, "content") {
contentField = k
break
}
}
}
// 3. 提取 content 的值
contentStr, ok := mapped[contentField].(string)
if !ok || contentStr == "" {
return mapped
}
// 4. 解析 content 内的 JSON
var innerData map[string]any
json.Unmarshal([]byte(contentStr), &innerData)
return innerData
}
// parseNodeBuild 解析节点构建结果BuildType == 2
func parseNodeBuild(taskRecord *entity.ComposeTask) map[string]any {
if taskRecord == nil {
return nil
}
var result map[string]any
switch v := taskRecord.Messages.(type) {
case *gvar.Var:
if v != nil {
json.Unmarshal([]byte(v.String()), &result)
}
case string:
json.Unmarshal([]byte(v), &result)
case map[string]any:
result = v
default:
b, _ := json.Marshal(v)
json.Unmarshal(b, &result)
}
return result
}
// Callback 回调处理
func Callback(ctx context.Context, req *promptDto.CallbackReq) error {
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
// ============ 先查任务是否存在 ============
task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
})
if err != nil {
return err
}
if task == nil {
return fmt.Errorf("任务不存在: %s", req.TaskId)
}
// ============ 根据状态区分处理 ============
if req.State == 3 {
// 失败:直接更新状态
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
})
return err
}
// ======================================
// 成功:解析模型输出
result, err := util.ParseOutput(req.Text)
if err != nil {
_, updateErr := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
})
if updateErr != nil {
g.Log().Warningf(ctx, "[Callback] 更新失败状态出错 taskId=%s err=%v", req.TaskId, updateErr)
}
return err
}
// ============ result 可能为 nil ============
var messages any
if result != nil {
messages = result
}
// =======================================
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
Messages: messages,
})
if err != nil {
g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err)
}
return err
}
// GetComposeTask 查询任务结果
func GetComposeTask(ctx context.Context, taskID string) (*promptDto.GetComposeTaskRes, error) {
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: taskID,
})
if err != nil {
return nil, err
}
if record == nil {
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
}
// 如果 Messages 是字符串,反序列化为 JSON 数组
messages := record.Messages
if str, ok := messages.(string); ok && str != "" {
var parsed any
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
messages = parsed
}
}
return &promptDto.GetComposeTaskRes{
TaskId: record.TaskId,
Status: record.Status,
ErrorMessage: record.ErrorMessage,
Messages: messages,
}, nil
}

View File

@@ -0,0 +1,261 @@
package prompt
import (
"archive/zip"
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
"prompts-core/common/util"
"prompts-core/service/gateway"
"github.com/gogf/gf/v2/frame/g"
)
// FetchFileTexts 从 URL 列表获取文件内容,支持 zip 内文件
func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
result := make(map[string]string)
if len(urls) == 0 {
return result
}
client := &http.Client{
Timeout: time.Duration(g.Cfg().MustGet(ctx, "userFiles.httpTimeoutSec", 8).Int()) * time.Second,
}
for _, rawURL := range urls {
url := util.SanitizeURL(rawURL)
if url == "" {
continue
}
if util.IsBannedExtension(url) {
continue
}
if util.IsZipExtension(url) {
zipTexts := fetchZipFileTexts(ctx, client, url)
for k, v := range zipTexts {
result[k] = v
}
continue
}
text, err := fetchFileContent(ctx, client, url)
if err != nil {
continue
}
if text == "" {
continue
}
text = util.CleanSymbols(text)
result[url] = text
}
return result
}
// fetchZipFileTexts 下载并解压 zip 文件,提取可读文本内容
func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map[string]string {
result := make(map[string]string)
zipBytes, err := downloadFile(client, url,
int64(g.Cfg().MustGet(ctx, "userFiles.zipMaxSizeMB", 10).Int())*1024*1024,
)
if err != nil {
return result
}
reader, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
if err != nil {
return result
}
entryMaxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipEntryMaxSizeKB", 500).Int()) * 1024
for _, file := range reader.File {
if file.FileInfo().IsDir() {
continue
}
fileName := file.Name
if util.IsBannedExtension(fileName) {
continue
}
if util.IsZipExtension(fileName) {
continue
}
rc, err := file.Open()
if err != nil {
continue
}
content, err := io.ReadAll(io.LimitReader(rc, entryMaxSize))
rc.Close()
if err != nil {
continue
}
contentType := http.DetectContentType(content)
if !util.IsReadableContentType(contentType) {
continue
}
text := util.CleanSymbols(string(content))
if text == "" {
continue
}
key := url + "::" + fileName
result[key] = text
}
return result
}
// downloadFile 下载文件,限制最大大小
func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
}
return io.ReadAll(io.LimitReader(resp.Body, maxSize))
}
// fetchFileContent 获取单个文本文件内容
func fetchFileContent(ctx context.Context, client *http.Client, url string) (string, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return "", err
}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
}
contentType := resp.Header.Get("Content-Type")
if !util.IsReadableContentType(contentType) {
return "", fmt.Errorf("unreadable content-type: %s", contentType)
}
body, err := io.ReadAll(
io.LimitReader(resp.Body,
int64(g.Cfg().MustGet(ctx, "userFiles.textFileMaxSizeKB", 500).Int())*1024,
),
)
if err != nil {
return "", err
}
return strings.TrimSpace(string(body)), nil
}
// SkillMdContent 根据 skillName 获取 zip 内所有 md 文件拼接内容
func SkillMdContent(ctx context.Context, skillName string) string {
skillResp, err := gateway.GetSkillUser(ctx, skillName)
if err != nil {
return ""
}
fullUrl := skillResp.ImgAddressPrefix + skillResp.FileUrl
client := &http.Client{
Timeout: time.Duration(g.Cfg().MustGet(ctx, "skillFiles.httpTimeoutSec", 30).Int()) * time.Second,
}
zipBytes, err := downloadFile(client, fullUrl,
int64(g.Cfg().MustGet(ctx, "skillFiles.zipMaxSizeMB", 10).Int())*1024*1024,
)
if err != nil {
return ""
}
mdContents, err := extractMdFiles(ctx, zipBytes)
if err != nil {
return ""
}
if len(mdContents) == 0 {
return ""
}
var builder strings.Builder
builder.WriteString(fmt.Sprintf("# Skill: %s\n\n", skillResp.Name))
if skillResp.Description != "" {
builder.WriteString(fmt.Sprintf("> %s\n\n", skillResp.Description))
}
for fileName, content := range mdContents {
builder.WriteString(fmt.Sprintf("## %s\n\n", fileName))
builder.WriteString(content)
builder.WriteString("\n\n---\n\n")
}
return strings.TrimSpace(builder.String())
}
// extractMdFiles 解压 zip 并提取所有 .md 文件内容
func extractMdFiles(ctx context.Context, zipBytes []byte) (map[string]string, error) {
result := make(map[string]string)
reader, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
if err != nil {
return nil, err
}
entryMaxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.mdMaxSizeKB", 500).Int()) * 1024
for _, file := range reader.File {
if file.FileInfo().IsDir() {
continue
}
if !strings.HasSuffix(strings.ToLower(file.Name), ".md") {
continue
}
rc, err := file.Open()
if err != nil {
continue
}
content, err := io.ReadAll(io.LimitReader(rc, entryMaxSize))
rc.Close()
if err != nil {
continue
}
if len(content) > 0 {
result[file.Name] = strings.TrimSpace(string(content))
}
}
return result, nil
}

View File

@@ -0,0 +1,264 @@
package prompt
import (
"context"
"encoding/json"
"fmt"
"prompts-core/common/util"
"strings"
"prompts-core/dao"
"prompts-core/model/entity"
)
// PromptIR 统一 Prompt 中间表示
type PromptIR struct {
System []Segment `json:"system"`
History []Segment `json:"history"`
User []Segment `json:"user"`
}
// Segment 消息片段
type Segment struct {
Type string `json:"type"` // text/image
Content string `json:"content"`
Role string `json:"role,omitempty"`
}
// NewPromptIR 创建空 PromptIR
func NewPromptIR() *PromptIR {
return &PromptIR{
System: make([]Segment, 0),
History: make([]Segment, 0),
User: make([]Segment, 0),
}
}
// AddSystem 添加系统提示
func (ir *PromptIR) AddSystem(content string) *PromptIR {
if content != "" {
ir.System = append(ir.System, Segment{Type: "text", Content: content})
}
return ir
}
// AddUser 添加用户消息
func (ir *PromptIR) AddUser(content string) *PromptIR {
if content != "" {
ir.User = append(ir.User, Segment{Type: "text", Content: content})
}
return ir
}
// AddHistory 添加历史消息
func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
if content != "" {
ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role})
}
return ir
}
// ToMessages 转换为 OpenAI 兼容的 messages 格式MVP 默认)
func (ir *PromptIR) ToMessages() []map[string]any {
var messages []map[string]any
// 1. 系统消息
for _, seg := range ir.System {
messages = append(messages, map[string]any{
"role": "system",
"content": seg.Content,
})
}
// 2. 历史消息
for _, seg := range ir.History {
messages = append(messages, map[string]any{
"role": seg.Role,
"content": seg.Content,
})
}
// 3. 用户消息
for _, seg := range ir.User {
messages = append(messages, map[string]any{
"role": "user",
"content": seg.Content,
})
}
return messages
}
// GetProtocolByProvider 根据 provider_name 获取协议配置
func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderProtocol, error) {
entity, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: providerName,
Status: 1,
})
if err != nil || entity == nil {
return nil, err
}
entity.MergeOrder = util.ParseJSONField(entity.MergeOrder)
entity.RoleMapping = util.ParseJSONField(entity.RoleMapping)
entity.ContentMapping = util.ParseJSONField(entity.ContentMapping)
entity.RequestTemplate = util.ParseJSONField(entity.RequestTemplate)
entity.ContentMapping = util.ParseJSONField(entity.ContentMapping)
return parseProtocol(entity), nil
}
// parseProtocol 将 DB entity 转为编译用协议配置
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
p := &ProviderProtocol{
TargetField: e.TargetField,
}
// MergeOrder: any → []string
if e.MergeOrder != nil {
b, _ := json.Marshal(e.MergeOrder)
json.Unmarshal(b, &p.MergeOrder)
}
// RoleMapping: any → map[string]string
if e.RoleMapping != nil {
b, _ := json.Marshal(e.RoleMapping)
json.Unmarshal(b, &p.RoleMapping)
}
// ContentMapping: any → ContentMapping
if e.ContentMapping != nil {
b, _ := json.Marshal(e.ContentMapping)
json.Unmarshal(b, &p.ContentMapping)
}
// RequestTemplate: any → map[string]any
if e.RequestTemplate != nil {
b, _ := json.Marshal(e.RequestTemplate)
json.Unmarshal(b, &p.RequestTemplate)
}
fmt.Printf("parseProtocol: %+v\n", p)
return p
}
// ProviderProtocol 协议编译配置(从 DB JSONB 字段解析)
type ProviderProtocol struct {
TargetField string `json:"target_field"`
MergeOrder []string `json:"merge_order"`
RoleMapping map[string]string `json:"role_mapping"`
ContentMapping ContentMapping `json:"content_mapping"`
RequestTemplate map[string]any `json:"request_template"`
}
// ContentMapping 内容字段映射
type ContentMapping struct {
Type string `json:"type"` // direct/parts
Field string `json:"field"` // content/text
}
// Compile 将 PromptIR 按协议配置编译为 Provider Request
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (map[string]any, error) {
if ir == nil || p == nil {
return nil, fmt.Errorf("ir and protocol are required")
}
// 1. 按 merge_order 拼接消息
messages := mergeByOrder(ir, p.MergeOrder)
// 2. 角色映射
messages = mapRoles(messages, p.RoleMapping)
// 3. 内容字段映射
messages = mapContent(messages, p.ContentMapping)
// 4. 按 target_field + request_template 构建请求体
return buildRequest(messages, p, chatModel), nil
}
// mergeByOrder 按协议配置顺序拼接消息
func mergeByOrder(ir *PromptIR, order []string) []map[string]any {
var messages []map[string]any
for _, part := range order {
switch part {
case "system":
for _, seg := range ir.System {
messages = append(messages, map[string]any{
"role": "system",
"content": seg.Content,
})
}
case "history":
for _, seg := range ir.History {
messages = append(messages, map[string]any{
"role": seg.Role,
"content": seg.Content,
})
}
case "user":
for _, seg := range ir.User {
messages = append(messages, map[string]any{
"role": "user",
"content": seg.Content,
})
}
}
}
return messages
}
// mapRoles 角色映射
func mapRoles(messages []map[string]any, mapping map[string]string) []map[string]any {
if len(mapping) == 0 {
return messages
}
for i, msg := range messages {
role, ok := msg["role"].(string)
if !ok {
continue
}
if mapped, exists := mapping[role]; exists {
messages[i]["role"] = mapped
}
}
return messages
}
// mapContent 内容字段映射
func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
for _, msg := range messages {
content := msg["content"]
delete(msg, "content")
switch cm.Type {
case "parts":
// Gemini 格式: {"parts": [{"text": "..."}]}
msg["parts"] = []map[string]any{
{cm.Field: content},
}
default:
// direct: {"content": "..."}
msg[cm.Field] = content
}
}
return messages
}
// buildRequest 按 target_field 和 request_template 构建请求体
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *entity.AsynchModel) map[string]any {
if len(p.RequestTemplate) > 0 {
return renderTemplate(p.RequestTemplate, messages, chatModel)
}
return map[string]any{
p.TargetField: messages,
}
}
// renderTemplate 简单的 {{key}} 模板替换
func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *entity.AsynchModel) map[string]any {
b, _ := json.Marshal(tmpl)
str := string(b)
// 替换 {{model}}
str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`)
// 替换 {{messages}}
msgBytes, _ := json.Marshal(messages)
str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes))
var result map[string]any
json.Unmarshal([]byte(str), &result)
return result
}

View File

@@ -0,0 +1,114 @@
package prompt
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/gogf/gf/v2/frame/g"
)
// ==================== Redis 操作 ====================
// saveToRedis 保存会话数据到Redis
func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error {
key := fmt.Sprintf("chat:session:%s", sessionId)
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
expireTime := time.Duration(expireSeconds) * time.Second
data := map[string]any{
"sessionId": sessionId,
"requestContent": requestMessages,
"responseContent": responseMessages,
"timestamp": time.Now().Unix(),
}
b, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("序列化会话数据失败: %w", err)
}
_, err = g.Redis().Do(ctx, "LPUSH", key, string(b))
if err != nil {
return fmt.Errorf("写入Redis失败: %w", err)
}
_, err = g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1)
if err != nil {
return fmt.Errorf("裁剪Redis列表失败: %w", err)
}
_, err = g.Redis().Do(ctx, "EXPIRE", key, int64(expireTime.Seconds()))
if err != nil {
return fmt.Errorf("设置过期时间失败: %w", err)
}
return nil
}
// getFromRedis 从Redis获取会话历史
func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) {
key := fmt.Sprintf("chat:session:%s", sessionId)
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
if err != nil {
return nil, fmt.Errorf("从Redis获取数据失败: %w", err)
}
if result == nil || result.IsNil() {
return []map[string]any{}, nil
}
var sessions []map[string]any
values := result.Strings()
for _, str := range values {
var data map[string]any
if err := json.Unmarshal([]byte(str), &data); err != nil {
g.Log().Warningf(ctx, "[会话] 解析Redis数据失败 err=%v", err)
continue
}
sessions = append(sessions, data)
}
// 反转Redis 最新在前 → 时间正序)
for i, j := 0, len(sessions)-1; i < j; i, j = i+1, j-1 {
sessions[i], sessions[j] = sessions[j], sessions[i]
}
return sessions, nil
}
// GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用)
func GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map[string]any, error) {
historyData, err := getFromRedis(ctx, sessionId)
if err != nil {
return nil, fmt.Errorf("获取历史会话失败: %w", err)
}
if len(historyData) == 0 {
return []map[string]any{}, nil
}
var messages []map[string]any
for _, round := range historyData {
if reqMsgs, ok := round["requestContent"].([]interface{}); ok {
for _, m := range reqMsgs {
if msg, ok := m.(map[string]interface{}); ok {
messages = append(messages, msg)
}
}
}
if respMsgs, ok := round["responseContent"].([]interface{}); ok {
for _, m := range respMsgs {
if msg, ok := m.(map[string]interface{}); ok {
messages = append(messages, msg)
}
}
}
}
return messages, nil
}

View File

@@ -0,0 +1,114 @@
package prompt
import (
"context"
"fmt"
sessionDao "prompts-core/dao"
"prompts-core/model/entity"
"prompts-core/common/util"
sessionDto "prompts-core/model/dto/prompt"
"gitea.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
func SessionCallback(ctx context.Context, req *sessionDto.SessionCallbackReq) (res *sessionDto.SessionCallbackRes, err error) {
// 1. 解析AI返回的文本
result, err := util.ParseOutput(req.Text)
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, err
}
// 2. 更新数据库
result["role"] = "assistant"
_, err = sessionDao.ComposeSession.Update(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
ResponseContent: result,
})
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, err
}
// 3. 获取当前轮次完整数据
session, err := sessionDao.ComposeSession.Get(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
})
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, err
}
// 4. 转换 json 并存入 Redis
requestMessages := util.ConvertToMessages(session.RequestContent)
responseMessages := util.ConvertToMessages(session.ResponseContent)
if err = saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil {
g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v",
session.SessionId, session.Id, err)
return nil, err
}
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
session.SessionId, session.Id, len(requestMessages), len(responseMessages))
return &sessionDto.SessionCallbackRes{}, nil
}
// GetHistoryMessages 获取历史信息
func GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) {
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
// 1. 先从 Redis 拿
redisHistory, err := GetSessionHistoryForInference(ctx, sessionId)
if err == nil && len(redisHistory) > 0 {
return redisHistory, nil
}
// 2. Redis 没有 → fallback DB
sessions, _, err := sessionDao.ComposeSession.List(ctx, &entity.ComposeSession{
SessionId: sessionId,
}, 1, maxRounds)
if err != nil {
return nil, fmt.Errorf("DB获取历史失败: %w", err)
}
var messages []map[string]any
for _, session := range sessions {
// request
reqMsgs := util.ConvertToMessages(session.RequestContent)
for _, m := range reqMsgs {
role := gconv.String(m["role"])
if role == "user" || role == "assistant" {
messages = append(messages, m)
}
}
// response
respMsgs := util.ConvertToMessages(session.ResponseContent)
for _, m := range respMsgs {
if m["role"] == nil {
m["role"] = "assistant"
}
messages = append(messages, m)
}
}
// 3. 回写 Redis
for _, session := range sessions {
reqMsgs := util.ConvertToMessages(session.RequestContent)
respMsgs := util.ConvertToMessages(session.ResponseContent)
for i := range respMsgs {
if respMsgs[i]["role"] == nil {
respMsgs[i]["role"] = "assistant"
}
}
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
_ = saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs)
}
}
return messages, nil
}