refactor(prompt): 重构提示词构建服务与数据模型
This commit is contained in:
72
README.md
72
README.md
@@ -1,30 +1,54 @@
|
||||
# prompts-core(提示词服务)[2026.5.12前,暂时弃置]
|
||||
# Prompts-Core 提示词核心服务
|
||||
## 项目简介
|
||||
Prompts-Core 是基于 Go 语言开发的**多模态 AI 提示词构建与管理系统**,专注于统一管理各类 AI 模型的提示词模板、维护智能会话上下文、适配主流模型协议,并支持文件解析与外部技能集成,为 AI 应用提供标准化、高效的提示词服务。
|
||||
|
||||
## 1. 功能范围(当前阶段)
|
||||
- 仅做提示词配置的基础 CRUD(最小可用版本)
|
||||
- 表:`prompts_model_prompt`
|
||||
## 核心功能
|
||||
1. **提示词构建引擎**
|
||||
支持文字/图片/音频/向量化/全模态 5 类任务提示词生成,提供完整流程、分步节点两种构建模式,支持超大内容按 Token 自动分批处理。
|
||||
2. **智能会话管理**
|
||||
基于缓存实现高效会话存储,自动控制会话轮数与过期时间,保障上下文连贯性。
|
||||
3. **多模型协议适配**
|
||||
动态适配 OpenAI、DeepSeek、Qwen、Gemini 等主流 AI 模型协议,支持角色、字段、消息顺序灵活映射。
|
||||
4. **文件与技能集成**
|
||||
自动提取文本、ZIP 压缩包内容,支持加载外部 Markdown 技能配置,扩展服务能力。
|
||||
5. **异步任务调度**
|
||||
支持异步任务处理、状态轮询与回调通知,自带可配置重试机制。
|
||||
|
||||
## 2. 接口
|
||||
> 路由注册方式与参考项目一致:使用 `common/http.RouteRegister` 注册 controller。
|
||||
## 技术架构
|
||||
- 开发语言:Go 1.26.0
|
||||
- Web 框架:GoFrame v2.10.0
|
||||
- 核心存储:Redis(会话缓存)
|
||||
- 服务组件:Consul(服务注册)、Jaeger(链路追踪)
|
||||
- 调用链路:客户端 → Prompts-Core → 模型网关 → AI 模型
|
||||
|
||||
- `POST /composeMessages`:按 `modelTypeId` 读取 `prompt_info + response_json_schema`,`modelName` 作为实际调用的网关模型;结合前端 `form(role/value)` 与 `userfiles` 调用 `model-gateway /task/createTask`,同步等待回调后直接返回最终 `messages`
|
||||
- `GET /composeMessagesCallback/prompts-core`:`model-gateway` 成功回调接口(真实地址由 `callbackUrl + /bizName` 组成)
|
||||
- `GET /getComposeTask`:按 `taskId` 查询拼接任务状态和结果
|
||||
- `POST /createPrompt`:创建(默认启用)
|
||||
- `PUT /updatePrompt`:更新
|
||||
- `DELETE /deletePrompt`:删除
|
||||
- `GET /getPrompt`:详情
|
||||
- `POST /listPrompt`:列表分页
|
||||
## 快速开始
|
||||
### 环境要求
|
||||
Go 1.26+、Redis、已部署模型网关服务
|
||||
|
||||
## 3. 数据库初始化
|
||||
执行根目录 `update.sql`。
|
||||
### 启动步骤
|
||||
1. 克隆项目代码
|
||||
2. 完成项目配置文件修改
|
||||
3. 执行命令启动服务:
|
||||
```bash
|
||||
go run main.go
|
||||
```
|
||||
|
||||
## 4. 运行配置
|
||||
配置文件:`config.yml`
|
||||
## API 接口
|
||||
### 基础信息
|
||||
- 服务地址:`http://{host}:3009`
|
||||
- 请求类型:`application/json`
|
||||
- 认证方式:请求头携带 `Authorization`、`X-User`
|
||||
|
||||
### 新增说明
|
||||
- `prompts_model_prompt` 去除了 `limit_length`
|
||||
- 新增 `response_json_schema`
|
||||
- 新增任务记录表 `prompts_compose_task`
|
||||
- `callbackUrl` 必须填写 prompts-core 的绝对地址基路径,例如:`http://127.0.0.1:8002/composeMessagesCallback`
|
||||
- `model-gateway` 实际回调地址为:`callbackUrl/{bizName}`,本项目固定为:`/composeMessagesCallback/prompts-core`
|
||||
### 核心接口
|
||||
1. **提示词拼接接口**
|
||||
- 地址:`POST /composeMessages`
|
||||
- 功能:构建提示词并调用模型服务,同步返回结果
|
||||
2. **任务状态查询**
|
||||
- 地址:`GET /getComposeTask`
|
||||
- 功能:根据任务 ID 查询处理状态与结果
|
||||
3. **任务回调接口**
|
||||
- 地址:`GET /composeMessagesCallback/prompts-core`
|
||||
- 功能:接收模型服务处理完成回调
|
||||
4. **会话同步接口**
|
||||
- 地址:`POST /sessionCallback`
|
||||
- 功能:同步更新会话上下文历史
|
||||
@@ -8,11 +8,12 @@ import (
|
||||
)
|
||||
|
||||
// GetModelPrompt 获取请求模型的提示词
|
||||
func GetModelPrompt(ctx context.Context, Type int) string {
|
||||
return g.Cfg().MustGet(ctx, "modelPrompts.types."+gconv.String(Type), "").String()
|
||||
func GetModelPrompt(ctx context.Context, modelType int) string {
|
||||
key := "modelPrompts.types." + gconv.String(modelType)
|
||||
return g.Cfg().MustGet(ctx, key, "").String()
|
||||
}
|
||||
|
||||
// GetBuildPrompt 获取构建提示词
|
||||
func GetBuildPrompt(ctx context.Context, Type int) string {
|
||||
return g.Cfg().MustGet(ctx, "buildProject.types."+gconv.String(Type), "").String()
|
||||
// GetBuildPrompt 获取节点构建提示词
|
||||
func GetBuildPrompt(ctx context.Context) string {
|
||||
return g.Cfg().MustGet(ctx, "nodePrompts", "").String()
|
||||
}
|
||||
|
||||
@@ -6,8 +6,9 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// AllowedMIMEPrefixes 允许的文本类 MIME 类型前缀
|
||||
var AllowedMIMEPrefixes = []string{
|
||||
var (
|
||||
// AllowedMIMEPrefixes 允许的文本类 MIME 类型前缀
|
||||
AllowedMIMEPrefixes = []string{
|
||||
"text/",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
@@ -20,10 +21,10 @@ var AllowedMIMEPrefixes = []string{
|
||||
"application/x-python",
|
||||
"application/x-perl",
|
||||
"application/x-ruby",
|
||||
}
|
||||
}
|
||||
|
||||
// BannedExtensions 禁止的文件扩展名
|
||||
var BannedExtensions = map[string]bool{
|
||||
// BannedExtensions 禁止的文件扩展名
|
||||
BannedExtensions = map[string]bool{
|
||||
".png": true, ".jpg": true, ".jpeg": true, ".gif": true, ".bmp": true,
|
||||
".webp": true, ".svg": true, ".ico": true, ".tiff": true, ".tif": true,
|
||||
".mp3": true, ".wav": true, ".ogg": true, ".flac": true, ".aac": true,
|
||||
@@ -35,9 +36,11 @@ var BannedExtensions = map[string]bool{
|
||||
".class": true, ".pyc": true,
|
||||
".pdf": true, ".doc": true, ".docx": true, ".xls": true, ".xlsx": true,
|
||||
".ppt": true, ".pptx": true,
|
||||
}
|
||||
}
|
||||
|
||||
var symbolCleaner = regexp.MustCompile(`[\x00-\x08\x0B\x0C\x0E-\x1F]`)
|
||||
symbolCleaner = regexp.MustCompile(`[\x00-\x08\x0B\x0C\x0E-\x1F]`)
|
||||
multiNewlines = regexp.MustCompile(`\n{3,}`)
|
||||
)
|
||||
|
||||
// SanitizeURL 清洗 URL 字符串
|
||||
func SanitizeURL(raw string) string {
|
||||
@@ -51,25 +54,19 @@ func CleanSymbols(text string) string {
|
||||
text = symbolCleaner.ReplaceAllString(text, "")
|
||||
text = strings.ReplaceAll(text, "\r\n", "\n")
|
||||
text = strings.ReplaceAll(text, "\r", "\n")
|
||||
text = regexp.MustCompile(`\n{3,}`).ReplaceAllString(text, "\n\n")
|
||||
text = multiNewlines.ReplaceAllString(text, "\n\n")
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
// IsBannedExtension 判断是否为禁止的文件扩展名
|
||||
func IsBannedExtension(url string) bool {
|
||||
ext := strings.ToLower(filepath.Ext(url))
|
||||
if idx := strings.Index(ext, "?"); idx != -1 {
|
||||
ext = ext[:idx]
|
||||
}
|
||||
ext := extractExtension(url)
|
||||
return BannedExtensions[ext]
|
||||
}
|
||||
|
||||
// IsZipExtension 判断是否为 zip 文件
|
||||
func IsZipExtension(url string) bool {
|
||||
ext := strings.ToLower(filepath.Ext(url))
|
||||
if idx := strings.Index(ext, "?"); idx != -1 {
|
||||
ext = ext[:idx]
|
||||
}
|
||||
ext := extractExtension(url)
|
||||
return ext == ".zip"
|
||||
}
|
||||
|
||||
@@ -78,11 +75,22 @@ func IsReadableContentType(contentType string) bool {
|
||||
if contentType == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
ct := strings.ToLower(contentType)
|
||||
for _, prefix := range AllowedMIMEPrefixes {
|
||||
if strings.HasPrefix(ct, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// extractExtension returns the lower-cased file extension of url, including
// the leading dot (for example ".zip"), or "" when the last path element has
// no extension.
//
// The query string and fragment are removed before the extension is looked
// up. Removing them afterwards is not enough: filepath.Ext scans back to the
// last dot of the last path element, so "a.zip?v=1.2" would yield ".2" and
// "a.zip?dir=/tmp/x" would yield "".
func extractExtension(url string) string {
	if idx := strings.IndexAny(url, "?#"); idx != -1 {
		url = url[:idx]
	}
	return strings.ToLower(filepath.Ext(url))
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
// AsyncCtx 固化异步上下文中的 token 和用户信息,避免请求结束后丢失
|
||||
func AsyncCtx(ctx context.Context) context.Context {
|
||||
asyncCtx := context.WithoutCancel(ctx)
|
||||
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
asyncCtx = context.WithValue(asyncCtx, "token", token)
|
||||
@@ -18,9 +19,11 @@ func AsyncCtx(ctx context.Context) context.Context {
|
||||
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
|
||||
}
|
||||
}
|
||||
|
||||
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
|
||||
asyncCtx = context.WithValue(asyncCtx, "user", user)
|
||||
}
|
||||
|
||||
return asyncCtx
|
||||
}
|
||||
|
||||
@@ -28,25 +31,37 @@ func AsyncCtx(ctx context.Context) context.Context {
|
||||
func ForwardHeaders(ctx context.Context) map[string]string {
|
||||
headers := make(map[string]string)
|
||||
|
||||
if token, ok := ctx.Value("token").(string); ok && token != "" {
|
||||
headers["Authorization"] = token
|
||||
setHeaderFromContext(headers, ctx, "Authorization", "token")
|
||||
setHeaderFromContext(headers, ctx, "X-User-Info", "xUserInfo")
|
||||
|
||||
fallbackToRequestHeaders(headers, ctx)
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
// setHeaderFromContext 从上下文中设置 header
|
||||
func setHeaderFromContext(headers map[string]string, ctx context.Context, headerKey, ctxKey string) {
|
||||
if value, ok := ctx.Value(ctxKey).(string); ok && value != "" {
|
||||
headers[headerKey] = value
|
||||
}
|
||||
if x, ok := ctx.Value("xUserInfo").(string); ok && x != "" {
|
||||
headers["X-User-Info"] = x
|
||||
}
|
||||
|
||||
// fallbackToRequestHeaders 从请求头中获取作为兜底
|
||||
func fallbackToRequestHeaders(headers map[string]string, ctx context.Context) {
|
||||
r := g.RequestFromCtx(ctx)
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 兜底:从请求头获取
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
if headers["Authorization"] == "" {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
headers["Authorization"] = token
|
||||
}
|
||||
}
|
||||
|
||||
if headers["X-User-Info"] == "" {
|
||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||
headers["X-User-Info"] = userInfo
|
||||
}
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ func ParseOutput(text string) (map[string]any, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("解析模型输出失败: %w", err)
|
||||
}
|
||||
|
||||
return j.Map(), nil
|
||||
}
|
||||
|
||||
@@ -23,26 +24,17 @@ func ConvertToMessages(raw any) []map[string]any {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
j, err := gjson.LoadJson(gconv.Bytes(raw))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
// 如果有 messages 字段,直接返回
|
||||
|
||||
if j.Contains("messages") {
|
||||
return gconv.Maps(j.Get("messages").Array())
|
||||
}
|
||||
// 否则当成单条 message
|
||||
return []map[string]any{
|
||||
j.Map(),
|
||||
}
|
||||
}
|
||||
|
||||
// IsMessageValid 校验推理结果是否合法
|
||||
func IsMessageValid(message map[string]any) bool {
|
||||
if message == nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
return []map[string]any{j.Map()}
|
||||
}
|
||||
|
||||
// FormToJSON 将表单数据转换为 JSON 字符串
|
||||
@@ -50,6 +42,17 @@ func FormToJSON(form map[string]any) string {
|
||||
if form == nil {
|
||||
return "{}"
|
||||
}
|
||||
|
||||
b, _ := json.Marshal(form)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// UserFormToJSON serializes the user form list to a JSON string.
//
// A nil form yields the placeholder "{}". A form that cannot be encoded
// (for example one holding a channel or func value) also yields "{}" instead
// of an empty string, so the result is always valid JSON.
//
// NOTE(review): a non-nil form encodes as a JSON array while the placeholder
// is a JSON object; the placeholder is kept as-is for backward compatibility.
func UserFormToJSON(form []map[string]any) string {
	if form == nil {
		return "{}"
	}

	b, err := json.Marshal(form)
	if err != nil {
		return "{}"
	}
	return string(b)
}
|
||||
@@ -60,39 +63,16 @@ func MustMarshal(v any) string {
|
||||
if err != nil {
|
||||
return "{}"
|
||||
}
|
||||
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// ParseJSONField 解析 JSON 字段
|
||||
func ParseJSONField(field any) any {
|
||||
var v *gvar.Var
|
||||
switch val := field.(type) {
|
||||
case *gvar.Var:
|
||||
v = val
|
||||
default:
|
||||
return field
|
||||
}
|
||||
|
||||
if v == nil || v.IsNil() || v.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
|
||||
str := v.String()
|
||||
var result any
|
||||
if json.Unmarshal([]byte(str), &result) == nil {
|
||||
return result
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
// JSONPretty 将任意类型转为格式化的 JSON 字符串
|
||||
func JSONPretty(v any) string {
|
||||
// 处理 *gvar.Var 类型
|
||||
if gv, ok := v.(*gvar.Var); ok {
|
||||
v = gconv.Map(gv.String())
|
||||
}
|
||||
|
||||
// 统一转 map 再美化
|
||||
var tmp map[string]any
|
||||
if err := gconv.Struct(v, &tmp); err != nil {
|
||||
return gconv.String(v)
|
||||
@@ -101,3 +81,71 @@ func JSONPretty(v any) string {
|
||||
b, _ := json.MarshalIndent(tmp, "", " ")
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// GvarToMap 将 *gvar.Var 类型转换为 map[string]any
|
||||
func GvarToMap(v *gvar.Var) map[string]any {
|
||||
if v == nil || v.IsNil() {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
|
||||
// 方法1:尝试获取 map 值
|
||||
if m := v.Map(); len(m) > 0 {
|
||||
return m
|
||||
}
|
||||
|
||||
// 方法2:尝试解析 JSON 字符串
|
||||
str := v.String()
|
||||
if str != "" && str != "<nil>" {
|
||||
json.Unmarshal([]byte(str), &result)
|
||||
if len(result) > 0 {
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// 方法3:尝试获取 interface 再转换
|
||||
if val := v.Val(); val != nil {
|
||||
switch val.(type) {
|
||||
case map[string]any:
|
||||
return val.(map[string]any)
|
||||
default:
|
||||
data, _ := json.Marshal(val)
|
||||
json.Unmarshal(data, &result)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ParseJSONFieldFromGvar 专门处理 *gvar.Var 类型的 JSON 字段解析
|
||||
func ParseJSONFieldFromGvar(source any, target any) {
|
||||
if source == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch v := source.(type) {
|
||||
case *gvar.Var:
|
||||
if v.IsNil() {
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试获取 map
|
||||
if m := v.Map(); len(m) > 0 {
|
||||
data, _ := json.Marshal(m)
|
||||
json.Unmarshal(data, target)
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试解析 JSON 字符串
|
||||
str := v.String()
|
||||
if str != "" && str != "<nil>" {
|
||||
json.Unmarshal([]byte(str), target)
|
||||
}
|
||||
|
||||
default:
|
||||
// 其他类型走原来的逻辑
|
||||
data, _ := json.Marshal(source)
|
||||
json.Unmarshal(data, target)
|
||||
}
|
||||
}
|
||||
|
||||
229
common/util/token.go
Normal file
229
common/util/token.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
)
|
||||
|
||||
// Patterns used by the token estimator; compiled once at package load.
var (
	// enWordRegex matches one run of ASCII letters, counted as one English word.
	enWordRegex = regexp.MustCompile(`[A-Za-z]+`)
	// punctRegex matches a single character of the POSIX punct class.
	punctRegex = regexp.MustCompile(`[[:punct:]]`)
)
|
||||
|
||||
// TokenConfig holds the parameters of the heuristic token estimator and of
// the context-window budget derived from it.
type TokenConfig struct {
	// Per-unit weights applied by CalculateTokens.
	ZhRatio          float64 `json:"zh_ratio"`          // weight of one Han character
	EnRatio          float64 `json:"en_ratio"`          // weight of one English word
	SpaceRatio       float64 `json:"space_ratio"`       // weight of one space character
	PunctuationRatio float64 `json:"punctuation_ratio"` // weight of one punctuation character

	// Window budget used by GetAvailableWindow.
	MaxWindowSize int     `json:"max_window_size"` // total window size in tokens
	ReserveRatio  float64 `json:"reserve_ratio"`   // fraction of the window held back
	MinReserve    int     `json:"min_reserve"`     // lower bound for the held-back amount
}
|
||||
|
||||
// CalculateTokens 计算文本token数
|
||||
func CalculateTokens(text string, tokenConfig any) int {
|
||||
config := parseConfig(tokenConfig)
|
||||
if config == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
zhCount := countChineseChars(text)
|
||||
enCount := countEnglishWords(text)
|
||||
spaceCount := strings.Count(text, " ")
|
||||
punctCount := countPunctuation(text)
|
||||
|
||||
totalTokens := int(
|
||||
float64(zhCount)*config.ZhRatio +
|
||||
float64(enCount)*config.EnRatio +
|
||||
float64(spaceCount)*config.SpaceRatio +
|
||||
float64(punctCount)*config.PunctuationRatio,
|
||||
)
|
||||
|
||||
return totalTokens
|
||||
}
|
||||
|
||||
// CountToken 计算token是否超出窗口限制
|
||||
// 返回: true - 未超出(可用), false - 已超出(不可用)
|
||||
func CountToken(text string, tokenConfig any) bool {
|
||||
config := parseConfig(tokenConfig)
|
||||
if config == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
estimatedTokens := CalculateTokens(text, tokenConfig)
|
||||
availableWindow := GetAvailableWindow(tokenConfig)
|
||||
|
||||
return estimatedTokens <= availableWindow
|
||||
}
|
||||
|
||||
// GetAvailableWindow 获取可用窗口大小
|
||||
func GetAvailableWindow(tokenConfig any) int {
|
||||
config := parseConfig(tokenConfig)
|
||||
if config == nil {
|
||||
return 4096
|
||||
}
|
||||
|
||||
reserveByRatio := int(float64(config.MaxWindowSize) * config.ReserveRatio)
|
||||
reserve := reserveByRatio
|
||||
|
||||
if config.MinReserve > reserve {
|
||||
reserve = config.MinReserve
|
||||
}
|
||||
|
||||
available := config.MaxWindowSize - reserve
|
||||
if available < 0 {
|
||||
available = 0
|
||||
}
|
||||
|
||||
return available
|
||||
}
|
||||
|
||||
// GetMaxWindowSize 获取模型最大窗口大小
|
||||
func GetMaxWindowSize(tokenConfig any) int {
|
||||
config := parseConfig(tokenConfig)
|
||||
if config == nil {
|
||||
return 4096
|
||||
}
|
||||
|
||||
return config.MaxWindowSize
|
||||
}
|
||||
|
||||
// CheckUserFormWithinWindow 校验 UserForm 是否在窗口大小内
|
||||
// 返回: isValid, exceedTokens, error
|
||||
func CheckUserFormWithinWindow(userForm []map[string]any, tokenConfig any) (bool, int, error) {
|
||||
config := parseConfig(tokenConfig)
|
||||
if config == nil || len(userForm) == 0 {
|
||||
return true, 0, nil
|
||||
}
|
||||
|
||||
totalTokens := calculateUserFormTokens(userForm, tokenConfig)
|
||||
availableWindow := GetAvailableWindow(tokenConfig)
|
||||
|
||||
if totalTokens > availableWindow {
|
||||
return false, totalTokens - availableWindow, nil
|
||||
}
|
||||
|
||||
return true, 0, nil
|
||||
}
|
||||
|
||||
// CheckUserFormBatchWithinWindow 检查 UserForm 分批是否在窗口内
|
||||
// 返回: 需要拆分的批次数, 每批的token数, 错误
|
||||
func CheckUserFormBatchWithinWindow(userForm []map[string]any, tokenConfig any) (int, []int, error) {
|
||||
config := parseConfig(tokenConfig)
|
||||
if config == nil || len(userForm) == 0 {
|
||||
return 1, nil, nil
|
||||
}
|
||||
|
||||
availableWindow := GetAvailableWindow(tokenConfig)
|
||||
|
||||
batches := 1
|
||||
currentTokens := 0
|
||||
batchTokens := make([]int, 0)
|
||||
|
||||
for _, item := range userForm {
|
||||
itemStr := fmt.Sprintf("%v", item)
|
||||
itemTokens := CalculateTokens(itemStr, tokenConfig)
|
||||
|
||||
if currentTokens+itemTokens > availableWindow {
|
||||
batchTokens = append(batchTokens, currentTokens)
|
||||
batches++
|
||||
currentTokens = itemTokens
|
||||
} else {
|
||||
currentTokens += itemTokens
|
||||
}
|
||||
}
|
||||
|
||||
if currentTokens > 0 {
|
||||
batchTokens = append(batchTokens, currentTokens)
|
||||
}
|
||||
|
||||
return batches, batchTokens, nil
|
||||
}
|
||||
|
||||
// parseConfig 解析配置
|
||||
func parseConfig(tokenConfig any) *TokenConfig {
|
||||
if tokenConfig == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := tokenConfig.(type) {
|
||||
case *gvar.Var:
|
||||
return parseGVarConfig(v)
|
||||
case map[string]any:
|
||||
return parseMapConfig(v)
|
||||
case *TokenConfig:
|
||||
return v
|
||||
case TokenConfig:
|
||||
return &v
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// parseGVarConfig 解析 GVar 配置
|
||||
func parseGVarConfig(v *gvar.Var) *TokenConfig {
|
||||
if v.IsNil() {
|
||||
return nil
|
||||
}
|
||||
|
||||
mapVal := v.Map()
|
||||
if mapVal == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &TokenConfig{}
|
||||
data, _ := json.Marshal(mapVal)
|
||||
json.Unmarshal(data, config)
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// parseMapConfig 解析 Map 配置
|
||||
func parseMapConfig(v map[string]any) *TokenConfig {
|
||||
config := &TokenConfig{}
|
||||
data, _ := json.Marshal(v)
|
||||
json.Unmarshal(data, config)
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// countChineseChars returns the number of runes in text that belong to the
// Unicode Han script.
func countChineseChars(text string) int {
	var han int
	for _, r := range text {
		if !unicode.Is(unicode.Han, r) {
			continue
		}
		han++
	}
	return han
}
|
||||
|
||||
// countEnglishWords 统计英文单词数量
|
||||
func countEnglishWords(text string) int {
|
||||
return len(enWordRegex.FindAllString(text, -1))
|
||||
}
|
||||
|
||||
// countPunctuation 统计标点符号数量
|
||||
func countPunctuation(text string) int {
|
||||
return len(punctRegex.FindAllString(text, -1))
|
||||
}
|
||||
|
||||
// calculateUserFormTokens 计算 UserForm 总 token 数
|
||||
func calculateUserFormTokens(userForm []map[string]any, tokenConfig any) int {
|
||||
totalTokens := 0
|
||||
for _, item := range userForm {
|
||||
itemStr := fmt.Sprintf("%v", item)
|
||||
totalTokens += CalculateTokens(itemStr, tokenConfig)
|
||||
}
|
||||
return totalTokens
|
||||
}
|
||||
82
config.yml
82
config.yml
@@ -103,49 +103,7 @@ modelPrompts:
|
||||
在执行多模态任务时,你需要以全链路AI内容架构师、多模态交互专家、综合内容生成系统的身份完成处理,重点保证不同模态之间的语义一致性、风格统一性、信息完整性与交互连贯性,避免出现跨模态语义断裂或输出不一致的问题。
|
||||
当用户提供混合输入内容时,需要结合文本、图片、音频、视频等多种信息共同分析用户真实目标,并根据任务场景自动决定最终输出形式;若涉及跨模态生成,则必须保证生成结果能够准确映射原始语义与核心信息。
|
||||
|
||||
buildProject:
|
||||
types:
|
||||
1: |
|
||||
你是专业的JSON结构生成专家,必须严格遵守以下全部规则。
|
||||
【强制规则】
|
||||
必须根据【输出结构】里面返回的JSON结构进行生成,不得任何更改,最终内容与输出结构返回一致;
|
||||
完整阅读所有文本、规则、表单内容,禁止跳读、漏读;
|
||||
完整读取UserForm所有字段,不得忽略任何字段;
|
||||
如果有skill相关内容必须完整的将内容拼接到system角色描述中;
|
||||
理解全部语义后再输出,禁止断章取义;
|
||||
UserForm所有字段内容必须完整拼接赋值到user角色描述中,不得有任何遗漏。
|
||||
【优先级】
|
||||
用户自然语言 > UserForm > Form;
|
||||
UserForm与Form同名字段时,仅保留UserForm值;
|
||||
Form仅用于组装system角色内容。
|
||||
【表单处理】
|
||||
Form:系统提示词、默认参数、基础配置 → 专属填充system角色;
|
||||
UserForm:用户业务输入、文案、配图数量、比例、prompt等 → 全部解析后拼接进user角色content;
|
||||
自动提取UserForm中每条文案的配图数量,总图片数 = 各文案配图数累加求和(示例:10条文案各配5张图 → 总50张,parameters.n=50),用户没有相关数量必须默认1;
|
||||
图片尺寸为空时自动填充size=1024*1024。
|
||||
【结构铁律】
|
||||
严格沿用固定输出结构,不增删字段或修改层级;
|
||||
messages元素必须按结构返回;
|
||||
禁止将role对象转为字符串、禁止嵌套错乱;
|
||||
输出纯净JSON:无多余转义符、无换行符、无额外字符;
|
||||
所有括号、引号必须成对闭合,保证JSON合法。
|
||||
【参数赋值】
|
||||
model固定沿用传入值;
|
||||
返回结构里面的参数,需要根据语意进行赋值,缺失补默认值;
|
||||
history历史信息必须结合UserForm里的内容对用户描述部分进行补充;
|
||||
从UserForm提取信息整合进user描述,确保数量、尺寸、文案语义无遗漏。
|
||||
【输出要求】
|
||||
仅输出单行纯净JSON,无任何解释、备注、Markdown或多余符号;
|
||||
完整合UserForm全部字段语义到user描述;
|
||||
生成后自检JSON语法、结构、数量;错误则自动重新生成。
|
||||
【输出结构】
|
||||
%s
|
||||
【字段映射】
|
||||
%s
|
||||
【完整输入信息】
|
||||
%s
|
||||
直接输出最终JSON:
|
||||
2: |
|
||||
nodePrompts: |
|
||||
你是流程路由助手,你的任务是根据上下文,选择一个正确的节点ID返回。
|
||||
规则:
|
||||
1. 只允许从下面的可选节点ID列表中选择一个返回
|
||||
@@ -155,3 +113,41 @@ buildProject:
|
||||
%s
|
||||
上下文内容:
|
||||
%s
|
||||
|
||||
#你是专业的JSON结构生成专家,必须严格遵守以下全部规则。
|
||||
# 【强制规则】
|
||||
# 必须根据【输出结构】里面返回的JSON结构进行生成,不得任何更改,最终内容与输出结构返回一致;
|
||||
# 完整阅读所有文本、规则、表单内容,禁止跳读、漏读;
|
||||
# 完整读取UserForm所有字段,不得忽略任何字段;
|
||||
# 如果有skill相关内容必须完整的将内容拼接到system角色描述中;
|
||||
# 理解全部语义后再输出,禁止断章取义;
|
||||
# UserForm所有字段内容必须完整拼接赋值到user角色描述中,不得有任何遗漏。
|
||||
# 【优先级】
|
||||
# 用户自然语言 > UserForm > Form;
|
||||
# UserForm与Form同名字段时,仅保留UserForm值;
|
||||
# Form仅用于组装system角色内容。
|
||||
# 【表单处理】
|
||||
# Form:系统提示词、默认参数、基础配置 → 专属填充system角色;
|
||||
# UserForm:用户业务输入、文案、配图数量、比例、prompt等 → 全部解析后拼接进user角色content;
|
||||
# 自动提取UserForm中每条文案的配图数量,总图片数 = 各文案配图数累加求和,用户没有相关数量必须默认1;
|
||||
# 图片尺寸为空时自动填充size=1024*1024。
|
||||
# 【结构铁律】
|
||||
# 严格沿用固定输出结构,不增删字段或修改层级;
|
||||
# messages元素必须按结构返回;
|
||||
# 禁止将role对象转为字符串、禁止嵌套错乱;
|
||||
# 输出纯净JSON:无多余转义符、无换行符、无额外字符;
|
||||
# 所有括号、引号必须成对闭合,保证JSON合法。
|
||||
# 【参数赋值】
|
||||
# model固定沿用传入值;
|
||||
# 返回结构里面的参数,需要根据语意进行赋值,缺失补默认值;
|
||||
# history历史信息必须结合UserForm里的内容对用户描述部分进行补充;
|
||||
# 从UserForm提取信息整合进user描述,确保数量、尺寸、文案语义无遗漏。
|
||||
# 【输出要求】
|
||||
# 仅输出单行纯净JSON,无任何解释、备注、Markdown或多余符号;
|
||||
# 完整整合UserForm全部字段语义到user描述;
|
||||
# 生成后自检JSON语法、结构、数量;错误则自动重新生成。
|
||||
# 【输出结构】
|
||||
# %s
|
||||
# 【完整输入信息】
|
||||
# %s
|
||||
# 直接输出最终JSON:
|
||||
@@ -5,3 +5,8 @@ const (
|
||||
ComposeStatusSuccess = "success"
|
||||
ComposeStatusFailed = "failed"
|
||||
)
|
||||
|
||||
// Build types accepted by the compose request.
const (
	// BuildTypePrompt selects prompt building.
	BuildTypePrompt = 1
	// BuildTypeNode selects node building.
	BuildTypeNode = 2
)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package prompt
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"prompts-core/model/dto"
|
||||
|
||||
promptDto "prompts-core/model/dto/prompt"
|
||||
promptService "prompts-core/service/prompt"
|
||||
)
|
||||
|
||||
@@ -13,17 +13,17 @@ type prompt struct{}
|
||||
var Prompt = new(prompt)
|
||||
|
||||
// ComposeMessages 调用 model-gateway 异步任务并同步等待结果,
|
||||
func (c *prompt) ComposeMessages(ctx context.Context, req *promptDto.ComposeMessagesReq) (res *promptDto.ComposeMessagesRes, err error) {
|
||||
func (c *prompt) ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (res *dto.ComposeMessagesRes, err error) {
|
||||
return promptService.ComposeMessages(ctx, req)
|
||||
}
|
||||
|
||||
// Callback model-gateway 提示词回调
|
||||
func (c *prompt) Callback(ctx context.Context, req *promptDto.CallbackReq) (res *promptDto.CallbackRes, err error) {
|
||||
func (c *prompt) Callback(ctx context.Context, req *dto.CallbackReq) (res *dto.CallbackRes, err error) {
|
||||
err = promptService.Callback(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
// GetComposeTask 查询拼接任务结果
|
||||
func (c *prompt) GetComposeTask(ctx context.Context, req *promptDto.GetComposeTaskReq) (res *promptDto.GetComposeTaskRes, err error) {
|
||||
func (c *prompt) GetComposeTask(ctx context.Context, req *dto.GetComposeTaskReq) (res *dto.GetComposeTaskRes, err error) {
|
||||
return promptService.GetComposeTask(ctx, req.TaskId)
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
package prompt
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"prompts-core/model/dto"
|
||||
|
||||
promptDto "prompts-core/model/dto/prompt"
|
||||
promptService "prompts-core/service/prompt"
|
||||
)
|
||||
|
||||
@@ -13,6 +13,6 @@ type session struct{}
|
||||
var Session = new(session)
|
||||
|
||||
// SessionCallback 会话回调
|
||||
func (c *session) SessionCallback(ctx context.Context, req *promptDto.SessionCallbackReq) (res *promptDto.SessionCallbackRes, err error) {
|
||||
func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) {
|
||||
return promptService.SessionCallback(ctx, req)
|
||||
}
|
||||
@@ -15,7 +15,7 @@ type composeSessionDao struct{}
|
||||
|
||||
// Insert 插入
|
||||
func (d *composeSessionDao) Insert(ctx context.Context, req *entity.ComposeSession) (id int64, err error) {
|
||||
var m = new(entity.ComposeTask)
|
||||
var m = new(entity.ComposeSession)
|
||||
err = gconv.Struct(req, &m)
|
||||
if err != nil {
|
||||
return
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"prompts-core/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var ProviderProtocol = &providerProtocolDao{}
|
||||
@@ -14,7 +15,13 @@ type providerProtocolDao struct{}
|
||||
|
||||
// Insert 新增协议配置
|
||||
func (d *providerProtocolDao) Insert(ctx context.Context, req *entity.ProviderProtocol) (id int64, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).OmitEmpty().Data(req).Insert()
|
||||
var m = new(entity.ProviderProtocol)
|
||||
err = gconv.Struct(req, &m)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).
|
||||
Insert(m)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
6
main.go
6
main.go
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
"prompts-core/controller/prompt"
|
||||
"prompts-core/controller"
|
||||
"syscall"
|
||||
|
||||
"gitea.com/red-future/common/http"
|
||||
@@ -21,8 +21,8 @@ func main() {
|
||||
defer jaeger.ShutDown(ctx)
|
||||
// 注册路由
|
||||
http.RouteRegister([]interface{}{
|
||||
prompt.Prompt,
|
||||
prompt.Session,
|
||||
controller.Prompt,
|
||||
controller.Session,
|
||||
})
|
||||
|
||||
// 监听退出信号,确保 Ctrl+C 能完整退出并关闭 gateway server
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package prompt
|
||||
package dto
|
||||
|
||||
import "github.com/gogf/gf/v2/frame/g"
|
||||
|
||||
@@ -9,16 +9,22 @@ type ComposeMessagesReq struct {
|
||||
SessionId string `p:"sessionId" json:"sessionId" v:"required#sessionId不能为空" dc:"会话ID"`
|
||||
Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"`
|
||||
Form map[string]any `p:"form" json:"form" dc:"系统表单:form 下所有字段都作为系统提示词来源"`
|
||||
UserForm map[string]any `p:"userForm" json:"userForm" dc:"用户表单:userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"`
|
||||
UserForm []map[string]any `p:"userForm" json:"userForm" dc:"用户表单:userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"`
|
||||
SkillName string `p:"skillName" json:"skillName" dc:"技能名称"`
|
||||
UserFiles []string `p:"userFiles" json:"userFiles" dc:"用户附件地址列表"`
|
||||
}
|
||||
|
||||
type ComposeMessagesRes struct {
|
||||
Messages any `json:"messages,omitempty" dc:"最终消息数组"`
|
||||
Messages *MultiRoundResult `json:"messages,omitempty" dc:"最终消息数组"`
|
||||
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
|
||||
}
|
||||
|
||||
// MultiRoundResult is the multi-round result returned to the caller.
type MultiRoundResult struct {
	// TotalRounds is the total number of rounds.
	TotalRounds int `json:"total_rounds"`
	// Rounds holds the per-round details; the element type is dynamic.
	Rounds []any `json:"rounds"`
}
|
||||
|
||||
type CallbackReq struct {
|
||||
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调:callbackUrl/{bizName}"`
|
||||
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
|
||||
@@ -1,4 +1,4 @@
|
||||
package prompt
|
||||
package dto
|
||||
|
||||
import "github.com/gogf/gf/v2/frame/g"
|
||||
|
||||
@@ -16,10 +16,10 @@ type AsynchModel struct {
|
||||
ResponseBody any `orm:"response_body" json:"responseBody"`
|
||||
TokenMapping string `orm:"token_mapping" json:"tokenMapping"`
|
||||
Prompt string `orm:"prompt" json:"prompt"`
|
||||
IsPrivate int `orm:"is_private" json:"isPrivate"`
|
||||
IsChatModel int `orm:"is_chat_model" json:"isChatModel"`
|
||||
IsPrivate *int `orm:"is_private" json:"isPrivate"`
|
||||
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
|
||||
ApiKey string `orm:"api_key" json:"apiKey"`
|
||||
Enabled int `orm:"enabled" json:"enabled"`
|
||||
Enabled *int `orm:"enabled" json:"enabled"`
|
||||
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
|
||||
QueueLimit int `orm:"queue_limit" json:"queueLimit"`
|
||||
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
|
||||
@@ -28,6 +28,9 @@ type AsynchModel struct {
|
||||
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"`
|
||||
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
|
||||
Remark string `orm:"remark" json:"remark"`
|
||||
IsOwner *int `json:"isOwner" orm:"is_owner"`
|
||||
OperatorName string `orm:"operator_name" json:"operatorName"`
|
||||
TokenConfig any `orm:"token_config" json:"tokenConfig"`
|
||||
}
|
||||
|
||||
type asynchModelCol struct {
|
||||
@@ -55,6 +58,9 @@ type asynchModelCol struct {
|
||||
RetryQueueMaxSecs string
|
||||
AutoCleanSeconds string
|
||||
Remark string
|
||||
IsOwner string
|
||||
OperatorName string
|
||||
TokenConfig string
|
||||
}
|
||||
|
||||
var AsynchModelCol = asynchModelCol{
|
||||
@@ -82,4 +88,7 @@ var AsynchModelCol = asynchModelCol{
|
||||
RetryQueueMaxSecs: "retry_queue_max_seconds",
|
||||
AutoCleanSeconds: "auto_clean_seconds",
|
||||
Remark: "remark",
|
||||
IsOwner: "is_owner",
|
||||
OperatorName: "operator_name",
|
||||
TokenConfig: "token_config",
|
||||
}
|
||||
|
||||
@@ -4,23 +4,41 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"prompts-core/consts/public"
|
||||
"strings"
|
||||
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/dao"
|
||||
"prompts-core/model/dto/prompt"
|
||||
"prompts-core/model/dto"
|
||||
"prompts-core/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// buildInferenceRequest 构建返回请求
|
||||
func buildInferenceRequest(ctx context.Context, req *prompt.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
|
||||
// buildInferenceRequest 构建推理请求
|
||||
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, targetModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
|
||||
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, targetModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
|
||||
}
|
||||
|
||||
ir := NewPromptIR()
|
||||
// 1. 统一 Prompt IR
|
||||
|
||||
switch req.BuildType {
|
||||
case 1: //构建提示词请求
|
||||
ir.AddSystem(promptBuild(ctx, req, model))
|
||||
case public.BuildTypePrompt:
|
||||
return buildPromptTypeRequest(ctx, processedReq, targetModel, history, ir, totalBatches)
|
||||
case public.BuildTypeNode:
|
||||
return buildNodeTypeRequest(ctx, req, ir)
|
||||
default:
|
||||
return nil, errors.New("不支持的构建类型")
|
||||
}
|
||||
}
|
||||
|
||||
// buildPromptTypeRequest 构建提示词类型请求(BuildType=1)
|
||||
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, targetModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
|
||||
systemPrompt := promptBuildWithRounds(ctx, req, targetModel, totalBatches)
|
||||
ir.AddSystem(systemPrompt)
|
||||
|
||||
for _, msg := range history {
|
||||
role := gconv.String(msg["role"])
|
||||
if role != "user" && role != "assistant" {
|
||||
@@ -28,41 +46,71 @@ func buildInferenceRequest(ctx context.Context, req *prompt.ComposeMessagesReq,
|
||||
}
|
||||
ir.AddHistory(role, gconv.String(msg["content"]))
|
||||
}
|
||||
ir.AddUser(buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, model.ModelType)))
|
||||
case 2: //构建节点请求
|
||||
ir.AddUser(NodeBuild(ctx, req))
|
||||
default:
|
||||
return nil, errors.New("不支持的构建类型")
|
||||
|
||||
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, targetModel.ModelType))
|
||||
ir.AddUser(userPrompt)
|
||||
|
||||
if !checkOverallContent(ir, targetModel) {
|
||||
availableWindow := util.GetAvailableWindow(targetModel.TokenConfig)
|
||||
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
|
||||
}
|
||||
|
||||
// 2. 获取协议配置
|
||||
protocol, err := GetProtocolByProvider(ctx, "qwen")
|
||||
return compileToProviderRequest(ctx, ir, targetModel.OperatorName, targetModel)
|
||||
}
|
||||
|
||||
// buildNodeTypeRequest 构建节点类型请求(BuildType=2)
|
||||
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ir *PromptIR) (map[string]any, error) {
|
||||
ir.AddUser(NodeBuild(ctx, req))
|
||||
|
||||
protocol, err := GetProtocolByProvider(ctx, req.ModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("获取协议配置失败: %w", err)
|
||||
}
|
||||
if protocol == nil {
|
||||
return nil, errors.New("协议配置不存在")
|
||||
}
|
||||
|
||||
// 3. 编译为 Provider Request
|
||||
providerReq, err := Compile(ir, protocol, chatModel)
|
||||
providerReq, err := Compile(ir, protocol, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("编译请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 4. 构建请求体
|
||||
return map[string]any{
|
||||
"modelName": chatModel.ModelName,
|
||||
"modelName": req.ModelName,
|
||||
"bizName": "prompts-core",
|
||||
"callbackUrl": "/prompt/callback",
|
||||
"requestPayload": providerReq,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// promptBuild 构建系统提示词
|
||||
func promptBuild(ctx context.Context, req *prompt.ComposeMessagesReq, model *entity.AsynchModel) string {
|
||||
// compileToProviderRequest 编译为 Provider 请求
|
||||
func compileToProviderRequest(ctx context.Context, ir *PromptIR, providerName string, model *entity.AsynchModel) (map[string]any, error) {
|
||||
protocol, err := GetProtocolByProvider(ctx, providerName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取协议配置失败: %w", err)
|
||||
}
|
||||
if protocol == nil {
|
||||
return nil, errors.New("协议配置不存在")
|
||||
}
|
||||
|
||||
providerReq, err := Compile(ir, protocol, model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("编译请求失败: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("providerReq打印:", util.MustMarshal(providerReq))
|
||||
return map[string]any{
|
||||
"modelName": model.ModelName,
|
||||
"bizName": "prompts-core",
|
||||
"callbackUrl": "/prompt/callback",
|
||||
"requestPayload": providerReq,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// promptBuildWithRounds 构建系统提示词(包含轮次信息)
|
||||
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel, totalRounds int) string {
|
||||
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||
ProviderName: "qwen",
|
||||
ProviderName: model.OperatorName,
|
||||
Status: 1,
|
||||
})
|
||||
if err != nil || providerProtocol == nil {
|
||||
@@ -70,43 +118,104 @@ func promptBuild(ctx context.Context, req *prompt.ComposeMessagesReq, model *ent
|
||||
}
|
||||
|
||||
outputJSON := util.JSONPretty(model.RequestMapping)
|
||||
var userFormContent strings.Builder
|
||||
for k, v := range req.UserForm {
|
||||
userFormContent.WriteString(fmt.Sprintf("%s=%v;", k, v))
|
||||
}
|
||||
userFormFullText := strings.TrimSuffix(userFormContent.String(), ";")
|
||||
maxWindowSize := util.GetMaxWindowSize(model.TokenConfig)
|
||||
availableWindow := util.GetAvailableWindow(model.TokenConfig)
|
||||
|
||||
userFormContent := buildUserFormContent(req.UserForm)
|
||||
formInfo := fmt.Sprintf(`
|
||||
【系统表单(系统提示词/参数)】
|
||||
%s
|
||||
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
|
||||
%s
|
||||
`, util.FormToJSON(req.Form), userFormFullText)
|
||||
`, util.FormToJSON(req.Form), userFormContent)
|
||||
|
||||
return fmt.Sprintf(providerProtocol.SystemPromptTemplate, outputJSON, formInfo)
|
||||
inputInfo := fmt.Sprintf(`
|
||||
目标模型: %s
|
||||
%s
|
||||
技能名称: %s
|
||||
用户文件: %v
|
||||
`, req.ModelName, formInfo, req.SkillName, req.UserFiles)
|
||||
|
||||
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
|
||||
req.ModelName,
|
||||
maxWindowSize,
|
||||
availableWindow,
|
||||
totalRounds,
|
||||
totalRounds,
|
||||
totalRounds,
|
||||
outputJSON,
|
||||
inputInfo,
|
||||
totalRounds,
|
||||
)
|
||||
}
|
||||
|
||||
// 构建用户提示词
|
||||
func buildUserPrompt(ctx context.Context, req *prompt.ComposeMessagesReq, prompt string) string {
|
||||
payload := map[string]any{
|
||||
"model": req.ModelName, // 请求模型名称
|
||||
"promptInfo": prompt, // 数据库提示信息
|
||||
"form": req.Form, // 系统表单
|
||||
"userForm": req.UserForm, // 用户表单
|
||||
"userFiles": req.UserFiles, //文件url
|
||||
"userFilesText": FetchFileTexts(ctx, req.UserFiles), //解读文件(只支持可读类型 如:xml,json,yaml)
|
||||
"skills": SkillMdContent(ctx, req.SkillName), //skill 相关(根据传入的 skillName 获取 zip 内所有 md 文件拼接内容)
|
||||
// buildUserFormContent renders the user form as text: every entry is
// formatted with fmt's default %v verb and terminated by a newline.
// An empty or nil form yields the empty string.
func buildUserFormContent(userForm []map[string]any) string {
	var sb strings.Builder
	for i := range userForm {
		fmt.Fprintf(&sb, "%v\n", userForm[i])
	}
	return sb.String()
}
|
||||
|
||||
// checkOverallContent 检查整体内容是否超出窗口
|
||||
func checkOverallContent(ir *PromptIR, model *entity.AsynchModel) bool {
|
||||
fullContent := ir.String()
|
||||
return util.CountToken(fullContent, model.TokenConfig)
|
||||
}
|
||||
|
||||
// buildUserPrompt 构建用户提示词
|
||||
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
|
||||
userFormForPayload := prepareUserFormPayload(req.UserForm)
|
||||
|
||||
payload := map[string]any{
|
||||
"model": req.ModelName,
|
||||
"promptInfo": prompt,
|
||||
"form": req.Form,
|
||||
"userForm": userFormForPayload,
|
||||
"userFiles": req.UserFiles,
|
||||
"userFilesText": FetchFileTexts(ctx, req.UserFiles),
|
||||
"skills": SkillMdContent(ctx, req.SkillName),
|
||||
}
|
||||
|
||||
return util.MustMarshal(payload)
|
||||
}
|
||||
|
||||
// prepareUserFormPayload 准备用户表单载荷
|
||||
func prepareUserFormPayload(userForm []map[string]any) any {
|
||||
if len(userForm) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, ok := userForm[0]["batch_index"]; ok {
|
||||
return userForm
|
||||
}
|
||||
|
||||
return mergeUserFormTexts(userForm)
|
||||
}
|
||||
|
||||
// mergeUserFormTexts 合并 UserForm 中的所有文本内容
|
||||
func mergeUserFormTexts(userForm []map[string]any) string {
|
||||
var builder strings.Builder
|
||||
for i, item := range userForm {
|
||||
text := getItemText(item)
|
||||
if i > 0 {
|
||||
builder.WriteString("\n\n")
|
||||
}
|
||||
builder.WriteString(text)
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// NodeBuild 节点构建
|
||||
func NodeBuild(ctx context.Context, req *prompt.ComposeMessagesReq) string {
|
||||
promptTpl := util.GetBuildPrompt(ctx, req.BuildType)
|
||||
func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string {
|
||||
promptTpl := util.GetBuildPrompt(ctx)
|
||||
if promptTpl == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
formStr := util.FormToJSON(req.Form)
|
||||
userFormStr := util.FormToJSON(req.UserForm)
|
||||
userFormStr := util.UserFormToJSON(req.UserForm)
|
||||
|
||||
return fmt.Sprintf(promptTpl, formStr, userFormStr)
|
||||
}
|
||||
|
||||
@@ -5,171 +5,229 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"prompts-core/dao"
|
||||
"prompts-core/model/entity"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/consts/public"
|
||||
promptDto "prompts-core/model/dto/prompt"
|
||||
"prompts-core/service/gateway"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/consts/public"
|
||||
"prompts-core/dao"
|
||||
"prompts-core/model/dto"
|
||||
"prompts-core/model/entity"
|
||||
"prompts-core/service/gateway"
|
||||
)
|
||||
|
||||
// ComposeMessages 核心拼接提示词主流程
|
||||
func ComposeMessages(ctx context.Context, req *promptDto.ComposeMessagesReq) (*promptDto.ComposeMessagesRes, error) {
|
||||
var (
|
||||
epicycleId int64
|
||||
taskID string
|
||||
history []map[string]any
|
||||
message map[string]any
|
||||
err error
|
||||
taskRecord *entity.ComposeTask
|
||||
)
|
||||
// 获取模型信息
|
||||
func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
|
||||
chatModel, aiModel, err := GetModelMessage(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 根据构建类型进行判断处理
|
||||
if err = validateUserForm(ctx, req, aiModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch req.BuildType {
|
||||
//提示词构建
|
||||
case 1:
|
||||
case public.BuildTypePrompt:
|
||||
return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建
|
||||
case public.BuildTypeNode:
|
||||
return handleNodeBuild(ctx, req, chatModel, aiModel) // 节点构建
|
||||
default:
|
||||
return handleDefaultCase(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
// validateUserForm 校验用户表单
|
||||
func validateUserForm(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) error {
|
||||
if len(req.UserForm) == 0 {
|
||||
return nil
|
||||
}
|
||||
isValid, exceedTokens, err := util.CheckUserFormWithinWindow(req.UserForm, model.TokenConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("校验用户表单失败: %w", err)
|
||||
}
|
||||
|
||||
if !isValid {
|
||||
availableWindow := util.GetAvailableWindow(model.TokenConfig)
|
||||
return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens,可用窗口 %d tokens,请精简后重试",
|
||||
exceedTokens, availableWindow)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handlePromptBuild 处理提示词构建(BuildType=1)
|
||||
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
||||
maxRetryTimes := g.Cfg().MustGet(ctx, "promptsRetry.maxRetryTimes", 3).Int()
|
||||
//1. 获取历史会话
|
||||
history, err = GetHistoryMessages(ctx, req.SessionId)
|
||||
history, err := GetHistoryMessages(ctx, req.SessionId)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "获取历史会话失败: %v,将不使用历史会话", err)
|
||||
history = nil // 出错就用空的,不影响主流程
|
||||
history = nil
|
||||
}
|
||||
// 重试循环
|
||||
|
||||
var message *dto.MultiRoundResult
|
||||
var taskRecord *entity.ComposeTask
|
||||
for attempt := 0; attempt <= 0; attempt++ {
|
||||
if attempt > 0 {
|
||||
g.Log().Warningf(ctx, "[重试]第 %d/%d 次调用推理模型", attempt, maxRetryTimes)
|
||||
}
|
||||
// 2. 调用推理模型
|
||||
taskID, err = callInferenceModel(ctx, req, chatModel, aiModel, history)
|
||||
|
||||
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, history)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "调用推理模型失败(第%d次): %v", attempt+1, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 3. 保存记录
|
||||
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
ModelName: req.ModelName,
|
||||
SkillName: req.SkillName,
|
||||
RequestPayload: util.MustMarshal(req),
|
||||
Status: public.ComposeStatusPending,
|
||||
})
|
||||
if err != nil {
|
||||
if err = saveComposeTask(ctx, taskID, req); err != nil {
|
||||
g.Log().Errorf(ctx, "保存任务记录失败(第%d次): %v", attempt+1, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 4. 等待结果
|
||||
taskRecord, err = waitForResult(ctx, taskID)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "等待结果失败(第%d次): %v", attempt+1, err)
|
||||
continue
|
||||
}
|
||||
// 校验结果
|
||||
|
||||
message = parsePromptBuild(taskRecord, chatModel)
|
||||
if message != nil && util.IsMessageValid(message) {
|
||||
if message != nil {
|
||||
break
|
||||
}
|
||||
|
||||
g.Log().Warningf(ctx, "[重试] 推理结果不合法(第%d次),准备重新请求", attempt+1)
|
||||
message = nil
|
||||
}
|
||||
|
||||
if message == nil {
|
||||
return nil, errors.New("推理模型调用失败,请稍后再试")
|
||||
}
|
||||
//5.创建会话记录
|
||||
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
epicycleId, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
SessionId: req.SessionId,
|
||||
RequestContent: message,
|
||||
})
|
||||
//节点构建
|
||||
case 2:
|
||||
//1. 调用推理模型
|
||||
taskID, err = callInferenceModel(ctx, req, chatModel, aiModel, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
g.Log().Errorf(ctx, "创建会话记录失败: %v", err)
|
||||
}
|
||||
//2. 保存相关记录
|
||||
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
|
||||
return &dto.ComposeMessagesRes{
|
||||
Messages: message,
|
||||
EpicycleId: epicycleId,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handleNodeBuild 处理节点构建(BuildType=2)
|
||||
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
||||
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("调用推理模型失败: %w", err)
|
||||
}
|
||||
|
||||
if err := saveComposeTask(ctx, taskID, req); err != nil {
|
||||
return nil, fmt.Errorf("保存任务记录失败: %w", err)
|
||||
}
|
||||
|
||||
taskRecord, err := waitForResult(ctx, taskID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("等待结果失败: %w", err)
|
||||
}
|
||||
|
||||
message := parseNodeBuild(taskRecord)
|
||||
|
||||
return &dto.ComposeMessagesRes{
|
||||
Messages: message,
|
||||
EpicycleId: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handleDefaultCase 处理默认情况
|
||||
func handleDefaultCase(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
|
||||
epicycleId, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
SessionId: req.SessionId,
|
||||
Remark: req.Cause,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建会话记录失败: %w", err)
|
||||
}
|
||||
|
||||
return &dto.ComposeMessagesRes{
|
||||
EpicycleId: epicycleId,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// saveComposeTask 保存组合任务
|
||||
func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessagesReq) error {
|
||||
_, err := dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
ModelName: req.ModelName,
|
||||
SkillName: req.SkillName,
|
||||
RequestPayload: util.MustMarshal(req),
|
||||
Status: public.ComposeStatusPending,
|
||||
})
|
||||
//5. 等待结果
|
||||
taskRecord, err := waitForResult(ctx, taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
message = parseNodeBuild(taskRecord)
|
||||
default:
|
||||
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
SessionId: req.SessionId,
|
||||
Remark: req.Cause,
|
||||
})
|
||||
return &promptDto.ComposeMessagesRes{
|
||||
EpicycleId: epicycleId,
|
||||
}, nil
|
||||
}
|
||||
return &promptDto.ComposeMessagesRes{
|
||||
Messages: message,
|
||||
EpicycleId: epicycleId,
|
||||
}, nil
|
||||
return err
|
||||
}
|
||||
|
||||
// GetModelMessage 获取模型信息
|
||||
func GetModelMessage(ctx context.Context, req *promptDto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) {
|
||||
func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) {
|
||||
userInfo, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
|
||||
}
|
||||
// 1. 获取当前用户的会话模型
|
||||
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
|
||||
IsChatModel: 1,
|
||||
})
|
||||
|
||||
chatModel, err := getChatModel(ctx, userInfo.UserName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if chatModel == nil {
|
||||
return nil, nil, errors.New("当前没有对话模型,请添加")
|
||||
}
|
||||
// 2. 获取要构建的模型信息
|
||||
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
|
||||
ModelName: req.ModelName,
|
||||
})
|
||||
|
||||
aiModel, err := getAIModel(ctx, userInfo.UserName, req.ModelName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if aiModel == nil {
|
||||
return nil, nil, fmt.Errorf("需要构建的模型 %s 不存在", req.ModelName)
|
||||
}
|
||||
|
||||
return chatModel, aiModel, nil
|
||||
}
|
||||
|
||||
// getChatModel 获取聊天模型
|
||||
func getChatModel(ctx context.Context, userName string) (*entity.AsynchModel, error) {
|
||||
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
|
||||
IsChatModel: new(1),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询聊天模型失败: %w", err)
|
||||
}
|
||||
|
||||
if chatModel == nil {
|
||||
return nil, errors.New("当前没有对话模型,请添加")
|
||||
}
|
||||
|
||||
return chatModel, nil
|
||||
}
|
||||
|
||||
// getAIModel 获取AI模型
|
||||
func getAIModel(ctx context.Context, userName, modelName string) (*entity.AsynchModel, error) {
|
||||
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
|
||||
ModelName: modelName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询AI模型失败: %w", err)
|
||||
}
|
||||
|
||||
if aiModel == nil {
|
||||
return nil, fmt.Errorf("需要构建的模型 %s 不存在", modelName)
|
||||
}
|
||||
|
||||
return aiModel, nil
|
||||
}
|
||||
|
||||
// callInferenceModel 调用推理模型
|
||||
func callInferenceModel(ctx context.Context, req *promptDto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (string, error) {
|
||||
// 构建推理模型请求
|
||||
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (string, error) {
|
||||
taskReq, err := buildInferenceRequest(ctx, req, chatModel, model, history)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("构建推理请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建网关任务
|
||||
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建网关任务失败: %w", err)
|
||||
@@ -186,96 +244,131 @@ func callInferenceModel(ctx context.Context, req *promptDto.ComposeMessagesReq,
|
||||
func waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) {
|
||||
timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second
|
||||
pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond
|
||||
|
||||
deadline := time.Now().Add(timeout)
|
||||
ticker := time.NewTicker(pollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
// ===================== 修复点 1:检查上下文是否取消 =====================
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// 请求已被取消,直接返回,不继续查库
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// 1. 查数据库
|
||||
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
// ===================== 修复点 2:如果是上下文取消,直接返回 =====================
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("查询任务失败: %w", err)
|
||||
}
|
||||
|
||||
if record != nil {
|
||||
switch record.Status {
|
||||
case public.ComposeStatusSuccess:
|
||||
return record, nil
|
||||
case public.ComposeStatusFailed:
|
||||
if strings.TrimSpace(record.ErrorMessage) == "" {
|
||||
return nil, fmt.Errorf("任务失败(taskId=%s)", taskID)
|
||||
}
|
||||
return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage)
|
||||
if completed, result := checkTaskCompletion(record); completed {
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 查网关状态
|
||||
state, err := gateway.QueryGatewayTaskState(ctx, taskID)
|
||||
if err != nil {
|
||||
// 网关不可达不终止,继续轮询
|
||||
g.Log().Warningf(ctx, "[waitForResult] 查询网关失败 taskId=%s err=%v", taskID, err)
|
||||
} else {
|
||||
switch state {
|
||||
case 2: // 网关成功
|
||||
// 网关已成功,主动更新数据库
|
||||
if record != nil {
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
Status: public.ComposeStatusSuccess,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
|
||||
}
|
||||
}
|
||||
case 3: // 网关失败
|
||||
if record != nil {
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
Status: public.ComposeStatusFailed,
|
||||
ErrorMessage: "model-gateway 任务执行失败",
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID)
|
||||
}
|
||||
if err = syncGatewayTaskState(ctx, taskID, record); err != nil {
|
||||
g.Log().Warningf(ctx, "[waitForResult] 同步网关状态失败 taskId=%s err=%v", taskID, err)
|
||||
}
|
||||
|
||||
// 3. 超时检查
|
||||
if time.Now().After(deadline) {
|
||||
return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID)
|
||||
}
|
||||
|
||||
// ===================== 修复点3:sleep 也要监听 ctx 取消 =====================
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(pollInterval):
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkTaskCompletion 检查任务是否完成
|
||||
func checkTaskCompletion(record *entity.ComposeTask) (bool, *entity.ComposeTask) {
|
||||
if record == nil {
|
||||
return false, nil
|
||||
}
|
||||
switch record.Status {
|
||||
case public.ComposeStatusSuccess:
|
||||
return true, record
|
||||
case public.ComposeStatusFailed:
|
||||
errMsg := strings.TrimSpace(record.ErrorMessage)
|
||||
if errMsg == "" {
|
||||
return true, nil
|
||||
}
|
||||
return true, nil
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// syncGatewayTaskState 同步网关任务状态
|
||||
func syncGatewayTaskState(ctx context.Context, taskID string, record *entity.ComposeTask) error {
|
||||
state, err := gateway.QueryGatewayTaskState(ctx, taskID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询网关状态失败: %w", err)
|
||||
}
|
||||
switch state {
|
||||
case 2:
|
||||
return updateTaskStatus(ctx, taskID, public.ComposeStatusSuccess, "")
|
||||
case 3:
|
||||
updateTaskStatus(ctx, taskID, public.ComposeStatusFailed, "model-gateway 任务执行失败")
|
||||
return fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateTaskStatus 更新任务状态
|
||||
func updateTaskStatus(ctx context.Context, taskID string, status string, errorMsg string) error {
|
||||
task := &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
Status: status,
|
||||
}
|
||||
if errorMsg != "" {
|
||||
task.ErrorMessage = errorMsg
|
||||
}
|
||||
|
||||
_, err := dao.ComposeTask.Update(ctx, task)
|
||||
return err
|
||||
}
|
||||
|
||||
// parsePromptBuild 解析提示词构建结果(BuildType == 1)
|
||||
func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) map[string]any {
|
||||
func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) *dto.MultiRoundResult {
|
||||
if taskRecord == nil {
|
||||
return nil
|
||||
}
|
||||
mapped := parseTaskMessages(taskRecord.Messages)
|
||||
if mapped == nil {
|
||||
return createDefaultResult(nil)
|
||||
}
|
||||
|
||||
// 1. 解析 Messages
|
||||
contentField := getContentField(model)
|
||||
contentStr, ok := mapped[contentField].(string)
|
||||
if !ok || contentStr == "" {
|
||||
return createDefaultResult(mapped)
|
||||
}
|
||||
|
||||
if roundsArray := tryParseAsArray(contentStr); roundsArray != nil {
|
||||
return &dto.MultiRoundResult{
|
||||
TotalRounds: len(roundsArray),
|
||||
Rounds: roundsArray,
|
||||
}
|
||||
}
|
||||
|
||||
if singleRound := tryParseAsObject(contentStr); singleRound != nil {
|
||||
return &dto.MultiRoundResult{
|
||||
TotalRounds: 1,
|
||||
Rounds: []any{singleRound},
|
||||
}
|
||||
}
|
||||
|
||||
return createDefaultResult(map[string]any{"content": contentStr})
|
||||
}
|
||||
|
||||
// parseTaskMessages 解析任务消息
|
||||
func parseTaskMessages(messages any) map[string]any {
|
||||
var mapped map[string]any
|
||||
switch v := taskRecord.Messages.(type) {
|
||||
|
||||
switch v := messages.(type) {
|
||||
case *gvar.Var:
|
||||
if v != nil {
|
||||
json.Unmarshal([]byte(v.String()), &mapped)
|
||||
@@ -289,115 +382,137 @@ func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel)
|
||||
json.Unmarshal(b, &mapped)
|
||||
}
|
||||
|
||||
// 2. 解析模型 ResponseMapping 获取 content 字段名
|
||||
contentField := "content" // 默认值
|
||||
if model != nil {
|
||||
var respMapping map[string]string
|
||||
switch v := model.ResponseMapping.(type) {
|
||||
case *gvar.Var:
|
||||
if v != nil {
|
||||
json.Unmarshal([]byte(v.String()), &respMapping)
|
||||
}
|
||||
case string:
|
||||
json.Unmarshal([]byte(v), &respMapping)
|
||||
case map[string]interface{}:
|
||||
respMapping = make(map[string]string)
|
||||
for k, val := range v {
|
||||
if s, ok := val.(string); ok {
|
||||
respMapping[k] = s
|
||||
}
|
||||
}
|
||||
}
|
||||
// 从映射中找到 content 对应的字段名
|
||||
for k, v := range respMapping {
|
||||
if strings.Contains(v, "content") {
|
||||
contentField = k
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 提取 content 的值
|
||||
contentStr, ok := mapped[contentField].(string)
|
||||
if !ok || contentStr == "" {
|
||||
return mapped
|
||||
}
|
||||
|
||||
// 4. 解析 content 内的 JSON
|
||||
var innerData map[string]any
|
||||
json.Unmarshal([]byte(contentStr), &innerData)
|
||||
|
||||
return innerData
|
||||
}
|
||||
|
||||
// parseNodeBuild 解析节点构建结果(BuildType == 2)
|
||||
func parseNodeBuild(taskRecord *entity.ComposeTask) map[string]any {
|
||||
if taskRecord == nil {
|
||||
// tryParseAsArray 尝试将字符串解析为数组
|
||||
func tryParseAsArray(contentStr string) []any {
|
||||
var roundsArray []any
|
||||
if err := json.Unmarshal([]byte(contentStr), &roundsArray); err != nil {
|
||||
return nil
|
||||
}
|
||||
var result map[string]any
|
||||
switch v := taskRecord.Messages.(type) {
|
||||
return roundsArray
|
||||
}
|
||||
|
||||
// tryParseAsObject decodes contentStr as a JSON object.
//
// It returns the decoded object (a map[string]any) when contentStr is a JSON
// object, and nil for everything else: invalid JSON, JSON null, arrays and
// scalar values such as numbers, strings or booleans. Rejecting scalars keeps
// plain-text model output from being mistaken for a structured round.
func tryParseAsObject(contentStr string) any {
	var obj map[string]any
	if err := json.Unmarshal([]byte(contentStr), &obj); err != nil {
		return nil
	}
	if obj == nil {
		// JSON null decodes without error into a nil map; return an untyped
		// nil so callers comparing the result with nil see "no object".
		return nil
	}
	return obj
}
|
||||
|
||||
// createDefaultResult 创建默认结果
|
||||
func createDefaultResult(data any) *dto.MultiRoundResult {
|
||||
if data == nil {
|
||||
data = make(map[string]any)
|
||||
}
|
||||
return &dto.MultiRoundResult{
|
||||
TotalRounds: 1,
|
||||
Rounds: []any{data},
|
||||
}
|
||||
}
|
||||
|
||||
// getContentField 从模型 ResponseMapping 中获取 content 字段名
|
||||
func getContentField(model *entity.AsynchModel) string {
|
||||
if model == nil {
|
||||
return "content"
|
||||
}
|
||||
|
||||
respMapping := parseResponseMapping(model.ResponseMapping)
|
||||
for k, v := range respMapping {
|
||||
if strings.Contains(v, "content") {
|
||||
return k
|
||||
}
|
||||
}
|
||||
|
||||
return "content"
|
||||
}
|
||||
|
||||
// parseResponseMapping 解析响应映射
|
||||
func parseResponseMapping(mapping any) map[string]string {
|
||||
result := make(map[string]string)
|
||||
|
||||
switch v := mapping.(type) {
|
||||
case *gvar.Var:
|
||||
if v != nil {
|
||||
json.Unmarshal([]byte(v.String()), &result)
|
||||
}
|
||||
case string:
|
||||
json.Unmarshal([]byte(v), &result)
|
||||
case map[string]any:
|
||||
result = v
|
||||
default:
|
||||
b, _ := json.Marshal(v)
|
||||
json.Unmarshal(b, &result)
|
||||
case map[string]interface{}:
|
||||
for k, val := range v {
|
||||
if s, ok := val.(string); ok {
|
||||
result[k] = s
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// parseNodeBuild 解析节点构建结果(BuildType == 2)
|
||||
func parseNodeBuild(taskRecord *entity.ComposeTask) *dto.MultiRoundResult {
|
||||
if taskRecord == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := parseTaskMessages(taskRecord.Messages)
|
||||
if result == nil {
|
||||
result = make(map[string]any)
|
||||
}
|
||||
|
||||
return &dto.MultiRoundResult{
|
||||
TotalRounds: 1,
|
||||
Rounds: []any{result},
|
||||
}
|
||||
}
|
||||
|
||||
// Callback 回调处理
|
||||
func Callback(ctx context.Context, req *promptDto.CallbackReq) error {
|
||||
func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
|
||||
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
|
||||
|
||||
// ============ 先查任务是否存在 ============
|
||||
task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("查询任务失败: %w", err)
|
||||
}
|
||||
if task == nil {
|
||||
return fmt.Errorf("任务不存在: %s", req.TaskId)
|
||||
}
|
||||
// ============ 根据状态区分处理 ============
|
||||
|
||||
if req.State == 3 {
|
||||
// 失败:直接更新状态
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusFailed,
|
||||
ErrorMessage: req.ErrorMsg,
|
||||
})
|
||||
return err
|
||||
}
|
||||
// ======================================
|
||||
// 成功:解析模型输出
|
||||
result, err := util.ParseOutput(req.Text)
|
||||
if err != nil {
|
||||
_, updateErr := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusFailed,
|
||||
ErrorMessage: req.ErrorMsg,
|
||||
})
|
||||
if updateErr != nil {
|
||||
g.Log().Warningf(ctx, "[Callback] 更新失败状态出错 taskId=%s err=%v", req.TaskId, updateErr)
|
||||
}
|
||||
return err
|
||||
return handleCallbackFailure(ctx, req.TaskId, req.ErrorMsg)
|
||||
}
|
||||
|
||||
return handleCallbackSuccess(ctx, req)
|
||||
}
|
||||
|
||||
// handleCallbackFailure 处理回调失败
|
||||
func handleCallbackFailure(ctx context.Context, taskID, errorMsg string) error {
|
||||
_, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
Status: public.ComposeStatusFailed,
|
||||
ErrorMessage: errorMsg,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// handleCallbackSuccess 处理回调成功
|
||||
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq) error {
|
||||
result, err := util.ParseOutput(req.Text)
|
||||
if err != nil {
|
||||
handleCallbackFailure(ctx, req.TaskId, req.ErrorMsg)
|
||||
return fmt.Errorf("解析模型输出失败: %w", err)
|
||||
}
|
||||
|
||||
// ============ result 可能为 nil ============
|
||||
var messages any
|
||||
if result != nil {
|
||||
messages = result
|
||||
}
|
||||
// =======================================
|
||||
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
@@ -407,34 +522,43 @@ func Callback(ctx context.Context, req *promptDto.CallbackReq) error {
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// GetComposeTask 查询任务结果
|
||||
func GetComposeTask(ctx context.Context, taskID string) (*promptDto.GetComposeTaskRes, error) {
|
||||
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
|
||||
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("查询任务失败: %w", err)
|
||||
}
|
||||
if record == nil {
|
||||
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
|
||||
}
|
||||
|
||||
// 如果 Messages 是字符串,反序列化为 JSON 数组
|
||||
messages := record.Messages
|
||||
if str, ok := messages.(string); ok && str != "" {
|
||||
var parsed any
|
||||
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
|
||||
messages = parsed
|
||||
}
|
||||
}
|
||||
messages := parseMessagesForResponse(record.Messages)
|
||||
|
||||
return &promptDto.GetComposeTaskRes{
|
||||
return &dto.GetComposeTaskRes{
|
||||
TaskId: record.TaskId,
|
||||
Status: record.Status,
|
||||
ErrorMessage: record.ErrorMessage,
|
||||
Messages: messages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseMessagesForResponse prepares stored messages for an API response.
//
// If messages is a non-empty string that holds valid JSON, the decoded value
// is returned. In every other case (non-string value, empty string, or a
// string that is not valid JSON) messages is returned unchanged, so a decode
// failure never hides the stored data.
func parseMessagesForResponse(messages any) any {
	str, ok := messages.(string)
	if !ok || str == "" {
		return messages
	}

	var parsed any
	if err := json.Unmarshal([]byte(str), &parsed); err != nil {
		// Best effort: fall back to the raw string.
		return messages
	}
	return parsed
}
|
||||
|
||||
@@ -10,10 +10,15 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/service/gateway"
|
||||
)
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
// Multipliers used to convert size limits configured in KB / MB into bytes.
const (
	bytesPerKB = 1024
	bytesPerMB = 1024 * 1024
)
|
||||
|
||||
// FetchFileTexts 从 URL 列表获取文件内容,支持 zip 内文件
|
||||
@@ -24,51 +29,49 @@ func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
|
||||
return result
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: time.Duration(g.Cfg().MustGet(ctx, "userFiles.httpTimeoutSec", 8).Int()) * time.Second,
|
||||
}
|
||||
client := createHTTPClient(ctx, "userFiles.httpTimeoutSec", 8)
|
||||
|
||||
for _, rawURL := range urls {
|
||||
url := util.SanitizeURL(rawURL)
|
||||
if url == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if util.IsBannedExtension(url) {
|
||||
if url == "" || util.IsBannedExtension(url) {
|
||||
continue
|
||||
}
|
||||
|
||||
if util.IsZipExtension(url) {
|
||||
zipTexts := fetchZipFileTexts(ctx, client, url)
|
||||
for k, v := range zipTexts {
|
||||
result[k] = v
|
||||
}
|
||||
mergeMap(result, fetchZipFileTexts(ctx, client, url))
|
||||
continue
|
||||
}
|
||||
|
||||
text, err := fetchFileContent(ctx, client, url)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
text = util.CleanSymbols(text)
|
||||
if text := fetchAndCleanFileContent(ctx, client, url); text != "" {
|
||||
result[url] = text
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// mergeMap copies every entry of src into dst. Keys already present in dst
// are overwritten. dst must be a non-nil map; src may be nil or empty.
func mergeMap(dst, src map[string]string) {
	for k, v := range src {
		dst[k] = v
	}
}
|
||||
|
||||
// fetchAndCleanFileContent 获取并清理文件内容
|
||||
func fetchAndCleanFileContent(ctx context.Context, client *http.Client, url string) string {
|
||||
text, err := fetchFileContent(ctx, client, url)
|
||||
if err != nil || text == "" {
|
||||
return ""
|
||||
}
|
||||
return util.CleanSymbols(text)
|
||||
}
|
||||
|
||||
// fetchZipFileTexts 下载并解压 zip 文件,提取可读文本内容
|
||||
func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map[string]string {
|
||||
result := make(map[string]string)
|
||||
|
||||
zipBytes, err := downloadFile(client, url,
|
||||
int64(g.Cfg().MustGet(ctx, "userFiles.zipMaxSizeMB", 10).Int())*1024*1024,
|
||||
)
|
||||
maxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipMaxSizeMB", 10).Int()) * bytesPerMB
|
||||
zipBytes, err := downloadFile(client, url, maxSize)
|
||||
if err != nil {
|
||||
return result
|
||||
}
|
||||
@@ -78,61 +81,61 @@ func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map
|
||||
return result
|
||||
}
|
||||
|
||||
entryMaxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipEntryMaxSizeKB", 500).Int()) * 1024
|
||||
entryMaxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipEntryMaxSizeKB", 500).Int()) * bytesPerKB
|
||||
|
||||
for _, file := range reader.File {
|
||||
if file.FileInfo().IsDir() {
|
||||
if shouldSkipZipEntry(file.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
fileName := file.Name
|
||||
|
||||
if util.IsBannedExtension(fileName) {
|
||||
continue
|
||||
if text := extractZipEntryContent(file, entryMaxSize); text != "" {
|
||||
result[url+"::"+file.Name] = text
|
||||
}
|
||||
}
|
||||
|
||||
if util.IsZipExtension(fileName) {
|
||||
continue
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// shouldSkipZipEntry 判断是否应该跳过 zip 条目
|
||||
func shouldSkipZipEntry(fileName string) bool {
|
||||
return util.IsBannedExtension(fileName) || util.IsZipExtension(fileName)
|
||||
}
|
||||
|
||||
// extractZipEntryContent 提取 zip 条目内容
|
||||
func extractZipEntryContent(file *zip.File, maxSize int64) string {
|
||||
rc, err := file.Open()
|
||||
if err != nil {
|
||||
continue
|
||||
return ""
|
||||
}
|
||||
defer rc.Close()
|
||||
|
||||
content, err := io.ReadAll(io.LimitReader(rc, entryMaxSize))
|
||||
rc.Close()
|
||||
content, err := io.ReadAll(io.LimitReader(rc, maxSize))
|
||||
if err != nil {
|
||||
continue
|
||||
return ""
|
||||
}
|
||||
|
||||
contentType := http.DetectContentType(content)
|
||||
if !util.IsReadableContentType(contentType) {
|
||||
continue
|
||||
if !util.IsReadableContentType(http.DetectContentType(content)) {
|
||||
return ""
|
||||
}
|
||||
|
||||
text := util.CleanSymbols(string(content))
|
||||
if text == "" {
|
||||
continue
|
||||
return ""
|
||||
}
|
||||
|
||||
key := url + "::" + fileName
|
||||
result[key] = text
|
||||
}
|
||||
|
||||
return result
|
||||
return text
|
||||
}
|
||||
|
||||
// downloadFile 下载文件,限制最大大小
|
||||
func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("执行请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
@@ -140,19 +143,24 @@ func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error
|
||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(io.LimitReader(resp.Body, maxSize))
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// fetchFileContent 获取单个文本文件内容
|
||||
func fetchFileContent(ctx context.Context, client *http.Client, url string) (string, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("执行请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
@@ -162,16 +170,13 @@ func fetchFileContent(ctx context.Context, client *http.Client, url string) (str
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if !util.IsReadableContentType(contentType) {
|
||||
return "", fmt.Errorf("unreadable content-type: %s", contentType)
|
||||
return "", fmt.Errorf("不可读的内容类型: %s", contentType)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(
|
||||
io.LimitReader(resp.Body,
|
||||
int64(g.Cfg().MustGet(ctx, "userFiles.textFileMaxSizeKB", 500).Int())*1024,
|
||||
),
|
||||
)
|
||||
maxSize := int64(g.Cfg().MustGet(ctx, "userFiles.textFileMaxSizeKB", 500).Int()) * bytesPerKB
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(string(body)), nil
|
||||
@@ -186,27 +191,26 @@ func SkillMdContent(ctx context.Context, skillName string) string {
|
||||
|
||||
fullUrl := skillResp.ImgAddressPrefix + skillResp.FileUrl
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: time.Duration(g.Cfg().MustGet(ctx, "skillFiles.httpTimeoutSec", 30).Int()) * time.Second,
|
||||
}
|
||||
client := createHTTPClient(ctx, "skillFiles.httpTimeoutSec", 30)
|
||||
maxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.zipMaxSizeMB", 10).Int()) * bytesPerMB
|
||||
|
||||
zipBytes, err := downloadFile(client, fullUrl,
|
||||
int64(g.Cfg().MustGet(ctx, "skillFiles.zipMaxSizeMB", 10).Int())*1024*1024,
|
||||
)
|
||||
zipBytes, err := downloadFile(client, fullUrl, maxSize)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
mdContents, err := extractMdFiles(ctx, zipBytes)
|
||||
if err != nil {
|
||||
if err != nil || len(mdContents) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if len(mdContents) == 0 {
|
||||
return ""
|
||||
}
|
||||
return buildSkillMarkdown(skillResp, mdContents)
|
||||
}
|
||||
|
||||
// buildSkillMarkdown 构建技能 Markdown 内容
|
||||
func buildSkillMarkdown(skillResp *gateway.SkillUserVO, mdContents map[string]string) string {
|
||||
var builder strings.Builder
|
||||
|
||||
builder.WriteString(fmt.Sprintf("# Skill: %s\n\n", skillResp.Name))
|
||||
if skillResp.Description != "" {
|
||||
builder.WriteString(fmt.Sprintf("> %s\n\n", skillResp.Description))
|
||||
@@ -227,35 +231,53 @@ func extractMdFiles(ctx context.Context, zipBytes []byte) (map[string]string, er
|
||||
|
||||
reader, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("创建 zip 阅读器失败: %w", err)
|
||||
}
|
||||
|
||||
entryMaxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.mdMaxSizeKB", 500).Int()) * 1024
|
||||
entryMaxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.mdMaxSizeKB", 500).Int()) * bytesPerKB
|
||||
|
||||
for _, file := range reader.File {
|
||||
if file.FileInfo().IsDir() {
|
||||
if file.FileInfo().IsDir() || !isMarkdownFile(file.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(strings.ToLower(file.Name), ".md") {
|
||||
continue
|
||||
}
|
||||
|
||||
rc, err := file.Open()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
content, err := io.ReadAll(io.LimitReader(rc, entryMaxSize))
|
||||
rc.Close()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(content) > 0 {
|
||||
result[file.Name] = strings.TrimSpace(string(content))
|
||||
if content := readMarkdownFileContent(file, entryMaxSize); content != "" {
|
||||
result[file.Name] = content
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// isMarkdownFile reports whether fileName ends with the ".md" extension,
// compared case-insensitively.
func isMarkdownFile(fileName string) bool {
	return strings.HasSuffix(strings.ToLower(fileName), ".md")
}
|
||||
|
||||
// readMarkdownFileContent reads at most maxSize bytes from the zip entry file
// and returns them as a string with surrounding whitespace trimmed.
//
// It returns "" when the entry cannot be opened or read, or when it is empty.
// Content beyond maxSize is silently truncated by the LimitReader.
func readMarkdownFileContent(file *zip.File, maxSize int64) string {
	rc, err := file.Open()
	if err != nil {
		return ""
	}
	defer rc.Close()

	content, err := io.ReadAll(io.LimitReader(rc, maxSize))
	if err != nil || len(content) == 0 {
		return ""
	}

	return strings.TrimSpace(string(content))
}
|
||||
|
||||
// createHTTPClient 创建 HTTP 客户端
|
||||
func createHTTPClient(ctx context.Context, configKey string, defaultSeconds int) *http.Client {
|
||||
timeout := time.Duration(g.Cfg().MustGet(ctx, configKey, defaultSeconds).Int()) * time.Second
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
75
service/prompt/prompt_files_handle_service.markdown
Normal file
75
service/prompt/prompt_files_handle_service.markdown
Normal file
@@ -0,0 +1,75 @@
|
||||
# Prompts-Core(提示词核心服务)
|
||||
|
||||
> 智能提示词构建与管理系统,支持多模态 AI 模型的提示词组装、会话管理和协议适配。
|
||||
|
||||
---
|
||||
|
||||
## 项目简介
|
||||
|
||||
**Prompts-Core** 是一个基于 Go 语言开发的提示词核心服务,作为 AI 应用层与模型网关之间的桥梁,负责将业务需求转换为标准化的模型请求。
|
||||
|
||||
### 核心价值
|
||||
- **统一提示词管理**:集中化管理不同模型类型的提示词模板
|
||||
- **智能会话维护**:基于 Redis + PostgreSQL 的双层会话存储
|
||||
- **多协议适配**:支持 OpenAI、DeepSeek、Qwen、Gemini 等多种模型协议
|
||||
- **文件处理能力**:自动提取文本文件和 ZIP 压缩包内容
|
||||
- **技能系统集成**:支持从外部加载 Markdown 格式的技能描述
|
||||
|
||||
---
|
||||
|
||||
## 核心功能
|
||||
|
||||
### 1. 提示词构建引擎
|
||||
|
||||
#### 多模态支持
|
||||
| 类型 | 说明 | 适用场景 |
|
||||
|------|------|----------|
|
||||
| Type 1 | 文字处理助手 | 文章撰写、文案优化、翻译等 |
|
||||
| Type 2 | 图片处理助手 | 图像生成、风格迁移等 |
|
||||
| Type 3 | 音频处理助手 | 语音合成、识别、降噪等 |
|
||||
| Type 4 | 向量化处理助手 | 语义检索、知识索引等 |
|
||||
| Type 5 | 全模态助手 | 跨模态转换、多模态融合等 |
|
||||
|
||||
#### 构建模式
|
||||
- **BuildType 1(提示词构建)**:完整流程,包含系统提示词、历史会话、用户输入的智能组装
|
||||
- **BuildType 2(节点构建)**:工作流路由决策,根据上下文选择节点 ID
|
||||
|
||||
#### 分批处理
|
||||
当用户表单内容超出模型窗口限制时,自动按 Token 大小分批处理。
|
||||
|
||||
### 2. 会话管理系统
|
||||
|
||||
- **双层存储**:Redis 缓存(最近 N 轮)+ PostgreSQL 持久化
|
||||
- **自动管理**:最大轮数控制(默认 10 轮)、自动过期(默认 30 分钟)
|
||||
|
||||
### 3. 协议适配器
|
||||
|
||||
通过配置动态支持多种模型协议:
|
||||
- 角色映射:system/user/assistant → 目标协议角色
|
||||
- 内容字段映射:content → parts.text 等
|
||||
- 消息顺序控制:灵活配置拼接顺序
|
||||
- 请求模板渲染:支持占位符替换
|
||||
|
||||
### 4. 任务调度
|
||||
|
||||
- **异步流程**:创建网关任务 → 轮询等待 → 接收回调 → 返回结果
|
||||
- **重试机制**:可配置最大重试次数(默认 3 次)
|
||||
- **超时保护**:默认 300 秒超时
|
||||
|
||||
---
|
||||
|
||||
## 技术架构
|
||||
|
||||
### 技术栈
|
||||
|
||||
| 组件 | 版本 | 用途 |
|
||||
|------|------|------|
|
||||
| Go | 1.26.0 | 编程语言 |
|
||||
| GoFrame | v2.10.0 | Web 框架 |
|
||||
| PostgreSQL | - | 关系型数据库 |
|
||||
| Redis | - | 缓存与会话存储 |
|
||||
| Consul | - | 服务注册与发现 |
|
||||
| Jaeger | - | 分布式链路追踪 |
|
||||
|
||||
### 架构图
|
||||
|
||||
@@ -20,11 +20,27 @@ type PromptIR struct {
|
||||
|
||||
// Segment is one piece of a prompt message.
type Segment struct {
	// Type is the segment kind, e.g. "text" or "image".
	Type string `json:"type"`
	// Content is the segment payload.
	Content string `json:"content"`
	// Role is the speaker role; it is omitted from JSON when empty.
	Role string `json:"role,omitempty"`
}
|
||||
|
||||
// ProviderProtocol 协议编译配置(从 DB JSONB 字段解析)
|
||||
type ProviderProtocol struct {
|
||||
TargetField string `json:"target_field"`
|
||||
MergeOrder []string `json:"merge_order"`
|
||||
RoleMapping map[string]string `json:"role_mapping"`
|
||||
ContentMapping ContentMapping `json:"content_mapping"`
|
||||
RequestTemplate map[string]any `json:"request_template"`
|
||||
SystemPromptTemplate string `json:"system_prompt_template"`
|
||||
}
|
||||
|
||||
// ContentMapping describes how a message's content is laid out for a provider.
type ContentMapping struct {
	// Type is the layout kind: "direct" or "parts".
	Type string `json:"type"`
	// Field is the key that receives the content, e.g. "content" or "text".
	Field string `json:"field"`
}
|
||||
|
||||
// NewPromptIR 创建空 PromptIR
|
||||
func NewPromptIR() *PromptIR {
|
||||
return &PromptIR{
|
||||
@@ -34,6 +50,54 @@ func NewPromptIR() *PromptIR {
|
||||
}
|
||||
}
|
||||
|
||||
// String 返回 PromptIR 的完整内容字符串(用于 token 计算)
|
||||
func (ir *PromptIR) String() string {
|
||||
var builder strings.Builder
|
||||
|
||||
for _, seg := range ir.System {
|
||||
builder.WriteString("System: ")
|
||||
builder.WriteString(seg.Content)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
for _, seg := range ir.History {
|
||||
builder.WriteString(seg.Role)
|
||||
builder.WriteString(": ")
|
||||
builder.WriteString(seg.Content)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
for _, seg := range ir.User {
|
||||
builder.WriteString("User: ")
|
||||
builder.WriteString(seg.Content)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算)
|
||||
func (ir *PromptIR) GetTotalContent() string {
|
||||
var builder strings.Builder
|
||||
|
||||
for _, seg := range ir.System {
|
||||
builder.WriteString(seg.Content)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
for _, seg := range ir.History {
|
||||
builder.WriteString(seg.Content)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
for _, seg := range ir.User {
|
||||
builder.WriteString(seg.Content)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// AddSystem 添加系统提示
|
||||
func (ir *PromptIR) AddSystem(content string) *PromptIR {
|
||||
if content != "" {
|
||||
@@ -62,7 +126,6 @@ func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
|
||||
func (ir *PromptIR) ToMessages() []map[string]any {
|
||||
var messages []map[string]any
|
||||
|
||||
// 1. 系统消息
|
||||
for _, seg := range ir.System {
|
||||
messages = append(messages, map[string]any{
|
||||
"role": "system",
|
||||
@@ -70,7 +133,6 @@ func (ir *PromptIR) ToMessages() []map[string]any {
|
||||
})
|
||||
}
|
||||
|
||||
// 2. 历史消息
|
||||
for _, seg := range ir.History {
|
||||
messages = append(messages, map[string]any{
|
||||
"role": seg.Role,
|
||||
@@ -78,13 +140,13 @@ func (ir *PromptIR) ToMessages() []map[string]any {
|
||||
})
|
||||
}
|
||||
|
||||
// 3. 用户消息
|
||||
for _, seg := range ir.User {
|
||||
messages = append(messages, map[string]any{
|
||||
"role": "user",
|
||||
"content": seg.Content,
|
||||
})
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
@@ -97,11 +159,7 @@ func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderP
|
||||
if err != nil || entity == nil {
|
||||
return nil, err
|
||||
}
|
||||
entity.MergeOrder = util.ParseJSONField(entity.MergeOrder)
|
||||
entity.RoleMapping = util.ParseJSONField(entity.RoleMapping)
|
||||
entity.ContentMapping = util.ParseJSONField(entity.ContentMapping)
|
||||
entity.RequestTemplate = util.ParseJSONField(entity.RequestTemplate)
|
||||
entity.ContentMapping = util.ParseJSONField(entity.ContentMapping)
|
||||
fmt.Println("entity打印", entity)
|
||||
return parseProtocol(entity), nil
|
||||
}
|
||||
|
||||
@@ -109,62 +167,27 @@ func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderP
|
||||
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
|
||||
p := &ProviderProtocol{
|
||||
TargetField: e.TargetField,
|
||||
SystemPromptTemplate: e.SystemPromptTemplate,
|
||||
}
|
||||
|
||||
// MergeOrder: any → []string
|
||||
if e.MergeOrder != nil {
|
||||
b, _ := json.Marshal(e.MergeOrder)
|
||||
json.Unmarshal(b, &p.MergeOrder)
|
||||
}
|
||||
|
||||
// RoleMapping: any → map[string]string
|
||||
if e.RoleMapping != nil {
|
||||
b, _ := json.Marshal(e.RoleMapping)
|
||||
json.Unmarshal(b, &p.RoleMapping)
|
||||
}
|
||||
|
||||
// ContentMapping: any → ContentMapping
|
||||
if e.ContentMapping != nil {
|
||||
b, _ := json.Marshal(e.ContentMapping)
|
||||
json.Unmarshal(b, &p.ContentMapping)
|
||||
}
|
||||
|
||||
// RequestTemplate: any → map[string]any
|
||||
if e.RequestTemplate != nil {
|
||||
b, _ := json.Marshal(e.RequestTemplate)
|
||||
json.Unmarshal(b, &p.RequestTemplate)
|
||||
}
|
||||
fmt.Printf("parseProtocol: %+v\n", p)
|
||||
// 使用通用解析方法处理各个字段
|
||||
util.ParseJSONFieldFromGvar(e.MergeOrder, &p.MergeOrder)
|
||||
util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping)
|
||||
util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping)
|
||||
util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate)
|
||||
return p
|
||||
}
|
||||
|
||||
// ProviderProtocol 协议编译配置(从 DB JSONB 字段解析)
|
||||
type ProviderProtocol struct {
|
||||
TargetField string `json:"target_field"`
|
||||
MergeOrder []string `json:"merge_order"`
|
||||
RoleMapping map[string]string `json:"role_mapping"`
|
||||
ContentMapping ContentMapping `json:"content_mapping"`
|
||||
RequestTemplate map[string]any `json:"request_template"`
|
||||
}
|
||||
|
||||
// ContentMapping 内容字段映射
|
||||
type ContentMapping struct {
|
||||
Type string `json:"type"` // direct/parts
|
||||
Field string `json:"field"` // content/text
|
||||
}
|
||||
|
||||
// Compile 将 PromptIR 按协议配置编译为 Provider Request
|
||||
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (map[string]any, error) {
|
||||
if ir == nil || p == nil {
|
||||
return nil, fmt.Errorf("ir and protocol are required")
|
||||
}
|
||||
// 1. 按 merge_order 拼接消息
|
||||
|
||||
messages := mergeByOrder(ir, p.MergeOrder)
|
||||
// 2. 角色映射
|
||||
messages = mapRoles(messages, p.RoleMapping)
|
||||
// 3. 内容字段映射
|
||||
messages = mapContent(messages, p.ContentMapping)
|
||||
// 4. 按 target_field + request_template 构建请求体
|
||||
|
||||
return buildRequest(messages, p, chatModel), nil
|
||||
}
|
||||
|
||||
@@ -197,6 +220,7 @@ func mergeByOrder(ir *PromptIR, order []string) []map[string]any {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
@@ -205,15 +229,18 @@ func mapRoles(messages []map[string]any, mapping map[string]string) []map[string
|
||||
if len(mapping) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
for i, msg := range messages {
|
||||
role, ok := msg["role"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if mapped, exists := mapping[role]; exists {
|
||||
messages[i]["role"] = mapped
|
||||
}
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
@@ -225,15 +252,14 @@ func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
|
||||
|
||||
switch cm.Type {
|
||||
case "parts":
|
||||
// Gemini 格式: {"parts": [{"text": "..."}]}
|
||||
msg["parts"] = []map[string]any{
|
||||
{cm.Field: content},
|
||||
}
|
||||
default:
|
||||
// direct: {"content": "..."}
|
||||
msg[cm.Field] = content
|
||||
}
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
@@ -242,6 +268,7 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *ent
|
||||
if len(p.RequestTemplate) > 0 {
|
||||
return renderTemplate(p.RequestTemplate, messages, chatModel)
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
p.TargetField: messages,
|
||||
}
|
||||
@@ -252,13 +279,13 @@ func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *e
|
||||
b, _ := json.Marshal(tmpl)
|
||||
str := string(b)
|
||||
|
||||
// 替换 {{model}}
|
||||
str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`)
|
||||
// 替换 {{messages}}
|
||||
|
||||
msgBytes, _ := json.Marshal(messages)
|
||||
str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes))
|
||||
|
||||
var result map[string]any
|
||||
json.Unmarshal([]byte(str), &result)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -9,15 +9,16 @@ import (
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// ==================== Redis 操作 ====================
|
||||
const (
	// redisKeyPrefix is the fmt format of a session's Redis key; its single
	// %s verb is filled with the session ID.
	redisKeyPrefix = "chat:session:%s"
)
|
||||
|
||||
// saveToRedis 保存会话数据到Redis
|
||||
func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error {
|
||||
key := fmt.Sprintf("chat:session:%s", sessionId)
|
||||
key := formatRedisKey(sessionId)
|
||||
|
||||
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
||||
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
|
||||
expireTime := time.Duration(expireSeconds) * time.Second
|
||||
|
||||
data := map[string]any{
|
||||
"sessionId": sessionId,
|
||||
@@ -31,18 +32,29 @@ func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[st
|
||||
return fmt.Errorf("序列化会话数据失败: %w", err)
|
||||
}
|
||||
|
||||
_, err = g.Redis().Do(ctx, "LPUSH", key, string(b))
|
||||
if err != nil {
|
||||
if err := executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatRedisKey 格式化Redis键
|
||||
func formatRedisKey(sessionId string) string {
|
||||
return fmt.Sprintf(redisKeyPrefix, sessionId)
|
||||
}
|
||||
|
||||
// executeRedisCommands 执行Redis命令
|
||||
func executeRedisCommands(ctx context.Context, key string, value string, maxRounds int, expireSeconds int64) error {
|
||||
if _, err := g.Redis().Do(ctx, "LPUSH", key, value); err != nil {
|
||||
return fmt.Errorf("写入Redis失败: %w", err)
|
||||
}
|
||||
|
||||
_, err = g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1)
|
||||
if err != nil {
|
||||
if _, err := g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1); err != nil {
|
||||
return fmt.Errorf("裁剪Redis列表失败: %w", err)
|
||||
}
|
||||
|
||||
_, err = g.Redis().Do(ctx, "EXPIRE", key, int64(expireTime.Seconds()))
|
||||
if err != nil {
|
||||
if _, err := g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
|
||||
return fmt.Errorf("设置过期时间失败: %w", err)
|
||||
}
|
||||
|
||||
@@ -51,7 +63,7 @@ func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[st
|
||||
|
||||
// getFromRedis 从Redis获取会话历史
|
||||
func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
||||
key := fmt.Sprintf("chat:session:%s", sessionId)
|
||||
key := formatRedisKey(sessionId)
|
||||
|
||||
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
|
||||
if err != nil {
|
||||
@@ -62,8 +74,17 @@ func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, erro
|
||||
return []map[string]any{}, nil
|
||||
}
|
||||
|
||||
sessions := parseRedisSessions(ctx, result.Strings())
|
||||
|
||||
reverseSlice(sessions)
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
// parseRedisSessions 解析Redis会话数据
|
||||
func parseRedisSessions(ctx context.Context, values []string) []map[string]any {
|
||||
var sessions []map[string]any
|
||||
values := result.Strings()
|
||||
|
||||
for _, str := range values {
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(str), &data); err != nil {
|
||||
@@ -73,12 +94,14 @@ func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, erro
|
||||
sessions = append(sessions, data)
|
||||
}
|
||||
|
||||
// 反转(Redis 最新在前 → 时间正序)
|
||||
for i, j := 0, len(sessions)-1; i < j; i, j = i+1, j-1 {
|
||||
sessions[i], sessions[j] = sessions[j], sessions[i]
|
||||
}
|
||||
return sessions
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
// reverseSlice reverses s in place. A nil or single-element slice is left
// unchanged.
func reverseSlice(s []map[string]any) {
	for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {
		s[i], s[j] = s[j], s[i]
	}
}
|
||||
|
||||
// GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用)
|
||||
@@ -92,23 +115,31 @@ func GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map
|
||||
return []map[string]any{}, nil
|
||||
}
|
||||
|
||||
return flattenHistoryMessages(historyData), nil
|
||||
}
|
||||
|
||||
// flattenHistoryMessages 扁平化历史消息
|
||||
func flattenHistoryMessages(historyData []map[string]any) []map[string]any {
|
||||
var messages []map[string]any
|
||||
|
||||
for _, round := range historyData {
|
||||
if reqMsgs, ok := round["requestContent"].([]interface{}); ok {
|
||||
for _, m := range reqMsgs {
|
||||
if msg, ok := m.(map[string]interface{}); ok {
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
if respMsgs, ok := round["responseContent"].([]interface{}); ok {
|
||||
for _, m := range respMsgs {
|
||||
if msg, ok := m.(map[string]interface{}); ok {
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
appendMessagesFromField(round, "requestContent", &messages)
|
||||
appendMessagesFromField(round, "responseContent", &messages)
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
return messages
|
||||
}
|
||||
|
||||
// appendMessagesFromField appends to *messages every element of data[field]
// that is a JSON object (map[string]any).
//
// Nothing is appended when the field is missing or is not a JSON array;
// array elements of any other type are skipped.
func appendMessagesFromField(data map[string]any, field string, messages *[]map[string]any) {
	items, ok := data[field].([]any)
	if !ok {
		return
	}

	for _, item := range items {
		if msg, ok := item.(map[string]any); ok {
			*messages = append(*messages, msg)
		}
	}
}
|
||||
|
||||
@@ -3,112 +3,164 @@ package prompt
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
sessionDao "prompts-core/dao"
|
||||
"prompts-core/model/entity"
|
||||
|
||||
"prompts-core/common/util"
|
||||
sessionDto "prompts-core/model/dto/prompt"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/dao"
|
||||
"prompts-core/model/dto"
|
||||
"prompts-core/model/entity"
|
||||
)
|
||||
|
||||
func SessionCallback(ctx context.Context, req *sessionDto.SessionCallbackReq) (res *sessionDto.SessionCallbackRes, err error) {
|
||||
// 1. 解析AI返回的文本
|
||||
// SessionCallback 会话回调
|
||||
func SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
|
||||
result, err := util.ParseOutput(req.Text)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("解析模型输出失败: %w", err)
|
||||
}
|
||||
|
||||
// 2. 更新数据库
|
||||
result["role"] = "assistant"
|
||||
_, err = sessionDao.ComposeSession.Update(ctx, &entity.ComposeSession{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
||||
ResponseContent: result,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
||||
|
||||
if err := updateSessionResponse(ctx, req.EpicycleId, result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3. 获取当前轮次完整数据
|
||||
session, err := sessionDao.ComposeSession.Get(ctx, &entity.ComposeSession{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
||||
})
|
||||
session, err := getSessionById(ctx, req.EpicycleId)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. 转换 json 并存入 Redis
|
||||
if err := saveSessionToRedis(ctx, session); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
requestMessages := util.ConvertToMessages(session.RequestContent)
|
||||
responseMessages := util.ConvertToMessages(session.ResponseContent)
|
||||
|
||||
if err = saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil {
|
||||
g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v",
|
||||
session.SessionId, session.Id, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
|
||||
session.SessionId, session.Id, len(requestMessages), len(responseMessages))
|
||||
return &sessionDto.SessionCallbackRes{}, nil
|
||||
|
||||
return &dto.SessionCallbackRes{}, nil
|
||||
}
|
||||
|
||||
// updateSessionResponse 更新会话响应
|
||||
func updateSessionResponse(ctx context.Context, epicycleId int64, response any) error {
|
||||
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: epicycleId},
|
||||
ResponseContent: response,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", epicycleId, err)
|
||||
return fmt.Errorf("更新数据库失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getSessionById 根据ID获取会话
|
||||
func getSessionById(ctx context.Context, epicycleId int64) (*entity.ComposeSession, error) {
|
||||
session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: epicycleId},
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", epicycleId, err)
|
||||
return nil, fmt.Errorf("获取会话数据失败: %w", err)
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// saveSessionToRedis 保存会话到Redis
|
||||
func saveSessionToRedis(ctx context.Context, session *entity.ComposeSession) error {
|
||||
requestMessages := util.ConvertToMessages(session.RequestContent)
|
||||
responseMessages := util.ConvertToMessages(session.ResponseContent)
|
||||
|
||||
if err := saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil {
|
||||
g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v",
|
||||
session.SessionId, session.Id, err)
|
||||
return fmt.Errorf("Redis存储失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetHistoryMessages 获取历史信息
|
||||
func GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
||||
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
||||
|
||||
// 1. 先从 Redis 拿
|
||||
redisHistory, err := GetSessionHistoryForInference(ctx, sessionId)
|
||||
if err == nil && len(redisHistory) > 0 {
|
||||
return redisHistory, nil
|
||||
}
|
||||
|
||||
// 2. Redis 没有 → fallback DB
|
||||
sessions, _, err := sessionDao.ComposeSession.List(ctx, &entity.ComposeSession{
|
||||
return getHistoryFromDatabase(ctx, sessionId, maxRounds)
|
||||
}
|
||||
|
||||
// getHistoryFromDatabase 从数据库获取历史记录
|
||||
func getHistoryFromDatabase(ctx context.Context, sessionId string, maxRounds int) ([]map[string]any, error) {
|
||||
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
|
||||
SessionId: sessionId,
|
||||
}, 1, maxRounds)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DB获取历史失败: %w", err)
|
||||
}
|
||||
|
||||
messages := extractMessagesFromSessions(sessions)
|
||||
|
||||
cacheSessionsToRedis(ctx, sessions)
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// extractMessagesFromSessions 从会话列表中提取消息
|
||||
func extractMessagesFromSessions(sessions []*entity.ComposeSession) []map[string]any {
|
||||
var messages []map[string]any
|
||||
|
||||
for _, session := range sessions {
|
||||
// request
|
||||
reqMsgs := util.ConvertToMessages(session.RequestContent)
|
||||
appendRequestMessages(session.RequestContent, &messages)
|
||||
appendResponseMessages(session.ResponseContent, &messages)
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
// appendRequestMessages 追加请求消息
|
||||
func appendRequestMessages(requestContent any, messages *[]map[string]any) {
|
||||
reqMsgs := util.ConvertToMessages(requestContent)
|
||||
for _, m := range reqMsgs {
|
||||
role := gconv.String(m["role"])
|
||||
if role == "user" || role == "assistant" {
|
||||
messages = append(messages, m)
|
||||
*messages = append(*messages, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// response
|
||||
respMsgs := util.ConvertToMessages(session.ResponseContent)
|
||||
// appendResponseMessages 追加响应消息
|
||||
func appendResponseMessages(responseContent any, messages *[]map[string]any) {
|
||||
respMsgs := util.ConvertToMessages(responseContent)
|
||||
for _, m := range respMsgs {
|
||||
if m["role"] == nil {
|
||||
m["role"] = "assistant"
|
||||
}
|
||||
messages = append(messages, m)
|
||||
}
|
||||
*messages = append(*messages, m)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 回写 Redis
|
||||
// cacheSessionsToRedis 将会话缓存到Redis
|
||||
func cacheSessionsToRedis(ctx context.Context, sessions []*entity.ComposeSession) {
|
||||
for _, session := range sessions {
|
||||
reqMsgs := util.ConvertToMessages(session.RequestContent)
|
||||
respMsgs := util.ConvertToMessages(session.ResponseContent)
|
||||
|
||||
for i := range respMsgs {
|
||||
if respMsgs[i]["role"] == nil {
|
||||
respMsgs[i]["role"] = "assistant"
|
||||
}
|
||||
}
|
||||
|
||||
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
|
||||
_ = saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs)
|
||||
}
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
135
service/prompt/prompt_user_form_batches.go
Normal file
135
service/prompt/prompt_user_form_batches.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package prompt
|
||||
|
||||
import (
	"context"
	"fmt"
	"sort"
	"strings"

	"github.com/gogf/gf/v2/frame/g"

	"prompts-core/common/util"
	"prompts-core/model/dto"
	"prompts-core/model/entity"
)
|
||||
|
||||
// ProcessUserFormBatches 处理 UserForm 分批(按 token 大小拼接内容)
|
||||
func ProcessUserFormBatches(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) (*dto.ComposeMessagesReq, int, error) {
|
||||
if model.TokenConfig == nil || len(req.UserForm) == 0 {
|
||||
return req, 1, nil
|
||||
}
|
||||
|
||||
availableWindow := util.GetAvailableWindow(model.TokenConfig)
|
||||
batches := splitUserFormByTokenSize(req.UserForm, availableWindow, model.TokenConfig)
|
||||
|
||||
if len(batches) <= 1 {
|
||||
return req, 1, nil
|
||||
}
|
||||
|
||||
newUserForm := buildBatchedUserForm(batches)
|
||||
|
||||
newReq := *req
|
||||
newReq.UserForm = newUserForm
|
||||
|
||||
g.Log().Infof(ctx, "[ProcessUserFormBatches] UserForm分批完成: 原始%d条 -> %d批 (按token大小拼接)",
|
||||
len(req.UserForm), len(batches))
|
||||
|
||||
return &newReq, len(batches), nil
|
||||
}
|
||||
|
||||
// buildBatchedUserForm 构建分批后的用户表单
|
||||
func buildBatchedUserForm(batches [][]map[string]any) []map[string]any {
|
||||
newUserForm := make([]map[string]any, 0, len(batches))
|
||||
|
||||
for i, batch := range batches {
|
||||
combinedText := combineBatchText(batch)
|
||||
newUserForm = append(newUserForm, map[string]any{
|
||||
"batch_index": i + 1,
|
||||
"total_batches": len(batches),
|
||||
"text": combinedText,
|
||||
"item_count": len(batch),
|
||||
})
|
||||
}
|
||||
|
||||
return newUserForm
|
||||
}
|
||||
|
||||
// combineBatchText 合并批次中的所有文本(合并所有字段的值)
|
||||
func combineBatchText(batch []map[string]any) string {
|
||||
var builder strings.Builder
|
||||
|
||||
for j, item := range batch {
|
||||
itemText := getItemText(item)
|
||||
if itemText == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if j > 0 {
|
||||
builder.WriteString("\n\n")
|
||||
}
|
||||
builder.WriteString(itemText)
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// splitUserFormByTokenSize 按 token 大小将 UserForm 内容拼接后分批
|
||||
func splitUserFormByTokenSize(userForm []map[string]any, maxTokens int, tokenConfig any) [][]map[string]any {
|
||||
if len(userForm) == 0 {
|
||||
return [][]map[string]any{}
|
||||
}
|
||||
|
||||
batches := make([][]map[string]any, 0)
|
||||
currentBatch := make([]map[string]any, 0)
|
||||
currentTokens := 0
|
||||
|
||||
for i, item := range userForm {
|
||||
itemText := getItemText(item)
|
||||
itemTokens := util.CalculateTokens(itemText, tokenConfig)
|
||||
|
||||
// 单个元素超过窗口,单独成一批
|
||||
if itemTokens > maxTokens {
|
||||
if len(currentBatch) > 0 {
|
||||
batches = append(batches, currentBatch)
|
||||
currentBatch = make([]map[string]any, 0)
|
||||
currentTokens = 0
|
||||
}
|
||||
batches = append(batches, []map[string]any{item})
|
||||
continue
|
||||
}
|
||||
|
||||
// 判断是否需要新开一批
|
||||
if currentTokens+itemTokens > maxTokens && len(currentBatch) > 0 {
|
||||
batches = append(batches, currentBatch)
|
||||
currentBatch = make([]map[string]any, 0)
|
||||
currentTokens = 0
|
||||
}
|
||||
|
||||
currentBatch = append(currentBatch, item)
|
||||
currentTokens += itemTokens
|
||||
|
||||
// 最后一批
|
||||
if i == len(userForm)-1 && len(currentBatch) > 0 {
|
||||
batches = append(batches, currentBatch)
|
||||
}
|
||||
}
|
||||
|
||||
return batches
|
||||
}
|
||||
|
||||
// getItemText renders every value of item with fmt's %v verb and joins the
// results with "\n".
//
// The keys "batch_index", "total_batches" and "item_count" (metadata added by
// batching) are skipped. Values are emitted in ascending key order: Go
// randomizes map iteration, so sorting is what makes the text identical across
// calls for the same item. Returns "" for a nil or empty item, or when only
// metadata keys are present.
func getItemText(item map[string]any) string {
	if len(item) == 0 {
		return ""
	}

	keys := make([]string, 0, len(item))
	for key := range item {
		switch key {
		case "batch_index", "total_batches", "item_count":
			// Metadata added when batching; not part of the content.
			continue
		}
		keys = append(keys, key)
	}
	sort.Strings(keys)

	parts := make([]string, 0, len(keys))
	for _, key := range keys {
		parts = append(parts, fmt.Sprintf("%v", item[key]))
	}

	return strings.Join(parts, "\n")
}
|
||||
Reference in New Issue
Block a user