refactor(prompt): 重构提示词构建服务与数据模型

This commit is contained in:
2026-05-20 11:36:39 +08:00
parent c49144794d
commit 35bc3bd6ec
24 changed files with 1682 additions and 759 deletions

View File

@@ -1,30 +1,54 @@
# prompts-core提示词服务[2026.5.12前,暂时弃置]
# Prompts-Core 提示词核心服务
## 项目简介
Prompts-Core 是基于 Go 语言开发的**多模态 AI 提示词构建与管理系统**,专注于统一管理各类 AI 模型的提示词模板、维护智能会话上下文、适配主流模型协议,并支持文件解析与外部技能集成,为 AI 应用提供标准化、高效的提示词服务。
## 1. 功能范围(当前阶段)
- 仅做提示词配置的基础 CRUD最小可用版本
- 表:`prompts_model_prompt`
## 核心功能
1. **提示词构建引擎**
支持文字/图片/音频/向量化/全模态 5 类任务提示词生成,提供完整流程、分步节点两种构建模式,支持超大内容按 Token 自动分批处理。
2. **智能会话管理**
基于缓存实现高效会话存储,自动控制会话轮数与过期时间,保障上下文连贯性。
3. **多模型协议适配**
动态适配 OpenAI、DeepSeek、Qwen、Gemini 等主流 AI 模型协议,支持角色、字段、消息顺序灵活映射。
4. **文件与技能集成**
自动提取文本、ZIP 压缩包内容,支持加载外部 Markdown 技能配置,扩展服务能力。
5. **异步任务调度**
支持异步任务处理、状态轮询与回调通知,自带可配置重试机制。
## 2. 接口
> 路由注册方式与参考项目一致:使用 `common/http.RouteRegister` 注册 controller。
## 技术架构
- 开发语言Go 1.26.0
- Web 框架GoFrame v2.10.0
- 核心存储Redis会话缓存
- 服务组件Consul服务注册、Jaeger链路追踪
- 调用链路:客户端 → Prompts-Core → 模型网关 → AI 模型
- `POST /composeMessages`:按 `modelTypeId` 读取 `prompt_info + response_json_schema``modelName` 作为实际调用的网关模型;结合前端 `form(role/value)``userfiles` 调用 `model-gateway /task/createTask`,同步等待回调后直接返回最终 `messages`
- `GET /composeMessagesCallback/prompts-core``model-gateway` 成功回调接口(真实地址由 `callbackUrl + /bizName` 组成)
- `GET /getComposeTask`:按 `taskId` 查询拼接任务状态和结果
- `POST /createPrompt`:创建(默认启用)
- `PUT /updatePrompt`:更新
- `DELETE /deletePrompt`:删除
- `GET /getPrompt`:详情
- `POST /listPrompt`:列表分页
## 快速开始
### 环境要求
Go 1.26+、Redis、已部署模型网关服务
## 3. 数据库初始化
执行根目录 `update.sql`
### 启动步骤
1. 克隆项目代码
2. 完成项目配置文件修改
3. 执行命令启动服务:
```bash
go run main.go
```
## 4. 运行配置
配置文件:`config.yml`
## API 接口
### 基础信息
- 服务地址:`http://{host}:3009`
- 请求类型:`application/json`
- 认证方式:请求头携带 `Authorization``X-User`
### 新增说明
- `prompts_model_prompt` 去除了 `limit_length`
- 新增 `response_json_schema`
- 新增任务记录表 `prompts_compose_task`
- `callbackUrl` 必须填写 prompts-core 的绝对地址基路径,例如:`http://127.0.0.1:8002/composeMessagesCallback`
- `model-gateway` 实际回调地址`callbackUrl/{bizName}`,本项目固定为:`/composeMessagesCallback/prompts-core`
### 核心接口
1. **提示词拼接接口**
- 地址:`POST /composeMessages`
- 功能:构建提示词并调用模型服务,同步返回结果
2. **任务状态查询**
- 地址:`GET /getComposeTask`
- 功能:根据任务 ID 查询处理状态与结果
3. **任务回调接口**
- 地址:`GET /composeMessagesCallback/prompts-core`
- 功能:接收模型服务处理完成回调
4. **会话同步接口**
- 地址:`POST /sessionCallback`
- 功能:同步更新会话上下文历史

View File

@@ -8,11 +8,12 @@ import (
)
// GetModelPrompt 获取请求模型的提示词
func GetModelPrompt(ctx context.Context, Type int) string {
return g.Cfg().MustGet(ctx, "modelPrompts.types."+gconv.String(Type), "").String()
func GetModelPrompt(ctx context.Context, modelType int) string {
key := "modelPrompts.types." + gconv.String(modelType)
return g.Cfg().MustGet(ctx, key, "").String()
}
// GetBuildPrompt 获取构建提示词
func GetBuildPrompt(ctx context.Context, Type int) string {
return g.Cfg().MustGet(ctx, "buildProject.types."+gconv.String(Type), "").String()
// GetBuildPrompt 获取节点构建提示词
func GetBuildPrompt(ctx context.Context) string {
return g.Cfg().MustGet(ctx, "nodePrompts", "").String()
}

View File

@@ -6,8 +6,9 @@ import (
"strings"
)
// AllowedMIMEPrefixes 允许的文本类 MIME 类型前缀
var AllowedMIMEPrefixes = []string{
var (
// AllowedMIMEPrefixes 允许的文本类 MIME 类型前缀
AllowedMIMEPrefixes = []string{
"text/",
"application/json",
"application/xml",
@@ -20,10 +21,10 @@ var AllowedMIMEPrefixes = []string{
"application/x-python",
"application/x-perl",
"application/x-ruby",
}
}
// BannedExtensions 禁止的文件扩展名
var BannedExtensions = map[string]bool{
// BannedExtensions 禁止的文件扩展名
BannedExtensions = map[string]bool{
".png": true, ".jpg": true, ".jpeg": true, ".gif": true, ".bmp": true,
".webp": true, ".svg": true, ".ico": true, ".tiff": true, ".tif": true,
".mp3": true, ".wav": true, ".ogg": true, ".flac": true, ".aac": true,
@@ -35,9 +36,11 @@ var BannedExtensions = map[string]bool{
".class": true, ".pyc": true,
".pdf": true, ".doc": true, ".docx": true, ".xls": true, ".xlsx": true,
".ppt": true, ".pptx": true,
}
}
var symbolCleaner = regexp.MustCompile(`[\x00-\x08\x0B\x0C\x0E-\x1F]`)
symbolCleaner = regexp.MustCompile(`[\x00-\x08\x0B\x0C\x0E-\x1F]`)
multiNewlines = regexp.MustCompile(`\n{3,}`)
)
// SanitizeURL 清洗 URL 字符串
func SanitizeURL(raw string) string {
@@ -51,25 +54,19 @@ func CleanSymbols(text string) string {
text = symbolCleaner.ReplaceAllString(text, "")
text = strings.ReplaceAll(text, "\r\n", "\n")
text = strings.ReplaceAll(text, "\r", "\n")
text = regexp.MustCompile(`\n{3,}`).ReplaceAllString(text, "\n\n")
text = multiNewlines.ReplaceAllString(text, "\n\n")
return strings.TrimSpace(text)
}
// IsBannedExtension 判断是否为禁止的文件扩展名
func IsBannedExtension(url string) bool {
ext := strings.ToLower(filepath.Ext(url))
if idx := strings.Index(ext, "?"); idx != -1 {
ext = ext[:idx]
}
ext := extractExtension(url)
return BannedExtensions[ext]
}
// IsZipExtension 判断是否为 zip 文件
func IsZipExtension(url string) bool {
ext := strings.ToLower(filepath.Ext(url))
if idx := strings.Index(ext, "?"); idx != -1 {
ext = ext[:idx]
}
ext := extractExtension(url)
return ext == ".zip"
}
@@ -78,11 +75,22 @@ func IsReadableContentType(contentType string) bool {
if contentType == "" {
return false
}
ct := strings.ToLower(contentType)
for _, prefix := range AllowedMIMEPrefixes {
if strings.HasPrefix(ct, prefix) {
return true
}
}
return false
}
// extractExtension 提取文件扩展名并清理查询参数
func extractExtension(url string) string {
ext := strings.ToLower(filepath.Ext(url))
if idx := strings.Index(ext, "?"); idx != -1 {
ext = ext[:idx]
}
return ext
}

View File

@@ -10,6 +10,7 @@ import (
// AsyncCtx 固化异步上下文中的 token 和用户信息,避免请求结束后丢失
func AsyncCtx(ctx context.Context) context.Context {
asyncCtx := context.WithoutCancel(ctx)
if r := g.RequestFromCtx(ctx); r != nil {
if token := r.Header.Get("Authorization"); token != "" {
asyncCtx = context.WithValue(asyncCtx, "token", token)
@@ -18,9 +19,11 @@ func AsyncCtx(ctx context.Context) context.Context {
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
}
}
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
asyncCtx = context.WithValue(asyncCtx, "user", user)
}
return asyncCtx
}
@@ -28,25 +31,37 @@ func AsyncCtx(ctx context.Context) context.Context {
func ForwardHeaders(ctx context.Context) map[string]string {
headers := make(map[string]string)
if token, ok := ctx.Value("token").(string); ok && token != "" {
headers["Authorization"] = token
setHeaderFromContext(headers, ctx, "Authorization", "token")
setHeaderFromContext(headers, ctx, "X-User-Info", "xUserInfo")
fallbackToRequestHeaders(headers, ctx)
return headers
}
// setHeaderFromContext 从上下文中设置 header
func setHeaderFromContext(headers map[string]string, ctx context.Context, headerKey, ctxKey string) {
if value, ok := ctx.Value(ctxKey).(string); ok && value != "" {
headers[headerKey] = value
}
if x, ok := ctx.Value("xUserInfo").(string); ok && x != "" {
headers["X-User-Info"] = x
}
// fallbackToRequestHeaders 从请求头中获取作为兜底
func fallbackToRequestHeaders(headers map[string]string, ctx context.Context) {
r := g.RequestFromCtx(ctx)
if r == nil {
return
}
// 兜底:从请求头获取
if r := g.RequestFromCtx(ctx); r != nil {
if headers["Authorization"] == "" {
if token := r.Header.Get("Authorization"); token != "" {
headers["Authorization"] = token
}
}
if headers["X-User-Info"] == "" {
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
headers["X-User-Info"] = userInfo
}
}
}
return headers
}

View File

@@ -15,6 +15,7 @@ func ParseOutput(text string) (map[string]any, error) {
if err != nil {
return nil, fmt.Errorf("解析模型输出失败: %w", err)
}
return j.Map(), nil
}
@@ -23,26 +24,17 @@ func ConvertToMessages(raw any) []map[string]any {
if raw == nil {
return nil
}
j, err := gjson.LoadJson(gconv.Bytes(raw))
if err != nil {
return nil
}
// 如果有 messages 字段,直接返回
if j.Contains("messages") {
return gconv.Maps(j.Get("messages").Array())
}
// 否则当成单条 message
return []map[string]any{
j.Map(),
}
}
// IsMessageValid 校验推理结果是否合法
func IsMessageValid(message map[string]any) bool {
if message == nil {
return false
}
return true
return []map[string]any{j.Map()}
}
// FormToJSON 将表单数据转换为 JSON 字符串
@@ -50,6 +42,17 @@ func FormToJSON(form map[string]any) string {
if form == nil {
return "{}"
}
b, _ := json.Marshal(form)
return string(b)
}
// UserFormToJSON 将用户表单数据转换为 JSON 字符串
func UserFormToJSON(form []map[string]any) string {
if form == nil {
return "{}"
}
b, _ := json.Marshal(form)
return string(b)
}
@@ -60,39 +63,16 @@ func MustMarshal(v any) string {
if err != nil {
return "{}"
}
return string(b)
}
// ParseJSONField 解析 JSON 字段
func ParseJSONField(field any) any {
var v *gvar.Var
switch val := field.(type) {
case *gvar.Var:
v = val
default:
return field
}
if v == nil || v.IsNil() || v.IsEmpty() {
return nil
}
str := v.String()
var result any
if json.Unmarshal([]byte(str), &result) == nil {
return result
}
return str
}
// JSONPretty 将任意类型转为格式化的 JSON 字符串
func JSONPretty(v any) string {
// 处理 *gvar.Var 类型
if gv, ok := v.(*gvar.Var); ok {
v = gconv.Map(gv.String())
}
// 统一转 map 再美化
var tmp map[string]any
if err := gconv.Struct(v, &tmp); err != nil {
return gconv.String(v)
@@ -101,3 +81,71 @@ func JSONPretty(v any) string {
b, _ := json.MarshalIndent(tmp, "", " ")
return string(b)
}
// GvarToMap 将 *gvar.Var 类型转换为 map[string]any
func GvarToMap(v *gvar.Var) map[string]any {
if v == nil || v.IsNil() {
return nil
}
result := make(map[string]any)
// 方法1尝试获取 map 值
if m := v.Map(); len(m) > 0 {
return m
}
// 方法2尝试解析 JSON 字符串
str := v.String()
if str != "" && str != "<nil>" {
json.Unmarshal([]byte(str), &result)
if len(result) > 0 {
return result
}
}
// 方法3尝试获取 interface 再转换
if val := v.Val(); val != nil {
switch val.(type) {
case map[string]any:
return val.(map[string]any)
default:
data, _ := json.Marshal(val)
json.Unmarshal(data, &result)
}
}
return result
}
// ParseJSONFieldFromGvar 专门处理 *gvar.Var 类型的 JSON 字段解析
func ParseJSONFieldFromGvar(source any, target any) {
if source == nil {
return
}
switch v := source.(type) {
case *gvar.Var:
if v.IsNil() {
return
}
// 尝试获取 map
if m := v.Map(); len(m) > 0 {
data, _ := json.Marshal(m)
json.Unmarshal(data, target)
return
}
// 尝试解析 JSON 字符串
str := v.String()
if str != "" && str != "<nil>" {
json.Unmarshal([]byte(str), target)
}
default:
// 其他类型走原来的逻辑
data, _ := json.Marshal(source)
json.Unmarshal(data, target)
}
}

229
common/util/token.go Normal file
View File

@@ -0,0 +1,229 @@
package util
import (
"encoding/json"
"fmt"
"regexp"
"strings"
"unicode"
"github.com/gogf/gf/v2/container/gvar"
)
var (
enWordRegex = regexp.MustCompile(`[A-Za-z]+`)
punctRegex = regexp.MustCompile(`[[:punct:]]`)
)
// TokenConfig Token计算配置
type TokenConfig struct {
ZhRatio float64 `json:"zh_ratio"`
EnRatio float64 `json:"en_ratio"`
SpaceRatio float64 `json:"space_ratio"`
PunctuationRatio float64 `json:"punctuation_ratio"`
MaxWindowSize int `json:"max_window_size"`
ReserveRatio float64 `json:"reserve_ratio"`
MinReserve int `json:"min_reserve"`
}
// CalculateTokens 计算文本token数
func CalculateTokens(text string, tokenConfig any) int {
config := parseConfig(tokenConfig)
if config == nil {
return 0
}
if text == "" {
return 0
}
zhCount := countChineseChars(text)
enCount := countEnglishWords(text)
spaceCount := strings.Count(text, " ")
punctCount := countPunctuation(text)
totalTokens := int(
float64(zhCount)*config.ZhRatio +
float64(enCount)*config.EnRatio +
float64(spaceCount)*config.SpaceRatio +
float64(punctCount)*config.PunctuationRatio,
)
return totalTokens
}
// CountToken 计算token是否超出窗口限制
// 返回: true - 未超出(可用), false - 已超出(不可用)
func CountToken(text string, tokenConfig any) bool {
config := parseConfig(tokenConfig)
if config == nil {
return false
}
estimatedTokens := CalculateTokens(text, tokenConfig)
availableWindow := GetAvailableWindow(tokenConfig)
return estimatedTokens <= availableWindow
}
// GetAvailableWindow 获取可用窗口大小
func GetAvailableWindow(tokenConfig any) int {
config := parseConfig(tokenConfig)
if config == nil {
return 4096
}
reserveByRatio := int(float64(config.MaxWindowSize) * config.ReserveRatio)
reserve := reserveByRatio
if config.MinReserve > reserve {
reserve = config.MinReserve
}
available := config.MaxWindowSize - reserve
if available < 0 {
available = 0
}
return available
}
// GetMaxWindowSize 获取模型最大窗口大小
func GetMaxWindowSize(tokenConfig any) int {
config := parseConfig(tokenConfig)
if config == nil {
return 4096
}
return config.MaxWindowSize
}
// CheckUserFormWithinWindow 校验 UserForm 是否在窗口大小内
// 返回: isValid, exceedTokens, error
func CheckUserFormWithinWindow(userForm []map[string]any, tokenConfig any) (bool, int, error) {
config := parseConfig(tokenConfig)
if config == nil || len(userForm) == 0 {
return true, 0, nil
}
totalTokens := calculateUserFormTokens(userForm, tokenConfig)
availableWindow := GetAvailableWindow(tokenConfig)
if totalTokens > availableWindow {
return false, totalTokens - availableWindow, nil
}
return true, 0, nil
}
// CheckUserFormBatchWithinWindow 检查 UserForm 分批是否在窗口内
// 返回: 需要拆分的批次数, 每批的token数, 错误
func CheckUserFormBatchWithinWindow(userForm []map[string]any, tokenConfig any) (int, []int, error) {
config := parseConfig(tokenConfig)
if config == nil || len(userForm) == 0 {
return 1, nil, nil
}
availableWindow := GetAvailableWindow(tokenConfig)
batches := 1
currentTokens := 0
batchTokens := make([]int, 0)
for _, item := range userForm {
itemStr := fmt.Sprintf("%v", item)
itemTokens := CalculateTokens(itemStr, tokenConfig)
if currentTokens+itemTokens > availableWindow {
batchTokens = append(batchTokens, currentTokens)
batches++
currentTokens = itemTokens
} else {
currentTokens += itemTokens
}
}
if currentTokens > 0 {
batchTokens = append(batchTokens, currentTokens)
}
return batches, batchTokens, nil
}
// parseConfig 解析配置
func parseConfig(tokenConfig any) *TokenConfig {
if tokenConfig == nil {
return nil
}
switch v := tokenConfig.(type) {
case *gvar.Var:
return parseGVarConfig(v)
case map[string]any:
return parseMapConfig(v)
case *TokenConfig:
return v
case TokenConfig:
return &v
default:
return nil
}
}
// parseGVarConfig 解析 GVar 配置
func parseGVarConfig(v *gvar.Var) *TokenConfig {
if v.IsNil() {
return nil
}
mapVal := v.Map()
if mapVal == nil {
return nil
}
config := &TokenConfig{}
data, _ := json.Marshal(mapVal)
json.Unmarshal(data, config)
return config
}
// parseMapConfig 解析 Map 配置
func parseMapConfig(v map[string]any) *TokenConfig {
config := &TokenConfig{}
data, _ := json.Marshal(v)
json.Unmarshal(data, config)
return config
}
// countChineseChars 统计中文字符数量
func countChineseChars(text string) int {
count := 0
for _, r := range text {
if unicode.Is(unicode.Han, r) {
count++
}
}
return count
}
// countEnglishWords 统计英文单词数量
func countEnglishWords(text string) int {
return len(enWordRegex.FindAllString(text, -1))
}
// countPunctuation 统计标点符号数量
func countPunctuation(text string) int {
return len(punctRegex.FindAllString(text, -1))
}
// calculateUserFormTokens 计算 UserForm 总 token 数
func calculateUserFormTokens(userForm []map[string]any, tokenConfig any) int {
totalTokens := 0
for _, item := range userForm {
itemStr := fmt.Sprintf("%v", item)
totalTokens += CalculateTokens(itemStr, tokenConfig)
}
return totalTokens
}

View File

@@ -103,49 +103,7 @@ modelPrompts:
在执行多模态任务时你需要以全链路AI内容架构师、多模态交互专家、综合内容生成系统的身份完成处理重点保证不同模态之间的语义一致性、风格统一性、信息完整性与交互连贯性避免出现跨模态语义断裂或输出不一致的问题。
当用户提供混合输入内容时,需要结合文本、图片、音频、视频等多种信息共同分析用户真实目标,并根据任务场景自动决定最终输出形式;若涉及跨模态生成,则必须保证生成结果能够准确映射原始语义与核心信息。
buildProject:
types:
1: |
你是专业的JSON结构生成专家必须严格遵守以下全部规则。
【强制规则】
必须根据【输出结构】里面返回的JSON结构进行生成不得任何更改最终内容与输出结构返回一致
完整阅读所有文本、规则、表单内容,禁止跳读、漏读;
完整读取UserForm所有字段不得忽略任何字段
如果有skill相关内容必须完整的将内容拼接到system角色描述中
理解全部语义后再输出,禁止断章取义;
UserForm所有字段内容必须完整拼接赋值到user角色描述中不得有任何遗漏。
【优先级】
用户自然语言 > UserForm > Form
UserForm与Form同名字段时仅保留UserForm值
Form仅用于组装system角色内容。
【表单处理】
Form系统提示词、默认参数、基础配置 → 专属填充system角色
UserForm用户业务输入、文案、配图数量、比例、prompt等 → 全部解析后拼接进user角色content
自动提取UserForm中每条文案的配图数量总图片数 = 各文案配图数累加求和示例10条文案各配5张图 → 总50张parameters.n=50,用户没有相关数量必须默认1
图片尺寸为空时自动填充size=1024*1024。
【结构铁律】
严格沿用固定输出结构,不增删字段或修改层级;
messages元素必须按结构返回
禁止将role对象转为字符串、禁止嵌套错乱
输出纯净JSON无多余转义符、无换行符、无额外字符
所有括号、引号必须成对闭合保证JSON合法。
【参数赋值】
model固定沿用传入值
返回结构里面的参数,需要根据语意进行赋值,缺失补默认值;
history历史信息必须结合UserForm里的内容对用户描述部分进行补充
从UserForm提取信息整合进user描述确保数量、尺寸、文案语义无遗漏。
【输出要求】
仅输出单行纯净JSON无任何解释、备注、Markdown或多余符号
完整合UserForm全部字段语义到user描述
生成后自检JSON语法、结构、数量错误则自动重新生成。
【输出结构】
%s
【字段映射】
%s
【完整输入信息】
%s
直接输出最终JSON
2: |
nodePrompts: |
你是流程路由助手你的任务是根据上下文选择一个正确的节点ID返回。
规则:
1. 只允许从下面的可选节点ID列表中选择一个返回
@@ -155,3 +113,41 @@ buildProject:
%s
上下文内容:
%s
#你是专业的JSON结构生成专家必须严格遵守以下全部规则。
# 【强制规则】
# 必须根据【输出结构】里面返回的JSON结构进行生成不得任何更改最终内容与输出结构返回一致
# 完整阅读所有文本、规则、表单内容,禁止跳读、漏读;
# 完整读取UserForm所有字段不得忽略任何字段
# 如果有skill相关内容必须完整的将内容拼接到system角色描述中
# 理解全部语义后再输出,禁止断章取义;
# UserForm所有字段内容必须完整拼接赋值到user角色描述中不得有任何遗漏。
# 【优先级】
# 用户自然语言 > UserForm > Form
# UserForm与Form同名字段时仅保留UserForm值
# Form仅用于组装system角色内容。
# 【表单处理】
# Form系统提示词、默认参数、基础配置 → 专属填充system角色
# UserForm用户业务输入、文案、配图数量、比例、prompt等 → 全部解析后拼接进user角色content
# 自动提取UserForm中每条文案的配图数量总图片数 = 各文案配图数累加求和用户没有相关数量必须默认1
# 图片尺寸为空时自动填充size=1024*1024。
# 【结构铁律】
# 严格沿用固定输出结构,不增删字段或修改层级;
# messages元素必须按结构返回
# 禁止将role对象转为字符串、禁止嵌套错乱
# 输出纯净JSON无多余转义符、无换行符、无额外字符
# 所有括号、引号必须成对闭合保证JSON合法。
# 【参数赋值】
# model固定沿用传入值
# 返回结构里面的参数,需要根据语意进行赋值,缺失补默认值;
# history历史信息必须结合UserForm里的内容对用户描述部分进行补充
# 从UserForm提取信息整合进user描述确保数量、尺寸、文案语义无遗漏。
# 【输出要求】
# 仅输出单行纯净JSON无任何解释、备注、Markdown或多余符号
# 完整合UserForm全部字段语义到user描述
# 生成后自检JSON语法、结构、数量错误则自动重新生成。
# 【输出结构】
# %s
# 【完整输入信息】
# %s
# 直接输出最终JSON

View File

@@ -5,3 +5,8 @@ const (
ComposeStatusSuccess = "success"
ComposeStatusFailed = "failed"
)
const (
BuildTypePrompt = 1 //提示词构建
BuildTypeNode = 2 //节点构建
)

View File

@@ -1,9 +1,9 @@
package prompt
package controller
import (
"context"
"prompts-core/model/dto"
promptDto "prompts-core/model/dto/prompt"
promptService "prompts-core/service/prompt"
)
@@ -13,17 +13,17 @@ type prompt struct{}
var Prompt = new(prompt)
// ComposeMessages 调用 model-gateway 异步任务并同步等待结果,
func (c *prompt) ComposeMessages(ctx context.Context, req *promptDto.ComposeMessagesReq) (res *promptDto.ComposeMessagesRes, err error) {
func (c *prompt) ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (res *dto.ComposeMessagesRes, err error) {
return promptService.ComposeMessages(ctx, req)
}
// Callback model-gateway 提示词回调
func (c *prompt) Callback(ctx context.Context, req *promptDto.CallbackReq) (res *promptDto.CallbackRes, err error) {
func (c *prompt) Callback(ctx context.Context, req *dto.CallbackReq) (res *dto.CallbackRes, err error) {
err = promptService.Callback(ctx, req)
return
}
// GetComposeTask 查询拼接任务结果
func (c *prompt) GetComposeTask(ctx context.Context, req *promptDto.GetComposeTaskReq) (res *promptDto.GetComposeTaskRes, err error) {
func (c *prompt) GetComposeTask(ctx context.Context, req *dto.GetComposeTaskReq) (res *dto.GetComposeTaskRes, err error) {
return promptService.GetComposeTask(ctx, req.TaskId)
}

View File

@@ -1,9 +1,9 @@
package prompt
package controller
import (
"context"
"prompts-core/model/dto"
promptDto "prompts-core/model/dto/prompt"
promptService "prompts-core/service/prompt"
)
@@ -13,6 +13,6 @@ type session struct{}
var Session = new(session)
// SessionCallback 会话回调
func (c *session) SessionCallback(ctx context.Context, req *promptDto.SessionCallbackReq) (res *promptDto.SessionCallbackRes, err error) {
func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) {
return promptService.SessionCallback(ctx, req)
}

View File

@@ -15,7 +15,7 @@ type composeSessionDao struct{}
// Insert 插入
func (d *composeSessionDao) Insert(ctx context.Context, req *entity.ComposeSession) (id int64, err error) {
var m = new(entity.ComposeTask)
var m = new(entity.ComposeSession)
err = gconv.Struct(req, &m)
if err != nil {
return

View File

@@ -6,6 +6,7 @@ import (
"prompts-core/model/entity"
"gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
)
var ProviderProtocol = &providerProtocolDao{}
@@ -14,7 +15,13 @@ type providerProtocolDao struct{}
// Insert 新增协议配置
func (d *providerProtocolDao) Insert(ctx context.Context, req *entity.ProviderProtocol) (id int64, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).OmitEmpty().Data(req).Insert()
var m = new(entity.ProviderProtocol)
err = gconv.Struct(req, &m)
if err != nil {
return
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).
Insert(m)
if err != nil {
return 0, err
}

View File

@@ -4,7 +4,7 @@ import (
"context"
"os"
"os/signal"
"prompts-core/controller/prompt"
"prompts-core/controller"
"syscall"
"gitea.com/red-future/common/http"
@@ -21,8 +21,8 @@ func main() {
defer jaeger.ShutDown(ctx)
// 注册路由
http.RouteRegister([]interface{}{
prompt.Prompt,
prompt.Session,
controller.Prompt,
controller.Session,
})
// 监听退出信号,确保 Ctrl+C 能完整退出并关闭 gateway server

View File

@@ -1,4 +1,4 @@
package prompt
package dto
import "github.com/gogf/gf/v2/frame/g"
@@ -9,16 +9,22 @@ type ComposeMessagesReq struct {
SessionId string `p:"sessionId" json:"sessionId" v:"required#sessionId不能为空" dc:"会话ID"`
Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"`
Form map[string]any `p:"form" json:"form" dc:"系统表单form 下所有字段都作为系统提示词来源"`
UserForm map[string]any `p:"userForm" json:"userForm" dc:"用户表单userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"`
UserForm []map[string]any `p:"userForm" json:"userForm" dc:"用户表单userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"`
SkillName string `p:"skillName" json:"skillName" dc:"技能名称"`
UserFiles []string `p:"userFiles" json:"userFiles" dc:"用户附件地址列表"`
}
type ComposeMessagesRes struct {
Messages any `json:"messages,omitempty" dc:"最终消息数组"`
Messages *MultiRoundResult `json:"messages,omitempty" dc:"最终消息数组"`
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
}
// MultiRoundResult 多轮返回结果
type MultiRoundResult struct {
TotalRounds int `json:"total_rounds"` // 总轮数
Rounds []any `json:"rounds"` // 每轮详情(动态类型)
}
type CallbackReq struct {
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调callbackUrl/{bizName}"`
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`

View File

@@ -1,4 +1,4 @@
package prompt
package dto
import "github.com/gogf/gf/v2/frame/g"

View File

@@ -16,10 +16,10 @@ type AsynchModel struct {
ResponseBody any `orm:"response_body" json:"responseBody"`
TokenMapping string `orm:"token_mapping" json:"tokenMapping"`
Prompt string `orm:"prompt" json:"prompt"`
IsPrivate int `orm:"is_private" json:"isPrivate"`
IsChatModel int `orm:"is_chat_model" json:"isChatModel"`
IsPrivate *int `orm:"is_private" json:"isPrivate"`
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
ApiKey string `orm:"api_key" json:"apiKey"`
Enabled int `orm:"enabled" json:"enabled"`
Enabled *int `orm:"enabled" json:"enabled"`
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
QueueLimit int `orm:"queue_limit" json:"queueLimit"`
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
@@ -28,6 +28,9 @@ type AsynchModel struct {
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"`
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
Remark string `orm:"remark" json:"remark"`
IsOwner *int `json:"isOwner" orm:"is_owner"`
OperatorName string `orm:"operator_name" json:"operatorName"`
TokenConfig any `orm:"token_config" json:"tokenConfig"`
}
type asynchModelCol struct {
@@ -55,6 +58,9 @@ type asynchModelCol struct {
RetryQueueMaxSecs string
AutoCleanSeconds string
Remark string
IsOwner string
OperatorName string
TokenConfig string
}
var AsynchModelCol = asynchModelCol{
@@ -82,4 +88,7 @@ var AsynchModelCol = asynchModelCol{
RetryQueueMaxSecs: "retry_queue_max_seconds",
AutoCleanSeconds: "auto_clean_seconds",
Remark: "remark",
IsOwner: "is_owner",
OperatorName: "operator_name",
TokenConfig: "token_config",
}

View File

@@ -4,23 +4,41 @@ import (
"context"
"errors"
"fmt"
"prompts-core/consts/public"
"strings"
"prompts-core/common/util"
"prompts-core/dao"
"prompts-core/model/dto/prompt"
"prompts-core/model/dto"
"prompts-core/model/entity"
"github.com/gogf/gf/v2/util/gconv"
)
// buildInferenceRequest 构建返回请求
func buildInferenceRequest(ctx context.Context, req *prompt.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
// buildInferenceRequest 构建推理请求
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, targetModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, targetModel)
if err != nil {
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
}
ir := NewPromptIR()
// 1. 统一 Prompt IR
switch req.BuildType {
case 1: //构建提示词请求
ir.AddSystem(promptBuild(ctx, req, model))
case public.BuildTypePrompt:
return buildPromptTypeRequest(ctx, processedReq, targetModel, history, ir, totalBatches)
case public.BuildTypeNode:
return buildNodeTypeRequest(ctx, req, ir)
default:
return nil, errors.New("不支持的构建类型")
}
}
// buildPromptTypeRequest 构建提示词类型请求BuildType=1
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, targetModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
systemPrompt := promptBuildWithRounds(ctx, req, targetModel, totalBatches)
ir.AddSystem(systemPrompt)
for _, msg := range history {
role := gconv.String(msg["role"])
if role != "user" && role != "assistant" {
@@ -28,41 +46,71 @@ func buildInferenceRequest(ctx context.Context, req *prompt.ComposeMessagesReq,
}
ir.AddHistory(role, gconv.String(msg["content"]))
}
ir.AddUser(buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, model.ModelType)))
case 2: //构建节点请求
ir.AddUser(NodeBuild(ctx, req))
default:
return nil, errors.New("不支持的构建类型")
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, targetModel.ModelType))
ir.AddUser(userPrompt)
if !checkOverallContent(ir, targetModel) {
availableWindow := util.GetAvailableWindow(targetModel.TokenConfig)
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
}
// 2. 获取协议配置
protocol, err := GetProtocolByProvider(ctx, "qwen")
return compileToProviderRequest(ctx, ir, targetModel.OperatorName, targetModel)
}
// buildNodeTypeRequest 构建节点类型请求BuildType=2
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ir *PromptIR) (map[string]any, error) {
ir.AddUser(NodeBuild(ctx, req))
protocol, err := GetProtocolByProvider(ctx, req.ModelName)
if err != nil {
return nil, err
return nil, fmt.Errorf("获取协议配置失败: %w", err)
}
if protocol == nil {
return nil, errors.New("协议配置不存在")
}
// 3. 编译为 Provider Request
providerReq, err := Compile(ir, protocol, chatModel)
providerReq, err := Compile(ir, protocol, nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("编译请求失败: %w", err)
}
// 4. 构建请求体
return map[string]any{
"modelName": chatModel.ModelName,
"modelName": req.ModelName,
"bizName": "prompts-core",
"callbackUrl": "/prompt/callback",
"requestPayload": providerReq,
}, nil
}
// promptBuild 构建系统提示词
func promptBuild(ctx context.Context, req *prompt.ComposeMessagesReq, model *entity.AsynchModel) string {
// compileToProviderRequest 编译为 Provider 请求
func compileToProviderRequest(ctx context.Context, ir *PromptIR, providerName string, model *entity.AsynchModel) (map[string]any, error) {
protocol, err := GetProtocolByProvider(ctx, providerName)
if err != nil {
return nil, fmt.Errorf("获取协议配置失败: %w", err)
}
if protocol == nil {
return nil, errors.New("协议配置不存在")
}
providerReq, err := Compile(ir, protocol, model)
if err != nil {
return nil, fmt.Errorf("编译请求失败: %w", err)
}
fmt.Println("providerReq打印:", util.MustMarshal(providerReq))
return map[string]any{
"modelName": model.ModelName,
"bizName": "prompts-core",
"callbackUrl": "/prompt/callback",
"requestPayload": providerReq,
}, nil
}
// promptBuildWithRounds 构建系统提示词(包含轮次信息)
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel, totalRounds int) string {
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: "qwen",
ProviderName: model.OperatorName,
Status: 1,
})
if err != nil || providerProtocol == nil {
@@ -70,43 +118,104 @@ func promptBuild(ctx context.Context, req *prompt.ComposeMessagesReq, model *ent
}
outputJSON := util.JSONPretty(model.RequestMapping)
var userFormContent strings.Builder
for k, v := range req.UserForm {
userFormContent.WriteString(fmt.Sprintf("%s=%v", k, v))
}
userFormFullText := strings.TrimSuffix(userFormContent.String(), "")
maxWindowSize := util.GetMaxWindowSize(model.TokenConfig)
availableWindow := util.GetAvailableWindow(model.TokenConfig)
userFormContent := buildUserFormContent(req.UserForm)
formInfo := fmt.Sprintf(`
【系统表单(系统提示词/参数)】
%s
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
%s
`, util.FormToJSON(req.Form), userFormFullText)
`, util.FormToJSON(req.Form), userFormContent)
return fmt.Sprintf(providerProtocol.SystemPromptTemplate, outputJSON, formInfo)
inputInfo := fmt.Sprintf(`
目标模型: %s
%s
技能名称: %s
用户文件: %v
`, req.ModelName, formInfo, req.SkillName, req.UserFiles)
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
req.ModelName,
maxWindowSize,
availableWindow,
totalRounds,
totalRounds,
totalRounds,
outputJSON,
inputInfo,
totalRounds,
)
}
// 构建用户提示词
func buildUserPrompt(ctx context.Context, req *prompt.ComposeMessagesReq, prompt string) string {
payload := map[string]any{
"model": req.ModelName, // 请求模型名称
"promptInfo": prompt, // 数据库提示信息
"form": req.Form, // 系统表单
"userForm": req.UserForm, // 用户表单
"userFiles": req.UserFiles, //文件url
"userFilesText": FetchFileTexts(ctx, req.UserFiles), //解读文件(只支持可读类型 如xmljson,yaml
"skills": SkillMdContent(ctx, req.SkillName), //skill 相关(根据传入的 skillName 获取 zip 内所有 md 文件拼接内容)
// buildUserFormContent 构建用户表单内容字符串
func buildUserFormContent(userForm []map[string]any) string {
var builder strings.Builder
for _, item := range userForm {
builder.WriteString(fmt.Sprintf("%v\n", item))
}
return builder.String()
}
// checkOverallContent 检查整体内容是否超出窗口
func checkOverallContent(ir *PromptIR, model *entity.AsynchModel) bool {
fullContent := ir.String()
return util.CountToken(fullContent, model.TokenConfig)
}
// buildUserPrompt 构建用户提示词
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
userFormForPayload := prepareUserFormPayload(req.UserForm)
payload := map[string]any{
"model": req.ModelName,
"promptInfo": prompt,
"form": req.Form,
"userForm": userFormForPayload,
"userFiles": req.UserFiles,
"userFilesText": FetchFileTexts(ctx, req.UserFiles),
"skills": SkillMdContent(ctx, req.SkillName),
}
return util.MustMarshal(payload)
}
// prepareUserFormPayload 准备用户表单载荷
func prepareUserFormPayload(userForm []map[string]any) any {
if len(userForm) == 0 {
return nil
}
if _, ok := userForm[0]["batch_index"]; ok {
return userForm
}
return mergeUserFormTexts(userForm)
}
// mergeUserFormTexts 合并 UserForm 中的所有文本内容
func mergeUserFormTexts(userForm []map[string]any) string {
var builder strings.Builder
for i, item := range userForm {
text := getItemText(item)
if i > 0 {
builder.WriteString("\n\n")
}
builder.WriteString(text)
}
return builder.String()
}
// NodeBuild 节点构建
func NodeBuild(ctx context.Context, req *prompt.ComposeMessagesReq) string {
promptTpl := util.GetBuildPrompt(ctx, req.BuildType)
func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string {
promptTpl := util.GetBuildPrompt(ctx)
if promptTpl == "" {
return ""
}
formStr := util.FormToJSON(req.Form)
userFormStr := util.FormToJSON(req.UserForm)
userFormStr := util.UserFormToJSON(req.UserForm)
return fmt.Sprintf(promptTpl, formStr, userFormStr)
}

View File

@@ -5,171 +5,229 @@ import (
"encoding/json"
"errors"
"fmt"
"prompts-core/dao"
"prompts-core/model/entity"
"strings"
"time"
"prompts-core/common/util"
"prompts-core/consts/public"
promptDto "prompts-core/model/dto/prompt"
"prompts-core/service/gateway"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/frame/g"
"prompts-core/common/util"
"prompts-core/consts/public"
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
"prompts-core/service/gateway"
)
// ComposeMessages 核心拼接提示词主流程
func ComposeMessages(ctx context.Context, req *promptDto.ComposeMessagesReq) (*promptDto.ComposeMessagesRes, error) {
var (
epicycleId int64
taskID string
history []map[string]any
message map[string]any
err error
taskRecord *entity.ComposeTask
)
// 获取模型信息
func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
chatModel, aiModel, err := GetModelMessage(ctx, req)
if err != nil {
return nil, err
}
// 根据构建类型进行判断处理
if err = validateUserForm(ctx, req, aiModel); err != nil {
return nil, err
}
switch req.BuildType {
//提示词构建
case 1:
case public.BuildTypePrompt:
return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建
case public.BuildTypeNode:
return handleNodeBuild(ctx, req, chatModel, aiModel) // 节点构建
default:
return handleDefaultCase(ctx, req)
}
}
// validateUserForm 校验用户表单
func validateUserForm(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) error {
if len(req.UserForm) == 0 {
return nil
}
isValid, exceedTokens, err := util.CheckUserFormWithinWindow(req.UserForm, model.TokenConfig)
if err != nil {
return fmt.Errorf("校验用户表单失败: %w", err)
}
if !isValid {
availableWindow := util.GetAvailableWindow(model.TokenConfig)
return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens可用窗口 %d tokens请精简后重试",
exceedTokens, availableWindow)
}
return nil
}
// handlePromptBuild 处理提示词构建BuildType=1
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
maxRetryTimes := g.Cfg().MustGet(ctx, "promptsRetry.maxRetryTimes", 3).Int()
//1. 获取历史会话
history, err = GetHistoryMessages(ctx, req.SessionId)
history, err := GetHistoryMessages(ctx, req.SessionId)
if err != nil {
g.Log().Errorf(ctx, "获取历史会话失败: %v将不使用历史会话", err)
history = nil // 出错就用空的,不影响主流程
history = nil
}
// 重试循环
var message *dto.MultiRoundResult
var taskRecord *entity.ComposeTask
for attempt := 0; attempt <= 0; attempt++ {
if attempt > 0 {
g.Log().Warningf(ctx, "[重试]第 %d/%d 次调用推理模型", attempt, maxRetryTimes)
}
// 2. 调用推理模型
taskID, err = callInferenceModel(ctx, req, chatModel, aiModel, history)
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, history)
if err != nil {
g.Log().Errorf(ctx, "调用推理模型失败(第%d次): %v", attempt+1, err)
continue
}
// 3. 保存记录
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
RequestPayload: util.MustMarshal(req),
Status: public.ComposeStatusPending,
})
if err != nil {
if err = saveComposeTask(ctx, taskID, req); err != nil {
g.Log().Errorf(ctx, "保存任务记录失败(第%d次): %v", attempt+1, err)
continue
}
// 4. 等待结果
taskRecord, err = waitForResult(ctx, taskID)
if err != nil {
g.Log().Errorf(ctx, "等待结果失败(第%d次): %v", attempt+1, err)
continue
}
// 校验结果
message = parsePromptBuild(taskRecord, chatModel)
if message != nil && util.IsMessageValid(message) {
if message != nil {
break
}
g.Log().Warningf(ctx, "[重试] 推理结果不合法(第%d次),准备重新请求", attempt+1)
message = nil
}
if message == nil {
return nil, errors.New("推理模型调用失败,请稍后再试")
}
//5.创建会话记录
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
epicycleId, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
RequestContent: message,
})
//节点构建
case 2:
//1. 调用推理模型
taskID, err = callInferenceModel(ctx, req, chatModel, aiModel, nil)
if err != nil {
return nil, err
g.Log().Errorf(ctx, "创建会话记录失败: %v", err)
}
//2. 保存相关记录
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
return &dto.ComposeMessagesRes{
Messages: message,
EpicycleId: epicycleId,
}, nil
}
// handleNodeBuild 处理节点构建BuildType=2
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
if err != nil {
return nil, fmt.Errorf("调用推理模型失败: %w", err)
}
if err := saveComposeTask(ctx, taskID, req); err != nil {
return nil, fmt.Errorf("保存任务记录失败: %w", err)
}
taskRecord, err := waitForResult(ctx, taskID)
if err != nil {
return nil, fmt.Errorf("等待结果失败: %w", err)
}
message := parseNodeBuild(taskRecord)
return &dto.ComposeMessagesRes{
Messages: message,
EpicycleId: 0,
}, nil
}
// handleDefaultCase 处理默认情况
func handleDefaultCase(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
epicycleId, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
Remark: req.Cause,
})
if err != nil {
return nil, fmt.Errorf("创建会话记录失败: %w", err)
}
return &dto.ComposeMessagesRes{
EpicycleId: epicycleId,
}, nil
}
// saveComposeTask 保存组合任务
func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessagesReq) error {
_, err := dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
RequestPayload: util.MustMarshal(req),
Status: public.ComposeStatusPending,
})
//5. 等待结果
taskRecord, err := waitForResult(ctx, taskID)
if err != nil {
return nil, err
}
message = parseNodeBuild(taskRecord)
default:
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
Remark: req.Cause,
})
return &promptDto.ComposeMessagesRes{
EpicycleId: epicycleId,
}, nil
}
return &promptDto.ComposeMessagesRes{
Messages: message,
EpicycleId: epicycleId,
}, nil
return err
}
// GetModelMessage 获取模型信息
func GetModelMessage(ctx context.Context, req *promptDto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) {
func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) {
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, nil, err
return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
}
// 1. 获取当前用户的会话模型
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
IsChatModel: 1,
})
chatModel, err := getChatModel(ctx, userInfo.UserName)
if err != nil {
return nil, nil, err
}
if chatModel == nil {
return nil, nil, errors.New("当前没有对话模型,请添加")
}
// 2. 获取要构建的模型信息
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
ModelName: req.ModelName,
})
aiModel, err := getAIModel(ctx, userInfo.UserName, req.ModelName)
if err != nil {
return nil, nil, err
}
if aiModel == nil {
return nil, nil, fmt.Errorf("需要构建的模型 %s 不存在", req.ModelName)
}
return chatModel, aiModel, nil
}
// getChatModel 获取聊天模型
func getChatModel(ctx context.Context, userName string) (*entity.AsynchModel, error) {
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
IsChatModel: new(1),
})
if err != nil {
return nil, fmt.Errorf("查询聊天模型失败: %w", err)
}
if chatModel == nil {
return nil, errors.New("当前没有对话模型,请添加")
}
return chatModel, nil
}
// getAIModel 获取AI模型
func getAIModel(ctx context.Context, userName, modelName string) (*entity.AsynchModel, error) {
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
ModelName: modelName,
})
if err != nil {
return nil, fmt.Errorf("查询AI模型失败: %w", err)
}
if aiModel == nil {
return nil, fmt.Errorf("需要构建的模型 %s 不存在", modelName)
}
return aiModel, nil
}
// callInferenceModel 调用推理模型
func callInferenceModel(ctx context.Context, req *promptDto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (string, error) {
// 构建推理模型请求
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (string, error) {
taskReq, err := buildInferenceRequest(ctx, req, chatModel, model, history)
if err != nil {
return "", fmt.Errorf("构建推理请求失败: %w", err)
}
// 创建网关任务
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
if err != nil {
return "", fmt.Errorf("创建网关任务失败: %w", err)
@@ -186,96 +244,131 @@ func callInferenceModel(ctx context.Context, req *promptDto.ComposeMessagesReq,
func waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) {
timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second
pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond
deadline := time.Now().Add(timeout)
ticker := time.NewTicker(pollInterval)
defer ticker.Stop()
for {
// ===================== 修复点 1检查上下文是否取消 =====================
select {
case <-ctx.Done():
// 请求已被取消,直接返回,不继续查库
return nil, ctx.Err()
default:
}
// 1. 查数据库
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: taskID,
})
if err != nil {
// ===================== 修复点 2如果是上下文取消直接返回 =====================
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err
}
return nil, err
return nil, fmt.Errorf("查询任务失败: %w", err)
}
if record != nil {
switch record.Status {
case public.ComposeStatusSuccess:
return record, nil
case public.ComposeStatusFailed:
if strings.TrimSpace(record.ErrorMessage) == "" {
return nil, fmt.Errorf("任务失败(taskId=%s)", taskID)
}
return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage)
if completed, result := checkTaskCompletion(record); completed {
return result, nil
}
}
// 2. 查网关状态
state, err := gateway.QueryGatewayTaskState(ctx, taskID)
if err != nil {
// 网关不可达不终止,继续轮询
g.Log().Warningf(ctx, "[waitForResult] 查询网关失败 taskId=%s err=%v", taskID, err)
} else {
switch state {
case 2: // 网关成功
// 网关已成功,主动更新数据库
if record != nil {
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: taskID,
Status: public.ComposeStatusSuccess,
})
if err != nil {
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
}
}
case 3: // 网关失败
if record != nil {
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: taskID,
Status: public.ComposeStatusFailed,
ErrorMessage: "model-gateway 任务执行失败",
})
if err != nil {
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
}
}
return nil, fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID)
}
if err = syncGatewayTaskState(ctx, taskID, record); err != nil {
g.Log().Warningf(ctx, "[waitForResult] 同步网关状态失败 taskId=%s err=%v", taskID, err)
}
// 3. 超时检查
if time.Now().After(deadline) {
return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID)
}
// ===================== 修复点3sleep 也要监听 ctx 取消 =====================
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(pollInterval):
case <-ticker.C:
}
}
}
// checkTaskCompletion 检查任务是否完成
func checkTaskCompletion(record *entity.ComposeTask) (bool, *entity.ComposeTask) {
if record == nil {
return false, nil
}
switch record.Status {
case public.ComposeStatusSuccess:
return true, record
case public.ComposeStatusFailed:
errMsg := strings.TrimSpace(record.ErrorMessage)
if errMsg == "" {
return true, nil
}
return true, nil
default:
return false, nil
}
}
// syncGatewayTaskState 同步网关任务状态
func syncGatewayTaskState(ctx context.Context, taskID string, record *entity.ComposeTask) error {
state, err := gateway.QueryGatewayTaskState(ctx, taskID)
if err != nil {
return fmt.Errorf("查询网关状态失败: %w", err)
}
switch state {
case 2:
return updateTaskStatus(ctx, taskID, public.ComposeStatusSuccess, "")
case 3:
updateTaskStatus(ctx, taskID, public.ComposeStatusFailed, "model-gateway 任务执行失败")
return fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID)
}
return nil
}
// updateTaskStatus 更新任务状态
func updateTaskStatus(ctx context.Context, taskID string, status string, errorMsg string) error {
task := &entity.ComposeTask{
TaskId: taskID,
Status: status,
}
if errorMsg != "" {
task.ErrorMessage = errorMsg
}
_, err := dao.ComposeTask.Update(ctx, task)
return err
}
// parsePromptBuild 解析提示词构建结果BuildType == 1
func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) map[string]any {
func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) *dto.MultiRoundResult {
if taskRecord == nil {
return nil
}
mapped := parseTaskMessages(taskRecord.Messages)
if mapped == nil {
return createDefaultResult(nil)
}
// 1. 解析 Messages
contentField := getContentField(model)
contentStr, ok := mapped[contentField].(string)
if !ok || contentStr == "" {
return createDefaultResult(mapped)
}
if roundsArray := tryParseAsArray(contentStr); roundsArray != nil {
return &dto.MultiRoundResult{
TotalRounds: len(roundsArray),
Rounds: roundsArray,
}
}
if singleRound := tryParseAsObject(contentStr); singleRound != nil {
return &dto.MultiRoundResult{
TotalRounds: 1,
Rounds: []any{singleRound},
}
}
return createDefaultResult(map[string]any{"content": contentStr})
}
// parseTaskMessages 解析任务消息
func parseTaskMessages(messages any) map[string]any {
var mapped map[string]any
switch v := taskRecord.Messages.(type) {
switch v := messages.(type) {
case *gvar.Var:
if v != nil {
json.Unmarshal([]byte(v.String()), &mapped)
@@ -289,115 +382,137 @@ func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel)
json.Unmarshal(b, &mapped)
}
// 2. 解析模型 ResponseMapping 获取 content 字段名
contentField := "content" // 默认值
if model != nil {
var respMapping map[string]string
switch v := model.ResponseMapping.(type) {
case *gvar.Var:
if v != nil {
json.Unmarshal([]byte(v.String()), &respMapping)
}
case string:
json.Unmarshal([]byte(v), &respMapping)
case map[string]interface{}:
respMapping = make(map[string]string)
for k, val := range v {
if s, ok := val.(string); ok {
respMapping[k] = s
}
}
}
// 从映射中找到 content 对应的字段名
for k, v := range respMapping {
if strings.Contains(v, "content") {
contentField = k
break
}
}
}
// 3. 提取 content 的值
contentStr, ok := mapped[contentField].(string)
if !ok || contentStr == "" {
return mapped
}
// 4. 解析 content 内的 JSON
var innerData map[string]any
json.Unmarshal([]byte(contentStr), &innerData)
return innerData
}
// parseNodeBuild 解析节点构建结果BuildType == 2
func parseNodeBuild(taskRecord *entity.ComposeTask) map[string]any {
if taskRecord == nil {
// tryParseAsArray 尝试将字符串解析为数组
func tryParseAsArray(contentStr string) []any {
var roundsArray []any
if err := json.Unmarshal([]byte(contentStr), &roundsArray); err != nil {
return nil
}
var result map[string]any
switch v := taskRecord.Messages.(type) {
return roundsArray
}
// tryParseAsObject 尝试将字符串解析为对象
func tryParseAsObject(contentStr string) any {
var singleRound any
if err := json.Unmarshal([]byte(contentStr), &singleRound); err != nil {
return nil
}
return singleRound
}
// createDefaultResult 创建默认结果
func createDefaultResult(data any) *dto.MultiRoundResult {
if data == nil {
data = make(map[string]any)
}
return &dto.MultiRoundResult{
TotalRounds: 1,
Rounds: []any{data},
}
}
// getContentField 从模型 ResponseMapping 中获取 content 字段名
func getContentField(model *entity.AsynchModel) string {
if model == nil {
return "content"
}
respMapping := parseResponseMapping(model.ResponseMapping)
for k, v := range respMapping {
if strings.Contains(v, "content") {
return k
}
}
return "content"
}
// parseResponseMapping 解析响应映射
func parseResponseMapping(mapping any) map[string]string {
result := make(map[string]string)
switch v := mapping.(type) {
case *gvar.Var:
if v != nil {
json.Unmarshal([]byte(v.String()), &result)
}
case string:
json.Unmarshal([]byte(v), &result)
case map[string]any:
result = v
default:
b, _ := json.Marshal(v)
json.Unmarshal(b, &result)
case map[string]interface{}:
for k, val := range v {
if s, ok := val.(string); ok {
result[k] = s
}
}
}
return result
}
// parseNodeBuild 解析节点构建结果BuildType == 2
func parseNodeBuild(taskRecord *entity.ComposeTask) *dto.MultiRoundResult {
if taskRecord == nil {
return nil
}
result := parseTaskMessages(taskRecord.Messages)
if result == nil {
result = make(map[string]any)
}
return &dto.MultiRoundResult{
TotalRounds: 1,
Rounds: []any{result},
}
}
// Callback 回调处理
func Callback(ctx context.Context, req *promptDto.CallbackReq) error {
func Callback(ctx context.Context, req *dto.CallbackReq) error {
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
// ============ 先查任务是否存在 ============
task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
})
if err != nil {
return err
return fmt.Errorf("查询任务失败: %w", err)
}
if task == nil {
return fmt.Errorf("任务不存在: %s", req.TaskId)
}
// ============ 根据状态区分处理 ============
if req.State == 3 {
// 失败:直接更新状态
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
})
return err
}
// ======================================
// 成功:解析模型输出
result, err := util.ParseOutput(req.Text)
if err != nil {
_, updateErr := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
})
if updateErr != nil {
g.Log().Warningf(ctx, "[Callback] 更新失败状态出错 taskId=%s err=%v", req.TaskId, updateErr)
}
return err
return handleCallbackFailure(ctx, req.TaskId, req.ErrorMsg)
}
return handleCallbackSuccess(ctx, req)
}
// handleCallbackFailure 处理回调失败
func handleCallbackFailure(ctx context.Context, taskID, errorMsg string) error {
_, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: taskID,
Status: public.ComposeStatusFailed,
ErrorMessage: errorMsg,
})
return err
}
// handleCallbackSuccess 处理回调成功
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq) error {
result, err := util.ParseOutput(req.Text)
if err != nil {
handleCallbackFailure(ctx, req.TaskId, req.ErrorMsg)
return fmt.Errorf("解析模型输出失败: %w", err)
}
// ============ result 可能为 nil ============
var messages any
if result != nil {
messages = result
}
// =======================================
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
@@ -407,34 +522,43 @@ func Callback(ctx context.Context, req *promptDto.CallbackReq) error {
if err != nil {
g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err)
}
return err
}
// GetComposeTask 查询任务结果
func GetComposeTask(ctx context.Context, taskID string) (*promptDto.GetComposeTaskRes, error) {
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: taskID,
})
if err != nil {
return nil, err
return nil, fmt.Errorf("查询任务失败: %w", err)
}
if record == nil {
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
}
// 如果 Messages 是字符串,反序列化为 JSON 数组
messages := record.Messages
if str, ok := messages.(string); ok && str != "" {
var parsed any
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
messages = parsed
}
}
messages := parseMessagesForResponse(record.Messages)
return &promptDto.GetComposeTaskRes{
return &dto.GetComposeTaskRes{
TaskId: record.TaskId,
Status: record.Status,
ErrorMessage: record.ErrorMessage,
Messages: messages,
}, nil
}
// parseMessagesForResponse 解析用于响应的消息
func parseMessagesForResponse(messages any) any {
str, ok := messages.(string)
if !ok || str == "" {
return messages
}
var parsed any
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
return parsed
}
return messages
}

View File

@@ -10,10 +10,15 @@ import (
"strings"
"time"
"github.com/gogf/gf/v2/frame/g"
"prompts-core/common/util"
"prompts-core/service/gateway"
)
"github.com/gogf/gf/v2/frame/g"
const (
bytesPerKB = 1024
bytesPerMB = 1024 * 1024
)
// FetchFileTexts 从 URL 列表获取文件内容,支持 zip 内文件
@@ -24,51 +29,49 @@ func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
return result
}
client := &http.Client{
Timeout: time.Duration(g.Cfg().MustGet(ctx, "userFiles.httpTimeoutSec", 8).Int()) * time.Second,
}
client := createHTTPClient(ctx, "userFiles.httpTimeoutSec", 8)
for _, rawURL := range urls {
url := util.SanitizeURL(rawURL)
if url == "" {
continue
}
if util.IsBannedExtension(url) {
if url == "" || util.IsBannedExtension(url) {
continue
}
if util.IsZipExtension(url) {
zipTexts := fetchZipFileTexts(ctx, client, url)
for k, v := range zipTexts {
result[k] = v
}
mergeMap(result, fetchZipFileTexts(ctx, client, url))
continue
}
text, err := fetchFileContent(ctx, client, url)
if err != nil {
continue
}
if text == "" {
continue
}
text = util.CleanSymbols(text)
if text := fetchAndCleanFileContent(ctx, client, url); text != "" {
result[url] = text
}
}
return result
}
// mergeMap 合并 map
func mergeMap(dst, src map[string]string) {
for k, v := range src {
dst[k] = v
}
}
// fetchAndCleanFileContent 获取并清理文件内容
func fetchAndCleanFileContent(ctx context.Context, client *http.Client, url string) string {
text, err := fetchFileContent(ctx, client, url)
if err != nil || text == "" {
return ""
}
return util.CleanSymbols(text)
}
// fetchZipFileTexts 下载并解压 zip 文件,提取可读文本内容
func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map[string]string {
result := make(map[string]string)
zipBytes, err := downloadFile(client, url,
int64(g.Cfg().MustGet(ctx, "userFiles.zipMaxSizeMB", 10).Int())*1024*1024,
)
maxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipMaxSizeMB", 10).Int()) * bytesPerMB
zipBytes, err := downloadFile(client, url, maxSize)
if err != nil {
return result
}
@@ -78,61 +81,61 @@ func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map
return result
}
entryMaxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipEntryMaxSizeKB", 500).Int()) * 1024
entryMaxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipEntryMaxSizeKB", 500).Int()) * bytesPerKB
for _, file := range reader.File {
if file.FileInfo().IsDir() {
if shouldSkipZipEntry(file.Name) {
continue
}
fileName := file.Name
if util.IsBannedExtension(fileName) {
continue
if text := extractZipEntryContent(file, entryMaxSize); text != "" {
result[url+"::"+file.Name] = text
}
}
if util.IsZipExtension(fileName) {
continue
}
return result
}
// shouldSkipZipEntry 判断是否应该跳过 zip 条目
func shouldSkipZipEntry(fileName string) bool {
return util.IsBannedExtension(fileName) || util.IsZipExtension(fileName)
}
// extractZipEntryContent 提取 zip 条目内容
func extractZipEntryContent(file *zip.File, maxSize int64) string {
rc, err := file.Open()
if err != nil {
continue
return ""
}
defer rc.Close()
content, err := io.ReadAll(io.LimitReader(rc, entryMaxSize))
rc.Close()
content, err := io.ReadAll(io.LimitReader(rc, maxSize))
if err != nil {
continue
return ""
}
contentType := http.DetectContentType(content)
if !util.IsReadableContentType(contentType) {
continue
if !util.IsReadableContentType(http.DetectContentType(content)) {
return ""
}
text := util.CleanSymbols(string(content))
if text == "" {
continue
return ""
}
key := url + "::" + fileName
result[key] = text
}
return result
return text
}
// downloadFile 下载文件,限制最大大小
func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("创建请求失败: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, err
return nil, fmt.Errorf("执行请求失败: %w", err)
}
defer resp.Body.Close()
@@ -140,19 +143,24 @@ func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
}
return io.ReadAll(io.LimitReader(resp.Body, maxSize))
body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
return body, nil
}
// fetchFileContent 获取单个文本文件内容
func fetchFileContent(ctx context.Context, client *http.Client, url string) (string, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return "", err
return "", fmt.Errorf("创建请求失败: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return "", err
return "", fmt.Errorf("执行请求失败: %w", err)
}
defer resp.Body.Close()
@@ -162,16 +170,13 @@ func fetchFileContent(ctx context.Context, client *http.Client, url string) (str
contentType := resp.Header.Get("Content-Type")
if !util.IsReadableContentType(contentType) {
return "", fmt.Errorf("unreadable content-type: %s", contentType)
return "", fmt.Errorf("不可读的内容类型: %s", contentType)
}
body, err := io.ReadAll(
io.LimitReader(resp.Body,
int64(g.Cfg().MustGet(ctx, "userFiles.textFileMaxSizeKB", 500).Int())*1024,
),
)
maxSize := int64(g.Cfg().MustGet(ctx, "userFiles.textFileMaxSizeKB", 500).Int()) * bytesPerKB
body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
if err != nil {
return "", err
return "", fmt.Errorf("读取响应失败: %w", err)
}
return strings.TrimSpace(string(body)), nil
@@ -186,27 +191,26 @@ func SkillMdContent(ctx context.Context, skillName string) string {
fullUrl := skillResp.ImgAddressPrefix + skillResp.FileUrl
client := &http.Client{
Timeout: time.Duration(g.Cfg().MustGet(ctx, "skillFiles.httpTimeoutSec", 30).Int()) * time.Second,
}
client := createHTTPClient(ctx, "skillFiles.httpTimeoutSec", 30)
maxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.zipMaxSizeMB", 10).Int()) * bytesPerMB
zipBytes, err := downloadFile(client, fullUrl,
int64(g.Cfg().MustGet(ctx, "skillFiles.zipMaxSizeMB", 10).Int())*1024*1024,
)
zipBytes, err := downloadFile(client, fullUrl, maxSize)
if err != nil {
return ""
}
mdContents, err := extractMdFiles(ctx, zipBytes)
if err != nil {
if err != nil || len(mdContents) == 0 {
return ""
}
if len(mdContents) == 0 {
return ""
}
return buildSkillMarkdown(skillResp, mdContents)
}
// buildSkillMarkdown 构建技能 Markdown 内容
func buildSkillMarkdown(skillResp *gateway.SkillUserVO, mdContents map[string]string) string {
var builder strings.Builder
builder.WriteString(fmt.Sprintf("# Skill: %s\n\n", skillResp.Name))
if skillResp.Description != "" {
builder.WriteString(fmt.Sprintf("> %s\n\n", skillResp.Description))
@@ -227,35 +231,53 @@ func extractMdFiles(ctx context.Context, zipBytes []byte) (map[string]string, er
reader, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
if err != nil {
return nil, err
return nil, fmt.Errorf("创建 zip 阅读器失败: %w", err)
}
entryMaxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.mdMaxSizeKB", 500).Int()) * 1024
entryMaxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.mdMaxSizeKB", 500).Int()) * bytesPerKB
for _, file := range reader.File {
if file.FileInfo().IsDir() {
if file.FileInfo().IsDir() || !isMarkdownFile(file.Name) {
continue
}
if !strings.HasSuffix(strings.ToLower(file.Name), ".md") {
continue
}
rc, err := file.Open()
if err != nil {
continue
}
content, err := io.ReadAll(io.LimitReader(rc, entryMaxSize))
rc.Close()
if err != nil {
continue
}
if len(content) > 0 {
result[file.Name] = strings.TrimSpace(string(content))
if content := readMarkdownFileContent(file, entryMaxSize); content != "" {
result[file.Name] = content
}
}
return result, nil
}
// isMarkdownFile 判断是否为 Markdown 文件
func isMarkdownFile(fileName string) bool {
return strings.HasSuffix(strings.ToLower(fileName), ".md")
}
// readMarkdownFileContent 读取 Markdown 文件内容
func readMarkdownFileContent(file *zip.File, maxSize int64) string {
rc, err := file.Open()
if err != nil {
return ""
}
defer rc.Close()
content, err := io.ReadAll(io.LimitReader(rc, maxSize))
if err != nil {
return ""
}
if len(content) == 0 {
return ""
}
return strings.TrimSpace(string(content))
}
// createHTTPClient 创建 HTTP 客户端
func createHTTPClient(ctx context.Context, configKey string, defaultSeconds int) *http.Client {
timeout := time.Duration(g.Cfg().MustGet(ctx, configKey, defaultSeconds).Int()) * time.Second
return &http.Client{
Timeout: timeout,
}
}

View File

@@ -0,0 +1,75 @@
# Prompts-Core提示词核心服务
> 智能提示词构建与管理系统,支持多模态 AI 模型的提示词组装、会话管理和协议适配。
---
## 项目简介
**Prompts-Core** 是一个基于 Go 语言开发的提示词核心服务,作为 AI 应用层与模型网关之间的桥梁,负责将业务需求转换为标准化的模型请求。
### 核心价值
- **统一提示词管理**:集中化管理不同模型类型的提示词模板
- **智能会话维护**:基于 Redis + PostgreSQL 的双层会话存储
- **多协议适配**:支持 OpenAI、DeepSeek、Qwen、Gemini 等多种模型协议
- **文件处理能力**:自动提取文本文件和 ZIP 压缩包内容
- **技能系统集成**:支持从外部加载 Markdown 格式的技能描述
---
## 核心功能
### 1. 提示词构建引擎
#### 多模态支持
| 类型 | 说明 | 适用场景 |
|------|------|----------|
| Type 1 | 文字处理助手 | 文章撰写、文案优化、翻译等 |
| Type 2 | 图片处理助手 | 图像生成、风格迁移等 |
| Type 3 | 音频处理助手 | 语音合成、识别、降噪等 |
| Type 4 | 向量化处理助手 | 语义检索、知识索引等 |
| Type 5 | 全模态助手 | 跨模态转换、多模态融合等 |
#### 构建模式
- **BuildType 1提示词构建**:完整流程,包含系统提示词、历史会话、用户输入的智能组装
- **BuildType 2节点构建**:工作流路由决策,根据上下文选择节点 ID
#### 分批处理
当用户表单内容超出模型窗口限制时,自动按 Token 大小分批处理。
### 2. 会话管理系统
- **双层存储**Redis 缓存(最近 N 轮)+ PostgreSQL 持久化
- **自动管理**:最大轮数控制(默认 10 轮)、自动过期(默认 30 分钟)
### 3. 协议适配器
通过配置动态支持多种模型协议:
- 角色映射system/user/assistant → 目标协议角色
- 内容字段映射content → parts.text 等
- 消息顺序控制:灵活配置拼接顺序
- 请求模板渲染:支持占位符替换
### 4. 任务调度
- **异步流程**:创建网关任务 → 轮询等待 → 接收回调 → 返回结果
- **重试机制**:可配置最大重试次数(默认 3 次)
- **超时保护**:默认 300 秒超时
---
## 技术架构
### 技术栈
| 组件 | 版本 | 用途 |
|------|------|------|
| Go | 1.26.0 | 编程语言 |
| GoFrame | v2.10.0 | Web 框架 |
| PostgreSQL | - | 关系型数据库 |
| Redis | - | 缓存与会话存储 |
| Consul | - | 服务注册与发现 |
| Jaeger | - | 分布式链路追踪 |
### 架构图

View File

@@ -20,11 +20,27 @@ type PromptIR struct {
// Segment 消息片段
type Segment struct {
Type string `json:"type"` // text/image
Type string `json:"type"`
Content string `json:"content"`
Role string `json:"role,omitempty"`
}
// ProviderProtocol 协议编译配置(从 DB JSONB 字段解析)
type ProviderProtocol struct {
TargetField string `json:"target_field"`
MergeOrder []string `json:"merge_order"`
RoleMapping map[string]string `json:"role_mapping"`
ContentMapping ContentMapping `json:"content_mapping"`
RequestTemplate map[string]any `json:"request_template"`
SystemPromptTemplate string `json:"system_prompt_template"`
}
// ContentMapping 内容字段映射
type ContentMapping struct {
Type string `json:"type"`
Field string `json:"field"`
}
// NewPromptIR 创建空 PromptIR
func NewPromptIR() *PromptIR {
return &PromptIR{
@@ -34,6 +50,54 @@ func NewPromptIR() *PromptIR {
}
}
// String 返回 PromptIR 的完整内容字符串(用于 token 计算)
func (ir *PromptIR) String() string {
var builder strings.Builder
for _, seg := range ir.System {
builder.WriteString("System: ")
builder.WriteString(seg.Content)
builder.WriteString("\n")
}
for _, seg := range ir.History {
builder.WriteString(seg.Role)
builder.WriteString(": ")
builder.WriteString(seg.Content)
builder.WriteString("\n")
}
for _, seg := range ir.User {
builder.WriteString("User: ")
builder.WriteString(seg.Content)
builder.WriteString("\n")
}
return builder.String()
}
// GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算)
func (ir *PromptIR) GetTotalContent() string {
var builder strings.Builder
for _, seg := range ir.System {
builder.WriteString(seg.Content)
builder.WriteString("\n")
}
for _, seg := range ir.History {
builder.WriteString(seg.Content)
builder.WriteString("\n")
}
for _, seg := range ir.User {
builder.WriteString(seg.Content)
builder.WriteString("\n")
}
return builder.String()
}
// AddSystem 添加系统提示
func (ir *PromptIR) AddSystem(content string) *PromptIR {
if content != "" {
@@ -62,7 +126,6 @@ func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
func (ir *PromptIR) ToMessages() []map[string]any {
var messages []map[string]any
// 1. 系统消息
for _, seg := range ir.System {
messages = append(messages, map[string]any{
"role": "system",
@@ -70,7 +133,6 @@ func (ir *PromptIR) ToMessages() []map[string]any {
})
}
// 2. 历史消息
for _, seg := range ir.History {
messages = append(messages, map[string]any{
"role": seg.Role,
@@ -78,13 +140,13 @@ func (ir *PromptIR) ToMessages() []map[string]any {
})
}
// 3. 用户消息
for _, seg := range ir.User {
messages = append(messages, map[string]any{
"role": "user",
"content": seg.Content,
})
}
return messages
}
@@ -97,11 +159,7 @@ func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderP
if err != nil || entity == nil {
return nil, err
}
entity.MergeOrder = util.ParseJSONField(entity.MergeOrder)
entity.RoleMapping = util.ParseJSONField(entity.RoleMapping)
entity.ContentMapping = util.ParseJSONField(entity.ContentMapping)
entity.RequestTemplate = util.ParseJSONField(entity.RequestTemplate)
entity.ContentMapping = util.ParseJSONField(entity.ContentMapping)
fmt.Println("entity打印", entity)
return parseProtocol(entity), nil
}
@@ -109,62 +167,27 @@ func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderP
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
p := &ProviderProtocol{
TargetField: e.TargetField,
SystemPromptTemplate: e.SystemPromptTemplate,
}
// MergeOrder: any → []string
if e.MergeOrder != nil {
b, _ := json.Marshal(e.MergeOrder)
json.Unmarshal(b, &p.MergeOrder)
}
// RoleMapping: any → map[string]string
if e.RoleMapping != nil {
b, _ := json.Marshal(e.RoleMapping)
json.Unmarshal(b, &p.RoleMapping)
}
// ContentMapping: any → ContentMapping
if e.ContentMapping != nil {
b, _ := json.Marshal(e.ContentMapping)
json.Unmarshal(b, &p.ContentMapping)
}
// RequestTemplate: any → map[string]any
if e.RequestTemplate != nil {
b, _ := json.Marshal(e.RequestTemplate)
json.Unmarshal(b, &p.RequestTemplate)
}
fmt.Printf("parseProtocol: %+v\n", p)
// 使用通用解析方法处理各个字段
util.ParseJSONFieldFromGvar(e.MergeOrder, &p.MergeOrder)
util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping)
util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping)
util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate)
return p
}
// ProviderProtocol 协议编译配置(从 DB JSONB 字段解析)
type ProviderProtocol struct {
TargetField string `json:"target_field"`
MergeOrder []string `json:"merge_order"`
RoleMapping map[string]string `json:"role_mapping"`
ContentMapping ContentMapping `json:"content_mapping"`
RequestTemplate map[string]any `json:"request_template"`
}
// ContentMapping 内容字段映射
type ContentMapping struct {
Type string `json:"type"` // direct/parts
Field string `json:"field"` // content/text
}
// Compile 将 PromptIR 按协议配置编译为 Provider Request
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (map[string]any, error) {
if ir == nil || p == nil {
return nil, fmt.Errorf("ir and protocol are required")
}
// 1. 按 merge_order 拼接消息
messages := mergeByOrder(ir, p.MergeOrder)
// 2. 角色映射
messages = mapRoles(messages, p.RoleMapping)
// 3. 内容字段映射
messages = mapContent(messages, p.ContentMapping)
// 4. 按 target_field + request_template 构建请求体
return buildRequest(messages, p, chatModel), nil
}
@@ -197,6 +220,7 @@ func mergeByOrder(ir *PromptIR, order []string) []map[string]any {
}
}
}
return messages
}
@@ -205,15 +229,18 @@ func mapRoles(messages []map[string]any, mapping map[string]string) []map[string
if len(mapping) == 0 {
return messages
}
for i, msg := range messages {
role, ok := msg["role"].(string)
if !ok {
continue
}
if mapped, exists := mapping[role]; exists {
messages[i]["role"] = mapped
}
}
return messages
}
@@ -225,15 +252,14 @@ func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
switch cm.Type {
case "parts":
// Gemini 格式: {"parts": [{"text": "..."}]}
msg["parts"] = []map[string]any{
{cm.Field: content},
}
default:
// direct: {"content": "..."}
msg[cm.Field] = content
}
}
return messages
}
@@ -242,6 +268,7 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *ent
if len(p.RequestTemplate) > 0 {
return renderTemplate(p.RequestTemplate, messages, chatModel)
}
return map[string]any{
p.TargetField: messages,
}
@@ -252,13 +279,13 @@ func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *e
b, _ := json.Marshal(tmpl)
str := string(b)
// 替换 {{model}}
str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`)
// 替换 {{messages}}
msgBytes, _ := json.Marshal(messages)
str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes))
var result map[string]any
json.Unmarshal([]byte(str), &result)
return result
}

View File

@@ -9,15 +9,16 @@ import (
"github.com/gogf/gf/v2/frame/g"
)
// ==================== Redis 操作 ====================
const (
redisKeyPrefix = "chat:session:%s"
)
// saveToRedis 保存会话数据到Redis
func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error {
key := fmt.Sprintf("chat:session:%s", sessionId)
key := formatRedisKey(sessionId)
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
expireTime := time.Duration(expireSeconds) * time.Second
data := map[string]any{
"sessionId": sessionId,
@@ -31,18 +32,29 @@ func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[st
return fmt.Errorf("序列化会话数据失败: %w", err)
}
_, err = g.Redis().Do(ctx, "LPUSH", key, string(b))
if err != nil {
if err := executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil {
return err
}
return nil
}
// formatRedisKey 格式化Redis键
func formatRedisKey(sessionId string) string {
return fmt.Sprintf(redisKeyPrefix, sessionId)
}
// executeRedisCommands 执行Redis命令
func executeRedisCommands(ctx context.Context, key string, value string, maxRounds int, expireSeconds int64) error {
if _, err := g.Redis().Do(ctx, "LPUSH", key, value); err != nil {
return fmt.Errorf("写入Redis失败: %w", err)
}
_, err = g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1)
if err != nil {
if _, err := g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1); err != nil {
return fmt.Errorf("裁剪Redis列表失败: %w", err)
}
_, err = g.Redis().Do(ctx, "EXPIRE", key, int64(expireTime.Seconds()))
if err != nil {
if _, err := g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
return fmt.Errorf("设置过期时间失败: %w", err)
}
@@ -51,7 +63,7 @@ func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[st
// getFromRedis 从Redis获取会话历史
func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) {
key := fmt.Sprintf("chat:session:%s", sessionId)
key := formatRedisKey(sessionId)
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
if err != nil {
@@ -62,8 +74,17 @@ func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, erro
return []map[string]any{}, nil
}
sessions := parseRedisSessions(ctx, result.Strings())
reverseSlice(sessions)
return sessions, nil
}
// parseRedisSessions 解析Redis会话数据
func parseRedisSessions(ctx context.Context, values []string) []map[string]any {
var sessions []map[string]any
values := result.Strings()
for _, str := range values {
var data map[string]any
if err := json.Unmarshal([]byte(str), &data); err != nil {
@@ -73,12 +94,14 @@ func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, erro
sessions = append(sessions, data)
}
// 反转Redis 最新在前 → 时间正序)
for i, j := 0, len(sessions)-1; i < j; i, j = i+1, j-1 {
sessions[i], sessions[j] = sessions[j], sessions[i]
}
return sessions
}
return sessions, nil
// reverseSlice 反转切片
func reverseSlice(s []map[string]any) {
for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {
s[i], s[j] = s[j], s[i]
}
}
// GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用)
@@ -92,23 +115,31 @@ func GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map
return []map[string]any{}, nil
}
return flattenHistoryMessages(historyData), nil
}
// flattenHistoryMessages 扁平化历史消息
func flattenHistoryMessages(historyData []map[string]any) []map[string]any {
var messages []map[string]any
for _, round := range historyData {
if reqMsgs, ok := round["requestContent"].([]interface{}); ok {
for _, m := range reqMsgs {
if msg, ok := m.(map[string]interface{}); ok {
messages = append(messages, msg)
}
}
}
if respMsgs, ok := round["responseContent"].([]interface{}); ok {
for _, m := range respMsgs {
if msg, ok := m.(map[string]interface{}); ok {
messages = append(messages, msg)
}
}
}
appendMessagesFromField(round, "requestContent", &messages)
appendMessagesFromField(round, "responseContent", &messages)
}
return messages, nil
return messages
}
// appendMessagesFromField 从指定字段追加消息
func appendMessagesFromField(data map[string]any, field string, messages *[]map[string]any) {
msgs, ok := data[field].([]interface{})
if !ok {
return
}
for _, m := range msgs {
if msg, ok := m.(map[string]interface{}); ok {
*messages = append(*messages, msg)
}
}
}

View File

@@ -3,112 +3,164 @@ package prompt
import (
"context"
"fmt"
sessionDao "prompts-core/dao"
"prompts-core/model/entity"
"prompts-core/common/util"
sessionDto "prompts-core/model/dto/prompt"
"gitea.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
"prompts-core/common/util"
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
)
func SessionCallback(ctx context.Context, req *sessionDto.SessionCallbackReq) (res *sessionDto.SessionCallbackRes, err error) {
// 1. 解析AI返回的文本
// SessionCallback 会话回调
func SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
result, err := util.ParseOutput(req.Text)
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, err
return nil, fmt.Errorf("解析模型输出失败: %w", err)
}
// 2. 更新数据库
result["role"] = "assistant"
_, err = sessionDao.ComposeSession.Update(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
ResponseContent: result,
})
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err)
if err := updateSessionResponse(ctx, req.EpicycleId, result); err != nil {
return nil, err
}
// 3. 获取当前轮次完整数据
session, err := sessionDao.ComposeSession.Get(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
})
session, err := getSessionById(ctx, req.EpicycleId)
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, err
}
// 4. 转换 json 并存入 Redis
if err := saveSessionToRedis(ctx, session); err != nil {
return nil, err
}
requestMessages := util.ConvertToMessages(session.RequestContent)
responseMessages := util.ConvertToMessages(session.ResponseContent)
if err = saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil {
g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v",
session.SessionId, session.Id, err)
return nil, err
}
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
session.SessionId, session.Id, len(requestMessages), len(responseMessages))
return &sessionDto.SessionCallbackRes{}, nil
return &dto.SessionCallbackRes{}, nil
}
// updateSessionResponse 更新会话响应
func updateSessionResponse(ctx context.Context, epicycleId int64, response any) error {
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: epicycleId},
ResponseContent: response,
})
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", epicycleId, err)
return fmt.Errorf("更新数据库失败: %w", err)
}
return nil
}
// getSessionById 根据ID获取会话
func getSessionById(ctx context.Context, epicycleId int64) (*entity.ComposeSession, error) {
session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: epicycleId},
})
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", epicycleId, err)
return nil, fmt.Errorf("获取会话数据失败: %w", err)
}
return session, nil
}
// saveSessionToRedis 保存会话到Redis
func saveSessionToRedis(ctx context.Context, session *entity.ComposeSession) error {
requestMessages := util.ConvertToMessages(session.RequestContent)
responseMessages := util.ConvertToMessages(session.ResponseContent)
if err := saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil {
g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v",
session.SessionId, session.Id, err)
return fmt.Errorf("Redis存储失败: %w", err)
}
return nil
}
// GetHistoryMessages 获取历史信息
func GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) {
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
// 1. 先从 Redis 拿
redisHistory, err := GetSessionHistoryForInference(ctx, sessionId)
if err == nil && len(redisHistory) > 0 {
return redisHistory, nil
}
// 2. Redis 没有 → fallback DB
sessions, _, err := sessionDao.ComposeSession.List(ctx, &entity.ComposeSession{
return getHistoryFromDatabase(ctx, sessionId, maxRounds)
}
// getHistoryFromDatabase 从数据库获取历史记录
func getHistoryFromDatabase(ctx context.Context, sessionId string, maxRounds int) ([]map[string]any, error) {
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
SessionId: sessionId,
}, 1, maxRounds)
if err != nil {
return nil, fmt.Errorf("DB获取历史失败: %w", err)
}
messages := extractMessagesFromSessions(sessions)
cacheSessionsToRedis(ctx, sessions)
return messages, nil
}
// extractMessagesFromSessions 从会话列表中提取消息
func extractMessagesFromSessions(sessions []*entity.ComposeSession) []map[string]any {
var messages []map[string]any
for _, session := range sessions {
// request
reqMsgs := util.ConvertToMessages(session.RequestContent)
appendRequestMessages(session.RequestContent, &messages)
appendResponseMessages(session.ResponseContent, &messages)
}
return messages
}
// appendRequestMessages 追加请求消息
func appendRequestMessages(requestContent any, messages *[]map[string]any) {
reqMsgs := util.ConvertToMessages(requestContent)
for _, m := range reqMsgs {
role := gconv.String(m["role"])
if role == "user" || role == "assistant" {
messages = append(messages, m)
*messages = append(*messages, m)
}
}
}
// response
respMsgs := util.ConvertToMessages(session.ResponseContent)
// appendResponseMessages 追加响应消息
func appendResponseMessages(responseContent any, messages *[]map[string]any) {
respMsgs := util.ConvertToMessages(responseContent)
for _, m := range respMsgs {
if m["role"] == nil {
m["role"] = "assistant"
}
messages = append(messages, m)
}
*messages = append(*messages, m)
}
}
// 3. 回写 Redis
// cacheSessionsToRedis 将会话缓存到Redis
func cacheSessionsToRedis(ctx context.Context, sessions []*entity.ComposeSession) {
for _, session := range sessions {
reqMsgs := util.ConvertToMessages(session.RequestContent)
respMsgs := util.ConvertToMessages(session.ResponseContent)
for i := range respMsgs {
if respMsgs[i]["role"] == nil {
respMsgs[i]["role"] = "assistant"
}
}
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
_ = saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs)
}
}
return messages, nil
}

View File

@@ -0,0 +1,135 @@
package prompt
import (
"context"
"fmt"
"strings"
"github.com/gogf/gf/v2/frame/g"
"prompts-core/common/util"
"prompts-core/model/dto"
"prompts-core/model/entity"
)
// ProcessUserFormBatches 处理 UserForm 分批(按 token 大小拼接内容)
func ProcessUserFormBatches(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) (*dto.ComposeMessagesReq, int, error) {
if model.TokenConfig == nil || len(req.UserForm) == 0 {
return req, 1, nil
}
availableWindow := util.GetAvailableWindow(model.TokenConfig)
batches := splitUserFormByTokenSize(req.UserForm, availableWindow, model.TokenConfig)
if len(batches) <= 1 {
return req, 1, nil
}
newUserForm := buildBatchedUserForm(batches)
newReq := *req
newReq.UserForm = newUserForm
g.Log().Infof(ctx, "[ProcessUserFormBatches] UserForm分批完成: 原始%d条 -> %d批 (按token大小拼接)",
len(req.UserForm), len(batches))
return &newReq, len(batches), nil
}
// buildBatchedUserForm 构建分批后的用户表单
func buildBatchedUserForm(batches [][]map[string]any) []map[string]any {
newUserForm := make([]map[string]any, 0, len(batches))
for i, batch := range batches {
combinedText := combineBatchText(batch)
newUserForm = append(newUserForm, map[string]any{
"batch_index": i + 1,
"total_batches": len(batches),
"text": combinedText,
"item_count": len(batch),
})
}
return newUserForm
}
// combineBatchText 合并批次中的所有文本(合并所有字段的值)
func combineBatchText(batch []map[string]any) string {
var builder strings.Builder
for j, item := range batch {
itemText := getItemText(item)
if itemText == "" {
continue
}
if j > 0 {
builder.WriteString("\n\n")
}
builder.WriteString(itemText)
}
return builder.String()
}
// splitUserFormByTokenSize 按 token 大小将 UserForm 内容拼接后分批
func splitUserFormByTokenSize(userForm []map[string]any, maxTokens int, tokenConfig any) [][]map[string]any {
if len(userForm) == 0 {
return [][]map[string]any{}
}
batches := make([][]map[string]any, 0)
currentBatch := make([]map[string]any, 0)
currentTokens := 0
for i, item := range userForm {
itemText := getItemText(item)
itemTokens := util.CalculateTokens(itemText, tokenConfig)
// 单个元素超过窗口,单独成一批
if itemTokens > maxTokens {
if len(currentBatch) > 0 {
batches = append(batches, currentBatch)
currentBatch = make([]map[string]any, 0)
currentTokens = 0
}
batches = append(batches, []map[string]any{item})
continue
}
// 判断是否需要新开一批
if currentTokens+itemTokens > maxTokens && len(currentBatch) > 0 {
batches = append(batches, currentBatch)
currentBatch = make([]map[string]any, 0)
currentTokens = 0
}
currentBatch = append(currentBatch, item)
currentTokens += itemTokens
// 最后一批
if i == len(userForm)-1 && len(currentBatch) > 0 {
batches = append(batches, currentBatch)
}
}
return batches
}
// getItemText 获取 item 中的所有文本内容(合并所有字段的值)
func getItemText(item map[string]any) string {
if len(item) == 0 {
return ""
}
var parts []string
for key, value := range item {
// 跳过分批时添加的元数据字段
if key == "batch_index" || key == "total_batches" || key == "item_count" {
continue
}
parts = append(parts, fmt.Sprintf("%v", value))
}
return strings.Join(parts, "\n")
}