refactor(prompt): 重构提示词构建服务与数据模型
This commit is contained in:
72
README.md
72
README.md
@@ -1,30 +1,54 @@
|
|||||||
# prompts-core(提示词服务)[2026.5.12前,暂时弃置]
|
# Prompts-Core 提示词核心服务
|
||||||
|
## 项目简介
|
||||||
|
Prompts-Core 是基于 Go 语言开发的**多模态 AI 提示词构建与管理系统**,专注于统一管理各类 AI 模型的提示词模板、维护智能会话上下文、适配主流模型协议,并支持文件解析与外部技能集成,为 AI 应用提供标准化、高效的提示词服务。
|
||||||
|
|
||||||
## 1. 功能范围(当前阶段)
|
## 核心功能
|
||||||
- 仅做提示词配置的基础 CRUD(最小可用版本)
|
1. **提示词构建引擎**
|
||||||
- 表:`prompts_model_prompt`
|
支持文字/图片/音频/向量化/全模态 5 类任务提示词生成,提供完整流程、分步节点两种构建模式,支持超大内容按 Token 自动分批处理。
|
||||||
|
2. **智能会话管理**
|
||||||
|
基于缓存实现高效会话存储,自动控制会话轮数与过期时间,保障上下文连贯性。
|
||||||
|
3. **多模型协议适配**
|
||||||
|
动态适配 OpenAI、DeepSeek、Qwen、Gemini 等主流 AI 模型协议,支持角色、字段、消息顺序灵活映射。
|
||||||
|
4. **文件与技能集成**
|
||||||
|
自动提取文本、ZIP 压缩包内容,支持加载外部 Markdown 技能配置,扩展服务能力。
|
||||||
|
5. **异步任务调度**
|
||||||
|
支持异步任务处理、状态轮询与回调通知,自带可配置重试机制。
|
||||||
|
|
||||||
## 2. 接口
|
## 技术架构
|
||||||
> 路由注册方式与参考项目一致:使用 `common/http.RouteRegister` 注册 controller。
|
- 开发语言:Go 1.26.0
|
||||||
|
- Web 框架:GoFrame v2.10.0
|
||||||
|
- 核心存储:Redis(会话缓存)
|
||||||
|
- 服务组件:Consul(服务注册)、Jaeger(链路追踪)
|
||||||
|
- 调用链路:客户端 → Prompts-Core → 模型网关 → AI 模型
|
||||||
|
|
||||||
- `POST /composeMessages`:按 `modelTypeId` 读取 `prompt_info + response_json_schema`,`modelName` 作为实际调用的网关模型;结合前端 `form(role/value)` 与 `userfiles` 调用 `model-gateway /task/createTask`,同步等待回调后直接返回最终 `messages`
|
## 快速开始
|
||||||
- `GET /composeMessagesCallback/prompts-core`:`model-gateway` 成功回调接口(真实地址由 `callbackUrl + /bizName` 组成)
|
### 环境要求
|
||||||
- `GET /getComposeTask`:按 `taskId` 查询拼接任务状态和结果
|
Go 1.26+、Redis、已部署模型网关服务
|
||||||
- `POST /createPrompt`:创建(默认启用)
|
|
||||||
- `PUT /updatePrompt`:更新
|
|
||||||
- `DELETE /deletePrompt`:删除
|
|
||||||
- `GET /getPrompt`:详情
|
|
||||||
- `POST /listPrompt`:列表分页
|
|
||||||
|
|
||||||
## 3. 数据库初始化
|
### 启动步骤
|
||||||
执行根目录 `update.sql`。
|
1. 克隆项目代码
|
||||||
|
2. 完成项目配置文件修改
|
||||||
|
3. 执行命令启动服务:
|
||||||
|
```bash
|
||||||
|
go run main.go
|
||||||
|
```
|
||||||
|
|
||||||
## 4. 运行配置
|
## API 接口
|
||||||
配置文件:`config.yml`
|
### 基础信息
|
||||||
|
- 服务地址:`http://{host}:3009`
|
||||||
|
- 请求类型:`application/json`
|
||||||
|
- 认证方式:请求头携带 `Authorization`、`X-User`
|
||||||
|
|
||||||
### 新增说明
|
### 核心接口
|
||||||
- `prompts_model_prompt` 去除了 `limit_length`
|
1. **提示词拼接接口**
|
||||||
- 新增 `response_json_schema`
|
- 地址:`POST /composeMessages`
|
||||||
- 新增任务记录表 `prompts_compose_task`
|
- 功能:构建提示词并调用模型服务,同步返回结果
|
||||||
- `callbackUrl` 必须填写 prompts-core 的绝对地址基路径,例如:`http://127.0.0.1:8002/composeMessagesCallback`
|
2. **任务状态查询**
|
||||||
- `model-gateway` 实际回调地址为:`callbackUrl/{bizName}`,本项目固定为:`/composeMessagesCallback/prompts-core`
|
- 地址:`GET /getComposeTask`
|
||||||
|
- 功能:根据任务 ID 查询处理状态与结果
|
||||||
|
3. **任务回调接口**
|
||||||
|
- 地址:`GET /composeMessagesCallback/prompts-core`
|
||||||
|
- 功能:接收模型服务处理完成回调
|
||||||
|
4. **会话同步接口**
|
||||||
|
- 地址:`POST /sessionCallback`
|
||||||
|
- 功能:同步更新会话上下文历史
|
||||||
@@ -8,11 +8,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// GetModelPrompt 获取请求模型的提示词
|
// GetModelPrompt 获取请求模型的提示词
|
||||||
func GetModelPrompt(ctx context.Context, Type int) string {
|
func GetModelPrompt(ctx context.Context, modelType int) string {
|
||||||
return g.Cfg().MustGet(ctx, "modelPrompts.types."+gconv.String(Type), "").String()
|
key := "modelPrompts.types." + gconv.String(modelType)
|
||||||
|
return g.Cfg().MustGet(ctx, key, "").String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetBuildPrompt 获取构建提示词
|
// GetBuildPrompt 获取节点构建提示词
|
||||||
func GetBuildPrompt(ctx context.Context, Type int) string {
|
func GetBuildPrompt(ctx context.Context) string {
|
||||||
return g.Cfg().MustGet(ctx, "buildProject.types."+gconv.String(Type), "").String()
|
return g.Cfg().MustGet(ctx, "nodePrompts", "").String()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,38 +6,41 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AllowedMIMEPrefixes 允许的文本类 MIME 类型前缀
|
var (
|
||||||
var AllowedMIMEPrefixes = []string{
|
// AllowedMIMEPrefixes 允许的文本类 MIME 类型前缀
|
||||||
"text/",
|
AllowedMIMEPrefixes = []string{
|
||||||
"application/json",
|
"text/",
|
||||||
"application/xml",
|
"application/json",
|
||||||
"application/javascript",
|
"application/xml",
|
||||||
"application/x-yaml",
|
"application/javascript",
|
||||||
"application/yaml",
|
"application/x-yaml",
|
||||||
"application/toml",
|
"application/yaml",
|
||||||
"application/x-httpd-php",
|
"application/toml",
|
||||||
"application/x-sh",
|
"application/x-httpd-php",
|
||||||
"application/x-python",
|
"application/x-sh",
|
||||||
"application/x-perl",
|
"application/x-python",
|
||||||
"application/x-ruby",
|
"application/x-perl",
|
||||||
}
|
"application/x-ruby",
|
||||||
|
}
|
||||||
|
|
||||||
// BannedExtensions 禁止的文件扩展名
|
// BannedExtensions 禁止的文件扩展名
|
||||||
var BannedExtensions = map[string]bool{
|
BannedExtensions = map[string]bool{
|
||||||
".png": true, ".jpg": true, ".jpeg": true, ".gif": true, ".bmp": true,
|
".png": true, ".jpg": true, ".jpeg": true, ".gif": true, ".bmp": true,
|
||||||
".webp": true, ".svg": true, ".ico": true, ".tiff": true, ".tif": true,
|
".webp": true, ".svg": true, ".ico": true, ".tiff": true, ".tif": true,
|
||||||
".mp3": true, ".wav": true, ".ogg": true, ".flac": true, ".aac": true,
|
".mp3": true, ".wav": true, ".ogg": true, ".flac": true, ".aac": true,
|
||||||
".wma": true, ".m4a": true,
|
".wma": true, ".m4a": true,
|
||||||
".mp4": true, ".avi": true, ".mkv": true, ".mov": true, ".wmv": true,
|
".mp4": true, ".avi": true, ".mkv": true, ".mov": true, ".wmv": true,
|
||||||
".flv": true, ".webm": true,
|
".flv": true, ".webm": true,
|
||||||
".tar": true, ".gz": true, ".rar": true, ".7z": true,
|
".tar": true, ".gz": true, ".rar": true, ".7z": true,
|
||||||
".exe": true, ".dll": true, ".so": true, ".bin": true, ".dat": true,
|
".exe": true, ".dll": true, ".so": true, ".bin": true, ".dat": true,
|
||||||
".class": true, ".pyc": true,
|
".class": true, ".pyc": true,
|
||||||
".pdf": true, ".doc": true, ".docx": true, ".xls": true, ".xlsx": true,
|
".pdf": true, ".doc": true, ".docx": true, ".xls": true, ".xlsx": true,
|
||||||
".ppt": true, ".pptx": true,
|
".ppt": true, ".pptx": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
var symbolCleaner = regexp.MustCompile(`[\x00-\x08\x0B\x0C\x0E-\x1F]`)
|
symbolCleaner = regexp.MustCompile(`[\x00-\x08\x0B\x0C\x0E-\x1F]`)
|
||||||
|
multiNewlines = regexp.MustCompile(`\n{3,}`)
|
||||||
|
)
|
||||||
|
|
||||||
// SanitizeURL 清洗 URL 字符串
|
// SanitizeURL 清洗 URL 字符串
|
||||||
func SanitizeURL(raw string) string {
|
func SanitizeURL(raw string) string {
|
||||||
@@ -51,25 +54,19 @@ func CleanSymbols(text string) string {
|
|||||||
text = symbolCleaner.ReplaceAllString(text, "")
|
text = symbolCleaner.ReplaceAllString(text, "")
|
||||||
text = strings.ReplaceAll(text, "\r\n", "\n")
|
text = strings.ReplaceAll(text, "\r\n", "\n")
|
||||||
text = strings.ReplaceAll(text, "\r", "\n")
|
text = strings.ReplaceAll(text, "\r", "\n")
|
||||||
text = regexp.MustCompile(`\n{3,}`).ReplaceAllString(text, "\n\n")
|
text = multiNewlines.ReplaceAllString(text, "\n\n")
|
||||||
return strings.TrimSpace(text)
|
return strings.TrimSpace(text)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsBannedExtension 判断是否为禁止的文件扩展名
|
// IsBannedExtension 判断是否为禁止的文件扩展名
|
||||||
func IsBannedExtension(url string) bool {
|
func IsBannedExtension(url string) bool {
|
||||||
ext := strings.ToLower(filepath.Ext(url))
|
ext := extractExtension(url)
|
||||||
if idx := strings.Index(ext, "?"); idx != -1 {
|
|
||||||
ext = ext[:idx]
|
|
||||||
}
|
|
||||||
return BannedExtensions[ext]
|
return BannedExtensions[ext]
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsZipExtension 判断是否为 zip 文件
|
// IsZipExtension 判断是否为 zip 文件
|
||||||
func IsZipExtension(url string) bool {
|
func IsZipExtension(url string) bool {
|
||||||
ext := strings.ToLower(filepath.Ext(url))
|
ext := extractExtension(url)
|
||||||
if idx := strings.Index(ext, "?"); idx != -1 {
|
|
||||||
ext = ext[:idx]
|
|
||||||
}
|
|
||||||
return ext == ".zip"
|
return ext == ".zip"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,11 +75,22 @@ func IsReadableContentType(contentType string) bool {
|
|||||||
if contentType == "" {
|
if contentType == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
ct := strings.ToLower(contentType)
|
ct := strings.ToLower(contentType)
|
||||||
for _, prefix := range AllowedMIMEPrefixes {
|
for _, prefix := range AllowedMIMEPrefixes {
|
||||||
if strings.HasPrefix(ct, prefix) {
|
if strings.HasPrefix(ct, prefix) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractExtension 提取文件扩展名并清理查询参数
|
||||||
|
func extractExtension(url string) string {
|
||||||
|
ext := strings.ToLower(filepath.Ext(url))
|
||||||
|
if idx := strings.Index(ext, "?"); idx != -1 {
|
||||||
|
ext = ext[:idx]
|
||||||
|
}
|
||||||
|
return ext
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
// AsyncCtx 固化异步上下文中的 token 和用户信息,避免请求结束后丢失
|
// AsyncCtx 固化异步上下文中的 token 和用户信息,避免请求结束后丢失
|
||||||
func AsyncCtx(ctx context.Context) context.Context {
|
func AsyncCtx(ctx context.Context) context.Context {
|
||||||
asyncCtx := context.WithoutCancel(ctx)
|
asyncCtx := context.WithoutCancel(ctx)
|
||||||
|
|
||||||
if r := g.RequestFromCtx(ctx); r != nil {
|
if r := g.RequestFromCtx(ctx); r != nil {
|
||||||
if token := r.Header.Get("Authorization"); token != "" {
|
if token := r.Header.Get("Authorization"); token != "" {
|
||||||
asyncCtx = context.WithValue(asyncCtx, "token", token)
|
asyncCtx = context.WithValue(asyncCtx, "token", token)
|
||||||
@@ -18,9 +19,11 @@ func AsyncCtx(ctx context.Context) context.Context {
|
|||||||
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
|
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
|
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
|
||||||
asyncCtx = context.WithValue(asyncCtx, "user", user)
|
asyncCtx = context.WithValue(asyncCtx, "user", user)
|
||||||
}
|
}
|
||||||
|
|
||||||
return asyncCtx
|
return asyncCtx
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -28,25 +31,37 @@ func AsyncCtx(ctx context.Context) context.Context {
|
|||||||
func ForwardHeaders(ctx context.Context) map[string]string {
|
func ForwardHeaders(ctx context.Context) map[string]string {
|
||||||
headers := make(map[string]string)
|
headers := make(map[string]string)
|
||||||
|
|
||||||
if token, ok := ctx.Value("token").(string); ok && token != "" {
|
setHeaderFromContext(headers, ctx, "Authorization", "token")
|
||||||
headers["Authorization"] = token
|
setHeaderFromContext(headers, ctx, "X-User-Info", "xUserInfo")
|
||||||
}
|
|
||||||
if x, ok := ctx.Value("xUserInfo").(string); ok && x != "" {
|
fallbackToRequestHeaders(headers, ctx)
|
||||||
headers["X-User-Info"] = x
|
|
||||||
}
|
|
||||||
|
|
||||||
// 兜底:从请求头获取
|
|
||||||
if r := g.RequestFromCtx(ctx); r != nil {
|
|
||||||
if headers["Authorization"] == "" {
|
|
||||||
if token := r.Header.Get("Authorization"); token != "" {
|
|
||||||
headers["Authorization"] = token
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if headers["X-User-Info"] == "" {
|
|
||||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
|
||||||
headers["X-User-Info"] = userInfo
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return headers
|
return headers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setHeaderFromContext 从上下文中设置 header
|
||||||
|
func setHeaderFromContext(headers map[string]string, ctx context.Context, headerKey, ctxKey string) {
|
||||||
|
if value, ok := ctx.Value(ctxKey).(string); ok && value != "" {
|
||||||
|
headers[headerKey] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fallbackToRequestHeaders 从请求头中获取作为兜底
|
||||||
|
func fallbackToRequestHeaders(headers map[string]string, ctx context.Context) {
|
||||||
|
r := g.RequestFromCtx(ctx)
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if headers["Authorization"] == "" {
|
||||||
|
if token := r.Header.Get("Authorization"); token != "" {
|
||||||
|
headers["Authorization"] = token
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if headers["X-User-Info"] == "" {
|
||||||
|
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||||
|
headers["X-User-Info"] = userInfo
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ func ParseOutput(text string) (map[string]any, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("解析模型输出失败: %w", err)
|
return nil, fmt.Errorf("解析模型输出失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return j.Map(), nil
|
return j.Map(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -23,26 +24,17 @@ func ConvertToMessages(raw any) []map[string]any {
|
|||||||
if raw == nil {
|
if raw == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
j, err := gjson.LoadJson(gconv.Bytes(raw))
|
j, err := gjson.LoadJson(gconv.Bytes(raw))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// 如果有 messages 字段,直接返回
|
|
||||||
if j.Contains("messages") {
|
if j.Contains("messages") {
|
||||||
return gconv.Maps(j.Get("messages").Array())
|
return gconv.Maps(j.Get("messages").Array())
|
||||||
}
|
}
|
||||||
// 否则当成单条 message
|
|
||||||
return []map[string]any{
|
|
||||||
j.Map(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsMessageValid 校验推理结果是否合法
|
return []map[string]any{j.Map()}
|
||||||
func IsMessageValid(message map[string]any) bool {
|
|
||||||
if message == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FormToJSON 将表单数据转换为 JSON 字符串
|
// FormToJSON 将表单数据转换为 JSON 字符串
|
||||||
@@ -50,6 +42,17 @@ func FormToJSON(form map[string]any) string {
|
|||||||
if form == nil {
|
if form == nil {
|
||||||
return "{}"
|
return "{}"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
b, _ := json.Marshal(form)
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserFormToJSON 将用户表单数据转换为 JSON 字符串
|
||||||
|
func UserFormToJSON(form []map[string]any) string {
|
||||||
|
if form == nil {
|
||||||
|
return "{}"
|
||||||
|
}
|
||||||
|
|
||||||
b, _ := json.Marshal(form)
|
b, _ := json.Marshal(form)
|
||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
@@ -60,39 +63,16 @@ func MustMarshal(v any) string {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "{}"
|
return "{}"
|
||||||
}
|
}
|
||||||
|
|
||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseJSONField 解析 JSON 字段
|
|
||||||
func ParseJSONField(field any) any {
|
|
||||||
var v *gvar.Var
|
|
||||||
switch val := field.(type) {
|
|
||||||
case *gvar.Var:
|
|
||||||
v = val
|
|
||||||
default:
|
|
||||||
return field
|
|
||||||
}
|
|
||||||
|
|
||||||
if v == nil || v.IsNil() || v.IsEmpty() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
str := v.String()
|
|
||||||
var result any
|
|
||||||
if json.Unmarshal([]byte(str), &result) == nil {
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
return str
|
|
||||||
}
|
|
||||||
|
|
||||||
// JSONPretty 将任意类型转为格式化的 JSON 字符串
|
// JSONPretty 将任意类型转为格式化的 JSON 字符串
|
||||||
func JSONPretty(v any) string {
|
func JSONPretty(v any) string {
|
||||||
// 处理 *gvar.Var 类型
|
|
||||||
if gv, ok := v.(*gvar.Var); ok {
|
if gv, ok := v.(*gvar.Var); ok {
|
||||||
v = gconv.Map(gv.String())
|
v = gconv.Map(gv.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// 统一转 map 再美化
|
|
||||||
var tmp map[string]any
|
var tmp map[string]any
|
||||||
if err := gconv.Struct(v, &tmp); err != nil {
|
if err := gconv.Struct(v, &tmp); err != nil {
|
||||||
return gconv.String(v)
|
return gconv.String(v)
|
||||||
@@ -101,3 +81,71 @@ func JSONPretty(v any) string {
|
|||||||
b, _ := json.MarshalIndent(tmp, "", " ")
|
b, _ := json.MarshalIndent(tmp, "", " ")
|
||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GvarToMap 将 *gvar.Var 类型转换为 map[string]any
|
||||||
|
func GvarToMap(v *gvar.Var) map[string]any {
|
||||||
|
if v == nil || v.IsNil() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make(map[string]any)
|
||||||
|
|
||||||
|
// 方法1:尝试获取 map 值
|
||||||
|
if m := v.Map(); len(m) > 0 {
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// 方法2:尝试解析 JSON 字符串
|
||||||
|
str := v.String()
|
||||||
|
if str != "" && str != "<nil>" {
|
||||||
|
json.Unmarshal([]byte(str), &result)
|
||||||
|
if len(result) > 0 {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 方法3:尝试获取 interface 再转换
|
||||||
|
if val := v.Val(); val != nil {
|
||||||
|
switch val.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
return val.(map[string]any)
|
||||||
|
default:
|
||||||
|
data, _ := json.Marshal(val)
|
||||||
|
json.Unmarshal(data, &result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseJSONFieldFromGvar 专门处理 *gvar.Var 类型的 JSON 字段解析
|
||||||
|
func ParseJSONFieldFromGvar(source any, target any) {
|
||||||
|
if source == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := source.(type) {
|
||||||
|
case *gvar.Var:
|
||||||
|
if v.IsNil() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试获取 map
|
||||||
|
if m := v.Map(); len(m) > 0 {
|
||||||
|
data, _ := json.Marshal(m)
|
||||||
|
json.Unmarshal(data, target)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试解析 JSON 字符串
|
||||||
|
str := v.String()
|
||||||
|
if str != "" && str != "<nil>" {
|
||||||
|
json.Unmarshal([]byte(str), target)
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
// 其他类型走原来的逻辑
|
||||||
|
data, _ := json.Marshal(source)
|
||||||
|
json.Unmarshal(data, target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
229
common/util/token.go
Normal file
229
common/util/token.go
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
"github.com/gogf/gf/v2/container/gvar"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
enWordRegex = regexp.MustCompile(`[A-Za-z]+`)
|
||||||
|
punctRegex = regexp.MustCompile(`[[:punct:]]`)
|
||||||
|
)
|
||||||
|
|
||||||
|
// TokenConfig Token计算配置
|
||||||
|
type TokenConfig struct {
|
||||||
|
ZhRatio float64 `json:"zh_ratio"`
|
||||||
|
EnRatio float64 `json:"en_ratio"`
|
||||||
|
SpaceRatio float64 `json:"space_ratio"`
|
||||||
|
PunctuationRatio float64 `json:"punctuation_ratio"`
|
||||||
|
MaxWindowSize int `json:"max_window_size"`
|
||||||
|
ReserveRatio float64 `json:"reserve_ratio"`
|
||||||
|
MinReserve int `json:"min_reserve"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CalculateTokens 计算文本token数
|
||||||
|
func CalculateTokens(text string, tokenConfig any) int {
|
||||||
|
config := parseConfig(tokenConfig)
|
||||||
|
if config == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if text == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
zhCount := countChineseChars(text)
|
||||||
|
enCount := countEnglishWords(text)
|
||||||
|
spaceCount := strings.Count(text, " ")
|
||||||
|
punctCount := countPunctuation(text)
|
||||||
|
|
||||||
|
totalTokens := int(
|
||||||
|
float64(zhCount)*config.ZhRatio +
|
||||||
|
float64(enCount)*config.EnRatio +
|
||||||
|
float64(spaceCount)*config.SpaceRatio +
|
||||||
|
float64(punctCount)*config.PunctuationRatio,
|
||||||
|
)
|
||||||
|
|
||||||
|
return totalTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountToken 计算token是否超出窗口限制
|
||||||
|
// 返回: true - 未超出(可用), false - 已超出(不可用)
|
||||||
|
func CountToken(text string, tokenConfig any) bool {
|
||||||
|
config := parseConfig(tokenConfig)
|
||||||
|
if config == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
estimatedTokens := CalculateTokens(text, tokenConfig)
|
||||||
|
availableWindow := GetAvailableWindow(tokenConfig)
|
||||||
|
|
||||||
|
return estimatedTokens <= availableWindow
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAvailableWindow 获取可用窗口大小
|
||||||
|
func GetAvailableWindow(tokenConfig any) int {
|
||||||
|
config := parseConfig(tokenConfig)
|
||||||
|
if config == nil {
|
||||||
|
return 4096
|
||||||
|
}
|
||||||
|
|
||||||
|
reserveByRatio := int(float64(config.MaxWindowSize) * config.ReserveRatio)
|
||||||
|
reserve := reserveByRatio
|
||||||
|
|
||||||
|
if config.MinReserve > reserve {
|
||||||
|
reserve = config.MinReserve
|
||||||
|
}
|
||||||
|
|
||||||
|
available := config.MaxWindowSize - reserve
|
||||||
|
if available < 0 {
|
||||||
|
available = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return available
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMaxWindowSize 获取模型最大窗口大小
|
||||||
|
func GetMaxWindowSize(tokenConfig any) int {
|
||||||
|
config := parseConfig(tokenConfig)
|
||||||
|
if config == nil {
|
||||||
|
return 4096
|
||||||
|
}
|
||||||
|
|
||||||
|
return config.MaxWindowSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckUserFormWithinWindow 校验 UserForm 是否在窗口大小内
|
||||||
|
// 返回: isValid, exceedTokens, error
|
||||||
|
func CheckUserFormWithinWindow(userForm []map[string]any, tokenConfig any) (bool, int, error) {
|
||||||
|
config := parseConfig(tokenConfig)
|
||||||
|
if config == nil || len(userForm) == 0 {
|
||||||
|
return true, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
totalTokens := calculateUserFormTokens(userForm, tokenConfig)
|
||||||
|
availableWindow := GetAvailableWindow(tokenConfig)
|
||||||
|
|
||||||
|
if totalTokens > availableWindow {
|
||||||
|
return false, totalTokens - availableWindow, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckUserFormBatchWithinWindow 检查 UserForm 分批是否在窗口内
|
||||||
|
// 返回: 需要拆分的批次数, 每批的token数, 错误
|
||||||
|
func CheckUserFormBatchWithinWindow(userForm []map[string]any, tokenConfig any) (int, []int, error) {
|
||||||
|
config := parseConfig(tokenConfig)
|
||||||
|
if config == nil || len(userForm) == 0 {
|
||||||
|
return 1, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
availableWindow := GetAvailableWindow(tokenConfig)
|
||||||
|
|
||||||
|
batches := 1
|
||||||
|
currentTokens := 0
|
||||||
|
batchTokens := make([]int, 0)
|
||||||
|
|
||||||
|
for _, item := range userForm {
|
||||||
|
itemStr := fmt.Sprintf("%v", item)
|
||||||
|
itemTokens := CalculateTokens(itemStr, tokenConfig)
|
||||||
|
|
||||||
|
if currentTokens+itemTokens > availableWindow {
|
||||||
|
batchTokens = append(batchTokens, currentTokens)
|
||||||
|
batches++
|
||||||
|
currentTokens = itemTokens
|
||||||
|
} else {
|
||||||
|
currentTokens += itemTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentTokens > 0 {
|
||||||
|
batchTokens = append(batchTokens, currentTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
return batches, batchTokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseConfig 解析配置
|
||||||
|
func parseConfig(tokenConfig any) *TokenConfig {
|
||||||
|
if tokenConfig == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := tokenConfig.(type) {
|
||||||
|
case *gvar.Var:
|
||||||
|
return parseGVarConfig(v)
|
||||||
|
case map[string]any:
|
||||||
|
return parseMapConfig(v)
|
||||||
|
case *TokenConfig:
|
||||||
|
return v
|
||||||
|
case TokenConfig:
|
||||||
|
return &v
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseGVarConfig 解析 GVar 配置
|
||||||
|
func parseGVarConfig(v *gvar.Var) *TokenConfig {
|
||||||
|
if v.IsNil() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mapVal := v.Map()
|
||||||
|
if mapVal == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &TokenConfig{}
|
||||||
|
data, _ := json.Marshal(mapVal)
|
||||||
|
json.Unmarshal(data, config)
|
||||||
|
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseMapConfig 解析 Map 配置
|
||||||
|
func parseMapConfig(v map[string]any) *TokenConfig {
|
||||||
|
config := &TokenConfig{}
|
||||||
|
data, _ := json.Marshal(v)
|
||||||
|
json.Unmarshal(data, config)
|
||||||
|
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
// countChineseChars 统计中文字符数量
|
||||||
|
func countChineseChars(text string) int {
|
||||||
|
count := 0
|
||||||
|
for _, r := range text {
|
||||||
|
if unicode.Is(unicode.Han, r) {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
// countEnglishWords 统计英文单词数量
|
||||||
|
func countEnglishWords(text string) int {
|
||||||
|
return len(enWordRegex.FindAllString(text, -1))
|
||||||
|
}
|
||||||
|
|
||||||
|
// countPunctuation 统计标点符号数量
|
||||||
|
func countPunctuation(text string) int {
|
||||||
|
return len(punctRegex.FindAllString(text, -1))
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateUserFormTokens 计算 UserForm 总 token 数
|
||||||
|
func calculateUserFormTokens(userForm []map[string]any, tokenConfig any) int {
|
||||||
|
totalTokens := 0
|
||||||
|
for _, item := range userForm {
|
||||||
|
itemStr := fmt.Sprintf("%v", item)
|
||||||
|
totalTokens += CalculateTokens(itemStr, tokenConfig)
|
||||||
|
}
|
||||||
|
return totalTokens
|
||||||
|
}
|
||||||
100
config.yml
100
config.yml
@@ -103,55 +103,51 @@ modelPrompts:
|
|||||||
在执行多模态任务时,你需要以全链路AI内容架构师、多模态交互专家、综合内容生成系统的身份完成处理,重点保证不同模态之间的语义一致性、风格统一性、信息完整性与交互连贯性,避免出现跨模态语义断裂或输出不一致的问题。
|
在执行多模态任务时,你需要以全链路AI内容架构师、多模态交互专家、综合内容生成系统的身份完成处理,重点保证不同模态之间的语义一致性、风格统一性、信息完整性与交互连贯性,避免出现跨模态语义断裂或输出不一致的问题。
|
||||||
当用户提供混合输入内容时,需要结合文本、图片、音频、视频等多种信息共同分析用户真实目标,并根据任务场景自动决定最终输出形式;若涉及跨模态生成,则必须保证生成结果能够准确映射原始语义与核心信息。
|
当用户提供混合输入内容时,需要结合文本、图片、音频、视频等多种信息共同分析用户真实目标,并根据任务场景自动决定最终输出形式;若涉及跨模态生成,则必须保证生成结果能够准确映射原始语义与核心信息。
|
||||||
|
|
||||||
buildProject:
|
nodePrompts: |
|
||||||
types:
|
你是流程路由助手,你的任务是根据上下文,选择一个正确的节点ID返回。
|
||||||
1: |
|
规则:
|
||||||
你是专业的JSON结构生成专家,必须严格遵守以下全部规则。
|
1. 只允许从下面的可选节点ID列表中选择一个返回
|
||||||
【强制规则】
|
2. 不要返回任何多余文字、标点、解释、标题
|
||||||
必须根据【输出结构】里面返回的JSON结构进行生成,不得任何更改,最终内容与输出结构返回一致;
|
3. 只返回纯节点ID
|
||||||
完整阅读所有文本、规则、表单内容,禁止跳读、漏读;
|
可选节点ID(ID: 节点描述):
|
||||||
完整读取UserForm所有字段,不得忽略任何字段;
|
%s
|
||||||
如果有skill相关内容必须完整的将内容拼接到system角色描述中;
|
上下文内容:
|
||||||
理解全部语义后再输出,禁止断章取义;
|
%s
|
||||||
UserForm所有字段内容必须完整拼接赋值到user角色描述中,不得有任何遗漏。
|
|
||||||
【优先级】
|
#你是专业的JSON结构生成专家,必须严格遵守以下全部规则。
|
||||||
用户自然语言 > UserForm > Form;
|
# 【强制规则】
|
||||||
UserForm与Form同名字段时,仅保留UserForm值;
|
# 必须根据【输出结构】里面返回的JSON结构进行生成,不得任何更改,最终内容与输出结构返回一致;
|
||||||
Form仅用于组装system角色内容。
|
# 完整阅读所有文本、规则、表单内容,禁止跳读、漏读;
|
||||||
【表单处理】
|
# 完整读取UserForm所有字段,不得忽略任何字段;
|
||||||
Form:系统提示词、默认参数、基础配置 → 专属填充system角色;
|
# 如果有skill相关内容必须完整的将内容拼接到system角色描述中;
|
||||||
UserForm:用户业务输入、文案、配图数量、比例、prompt等 → 全部解析后拼接进user角色content;
|
# 理解全部语义后再输出,禁止断章取义;
|
||||||
自动提取UserForm中每条文案的配图数量,总图片数 = 各文案配图数累加求和(示例:10条文案各配5张图 → 总50张,parameters.n=50),用户没有相关数量必须默认1;
|
# UserForm所有字段内容必须完整拼接赋值到user角色描述中,不得有任何遗漏。
|
||||||
图片尺寸为空时自动填充size=1024*1024。
|
# 【优先级】
|
||||||
【结构铁律】
|
# 用户自然语言 > UserForm > Form;
|
||||||
严格沿用固定输出结构,不增删字段或修改层级;
|
# UserForm与Form同名字段时,仅保留UserForm值;
|
||||||
messages元素必须按结构返回;
|
# Form仅用于组装system角色内容。
|
||||||
禁止将role对象转为字符串、禁止嵌套错乱;
|
# 【表单处理】
|
||||||
输出纯净JSON:无多余转义符、无换行符、无额外字符;
|
# Form:系统提示词、默认参数、基础配置 → 专属填充system角色;
|
||||||
所有括号、引号必须成对闭合,保证JSON合法。
|
# UserForm:用户业务输入、文案、配图数量、比例、prompt等 → 全部解析后拼接进user角色content;
|
||||||
【参数赋值】
|
# 自动提取UserForm中每条文案的配图数量,总图片数 = 各文案配图数累加求和,用户没有相关数量必须默认1;
|
||||||
model固定沿用传入值;
|
# 图片尺寸为空时自动填充size=1024*1024。
|
||||||
返回结构里面的参数,需要根据语意进行赋值,缺失补默认值;
|
# 【结构铁律】
|
||||||
history历史信息必须结合UserForm里的内容对用户描述部分进行补充;
|
# 严格沿用固定输出结构,不增删字段或修改层级;
|
||||||
从UserForm提取信息整合进user描述,确保数量、尺寸、文案语义无遗漏。
|
# messages元素必须按结构返回;
|
||||||
【输出要求】
|
# 禁止将role对象转为字符串、禁止嵌套错乱;
|
||||||
仅输出单行纯净JSON,无任何解释、备注、Markdown或多余符号;
|
# 输出纯净JSON:无多余转义符、无换行符、无额外字符;
|
||||||
完整合UserForm全部字段语义到user描述;
|
# 所有括号、引号必须成对闭合,保证JSON合法。
|
||||||
生成后自检JSON语法、结构、数量;错误则自动重新生成。
|
# 【参数赋值】
|
||||||
【输出结构】
|
# model固定沿用传入值;
|
||||||
%s
|
# 返回结构里面的参数,需要根据语意进行赋值,缺失补默认值;
|
||||||
【字段映射】
|
# history历史信息必须结合UserForm里的内容对用户描述部分进行补充;
|
||||||
%s
|
# 从UserForm提取信息整合进user描述,确保数量、尺寸、文案语义无遗漏。
|
||||||
【完整输入信息】
|
# 【输出要求】
|
||||||
%s
|
# 仅输出单行纯净JSON,无任何解释、备注、Markdown或多余符号;
|
||||||
直接输出最终JSON:
|
# 完整合UserForm全部字段语义到user描述;
|
||||||
2: |
|
# 生成后自检JSON语法、结构、数量;错误则自动重新生成。
|
||||||
你是流程路由助手,你的任务是根据上下文,选择一个正确的节点ID返回。
|
# 【输出结构】
|
||||||
规则:
|
# %s
|
||||||
1. 只允许从下面的可选节点ID列表中选择一个返回
|
# 【完整输入信息】
|
||||||
2. 不要返回任何多余文字、标点、解释、标题
|
# %s
|
||||||
3. 只返回纯节点ID
|
# 直接输出最终JSON:
|
||||||
可选节点ID(ID: 节点描述):
|
|
||||||
%s
|
|
||||||
上下文内容:
|
|
||||||
%s
|
|
||||||
@@ -5,3 +5,8 @@ const (
|
|||||||
ComposeStatusSuccess = "success"
|
ComposeStatusSuccess = "success"
|
||||||
ComposeStatusFailed = "failed"
|
ComposeStatusFailed = "failed"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
BuildTypePrompt = 1 //提示词构建
|
||||||
|
BuildTypeNode = 2 //节点构建
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
package prompt
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"prompts-core/model/dto"
|
||||||
|
|
||||||
promptDto "prompts-core/model/dto/prompt"
|
|
||||||
promptService "prompts-core/service/prompt"
|
promptService "prompts-core/service/prompt"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -13,17 +13,17 @@ type prompt struct{}
|
|||||||
var Prompt = new(prompt)
|
var Prompt = new(prompt)
|
||||||
|
|
||||||
// ComposeMessages 调用 model-gateway 异步任务并同步等待结果,
|
// ComposeMessages 调用 model-gateway 异步任务并同步等待结果,
|
||||||
func (c *prompt) ComposeMessages(ctx context.Context, req *promptDto.ComposeMessagesReq) (res *promptDto.ComposeMessagesRes, err error) {
|
func (c *prompt) ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (res *dto.ComposeMessagesRes, err error) {
|
||||||
return promptService.ComposeMessages(ctx, req)
|
return promptService.ComposeMessages(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Callback model-gateway 提示词回调
|
// Callback model-gateway 提示词回调
|
||||||
func (c *prompt) Callback(ctx context.Context, req *promptDto.CallbackReq) (res *promptDto.CallbackRes, err error) {
|
func (c *prompt) Callback(ctx context.Context, req *dto.CallbackReq) (res *dto.CallbackRes, err error) {
|
||||||
err = promptService.Callback(ctx, req)
|
err = promptService.Callback(ctx, req)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetComposeTask 查询拼接任务结果
|
// GetComposeTask 查询拼接任务结果
|
||||||
func (c *prompt) GetComposeTask(ctx context.Context, req *promptDto.GetComposeTaskReq) (res *promptDto.GetComposeTaskRes, err error) {
|
func (c *prompt) GetComposeTask(ctx context.Context, req *dto.GetComposeTaskReq) (res *dto.GetComposeTaskRes, err error) {
|
||||||
return promptService.GetComposeTask(ctx, req.TaskId)
|
return promptService.GetComposeTask(ctx, req.TaskId)
|
||||||
}
|
}
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
package prompt
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"prompts-core/model/dto"
|
||||||
|
|
||||||
promptDto "prompts-core/model/dto/prompt"
|
|
||||||
promptService "prompts-core/service/prompt"
|
promptService "prompts-core/service/prompt"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -13,6 +13,6 @@ type session struct{}
|
|||||||
var Session = new(session)
|
var Session = new(session)
|
||||||
|
|
||||||
// SessionCallback 会话回调
|
// SessionCallback 会话回调
|
||||||
func (c *session) SessionCallback(ctx context.Context, req *promptDto.SessionCallbackReq) (res *promptDto.SessionCallbackRes, err error) {
|
func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) {
|
||||||
return promptService.SessionCallback(ctx, req)
|
return promptService.SessionCallback(ctx, req)
|
||||||
}
|
}
|
||||||
@@ -15,7 +15,7 @@ type composeSessionDao struct{}
|
|||||||
|
|
||||||
// Insert 插入
|
// Insert 插入
|
||||||
func (d *composeSessionDao) Insert(ctx context.Context, req *entity.ComposeSession) (id int64, err error) {
|
func (d *composeSessionDao) Insert(ctx context.Context, req *entity.ComposeSession) (id int64, err error) {
|
||||||
var m = new(entity.ComposeTask)
|
var m = new(entity.ComposeSession)
|
||||||
err = gconv.Struct(req, &m)
|
err = gconv.Struct(req, &m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"prompts-core/model/entity"
|
"prompts-core/model/entity"
|
||||||
|
|
||||||
"gitea.com/red-future/common/db/gfdb"
|
"gitea.com/red-future/common/db/gfdb"
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ProviderProtocol = &providerProtocolDao{}
|
var ProviderProtocol = &providerProtocolDao{}
|
||||||
@@ -14,7 +15,13 @@ type providerProtocolDao struct{}
|
|||||||
|
|
||||||
// Insert 新增协议配置
|
// Insert 新增协议配置
|
||||||
func (d *providerProtocolDao) Insert(ctx context.Context, req *entity.ProviderProtocol) (id int64, err error) {
|
func (d *providerProtocolDao) Insert(ctx context.Context, req *entity.ProviderProtocol) (id int64, err error) {
|
||||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).OmitEmpty().Data(req).Insert()
|
var m = new(entity.ProviderProtocol)
|
||||||
|
err = gconv.Struct(req, &m)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).
|
||||||
|
Insert(m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|||||||
6
main.go
6
main.go
@@ -4,7 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"prompts-core/controller/prompt"
|
"prompts-core/controller"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"gitea.com/red-future/common/http"
|
"gitea.com/red-future/common/http"
|
||||||
@@ -21,8 +21,8 @@ func main() {
|
|||||||
defer jaeger.ShutDown(ctx)
|
defer jaeger.ShutDown(ctx)
|
||||||
// 注册路由
|
// 注册路由
|
||||||
http.RouteRegister([]interface{}{
|
http.RouteRegister([]interface{}{
|
||||||
prompt.Prompt,
|
controller.Prompt,
|
||||||
prompt.Session,
|
controller.Session,
|
||||||
})
|
})
|
||||||
|
|
||||||
// 监听退出信号,确保 Ctrl+C 能完整退出并关闭 gateway server
|
// 监听退出信号,确保 Ctrl+C 能完整退出并关闭 gateway server
|
||||||
|
|||||||
@@ -1,22 +1,28 @@
|
|||||||
package prompt
|
package dto
|
||||||
|
|
||||||
import "github.com/gogf/gf/v2/frame/g"
|
import "github.com/gogf/gf/v2/frame/g"
|
||||||
|
|
||||||
type ComposeMessagesReq struct {
|
type ComposeMessagesReq struct {
|
||||||
g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schema;form 作为系统表单,userForm 作为用户表单,结合 userFiles 调用 model-gateway,并直接返回最终 messages"`
|
g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schema;form 作为系统表单,userForm 作为用户表单,结合 userFiles 调用 model-gateway,并直接返回最终 messages"`
|
||||||
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"`
|
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"`
|
||||||
BuildType int `p:"buildType" json:"buildType" v:"required#buildType不能为空" dc:"构建类型"` //判断节点
|
BuildType int `p:"buildType" json:"buildType" v:"required#buildType不能为空" dc:"构建类型"` //判断节点
|
||||||
SessionId string `p:"sessionId" json:"sessionId" v:"required#sessionId不能为空" dc:"会话ID"`
|
SessionId string `p:"sessionId" json:"sessionId" v:"required#sessionId不能为空" dc:"会话ID"`
|
||||||
Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"`
|
Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"`
|
||||||
Form map[string]any `p:"form" json:"form" dc:"系统表单:form 下所有字段都作为系统提示词来源"`
|
Form map[string]any `p:"form" json:"form" dc:"系统表单:form 下所有字段都作为系统提示词来源"`
|
||||||
UserForm map[string]any `p:"userForm" json:"userForm" dc:"用户表单:userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"`
|
UserForm []map[string]any `p:"userForm" json:"userForm" dc:"用户表单:userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"`
|
||||||
SkillName string `p:"skillName" json:"skillName" dc:"技能名称"`
|
SkillName string `p:"skillName" json:"skillName" dc:"技能名称"`
|
||||||
UserFiles []string `p:"userFiles" json:"userFiles" dc:"用户附件地址列表"`
|
UserFiles []string `p:"userFiles" json:"userFiles" dc:"用户附件地址列表"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ComposeMessagesRes struct {
|
type ComposeMessagesRes struct {
|
||||||
Messages any `json:"messages,omitempty" dc:"最终消息数组"`
|
Messages *MultiRoundResult `json:"messages,omitempty" dc:"最终消息数组"`
|
||||||
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
|
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MultiRoundResult 多轮返回结果
|
||||||
|
type MultiRoundResult struct {
|
||||||
|
TotalRounds int `json:"total_rounds"` // 总轮数
|
||||||
|
Rounds []any `json:"rounds"` // 每轮详情(动态类型)
|
||||||
}
|
}
|
||||||
|
|
||||||
type CallbackReq struct {
|
type CallbackReq struct {
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package prompt
|
package dto
|
||||||
|
|
||||||
import "github.com/gogf/gf/v2/frame/g"
|
import "github.com/gogf/gf/v2/frame/g"
|
||||||
|
|
||||||
@@ -16,10 +16,10 @@ type AsynchModel struct {
|
|||||||
ResponseBody any `orm:"response_body" json:"responseBody"`
|
ResponseBody any `orm:"response_body" json:"responseBody"`
|
||||||
TokenMapping string `orm:"token_mapping" json:"tokenMapping"`
|
TokenMapping string `orm:"token_mapping" json:"tokenMapping"`
|
||||||
Prompt string `orm:"prompt" json:"prompt"`
|
Prompt string `orm:"prompt" json:"prompt"`
|
||||||
IsPrivate int `orm:"is_private" json:"isPrivate"`
|
IsPrivate *int `orm:"is_private" json:"isPrivate"`
|
||||||
IsChatModel int `orm:"is_chat_model" json:"isChatModel"`
|
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
|
||||||
ApiKey string `orm:"api_key" json:"apiKey"`
|
ApiKey string `orm:"api_key" json:"apiKey"`
|
||||||
Enabled int `orm:"enabled" json:"enabled"`
|
Enabled *int `orm:"enabled" json:"enabled"`
|
||||||
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
|
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
|
||||||
QueueLimit int `orm:"queue_limit" json:"queueLimit"`
|
QueueLimit int `orm:"queue_limit" json:"queueLimit"`
|
||||||
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
|
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
|
||||||
@@ -28,6 +28,9 @@ type AsynchModel struct {
|
|||||||
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"`
|
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"`
|
||||||
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
|
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
|
||||||
Remark string `orm:"remark" json:"remark"`
|
Remark string `orm:"remark" json:"remark"`
|
||||||
|
IsOwner *int `json:"isOwner" orm:"is_owner"`
|
||||||
|
OperatorName string `orm:"operator_name" json:"operatorName"`
|
||||||
|
TokenConfig any `orm:"token_config" json:"tokenConfig"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type asynchModelCol struct {
|
type asynchModelCol struct {
|
||||||
@@ -55,6 +58,9 @@ type asynchModelCol struct {
|
|||||||
RetryQueueMaxSecs string
|
RetryQueueMaxSecs string
|
||||||
AutoCleanSeconds string
|
AutoCleanSeconds string
|
||||||
Remark string
|
Remark string
|
||||||
|
IsOwner string
|
||||||
|
OperatorName string
|
||||||
|
TokenConfig string
|
||||||
}
|
}
|
||||||
|
|
||||||
var AsynchModelCol = asynchModelCol{
|
var AsynchModelCol = asynchModelCol{
|
||||||
@@ -82,4 +88,7 @@ var AsynchModelCol = asynchModelCol{
|
|||||||
RetryQueueMaxSecs: "retry_queue_max_seconds",
|
RetryQueueMaxSecs: "retry_queue_max_seconds",
|
||||||
AutoCleanSeconds: "auto_clean_seconds",
|
AutoCleanSeconds: "auto_clean_seconds",
|
||||||
Remark: "remark",
|
Remark: "remark",
|
||||||
|
IsOwner: "is_owner",
|
||||||
|
OperatorName: "operator_name",
|
||||||
|
TokenConfig: "token_config",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,65 +4,113 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"prompts-core/consts/public"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"prompts-core/common/util"
|
"prompts-core/common/util"
|
||||||
"prompts-core/dao"
|
"prompts-core/dao"
|
||||||
"prompts-core/model/dto/prompt"
|
"prompts-core/model/dto"
|
||||||
"prompts-core/model/entity"
|
"prompts-core/model/entity"
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// buildInferenceRequest 构建返回请求
|
// buildInferenceRequest 构建推理请求
|
||||||
func buildInferenceRequest(ctx context.Context, req *prompt.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
|
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, targetModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
|
||||||
|
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, targetModel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
ir := NewPromptIR()
|
ir := NewPromptIR()
|
||||||
// 1. 统一 Prompt IR
|
|
||||||
switch req.BuildType {
|
switch req.BuildType {
|
||||||
case 1: //构建提示词请求
|
case public.BuildTypePrompt:
|
||||||
ir.AddSystem(promptBuild(ctx, req, model))
|
return buildPromptTypeRequest(ctx, processedReq, targetModel, history, ir, totalBatches)
|
||||||
for _, msg := range history {
|
case public.BuildTypeNode:
|
||||||
role := gconv.String(msg["role"])
|
return buildNodeTypeRequest(ctx, req, ir)
|
||||||
if role != "user" && role != "assistant" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ir.AddHistory(role, gconv.String(msg["content"]))
|
|
||||||
}
|
|
||||||
ir.AddUser(buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, model.ModelType)))
|
|
||||||
case 2: //构建节点请求
|
|
||||||
ir.AddUser(NodeBuild(ctx, req))
|
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("不支持的构建类型")
|
return nil, errors.New("不支持的构建类型")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 2. 获取协议配置
|
// buildPromptTypeRequest 构建提示词类型请求(BuildType=1)
|
||||||
protocol, err := GetProtocolByProvider(ctx, "qwen")
|
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, targetModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
|
||||||
|
systemPrompt := promptBuildWithRounds(ctx, req, targetModel, totalBatches)
|
||||||
|
ir.AddSystem(systemPrompt)
|
||||||
|
|
||||||
|
for _, msg := range history {
|
||||||
|
role := gconv.String(msg["role"])
|
||||||
|
if role != "user" && role != "assistant" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ir.AddHistory(role, gconv.String(msg["content"]))
|
||||||
|
}
|
||||||
|
|
||||||
|
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, targetModel.ModelType))
|
||||||
|
ir.AddUser(userPrompt)
|
||||||
|
|
||||||
|
if !checkOverallContent(ir, targetModel) {
|
||||||
|
availableWindow := util.GetAvailableWindow(targetModel.TokenConfig)
|
||||||
|
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
|
||||||
|
}
|
||||||
|
|
||||||
|
return compileToProviderRequest(ctx, ir, targetModel.OperatorName, targetModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildNodeTypeRequest 构建节点类型请求(BuildType=2)
|
||||||
|
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ir *PromptIR) (map[string]any, error) {
|
||||||
|
ir.AddUser(NodeBuild(ctx, req))
|
||||||
|
|
||||||
|
protocol, err := GetProtocolByProvider(ctx, req.ModelName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("获取协议配置失败: %w", err)
|
||||||
}
|
}
|
||||||
if protocol == nil {
|
if protocol == nil {
|
||||||
return nil, errors.New("协议配置不存在")
|
return nil, errors.New("协议配置不存在")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 编译为 Provider Request
|
providerReq, err := Compile(ir, protocol, nil)
|
||||||
providerReq, err := Compile(ir, protocol, chatModel)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("编译请求失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 构建请求体
|
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"modelName": chatModel.ModelName,
|
"modelName": req.ModelName,
|
||||||
"bizName": "prompts-core",
|
"bizName": "prompts-core",
|
||||||
"callbackUrl": "/prompt/callback",
|
"callbackUrl": "/prompt/callback",
|
||||||
"requestPayload": providerReq,
|
"requestPayload": providerReq,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// promptBuild 构建系统提示词
|
// compileToProviderRequest 编译为 Provider 请求
|
||||||
func promptBuild(ctx context.Context, req *prompt.ComposeMessagesReq, model *entity.AsynchModel) string {
|
func compileToProviderRequest(ctx context.Context, ir *PromptIR, providerName string, model *entity.AsynchModel) (map[string]any, error) {
|
||||||
|
protocol, err := GetProtocolByProvider(ctx, providerName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("获取协议配置失败: %w", err)
|
||||||
|
}
|
||||||
|
if protocol == nil {
|
||||||
|
return nil, errors.New("协议配置不存在")
|
||||||
|
}
|
||||||
|
|
||||||
|
providerReq, err := Compile(ir, protocol, model)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("编译请求失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("providerReq打印:", util.MustMarshal(providerReq))
|
||||||
|
return map[string]any{
|
||||||
|
"modelName": model.ModelName,
|
||||||
|
"bizName": "prompts-core",
|
||||||
|
"callbackUrl": "/prompt/callback",
|
||||||
|
"requestPayload": providerReq,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// promptBuildWithRounds 构建系统提示词(包含轮次信息)
|
||||||
|
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel, totalRounds int) string {
|
||||||
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||||
ProviderName: "qwen",
|
ProviderName: model.OperatorName,
|
||||||
Status: 1,
|
Status: 1,
|
||||||
})
|
})
|
||||||
if err != nil || providerProtocol == nil {
|
if err != nil || providerProtocol == nil {
|
||||||
@@ -70,43 +118,104 @@ func promptBuild(ctx context.Context, req *prompt.ComposeMessagesReq, model *ent
|
|||||||
}
|
}
|
||||||
|
|
||||||
outputJSON := util.JSONPretty(model.RequestMapping)
|
outputJSON := util.JSONPretty(model.RequestMapping)
|
||||||
var userFormContent strings.Builder
|
maxWindowSize := util.GetMaxWindowSize(model.TokenConfig)
|
||||||
for k, v := range req.UserForm {
|
availableWindow := util.GetAvailableWindow(model.TokenConfig)
|
||||||
userFormContent.WriteString(fmt.Sprintf("%s=%v;", k, v))
|
|
||||||
}
|
|
||||||
userFormFullText := strings.TrimSuffix(userFormContent.String(), ";")
|
|
||||||
|
|
||||||
|
userFormContent := buildUserFormContent(req.UserForm)
|
||||||
formInfo := fmt.Sprintf(`
|
formInfo := fmt.Sprintf(`
|
||||||
【系统表单(系统提示词/参数)】
|
【系统表单(系统提示词/参数)】
|
||||||
%s
|
%s
|
||||||
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
|
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
|
||||||
%s
|
%s
|
||||||
`, util.FormToJSON(req.Form), userFormFullText)
|
`, util.FormToJSON(req.Form), userFormContent)
|
||||||
|
|
||||||
return fmt.Sprintf(providerProtocol.SystemPromptTemplate, outputJSON, formInfo)
|
inputInfo := fmt.Sprintf(`
|
||||||
|
目标模型: %s
|
||||||
|
%s
|
||||||
|
技能名称: %s
|
||||||
|
用户文件: %v
|
||||||
|
`, req.ModelName, formInfo, req.SkillName, req.UserFiles)
|
||||||
|
|
||||||
|
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
|
||||||
|
req.ModelName,
|
||||||
|
maxWindowSize,
|
||||||
|
availableWindow,
|
||||||
|
totalRounds,
|
||||||
|
totalRounds,
|
||||||
|
totalRounds,
|
||||||
|
outputJSON,
|
||||||
|
inputInfo,
|
||||||
|
totalRounds,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建用户提示词
|
// buildUserFormContent 构建用户表单内容字符串
|
||||||
func buildUserPrompt(ctx context.Context, req *prompt.ComposeMessagesReq, prompt string) string {
|
func buildUserFormContent(userForm []map[string]any) string {
|
||||||
payload := map[string]any{
|
var builder strings.Builder
|
||||||
"model": req.ModelName, // 请求模型名称
|
for _, item := range userForm {
|
||||||
"promptInfo": prompt, // 数据库提示信息
|
builder.WriteString(fmt.Sprintf("%v\n", item))
|
||||||
"form": req.Form, // 系统表单
|
|
||||||
"userForm": req.UserForm, // 用户表单
|
|
||||||
"userFiles": req.UserFiles, //文件url
|
|
||||||
"userFilesText": FetchFileTexts(ctx, req.UserFiles), //解读文件(只支持可读类型 如:xml,json,yaml)
|
|
||||||
"skills": SkillMdContent(ctx, req.SkillName), //skill 相关(根据传入的 skillName 获取 zip 内所有 md 文件拼接内容)
|
|
||||||
}
|
}
|
||||||
|
return builder.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkOverallContent 检查整体内容是否超出窗口
|
||||||
|
func checkOverallContent(ir *PromptIR, model *entity.AsynchModel) bool {
|
||||||
|
fullContent := ir.String()
|
||||||
|
return util.CountToken(fullContent, model.TokenConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildUserPrompt 构建用户提示词
|
||||||
|
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
|
||||||
|
userFormForPayload := prepareUserFormPayload(req.UserForm)
|
||||||
|
|
||||||
|
payload := map[string]any{
|
||||||
|
"model": req.ModelName,
|
||||||
|
"promptInfo": prompt,
|
||||||
|
"form": req.Form,
|
||||||
|
"userForm": userFormForPayload,
|
||||||
|
"userFiles": req.UserFiles,
|
||||||
|
"userFilesText": FetchFileTexts(ctx, req.UserFiles),
|
||||||
|
"skills": SkillMdContent(ctx, req.SkillName),
|
||||||
|
}
|
||||||
|
|
||||||
return util.MustMarshal(payload)
|
return util.MustMarshal(payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// prepareUserFormPayload 准备用户表单载荷
|
||||||
|
func prepareUserFormPayload(userForm []map[string]any) any {
|
||||||
|
if len(userForm) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := userForm[0]["batch_index"]; ok {
|
||||||
|
return userForm
|
||||||
|
}
|
||||||
|
|
||||||
|
return mergeUserFormTexts(userForm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeUserFormTexts 合并 UserForm 中的所有文本内容
|
||||||
|
func mergeUserFormTexts(userForm []map[string]any) string {
|
||||||
|
var builder strings.Builder
|
||||||
|
for i, item := range userForm {
|
||||||
|
text := getItemText(item)
|
||||||
|
if i > 0 {
|
||||||
|
builder.WriteString("\n\n")
|
||||||
|
}
|
||||||
|
builder.WriteString(text)
|
||||||
|
}
|
||||||
|
return builder.String()
|
||||||
|
}
|
||||||
|
|
||||||
// NodeBuild 节点构建
|
// NodeBuild 节点构建
|
||||||
func NodeBuild(ctx context.Context, req *prompt.ComposeMessagesReq) string {
|
func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string {
|
||||||
promptTpl := util.GetBuildPrompt(ctx, req.BuildType)
|
promptTpl := util.GetBuildPrompt(ctx)
|
||||||
if promptTpl == "" {
|
if promptTpl == "" {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
formStr := util.FormToJSON(req.Form)
|
formStr := util.FormToJSON(req.Form)
|
||||||
userFormStr := util.FormToJSON(req.UserForm)
|
userFormStr := util.UserFormToJSON(req.UserForm)
|
||||||
|
|
||||||
return fmt.Sprintf(promptTpl, formStr, userFormStr)
|
return fmt.Sprintf(promptTpl, formStr, userFormStr)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,171 +5,229 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"prompts-core/dao"
|
|
||||||
"prompts-core/model/entity"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"prompts-core/common/util"
|
|
||||||
"prompts-core/consts/public"
|
|
||||||
promptDto "prompts-core/model/dto/prompt"
|
|
||||||
"prompts-core/service/gateway"
|
|
||||||
|
|
||||||
"gitea.com/red-future/common/beans"
|
"gitea.com/red-future/common/beans"
|
||||||
"gitea.com/red-future/common/utils"
|
"gitea.com/red-future/common/utils"
|
||||||
"github.com/gogf/gf/v2/container/gvar"
|
"github.com/gogf/gf/v2/container/gvar"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
|
||||||
|
"prompts-core/common/util"
|
||||||
|
"prompts-core/consts/public"
|
||||||
|
"prompts-core/dao"
|
||||||
|
"prompts-core/model/dto"
|
||||||
|
"prompts-core/model/entity"
|
||||||
|
"prompts-core/service/gateway"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ComposeMessages 核心拼接提示词主流程
|
// ComposeMessages 核心拼接提示词主流程
|
||||||
func ComposeMessages(ctx context.Context, req *promptDto.ComposeMessagesReq) (*promptDto.ComposeMessagesRes, error) {
|
func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
|
||||||
var (
|
|
||||||
epicycleId int64
|
|
||||||
taskID string
|
|
||||||
history []map[string]any
|
|
||||||
message map[string]any
|
|
||||||
err error
|
|
||||||
taskRecord *entity.ComposeTask
|
|
||||||
)
|
|
||||||
// 获取模型信息
|
|
||||||
chatModel, aiModel, err := GetModelMessage(ctx, req)
|
chatModel, aiModel, err := GetModelMessage(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// 根据构建类型进行判断处理
|
if err = validateUserForm(ctx, req, aiModel); err != nil {
|
||||||
switch req.BuildType {
|
return nil, err
|
||||||
//提示词构建
|
|
||||||
case 1:
|
|
||||||
maxRetryTimes := g.Cfg().MustGet(ctx, "promptsRetry.maxRetryTimes", 3).Int()
|
|
||||||
//1. 获取历史会话
|
|
||||||
history, err = GetHistoryMessages(ctx, req.SessionId)
|
|
||||||
if err != nil {
|
|
||||||
g.Log().Errorf(ctx, "获取历史会话失败: %v,将不使用历史会话", err)
|
|
||||||
history = nil // 出错就用空的,不影响主流程
|
|
||||||
}
|
|
||||||
// 重试循环
|
|
||||||
for attempt := 0; attempt <= 0; attempt++ {
|
|
||||||
if attempt > 0 {
|
|
||||||
g.Log().Warningf(ctx, "[重试]第 %d/%d 次调用推理模型", attempt, maxRetryTimes)
|
|
||||||
}
|
|
||||||
// 2. 调用推理模型
|
|
||||||
taskID, err = callInferenceModel(ctx, req, chatModel, aiModel, history)
|
|
||||||
if err != nil {
|
|
||||||
g.Log().Errorf(ctx, "调用推理模型失败(第%d次): %v", attempt+1, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. 保存记录
|
|
||||||
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
|
|
||||||
TaskId: taskID,
|
|
||||||
ModelName: req.ModelName,
|
|
||||||
SkillName: req.SkillName,
|
|
||||||
RequestPayload: util.MustMarshal(req),
|
|
||||||
Status: public.ComposeStatusPending,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
g.Log().Errorf(ctx, "保存任务记录失败(第%d次): %v", attempt+1, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4. 等待结果
|
|
||||||
taskRecord, err = waitForResult(ctx, taskID)
|
|
||||||
if err != nil {
|
|
||||||
g.Log().Errorf(ctx, "等待结果失败(第%d次): %v", attempt+1, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// 校验结果
|
|
||||||
message = parsePromptBuild(taskRecord, chatModel)
|
|
||||||
if message != nil && util.IsMessageValid(message) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
g.Log().Warningf(ctx, "[重试] 推理结果不合法(第%d次),准备重新请求", attempt+1)
|
|
||||||
message = nil
|
|
||||||
}
|
|
||||||
if message == nil {
|
|
||||||
return nil, errors.New("推理模型调用失败,请稍后再试")
|
|
||||||
}
|
|
||||||
//5.创建会话记录
|
|
||||||
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
|
||||||
SessionId: req.SessionId,
|
|
||||||
RequestContent: message,
|
|
||||||
})
|
|
||||||
//节点构建
|
|
||||||
case 2:
|
|
||||||
//1. 调用推理模型
|
|
||||||
taskID, err = callInferenceModel(ctx, req, chatModel, aiModel, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
//2. 保存相关记录
|
|
||||||
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
|
|
||||||
TaskId: taskID,
|
|
||||||
ModelName: req.ModelName,
|
|
||||||
SkillName: req.SkillName,
|
|
||||||
RequestPayload: util.MustMarshal(req),
|
|
||||||
Status: public.ComposeStatusPending,
|
|
||||||
})
|
|
||||||
//5. 等待结果
|
|
||||||
taskRecord, err := waitForResult(ctx, taskID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
message = parseNodeBuild(taskRecord)
|
|
||||||
default:
|
|
||||||
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
|
||||||
SessionId: req.SessionId,
|
|
||||||
Remark: req.Cause,
|
|
||||||
})
|
|
||||||
return &promptDto.ComposeMessagesRes{
|
|
||||||
EpicycleId: epicycleId,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
return &promptDto.ComposeMessagesRes{
|
switch req.BuildType {
|
||||||
|
case public.BuildTypePrompt:
|
||||||
|
return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建
|
||||||
|
case public.BuildTypeNode:
|
||||||
|
return handleNodeBuild(ctx, req, chatModel, aiModel) // 节点构建
|
||||||
|
default:
|
||||||
|
return handleDefaultCase(ctx, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateUserForm 校验用户表单
|
||||||
|
func validateUserForm(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) error {
|
||||||
|
if len(req.UserForm) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
isValid, exceedTokens, err := util.CheckUserFormWithinWindow(req.UserForm, model.TokenConfig)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("校验用户表单失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValid {
|
||||||
|
availableWindow := util.GetAvailableWindow(model.TokenConfig)
|
||||||
|
return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens,可用窗口 %d tokens,请精简后重试",
|
||||||
|
exceedTokens, availableWindow)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handlePromptBuild 处理提示词构建(BuildType=1)
|
||||||
|
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
||||||
|
maxRetryTimes := g.Cfg().MustGet(ctx, "promptsRetry.maxRetryTimes", 3).Int()
|
||||||
|
history, err := GetHistoryMessages(ctx, req.SessionId)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Errorf(ctx, "获取历史会话失败: %v,将不使用历史会话", err)
|
||||||
|
history = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var message *dto.MultiRoundResult
|
||||||
|
var taskRecord *entity.ComposeTask
|
||||||
|
for attempt := 0; attempt <= 0; attempt++ {
|
||||||
|
if attempt > 0 {
|
||||||
|
g.Log().Warningf(ctx, "[重试]第 %d/%d 次调用推理模型", attempt, maxRetryTimes)
|
||||||
|
}
|
||||||
|
|
||||||
|
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, history)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Errorf(ctx, "调用推理模型失败(第%d次): %v", attempt+1, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = saveComposeTask(ctx, taskID, req); err != nil {
|
||||||
|
g.Log().Errorf(ctx, "保存任务记录失败(第%d次): %v", attempt+1, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
taskRecord, err = waitForResult(ctx, taskID)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Errorf(ctx, "等待结果失败(第%d次): %v", attempt+1, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
message = parsePromptBuild(taskRecord, chatModel)
|
||||||
|
if message != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
g.Log().Warningf(ctx, "[重试] 推理结果不合法(第%d次),准备重新请求", attempt+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if message == nil {
|
||||||
|
return nil, errors.New("推理模型调用失败,请稍后再试")
|
||||||
|
}
|
||||||
|
epicycleId, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||||
|
SessionId: req.SessionId,
|
||||||
|
RequestContent: message,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Errorf(ctx, "创建会话记录失败: %v", err)
|
||||||
|
}
|
||||||
|
return &dto.ComposeMessagesRes{
|
||||||
Messages: message,
|
Messages: message,
|
||||||
EpicycleId: epicycleId,
|
EpicycleId: epicycleId,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleNodeBuild 处理节点构建(BuildType=2)
|
||||||
|
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
||||||
|
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("调用推理模型失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := saveComposeTask(ctx, taskID, req); err != nil {
|
||||||
|
return nil, fmt.Errorf("保存任务记录失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
taskRecord, err := waitForResult(ctx, taskID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("等待结果失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
message := parseNodeBuild(taskRecord)
|
||||||
|
|
||||||
|
return &dto.ComposeMessagesRes{
|
||||||
|
Messages: message,
|
||||||
|
EpicycleId: 0,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleDefaultCase 处理默认情况
|
||||||
|
func handleDefaultCase(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
|
||||||
|
epicycleId, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||||
|
SessionId: req.SessionId,
|
||||||
|
Remark: req.Cause,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("创建会话记录失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &dto.ComposeMessagesRes{
|
||||||
|
EpicycleId: epicycleId,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveComposeTask 保存组合任务
|
||||||
|
func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessagesReq) error {
|
||||||
|
_, err := dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
|
||||||
|
TaskId: taskID,
|
||||||
|
ModelName: req.ModelName,
|
||||||
|
SkillName: req.SkillName,
|
||||||
|
RequestPayload: util.MustMarshal(req),
|
||||||
|
Status: public.ComposeStatusPending,
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// GetModelMessage 获取模型信息
|
// GetModelMessage 获取模型信息
|
||||||
func GetModelMessage(ctx context.Context, req *promptDto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) {
|
func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) {
|
||||||
userInfo, err := utils.GetUserInfo(ctx)
|
userInfo, err := utils.GetUserInfo(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
|
||||||
}
|
}
|
||||||
// 1. 获取当前用户的会话模型
|
|
||||||
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
chatModel, err := getChatModel(ctx, userInfo.UserName)
|
||||||
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
|
|
||||||
IsChatModel: 1,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if chatModel == nil {
|
|
||||||
return nil, nil, errors.New("当前没有对话模型,请添加")
|
aiModel, err := getAIModel(ctx, userInfo.UserName, req.ModelName)
|
||||||
}
|
|
||||||
// 2. 获取要构建的模型信息
|
|
||||||
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
|
|
||||||
ModelName: req.ModelName,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if aiModel == nil {
|
|
||||||
return nil, nil, fmt.Errorf("需要构建的模型 %s 不存在", req.ModelName)
|
|
||||||
}
|
|
||||||
return chatModel, aiModel, nil
|
return chatModel, aiModel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getChatModel 获取聊天模型
|
||||||
|
func getChatModel(ctx context.Context, userName string) (*entity.AsynchModel, error) {
|
||||||
|
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
|
||||||
|
IsChatModel: new(1),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("查询聊天模型失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if chatModel == nil {
|
||||||
|
return nil, errors.New("当前没有对话模型,请添加")
|
||||||
|
}
|
||||||
|
|
||||||
|
return chatModel, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAIModel 获取AI模型
|
||||||
|
func getAIModel(ctx context.Context, userName, modelName string) (*entity.AsynchModel, error) {
|
||||||
|
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
|
||||||
|
ModelName: modelName,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("查询AI模型失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if aiModel == nil {
|
||||||
|
return nil, fmt.Errorf("需要构建的模型 %s 不存在", modelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return aiModel, nil
|
||||||
|
}
|
||||||
|
|
||||||
// callInferenceModel 调用推理模型
|
// callInferenceModel 调用推理模型
|
||||||
func callInferenceModel(ctx context.Context, req *promptDto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (string, error) {
|
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (string, error) {
|
||||||
// 构建推理模型请求
|
|
||||||
taskReq, err := buildInferenceRequest(ctx, req, chatModel, model, history)
|
taskReq, err := buildInferenceRequest(ctx, req, chatModel, model, history)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("构建推理请求失败: %w", err)
|
return "", fmt.Errorf("构建推理请求失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建网关任务
|
|
||||||
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
|
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("创建网关任务失败: %w", err)
|
return "", fmt.Errorf("创建网关任务失败: %w", err)
|
||||||
@@ -186,96 +244,131 @@ func callInferenceModel(ctx context.Context, req *promptDto.ComposeMessagesReq,
|
|||||||
func waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) {
|
func waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) {
|
||||||
timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second
|
timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second
|
||||||
pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond
|
pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond
|
||||||
|
|
||||||
deadline := time.Now().Add(timeout)
|
deadline := time.Now().Add(timeout)
|
||||||
|
ticker := time.NewTicker(pollInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// ===================== 修复点 1:检查上下文是否取消 =====================
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
// 请求已被取消,直接返回,不继续查库
|
|
||||||
return nil, ctx.Err()
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
// 1. 查数据库
|
|
||||||
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||||
TaskId: taskID,
|
TaskId: taskID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// ===================== 修复点 2:如果是上下文取消,直接返回 =====================
|
|
||||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, fmt.Errorf("查询任务失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if record != nil {
|
if record != nil {
|
||||||
switch record.Status {
|
if completed, result := checkTaskCompletion(record); completed {
|
||||||
case public.ComposeStatusSuccess:
|
return result, nil
|
||||||
return record, nil
|
|
||||||
case public.ComposeStatusFailed:
|
|
||||||
if strings.TrimSpace(record.ErrorMessage) == "" {
|
|
||||||
return nil, fmt.Errorf("任务失败(taskId=%s)", taskID)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 查网关状态
|
if err = syncGatewayTaskState(ctx, taskID, record); err != nil {
|
||||||
state, err := gateway.QueryGatewayTaskState(ctx, taskID)
|
g.Log().Warningf(ctx, "[waitForResult] 同步网关状态失败 taskId=%s err=%v", taskID, err)
|
||||||
if err != nil {
|
|
||||||
// 网关不可达不终止,继续轮询
|
|
||||||
g.Log().Warningf(ctx, "[waitForResult] 查询网关失败 taskId=%s err=%v", taskID, err)
|
|
||||||
} else {
|
|
||||||
switch state {
|
|
||||||
case 2: // 网关成功
|
|
||||||
// 网关已成功,主动更新数据库
|
|
||||||
if record != nil {
|
|
||||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
|
||||||
TaskId: taskID,
|
|
||||||
Status: public.ComposeStatusSuccess,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case 3: // 网关失败
|
|
||||||
if record != nil {
|
|
||||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
|
||||||
TaskId: taskID,
|
|
||||||
Status: public.ComposeStatusFailed,
|
|
||||||
ErrorMessage: "model-gateway 任务执行失败",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 超时检查
|
|
||||||
if time.Now().After(deadline) {
|
if time.Now().After(deadline) {
|
||||||
return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID)
|
return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ===================== 修复点3:sleep 也要监听 ctx 取消 =====================
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil, ctx.Err()
|
return nil, ctx.Err()
|
||||||
case <-time.After(pollInterval):
|
case <-ticker.C:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkTaskCompletion 检查任务是否完成
|
||||||
|
func checkTaskCompletion(record *entity.ComposeTask) (bool, *entity.ComposeTask) {
|
||||||
|
if record == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
switch record.Status {
|
||||||
|
case public.ComposeStatusSuccess:
|
||||||
|
return true, record
|
||||||
|
case public.ComposeStatusFailed:
|
||||||
|
errMsg := strings.TrimSpace(record.ErrorMessage)
|
||||||
|
if errMsg == "" {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
default:
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// syncGatewayTaskState 同步网关任务状态
|
||||||
|
func syncGatewayTaskState(ctx context.Context, taskID string, record *entity.ComposeTask) error {
|
||||||
|
state, err := gateway.QueryGatewayTaskState(ctx, taskID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("查询网关状态失败: %w", err)
|
||||||
|
}
|
||||||
|
switch state {
|
||||||
|
case 2:
|
||||||
|
return updateTaskStatus(ctx, taskID, public.ComposeStatusSuccess, "")
|
||||||
|
case 3:
|
||||||
|
updateTaskStatus(ctx, taskID, public.ComposeStatusFailed, "model-gateway 任务执行失败")
|
||||||
|
return fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateTaskStatus 更新任务状态
|
||||||
|
func updateTaskStatus(ctx context.Context, taskID string, status string, errorMsg string) error {
|
||||||
|
task := &entity.ComposeTask{
|
||||||
|
TaskId: taskID,
|
||||||
|
Status: status,
|
||||||
|
}
|
||||||
|
if errorMsg != "" {
|
||||||
|
task.ErrorMessage = errorMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := dao.ComposeTask.Update(ctx, task)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// parsePromptBuild 解析提示词构建结果(BuildType == 1)
|
// parsePromptBuild 解析提示词构建结果(BuildType == 1)
|
||||||
func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) map[string]any {
|
func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) *dto.MultiRoundResult {
|
||||||
if taskRecord == nil {
|
if taskRecord == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
mapped := parseTaskMessages(taskRecord.Messages)
|
||||||
|
if mapped == nil {
|
||||||
|
return createDefaultResult(nil)
|
||||||
|
}
|
||||||
|
|
||||||
// 1. 解析 Messages
|
contentField := getContentField(model)
|
||||||
|
contentStr, ok := mapped[contentField].(string)
|
||||||
|
if !ok || contentStr == "" {
|
||||||
|
return createDefaultResult(mapped)
|
||||||
|
}
|
||||||
|
|
||||||
|
if roundsArray := tryParseAsArray(contentStr); roundsArray != nil {
|
||||||
|
return &dto.MultiRoundResult{
|
||||||
|
TotalRounds: len(roundsArray),
|
||||||
|
Rounds: roundsArray,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if singleRound := tryParseAsObject(contentStr); singleRound != nil {
|
||||||
|
return &dto.MultiRoundResult{
|
||||||
|
TotalRounds: 1,
|
||||||
|
Rounds: []any{singleRound},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return createDefaultResult(map[string]any{"content": contentStr})
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseTaskMessages 解析任务消息
|
||||||
|
func parseTaskMessages(messages any) map[string]any {
|
||||||
var mapped map[string]any
|
var mapped map[string]any
|
||||||
switch v := taskRecord.Messages.(type) {
|
|
||||||
|
switch v := messages.(type) {
|
||||||
case *gvar.Var:
|
case *gvar.Var:
|
||||||
if v != nil {
|
if v != nil {
|
||||||
json.Unmarshal([]byte(v.String()), &mapped)
|
json.Unmarshal([]byte(v.String()), &mapped)
|
||||||
@@ -289,115 +382,137 @@ func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel)
|
|||||||
json.Unmarshal(b, &mapped)
|
json.Unmarshal(b, &mapped)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 解析模型 ResponseMapping 获取 content 字段名
|
return mapped
|
||||||
contentField := "content" // 默认值
|
|
||||||
if model != nil {
|
|
||||||
var respMapping map[string]string
|
|
||||||
switch v := model.ResponseMapping.(type) {
|
|
||||||
case *gvar.Var:
|
|
||||||
if v != nil {
|
|
||||||
json.Unmarshal([]byte(v.String()), &respMapping)
|
|
||||||
}
|
|
||||||
case string:
|
|
||||||
json.Unmarshal([]byte(v), &respMapping)
|
|
||||||
case map[string]interface{}:
|
|
||||||
respMapping = make(map[string]string)
|
|
||||||
for k, val := range v {
|
|
||||||
if s, ok := val.(string); ok {
|
|
||||||
respMapping[k] = s
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 从映射中找到 content 对应的字段名
|
|
||||||
for k, v := range respMapping {
|
|
||||||
if strings.Contains(v, "content") {
|
|
||||||
contentField = k
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. 提取 content 的值
|
|
||||||
contentStr, ok := mapped[contentField].(string)
|
|
||||||
if !ok || contentStr == "" {
|
|
||||||
return mapped
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4. 解析 content 内的 JSON
|
|
||||||
var innerData map[string]any
|
|
||||||
json.Unmarshal([]byte(contentStr), &innerData)
|
|
||||||
|
|
||||||
return innerData
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseNodeBuild 解析节点构建结果(BuildType == 2)
|
// tryParseAsArray 尝试将字符串解析为数组
|
||||||
func parseNodeBuild(taskRecord *entity.ComposeTask) map[string]any {
|
func tryParseAsArray(contentStr string) []any {
|
||||||
if taskRecord == nil {
|
var roundsArray []any
|
||||||
|
if err := json.Unmarshal([]byte(contentStr), &roundsArray); err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
var result map[string]any
|
return roundsArray
|
||||||
switch v := taskRecord.Messages.(type) {
|
}
|
||||||
|
|
||||||
|
// tryParseAsObject 尝试将字符串解析为对象
|
||||||
|
func tryParseAsObject(contentStr string) any {
|
||||||
|
var singleRound any
|
||||||
|
if err := json.Unmarshal([]byte(contentStr), &singleRound); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return singleRound
|
||||||
|
}
|
||||||
|
|
||||||
|
// createDefaultResult 创建默认结果
|
||||||
|
func createDefaultResult(data any) *dto.MultiRoundResult {
|
||||||
|
if data == nil {
|
||||||
|
data = make(map[string]any)
|
||||||
|
}
|
||||||
|
return &dto.MultiRoundResult{
|
||||||
|
TotalRounds: 1,
|
||||||
|
Rounds: []any{data},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getContentField 从模型 ResponseMapping 中获取 content 字段名
|
||||||
|
func getContentField(model *entity.AsynchModel) string {
|
||||||
|
if model == nil {
|
||||||
|
return "content"
|
||||||
|
}
|
||||||
|
|
||||||
|
respMapping := parseResponseMapping(model.ResponseMapping)
|
||||||
|
for k, v := range respMapping {
|
||||||
|
if strings.Contains(v, "content") {
|
||||||
|
return k
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "content"
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseResponseMapping 解析响应映射
|
||||||
|
func parseResponseMapping(mapping any) map[string]string {
|
||||||
|
result := make(map[string]string)
|
||||||
|
|
||||||
|
switch v := mapping.(type) {
|
||||||
case *gvar.Var:
|
case *gvar.Var:
|
||||||
if v != nil {
|
if v != nil {
|
||||||
json.Unmarshal([]byte(v.String()), &result)
|
json.Unmarshal([]byte(v.String()), &result)
|
||||||
}
|
}
|
||||||
case string:
|
case string:
|
||||||
json.Unmarshal([]byte(v), &result)
|
json.Unmarshal([]byte(v), &result)
|
||||||
case map[string]any:
|
case map[string]interface{}:
|
||||||
result = v
|
for k, val := range v {
|
||||||
default:
|
if s, ok := val.(string); ok {
|
||||||
b, _ := json.Marshal(v)
|
result[k] = s
|
||||||
json.Unmarshal(b, &result)
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseNodeBuild 解析节点构建结果(BuildType == 2)
|
||||||
|
func parseNodeBuild(taskRecord *entity.ComposeTask) *dto.MultiRoundResult {
|
||||||
|
if taskRecord == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := parseTaskMessages(taskRecord.Messages)
|
||||||
|
if result == nil {
|
||||||
|
result = make(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &dto.MultiRoundResult{
|
||||||
|
TotalRounds: 1,
|
||||||
|
Rounds: []any{result},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Callback 回调处理
|
// Callback 回调处理
|
||||||
func Callback(ctx context.Context, req *promptDto.CallbackReq) error {
|
func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||||
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
|
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
|
||||||
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
|
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
|
||||||
|
|
||||||
// ============ 先查任务是否存在 ============
|
|
||||||
task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||||
TaskId: req.TaskId,
|
TaskId: req.TaskId,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("查询任务失败: %w", err)
|
||||||
}
|
}
|
||||||
if task == nil {
|
if task == nil {
|
||||||
return fmt.Errorf("任务不存在: %s", req.TaskId)
|
return fmt.Errorf("任务不存在: %s", req.TaskId)
|
||||||
}
|
}
|
||||||
// ============ 根据状态区分处理 ============
|
|
||||||
if req.State == 3 {
|
if req.State == 3 {
|
||||||
// 失败:直接更新状态
|
return handleCallbackFailure(ctx, req.TaskId, req.ErrorMsg)
|
||||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
}
|
||||||
TaskId: req.TaskId,
|
|
||||||
Status: public.ComposeStatusFailed,
|
return handleCallbackSuccess(ctx, req)
|
||||||
ErrorMessage: req.ErrorMsg,
|
}
|
||||||
})
|
|
||||||
return err
|
// handleCallbackFailure 处理回调失败
|
||||||
}
|
func handleCallbackFailure(ctx context.Context, taskID, errorMsg string) error {
|
||||||
// ======================================
|
_, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||||
// 成功:解析模型输出
|
TaskId: taskID,
|
||||||
result, err := util.ParseOutput(req.Text)
|
Status: public.ComposeStatusFailed,
|
||||||
if err != nil {
|
ErrorMessage: errorMsg,
|
||||||
_, updateErr := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
})
|
||||||
TaskId: req.TaskId,
|
return err
|
||||||
Status: public.ComposeStatusFailed,
|
}
|
||||||
ErrorMessage: req.ErrorMsg,
|
|
||||||
})
|
// handleCallbackSuccess 处理回调成功
|
||||||
if updateErr != nil {
|
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq) error {
|
||||||
g.Log().Warningf(ctx, "[Callback] 更新失败状态出错 taskId=%s err=%v", req.TaskId, updateErr)
|
result, err := util.ParseOutput(req.Text)
|
||||||
}
|
if err != nil {
|
||||||
return err
|
handleCallbackFailure(ctx, req.TaskId, req.ErrorMsg)
|
||||||
|
return fmt.Errorf("解析模型输出失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============ result 可能为 nil ============
|
|
||||||
var messages any
|
var messages any
|
||||||
if result != nil {
|
if result != nil {
|
||||||
messages = result
|
messages = result
|
||||||
}
|
}
|
||||||
// =======================================
|
|
||||||
|
|
||||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||||
TaskId: req.TaskId,
|
TaskId: req.TaskId,
|
||||||
@@ -407,34 +522,43 @@ func Callback(ctx context.Context, req *promptDto.CallbackReq) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err)
|
g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetComposeTask 查询任务结果
|
// GetComposeTask 查询任务结果
|
||||||
func GetComposeTask(ctx context.Context, taskID string) (*promptDto.GetComposeTaskRes, error) {
|
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
|
||||||
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||||
TaskId: taskID,
|
TaskId: taskID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("查询任务失败: %w", err)
|
||||||
}
|
}
|
||||||
if record == nil {
|
if record == nil {
|
||||||
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
|
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果 Messages 是字符串,反序列化为 JSON 数组
|
messages := parseMessagesForResponse(record.Messages)
|
||||||
messages := record.Messages
|
|
||||||
if str, ok := messages.(string); ok && str != "" {
|
|
||||||
var parsed any
|
|
||||||
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
|
|
||||||
messages = parsed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &promptDto.GetComposeTaskRes{
|
return &dto.GetComposeTaskRes{
|
||||||
TaskId: record.TaskId,
|
TaskId: record.TaskId,
|
||||||
Status: record.Status,
|
Status: record.Status,
|
||||||
ErrorMessage: record.ErrorMessage,
|
ErrorMessage: record.ErrorMessage,
|
||||||
Messages: messages,
|
Messages: messages,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseMessagesForResponse 解析用于响应的消息
|
||||||
|
func parseMessagesForResponse(messages any) any {
|
||||||
|
str, ok := messages.(string)
|
||||||
|
if !ok || str == "" {
|
||||||
|
return messages
|
||||||
|
}
|
||||||
|
|
||||||
|
var parsed any
|
||||||
|
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,10 +10,15 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
|
||||||
"prompts-core/common/util"
|
"prompts-core/common/util"
|
||||||
"prompts-core/service/gateway"
|
"prompts-core/service/gateway"
|
||||||
|
)
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
const (
|
||||||
|
bytesPerKB = 1024
|
||||||
|
bytesPerMB = 1024 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
// FetchFileTexts 从 URL 列表获取文件内容,支持 zip 内文件
|
// FetchFileTexts 从 URL 列表获取文件内容,支持 zip 内文件
|
||||||
@@ -24,51 +29,49 @@ func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &http.Client{
|
client := createHTTPClient(ctx, "userFiles.httpTimeoutSec", 8)
|
||||||
Timeout: time.Duration(g.Cfg().MustGet(ctx, "userFiles.httpTimeoutSec", 8).Int()) * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rawURL := range urls {
|
for _, rawURL := range urls {
|
||||||
url := util.SanitizeURL(rawURL)
|
url := util.SanitizeURL(rawURL)
|
||||||
if url == "" {
|
if url == "" || util.IsBannedExtension(url) {
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if util.IsBannedExtension(url) {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if util.IsZipExtension(url) {
|
if util.IsZipExtension(url) {
|
||||||
zipTexts := fetchZipFileTexts(ctx, client, url)
|
mergeMap(result, fetchZipFileTexts(ctx, client, url))
|
||||||
for k, v := range zipTexts {
|
|
||||||
result[k] = v
|
|
||||||
}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
text, err := fetchFileContent(ctx, client, url)
|
if text := fetchAndCleanFileContent(ctx, client, url); text != "" {
|
||||||
if err != nil {
|
result[url] = text
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if text == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
text = util.CleanSymbols(text)
|
|
||||||
result[url] = text
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mergeMap 合并 map
|
||||||
|
func mergeMap(dst, src map[string]string) {
|
||||||
|
for k, v := range src {
|
||||||
|
dst[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchAndCleanFileContent 获取并清理文件内容
|
||||||
|
func fetchAndCleanFileContent(ctx context.Context, client *http.Client, url string) string {
|
||||||
|
text, err := fetchFileContent(ctx, client, url)
|
||||||
|
if err != nil || text == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return util.CleanSymbols(text)
|
||||||
|
}
|
||||||
|
|
||||||
// fetchZipFileTexts 下载并解压 zip 文件,提取可读文本内容
|
// fetchZipFileTexts 下载并解压 zip 文件,提取可读文本内容
|
||||||
func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map[string]string {
|
func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map[string]string {
|
||||||
result := make(map[string]string)
|
result := make(map[string]string)
|
||||||
|
|
||||||
zipBytes, err := downloadFile(client, url,
|
maxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipMaxSizeMB", 10).Int()) * bytesPerMB
|
||||||
int64(g.Cfg().MustGet(ctx, "userFiles.zipMaxSizeMB", 10).Int())*1024*1024,
|
zipBytes, err := downloadFile(client, url, maxSize)
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
@@ -78,61 +81,61 @@ func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
entryMaxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipEntryMaxSizeKB", 500).Int()) * 1024
|
entryMaxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipEntryMaxSizeKB", 500).Int()) * bytesPerKB
|
||||||
|
|
||||||
for _, file := range reader.File {
|
for _, file := range reader.File {
|
||||||
if file.FileInfo().IsDir() {
|
if shouldSkipZipEntry(file.Name) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
fileName := file.Name
|
if text := extractZipEntryContent(file, entryMaxSize); text != "" {
|
||||||
|
result[url+"::"+file.Name] = text
|
||||||
if util.IsBannedExtension(fileName) {
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if util.IsZipExtension(fileName) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
rc, err := file.Open()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
content, err := io.ReadAll(io.LimitReader(rc, entryMaxSize))
|
|
||||||
rc.Close()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
contentType := http.DetectContentType(content)
|
|
||||||
if !util.IsReadableContentType(contentType) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
text := util.CleanSymbols(string(content))
|
|
||||||
if text == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
key := url + "::" + fileName
|
|
||||||
result[key] = text
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// shouldSkipZipEntry 判断是否应该跳过 zip 条目
|
||||||
|
func shouldSkipZipEntry(fileName string) bool {
|
||||||
|
return util.IsBannedExtension(fileName) || util.IsZipExtension(fileName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractZipEntryContent 提取 zip 条目内容
|
||||||
|
func extractZipEntryContent(file *zip.File, maxSize int64) string {
|
||||||
|
rc, err := file.Open()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer rc.Close()
|
||||||
|
|
||||||
|
content, err := io.ReadAll(io.LimitReader(rc, maxSize))
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if !util.IsReadableContentType(http.DetectContentType(content)) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
text := util.CleanSymbols(string(content))
|
||||||
|
if text == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
|
||||||
// downloadFile 下载文件,限制最大大小
|
// downloadFile 下载文件,限制最大大小
|
||||||
func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error) {
|
func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error) {
|
||||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("执行请求失败: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
@@ -140,19 +143,24 @@ func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error
|
|||||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
return io.ReadAll(io.LimitReader(resp.Body, maxSize))
|
body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// fetchFileContent 获取单个文本文件内容
|
// fetchFileContent 获取单个文本文件内容
|
||||||
func fetchFileContent(ctx context.Context, client *http.Client, url string) (string, error) {
|
func fetchFileContent(ctx context.Context, client *http.Client, url string) (string, error) {
|
||||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("创建请求失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("执行请求失败: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
@@ -162,16 +170,13 @@ func fetchFileContent(ctx context.Context, client *http.Client, url string) (str
|
|||||||
|
|
||||||
contentType := resp.Header.Get("Content-Type")
|
contentType := resp.Header.Get("Content-Type")
|
||||||
if !util.IsReadableContentType(contentType) {
|
if !util.IsReadableContentType(contentType) {
|
||||||
return "", fmt.Errorf("unreadable content-type: %s", contentType)
|
return "", fmt.Errorf("不可读的内容类型: %s", contentType)
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(
|
maxSize := int64(g.Cfg().MustGet(ctx, "userFiles.textFileMaxSizeKB", 500).Int()) * bytesPerKB
|
||||||
io.LimitReader(resp.Body,
|
body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
|
||||||
int64(g.Cfg().MustGet(ctx, "userFiles.textFileMaxSizeKB", 500).Int())*1024,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("读取响应失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.TrimSpace(string(body)), nil
|
return strings.TrimSpace(string(body)), nil
|
||||||
@@ -186,27 +191,26 @@ func SkillMdContent(ctx context.Context, skillName string) string {
|
|||||||
|
|
||||||
fullUrl := skillResp.ImgAddressPrefix + skillResp.FileUrl
|
fullUrl := skillResp.ImgAddressPrefix + skillResp.FileUrl
|
||||||
|
|
||||||
client := &http.Client{
|
client := createHTTPClient(ctx, "skillFiles.httpTimeoutSec", 30)
|
||||||
Timeout: time.Duration(g.Cfg().MustGet(ctx, "skillFiles.httpTimeoutSec", 30).Int()) * time.Second,
|
maxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.zipMaxSizeMB", 10).Int()) * bytesPerMB
|
||||||
}
|
|
||||||
|
|
||||||
zipBytes, err := downloadFile(client, fullUrl,
|
zipBytes, err := downloadFile(client, fullUrl, maxSize)
|
||||||
int64(g.Cfg().MustGet(ctx, "skillFiles.zipMaxSizeMB", 10).Int())*1024*1024,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
mdContents, err := extractMdFiles(ctx, zipBytes)
|
mdContents, err := extractMdFiles(ctx, zipBytes)
|
||||||
if err != nil {
|
if err != nil || len(mdContents) == 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(mdContents) == 0 {
|
return buildSkillMarkdown(skillResp, mdContents)
|
||||||
return ""
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
// buildSkillMarkdown 构建技能 Markdown 内容
|
||||||
|
func buildSkillMarkdown(skillResp *gateway.SkillUserVO, mdContents map[string]string) string {
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
|
|
||||||
builder.WriteString(fmt.Sprintf("# Skill: %s\n\n", skillResp.Name))
|
builder.WriteString(fmt.Sprintf("# Skill: %s\n\n", skillResp.Name))
|
||||||
if skillResp.Description != "" {
|
if skillResp.Description != "" {
|
||||||
builder.WriteString(fmt.Sprintf("> %s\n\n", skillResp.Description))
|
builder.WriteString(fmt.Sprintf("> %s\n\n", skillResp.Description))
|
||||||
@@ -227,35 +231,53 @@ func extractMdFiles(ctx context.Context, zipBytes []byte) (map[string]string, er
|
|||||||
|
|
||||||
reader, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
|
reader, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("创建 zip 阅读器失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
entryMaxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.mdMaxSizeKB", 500).Int()) * 1024
|
entryMaxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.mdMaxSizeKB", 500).Int()) * bytesPerKB
|
||||||
|
|
||||||
for _, file := range reader.File {
|
for _, file := range reader.File {
|
||||||
if file.FileInfo().IsDir() {
|
if file.FileInfo().IsDir() || !isMarkdownFile(file.Name) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.HasSuffix(strings.ToLower(file.Name), ".md") {
|
if content := readMarkdownFileContent(file, entryMaxSize); content != "" {
|
||||||
continue
|
result[file.Name] = content
|
||||||
}
|
|
||||||
|
|
||||||
rc, err := file.Open()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
content, err := io.ReadAll(io.LimitReader(rc, entryMaxSize))
|
|
||||||
rc.Close()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(content) > 0 {
|
|
||||||
result[file.Name] = strings.TrimSpace(string(content))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isMarkdownFile 判断是否为 Markdown 文件
|
||||||
|
func isMarkdownFile(fileName string) bool {
|
||||||
|
return strings.HasSuffix(strings.ToLower(fileName), ".md")
|
||||||
|
}
|
||||||
|
|
||||||
|
// readMarkdownFileContent 读取 Markdown 文件内容
|
||||||
|
func readMarkdownFileContent(file *zip.File, maxSize int64) string {
|
||||||
|
rc, err := file.Open()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer rc.Close()
|
||||||
|
|
||||||
|
content, err := io.ReadAll(io.LimitReader(rc, maxSize))
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(content) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimSpace(string(content))
|
||||||
|
}
|
||||||
|
|
||||||
|
// createHTTPClient 创建 HTTP 客户端
|
||||||
|
func createHTTPClient(ctx context.Context, configKey string, defaultSeconds int) *http.Client {
|
||||||
|
timeout := time.Duration(g.Cfg().MustGet(ctx, configKey, defaultSeconds).Int()) * time.Second
|
||||||
|
return &http.Client{
|
||||||
|
Timeout: timeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
75
service/prompt/prompt_files_handle_service.markdown
Normal file
75
service/prompt/prompt_files_handle_service.markdown
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
# Prompts-Core(提示词核心服务)
|
||||||
|
|
||||||
|
> 智能提示词构建与管理系统,支持多模态 AI 模型的提示词组装、会话管理和协议适配。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 项目简介
|
||||||
|
|
||||||
|
**Prompts-Core** 是一个基于 Go 语言开发的提示词核心服务,作为 AI 应用层与模型网关之间的桥梁,负责将业务需求转换为标准化的模型请求。
|
||||||
|
|
||||||
|
### 核心价值
|
||||||
|
- **统一提示词管理**:集中化管理不同模型类型的提示词模板
|
||||||
|
- **智能会话维护**:基于 Redis + PostgreSQL 的双层会话存储
|
||||||
|
- **多协议适配**:支持 OpenAI、DeepSeek、Qwen、Gemini 等多种模型协议
|
||||||
|
- **文件处理能力**:自动提取文本文件和 ZIP 压缩包内容
|
||||||
|
- **技能系统集成**:支持从外部加载 Markdown 格式的技能描述
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 核心功能
|
||||||
|
|
||||||
|
### 1. 提示词构建引擎
|
||||||
|
|
||||||
|
#### 多模态支持
|
||||||
|
| 类型 | 说明 | 适用场景 |
|
||||||
|
|------|------|----------|
|
||||||
|
| Type 1 | 文字处理助手 | 文章撰写、文案优化、翻译等 |
|
||||||
|
| Type 2 | 图片处理助手 | 图像生成、风格迁移等 |
|
||||||
|
| Type 3 | 音频处理助手 | 语音合成、识别、降噪等 |
|
||||||
|
| Type 4 | 向量化处理助手 | 语义检索、知识索引等 |
|
||||||
|
| Type 5 | 全模态助手 | 跨模态转换、多模态融合等 |
|
||||||
|
|
||||||
|
#### 构建模式
|
||||||
|
- **BuildType 1(提示词构建)**:完整流程,包含系统提示词、历史会话、用户输入的智能组装
|
||||||
|
- **BuildType 2(节点构建)**:工作流路由决策,根据上下文选择节点 ID
|
||||||
|
|
||||||
|
#### 分批处理
|
||||||
|
当用户表单内容超出模型窗口限制时,自动按 Token 大小分批处理。
|
||||||
|
|
||||||
|
### 2. 会话管理系统
|
||||||
|
|
||||||
|
- **双层存储**:Redis 缓存(最近 N 轮)+ PostgreSQL 持久化
|
||||||
|
- **自动管理**:最大轮数控制(默认 10 轮)、自动过期(默认 30 分钟)
|
||||||
|
|
||||||
|
### 3. 协议适配器
|
||||||
|
|
||||||
|
通过配置动态支持多种模型协议:
|
||||||
|
- 角色映射:system/user/assistant → 目标协议角色
|
||||||
|
- 内容字段映射:content → parts.text 等
|
||||||
|
- 消息顺序控制:灵活配置拼接顺序
|
||||||
|
- 请求模板渲染:支持占位符替换
|
||||||
|
|
||||||
|
### 4. 任务调度
|
||||||
|
|
||||||
|
- **异步流程**:创建网关任务 → 轮询等待 → 接收回调 → 返回结果
|
||||||
|
- **重试机制**:可配置最大重试次数(默认 3 次)
|
||||||
|
- **超时保护**:默认 300 秒超时
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 技术架构
|
||||||
|
|
||||||
|
### 技术栈
|
||||||
|
|
||||||
|
| 组件 | 版本 | 用途 |
|
||||||
|
|------|------|------|
|
||||||
|
| Go | 1.26.0 | 编程语言 |
|
||||||
|
| GoFrame | v2.10.0 | Web 框架 |
|
||||||
|
| PostgreSQL | - | 关系型数据库 |
|
||||||
|
| Redis | - | 缓存与会话存储 |
|
||||||
|
| Consul | - | 服务注册与发现 |
|
||||||
|
| Jaeger | - | 分布式链路追踪 |
|
||||||
|
|
||||||
|
### 架构图
|
||||||
|
|
||||||
@@ -20,11 +20,27 @@ type PromptIR struct {
|
|||||||
|
|
||||||
// Segment 消息片段
|
// Segment 消息片段
|
||||||
type Segment struct {
|
type Segment struct {
|
||||||
Type string `json:"type"` // text/image
|
Type string `json:"type"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
Role string `json:"role,omitempty"`
|
Role string `json:"role,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProviderProtocol 协议编译配置(从 DB JSONB 字段解析)
|
||||||
|
type ProviderProtocol struct {
|
||||||
|
TargetField string `json:"target_field"`
|
||||||
|
MergeOrder []string `json:"merge_order"`
|
||||||
|
RoleMapping map[string]string `json:"role_mapping"`
|
||||||
|
ContentMapping ContentMapping `json:"content_mapping"`
|
||||||
|
RequestTemplate map[string]any `json:"request_template"`
|
||||||
|
SystemPromptTemplate string `json:"system_prompt_template"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContentMapping 内容字段映射
|
||||||
|
type ContentMapping struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Field string `json:"field"`
|
||||||
|
}
|
||||||
|
|
||||||
// NewPromptIR 创建空 PromptIR
|
// NewPromptIR 创建空 PromptIR
|
||||||
func NewPromptIR() *PromptIR {
|
func NewPromptIR() *PromptIR {
|
||||||
return &PromptIR{
|
return &PromptIR{
|
||||||
@@ -34,6 +50,54 @@ func NewPromptIR() *PromptIR {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// String 返回 PromptIR 的完整内容字符串(用于 token 计算)
|
||||||
|
func (ir *PromptIR) String() string {
|
||||||
|
var builder strings.Builder
|
||||||
|
|
||||||
|
for _, seg := range ir.System {
|
||||||
|
builder.WriteString("System: ")
|
||||||
|
builder.WriteString(seg.Content)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, seg := range ir.History {
|
||||||
|
builder.WriteString(seg.Role)
|
||||||
|
builder.WriteString(": ")
|
||||||
|
builder.WriteString(seg.Content)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, seg := range ir.User {
|
||||||
|
builder.WriteString("User: ")
|
||||||
|
builder.WriteString(seg.Content)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算)
|
||||||
|
func (ir *PromptIR) GetTotalContent() string {
|
||||||
|
var builder strings.Builder
|
||||||
|
|
||||||
|
for _, seg := range ir.System {
|
||||||
|
builder.WriteString(seg.Content)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, seg := range ir.History {
|
||||||
|
builder.WriteString(seg.Content)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, seg := range ir.User {
|
||||||
|
builder.WriteString(seg.Content)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder.String()
|
||||||
|
}
|
||||||
|
|
||||||
// AddSystem 添加系统提示
|
// AddSystem 添加系统提示
|
||||||
func (ir *PromptIR) AddSystem(content string) *PromptIR {
|
func (ir *PromptIR) AddSystem(content string) *PromptIR {
|
||||||
if content != "" {
|
if content != "" {
|
||||||
@@ -62,7 +126,6 @@ func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
|
|||||||
func (ir *PromptIR) ToMessages() []map[string]any {
|
func (ir *PromptIR) ToMessages() []map[string]any {
|
||||||
var messages []map[string]any
|
var messages []map[string]any
|
||||||
|
|
||||||
// 1. 系统消息
|
|
||||||
for _, seg := range ir.System {
|
for _, seg := range ir.System {
|
||||||
messages = append(messages, map[string]any{
|
messages = append(messages, map[string]any{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
@@ -70,7 +133,6 @@ func (ir *PromptIR) ToMessages() []map[string]any {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 历史消息
|
|
||||||
for _, seg := range ir.History {
|
for _, seg := range ir.History {
|
||||||
messages = append(messages, map[string]any{
|
messages = append(messages, map[string]any{
|
||||||
"role": seg.Role,
|
"role": seg.Role,
|
||||||
@@ -78,13 +140,13 @@ func (ir *PromptIR) ToMessages() []map[string]any {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 用户消息
|
|
||||||
for _, seg := range ir.User {
|
for _, seg := range ir.User {
|
||||||
messages = append(messages, map[string]any{
|
messages = append(messages, map[string]any{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": seg.Content,
|
"content": seg.Content,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -97,74 +159,35 @@ func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderP
|
|||||||
if err != nil || entity == nil {
|
if err != nil || entity == nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
entity.MergeOrder = util.ParseJSONField(entity.MergeOrder)
|
fmt.Println("entity打印", entity)
|
||||||
entity.RoleMapping = util.ParseJSONField(entity.RoleMapping)
|
|
||||||
entity.ContentMapping = util.ParseJSONField(entity.ContentMapping)
|
|
||||||
entity.RequestTemplate = util.ParseJSONField(entity.RequestTemplate)
|
|
||||||
entity.ContentMapping = util.ParseJSONField(entity.ContentMapping)
|
|
||||||
return parseProtocol(entity), nil
|
return parseProtocol(entity), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseProtocol 将 DB entity 转为编译用协议配置
|
// parseProtocol 将 DB entity 转为编译用协议配置
|
||||||
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
|
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
|
||||||
p := &ProviderProtocol{
|
p := &ProviderProtocol{
|
||||||
TargetField: e.TargetField,
|
TargetField: e.TargetField,
|
||||||
|
SystemPromptTemplate: e.SystemPromptTemplate,
|
||||||
}
|
}
|
||||||
|
|
||||||
// MergeOrder: any → []string
|
// 使用通用解析方法处理各个字段
|
||||||
if e.MergeOrder != nil {
|
util.ParseJSONFieldFromGvar(e.MergeOrder, &p.MergeOrder)
|
||||||
b, _ := json.Marshal(e.MergeOrder)
|
util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping)
|
||||||
json.Unmarshal(b, &p.MergeOrder)
|
util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping)
|
||||||
}
|
util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate)
|
||||||
|
|
||||||
// RoleMapping: any → map[string]string
|
|
||||||
if e.RoleMapping != nil {
|
|
||||||
b, _ := json.Marshal(e.RoleMapping)
|
|
||||||
json.Unmarshal(b, &p.RoleMapping)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ContentMapping: any → ContentMapping
|
|
||||||
if e.ContentMapping != nil {
|
|
||||||
b, _ := json.Marshal(e.ContentMapping)
|
|
||||||
json.Unmarshal(b, &p.ContentMapping)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestTemplate: any → map[string]any
|
|
||||||
if e.RequestTemplate != nil {
|
|
||||||
b, _ := json.Marshal(e.RequestTemplate)
|
|
||||||
json.Unmarshal(b, &p.RequestTemplate)
|
|
||||||
}
|
|
||||||
fmt.Printf("parseProtocol: %+v\n", p)
|
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProviderProtocol 协议编译配置(从 DB JSONB 字段解析)
|
|
||||||
type ProviderProtocol struct {
|
|
||||||
TargetField string `json:"target_field"`
|
|
||||||
MergeOrder []string `json:"merge_order"`
|
|
||||||
RoleMapping map[string]string `json:"role_mapping"`
|
|
||||||
ContentMapping ContentMapping `json:"content_mapping"`
|
|
||||||
RequestTemplate map[string]any `json:"request_template"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ContentMapping 内容字段映射
|
|
||||||
type ContentMapping struct {
|
|
||||||
Type string `json:"type"` // direct/parts
|
|
||||||
Field string `json:"field"` // content/text
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compile 将 PromptIR 按协议配置编译为 Provider Request
|
// Compile 将 PromptIR 按协议配置编译为 Provider Request
|
||||||
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (map[string]any, error) {
|
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (map[string]any, error) {
|
||||||
if ir == nil || p == nil {
|
if ir == nil || p == nil {
|
||||||
return nil, fmt.Errorf("ir and protocol are required")
|
return nil, fmt.Errorf("ir and protocol are required")
|
||||||
}
|
}
|
||||||
// 1. 按 merge_order 拼接消息
|
|
||||||
messages := mergeByOrder(ir, p.MergeOrder)
|
messages := mergeByOrder(ir, p.MergeOrder)
|
||||||
// 2. 角色映射
|
|
||||||
messages = mapRoles(messages, p.RoleMapping)
|
messages = mapRoles(messages, p.RoleMapping)
|
||||||
// 3. 内容字段映射
|
|
||||||
messages = mapContent(messages, p.ContentMapping)
|
messages = mapContent(messages, p.ContentMapping)
|
||||||
// 4. 按 target_field + request_template 构建请求体
|
|
||||||
return buildRequest(messages, p, chatModel), nil
|
return buildRequest(messages, p, chatModel), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -197,6 +220,7 @@ func mergeByOrder(ir *PromptIR, order []string) []map[string]any {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -205,15 +229,18 @@ func mapRoles(messages []map[string]any, mapping map[string]string) []map[string
|
|||||||
if len(mapping) == 0 {
|
if len(mapping) == 0 {
|
||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, msg := range messages {
|
for i, msg := range messages {
|
||||||
role, ok := msg["role"].(string)
|
role, ok := msg["role"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if mapped, exists := mapping[role]; exists {
|
if mapped, exists := mapping[role]; exists {
|
||||||
messages[i]["role"] = mapped
|
messages[i]["role"] = mapped
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -225,15 +252,14 @@ func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
|
|||||||
|
|
||||||
switch cm.Type {
|
switch cm.Type {
|
||||||
case "parts":
|
case "parts":
|
||||||
// Gemini 格式: {"parts": [{"text": "..."}]}
|
|
||||||
msg["parts"] = []map[string]any{
|
msg["parts"] = []map[string]any{
|
||||||
{cm.Field: content},
|
{cm.Field: content},
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
// direct: {"content": "..."}
|
|
||||||
msg[cm.Field] = content
|
msg[cm.Field] = content
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -242,6 +268,7 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *ent
|
|||||||
if len(p.RequestTemplate) > 0 {
|
if len(p.RequestTemplate) > 0 {
|
||||||
return renderTemplate(p.RequestTemplate, messages, chatModel)
|
return renderTemplate(p.RequestTemplate, messages, chatModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
p.TargetField: messages,
|
p.TargetField: messages,
|
||||||
}
|
}
|
||||||
@@ -252,13 +279,13 @@ func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *e
|
|||||||
b, _ := json.Marshal(tmpl)
|
b, _ := json.Marshal(tmpl)
|
||||||
str := string(b)
|
str := string(b)
|
||||||
|
|
||||||
// 替换 {{model}}
|
|
||||||
str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`)
|
str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`)
|
||||||
// 替换 {{messages}}
|
|
||||||
msgBytes, _ := json.Marshal(messages)
|
msgBytes, _ := json.Marshal(messages)
|
||||||
str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes))
|
str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes))
|
||||||
|
|
||||||
var result map[string]any
|
var result map[string]any
|
||||||
json.Unmarshal([]byte(str), &result)
|
json.Unmarshal([]byte(str), &result)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,15 +9,16 @@ import (
|
|||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ==================== Redis 操作 ====================
|
const (
|
||||||
|
redisKeyPrefix = "chat:session:%s"
|
||||||
|
)
|
||||||
|
|
||||||
// saveToRedis 保存会话数据到Redis
|
// saveToRedis 保存会话数据到Redis
|
||||||
func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error {
|
func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error {
|
||||||
key := fmt.Sprintf("chat:session:%s", sessionId)
|
key := formatRedisKey(sessionId)
|
||||||
|
|
||||||
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
||||||
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
|
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
|
||||||
expireTime := time.Duration(expireSeconds) * time.Second
|
|
||||||
|
|
||||||
data := map[string]any{
|
data := map[string]any{
|
||||||
"sessionId": sessionId,
|
"sessionId": sessionId,
|
||||||
@@ -31,18 +32,29 @@ func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[st
|
|||||||
return fmt.Errorf("序列化会话数据失败: %w", err)
|
return fmt.Errorf("序列化会话数据失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = g.Redis().Do(ctx, "LPUSH", key, string(b))
|
if err := executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil {
|
||||||
if err != nil {
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatRedisKey 格式化Redis键
|
||||||
|
func formatRedisKey(sessionId string) string {
|
||||||
|
return fmt.Sprintf(redisKeyPrefix, sessionId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeRedisCommands 执行Redis命令
|
||||||
|
func executeRedisCommands(ctx context.Context, key string, value string, maxRounds int, expireSeconds int64) error {
|
||||||
|
if _, err := g.Redis().Do(ctx, "LPUSH", key, value); err != nil {
|
||||||
return fmt.Errorf("写入Redis失败: %w", err)
|
return fmt.Errorf("写入Redis失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1)
|
if _, err := g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("裁剪Redis列表失败: %w", err)
|
return fmt.Errorf("裁剪Redis列表失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = g.Redis().Do(ctx, "EXPIRE", key, int64(expireTime.Seconds()))
|
if _, err := g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("设置过期时间失败: %w", err)
|
return fmt.Errorf("设置过期时间失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,7 +63,7 @@ func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[st
|
|||||||
|
|
||||||
// getFromRedis 从Redis获取会话历史
|
// getFromRedis 从Redis获取会话历史
|
||||||
func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
||||||
key := fmt.Sprintf("chat:session:%s", sessionId)
|
key := formatRedisKey(sessionId)
|
||||||
|
|
||||||
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
|
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -62,8 +74,17 @@ func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, erro
|
|||||||
return []map[string]any{}, nil
|
return []map[string]any{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sessions := parseRedisSessions(ctx, result.Strings())
|
||||||
|
|
||||||
|
reverseSlice(sessions)
|
||||||
|
|
||||||
|
return sessions, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseRedisSessions 解析Redis会话数据
|
||||||
|
func parseRedisSessions(ctx context.Context, values []string) []map[string]any {
|
||||||
var sessions []map[string]any
|
var sessions []map[string]any
|
||||||
values := result.Strings()
|
|
||||||
for _, str := range values {
|
for _, str := range values {
|
||||||
var data map[string]any
|
var data map[string]any
|
||||||
if err := json.Unmarshal([]byte(str), &data); err != nil {
|
if err := json.Unmarshal([]byte(str), &data); err != nil {
|
||||||
@@ -73,12 +94,14 @@ func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, erro
|
|||||||
sessions = append(sessions, data)
|
sessions = append(sessions, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 反转(Redis 最新在前 → 时间正序)
|
return sessions
|
||||||
for i, j := 0, len(sessions)-1; i < j; i, j = i+1, j-1 {
|
}
|
||||||
sessions[i], sessions[j] = sessions[j], sessions[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
return sessions, nil
|
// reverseSlice 反转切片
|
||||||
|
func reverseSlice(s []map[string]any) {
|
||||||
|
for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {
|
||||||
|
s[i], s[j] = s[j], s[i]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用)
|
// GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用)
|
||||||
@@ -92,23 +115,31 @@ func GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map
|
|||||||
return []map[string]any{}, nil
|
return []map[string]any{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return flattenHistoryMessages(historyData), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// flattenHistoryMessages 扁平化历史消息
|
||||||
|
func flattenHistoryMessages(historyData []map[string]any) []map[string]any {
|
||||||
var messages []map[string]any
|
var messages []map[string]any
|
||||||
|
|
||||||
for _, round := range historyData {
|
for _, round := range historyData {
|
||||||
if reqMsgs, ok := round["requestContent"].([]interface{}); ok {
|
appendMessagesFromField(round, "requestContent", &messages)
|
||||||
for _, m := range reqMsgs {
|
appendMessagesFromField(round, "responseContent", &messages)
|
||||||
if msg, ok := m.(map[string]interface{}); ok {
|
|
||||||
messages = append(messages, msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if respMsgs, ok := round["responseContent"].([]interface{}); ok {
|
|
||||||
for _, m := range respMsgs {
|
|
||||||
if msg, ok := m.(map[string]interface{}); ok {
|
|
||||||
messages = append(messages, msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return messages, nil
|
return messages
|
||||||
|
}
|
||||||
|
|
||||||
|
// appendMessagesFromField 从指定字段追加消息
|
||||||
|
func appendMessagesFromField(data map[string]any, field string, messages *[]map[string]any) {
|
||||||
|
msgs, ok := data[field].([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range msgs {
|
||||||
|
if msg, ok := m.(map[string]interface{}); ok {
|
||||||
|
*messages = append(*messages, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,112 +3,164 @@ package prompt
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
sessionDao "prompts-core/dao"
|
|
||||||
"prompts-core/model/entity"
|
|
||||||
|
|
||||||
"prompts-core/common/util"
|
|
||||||
sessionDto "prompts-core/model/dto/prompt"
|
|
||||||
|
|
||||||
"gitea.com/red-future/common/beans"
|
"gitea.com/red-future/common/beans"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
|
|
||||||
|
"prompts-core/common/util"
|
||||||
|
"prompts-core/dao"
|
||||||
|
"prompts-core/model/dto"
|
||||||
|
"prompts-core/model/entity"
|
||||||
)
|
)
|
||||||
|
|
||||||
func SessionCallback(ctx context.Context, req *sessionDto.SessionCallbackReq) (res *sessionDto.SessionCallbackRes, err error) {
|
// SessionCallback 会话回调
|
||||||
// 1. 解析AI返回的文本
|
func SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
|
||||||
result, err := util.ParseOutput(req.Text)
|
result, err := util.ParseOutput(req.Text)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
||||||
return nil, err
|
return nil, fmt.Errorf("解析模型输出失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 更新数据库
|
|
||||||
result["role"] = "assistant"
|
result["role"] = "assistant"
|
||||||
_, err = sessionDao.ComposeSession.Update(ctx, &entity.ComposeSession{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
if err := updateSessionResponse(ctx, req.EpicycleId, result); err != nil {
|
||||||
ResponseContent: result,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 获取当前轮次完整数据
|
session, err := getSessionById(ctx, req.EpicycleId)
|
||||||
session, err := sessionDao.ComposeSession.Get(ctx, &entity.ComposeSession{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 转换 json 并存入 Redis
|
if err := saveSessionToRedis(ctx, session); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
requestMessages := util.ConvertToMessages(session.RequestContent)
|
requestMessages := util.ConvertToMessages(session.RequestContent)
|
||||||
responseMessages := util.ConvertToMessages(session.ResponseContent)
|
responseMessages := util.ConvertToMessages(session.ResponseContent)
|
||||||
|
|
||||||
if err = saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil {
|
|
||||||
g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v",
|
|
||||||
session.SessionId, session.Id, err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
|
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
|
||||||
session.SessionId, session.Id, len(requestMessages), len(responseMessages))
|
session.SessionId, session.Id, len(requestMessages), len(responseMessages))
|
||||||
return &sessionDto.SessionCallbackRes{}, nil
|
|
||||||
|
return &dto.SessionCallbackRes{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateSessionResponse 更新会话响应
|
||||||
|
func updateSessionResponse(ctx context.Context, epicycleId int64, response any) error {
|
||||||
|
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Id: epicycleId},
|
||||||
|
ResponseContent: response,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", epicycleId, err)
|
||||||
|
return fmt.Errorf("更新数据库失败: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSessionById 根据ID获取会话
|
||||||
|
func getSessionById(ctx context.Context, epicycleId int64) (*entity.ComposeSession, error) {
|
||||||
|
session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Id: epicycleId},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", epicycleId, err)
|
||||||
|
return nil, fmt.Errorf("获取会话数据失败: %w", err)
|
||||||
|
}
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveSessionToRedis 保存会话到Redis
|
||||||
|
func saveSessionToRedis(ctx context.Context, session *entity.ComposeSession) error {
|
||||||
|
requestMessages := util.ConvertToMessages(session.RequestContent)
|
||||||
|
responseMessages := util.ConvertToMessages(session.ResponseContent)
|
||||||
|
|
||||||
|
if err := saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil {
|
||||||
|
g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v",
|
||||||
|
session.SessionId, session.Id, err)
|
||||||
|
return fmt.Errorf("Redis存储失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHistoryMessages 获取历史信息
|
// GetHistoryMessages 获取历史信息
|
||||||
func GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
func GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
||||||
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
||||||
|
|
||||||
// 1. 先从 Redis 拿
|
|
||||||
redisHistory, err := GetSessionHistoryForInference(ctx, sessionId)
|
redisHistory, err := GetSessionHistoryForInference(ctx, sessionId)
|
||||||
if err == nil && len(redisHistory) > 0 {
|
if err == nil && len(redisHistory) > 0 {
|
||||||
return redisHistory, nil
|
return redisHistory, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Redis 没有 → fallback DB
|
return getHistoryFromDatabase(ctx, sessionId, maxRounds)
|
||||||
sessions, _, err := sessionDao.ComposeSession.List(ctx, &entity.ComposeSession{
|
}
|
||||||
|
|
||||||
|
// getHistoryFromDatabase 从数据库获取历史记录
|
||||||
|
func getHistoryFromDatabase(ctx context.Context, sessionId string, maxRounds int) ([]map[string]any, error) {
|
||||||
|
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
|
||||||
SessionId: sessionId,
|
SessionId: sessionId,
|
||||||
}, 1, maxRounds)
|
}, 1, maxRounds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("DB获取历史失败: %w", err)
|
return nil, fmt.Errorf("DB获取历史失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
messages := extractMessagesFromSessions(sessions)
|
||||||
|
|
||||||
|
cacheSessionsToRedis(ctx, sessions)
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractMessagesFromSessions 从会话列表中提取消息
|
||||||
|
func extractMessagesFromSessions(sessions []*entity.ComposeSession) []map[string]any {
|
||||||
var messages []map[string]any
|
var messages []map[string]any
|
||||||
|
|
||||||
for _, session := range sessions {
|
for _, session := range sessions {
|
||||||
// request
|
appendRequestMessages(session.RequestContent, &messages)
|
||||||
reqMsgs := util.ConvertToMessages(session.RequestContent)
|
appendResponseMessages(session.ResponseContent, &messages)
|
||||||
for _, m := range reqMsgs {
|
|
||||||
role := gconv.String(m["role"])
|
|
||||||
if role == "user" || role == "assistant" {
|
|
||||||
messages = append(messages, m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// response
|
|
||||||
respMsgs := util.ConvertToMessages(session.ResponseContent)
|
|
||||||
for _, m := range respMsgs {
|
|
||||||
if m["role"] == nil {
|
|
||||||
m["role"] = "assistant"
|
|
||||||
}
|
|
||||||
messages = append(messages, m)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 回写 Redis
|
return messages
|
||||||
|
}
|
||||||
|
|
||||||
|
// appendRequestMessages 追加请求消息
|
||||||
|
func appendRequestMessages(requestContent any, messages *[]map[string]any) {
|
||||||
|
reqMsgs := util.ConvertToMessages(requestContent)
|
||||||
|
for _, m := range reqMsgs {
|
||||||
|
role := gconv.String(m["role"])
|
||||||
|
if role == "user" || role == "assistant" {
|
||||||
|
*messages = append(*messages, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// appendResponseMessages 追加响应消息
|
||||||
|
func appendResponseMessages(responseContent any, messages *[]map[string]any) {
|
||||||
|
respMsgs := util.ConvertToMessages(responseContent)
|
||||||
|
for _, m := range respMsgs {
|
||||||
|
if m["role"] == nil {
|
||||||
|
m["role"] = "assistant"
|
||||||
|
}
|
||||||
|
*messages = append(*messages, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cacheSessionsToRedis 将会话缓存到Redis
|
||||||
|
func cacheSessionsToRedis(ctx context.Context, sessions []*entity.ComposeSession) {
|
||||||
for _, session := range sessions {
|
for _, session := range sessions {
|
||||||
reqMsgs := util.ConvertToMessages(session.RequestContent)
|
reqMsgs := util.ConvertToMessages(session.RequestContent)
|
||||||
respMsgs := util.ConvertToMessages(session.ResponseContent)
|
respMsgs := util.ConvertToMessages(session.ResponseContent)
|
||||||
|
|
||||||
for i := range respMsgs {
|
for i := range respMsgs {
|
||||||
if respMsgs[i]["role"] == nil {
|
if respMsgs[i]["role"] == nil {
|
||||||
respMsgs[i]["role"] = "assistant"
|
respMsgs[i]["role"] = "assistant"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
|
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
|
||||||
_ = saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs)
|
_ = saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return messages, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
135
service/prompt/prompt_user_form_batches.go
Normal file
135
service/prompt/prompt_user_form_batches.go
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
package prompt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
|
||||||
|
"prompts-core/common/util"
|
||||||
|
"prompts-core/model/dto"
|
||||||
|
"prompts-core/model/entity"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProcessUserFormBatches 处理 UserForm 分批(按 token 大小拼接内容)
|
||||||
|
func ProcessUserFormBatches(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) (*dto.ComposeMessagesReq, int, error) {
|
||||||
|
if model.TokenConfig == nil || len(req.UserForm) == 0 {
|
||||||
|
return req, 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
availableWindow := util.GetAvailableWindow(model.TokenConfig)
|
||||||
|
batches := splitUserFormByTokenSize(req.UserForm, availableWindow, model.TokenConfig)
|
||||||
|
|
||||||
|
if len(batches) <= 1 {
|
||||||
|
return req, 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
newUserForm := buildBatchedUserForm(batches)
|
||||||
|
|
||||||
|
newReq := *req
|
||||||
|
newReq.UserForm = newUserForm
|
||||||
|
|
||||||
|
g.Log().Infof(ctx, "[ProcessUserFormBatches] UserForm分批完成: 原始%d条 -> %d批 (按token大小拼接)",
|
||||||
|
len(req.UserForm), len(batches))
|
||||||
|
|
||||||
|
return &newReq, len(batches), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildBatchedUserForm 构建分批后的用户表单
|
||||||
|
func buildBatchedUserForm(batches [][]map[string]any) []map[string]any {
|
||||||
|
newUserForm := make([]map[string]any, 0, len(batches))
|
||||||
|
|
||||||
|
for i, batch := range batches {
|
||||||
|
combinedText := combineBatchText(batch)
|
||||||
|
newUserForm = append(newUserForm, map[string]any{
|
||||||
|
"batch_index": i + 1,
|
||||||
|
"total_batches": len(batches),
|
||||||
|
"text": combinedText,
|
||||||
|
"item_count": len(batch),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return newUserForm
|
||||||
|
}
|
||||||
|
|
||||||
|
// combineBatchText 合并批次中的所有文本(合并所有字段的值)
|
||||||
|
func combineBatchText(batch []map[string]any) string {
|
||||||
|
var builder strings.Builder
|
||||||
|
|
||||||
|
for j, item := range batch {
|
||||||
|
itemText := getItemText(item)
|
||||||
|
if itemText == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if j > 0 {
|
||||||
|
builder.WriteString("\n\n")
|
||||||
|
}
|
||||||
|
builder.WriteString(itemText)
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitUserFormByTokenSize 按 token 大小将 UserForm 内容拼接后分批
|
||||||
|
func splitUserFormByTokenSize(userForm []map[string]any, maxTokens int, tokenConfig any) [][]map[string]any {
|
||||||
|
if len(userForm) == 0 {
|
||||||
|
return [][]map[string]any{}
|
||||||
|
}
|
||||||
|
|
||||||
|
batches := make([][]map[string]any, 0)
|
||||||
|
currentBatch := make([]map[string]any, 0)
|
||||||
|
currentTokens := 0
|
||||||
|
|
||||||
|
for i, item := range userForm {
|
||||||
|
itemText := getItemText(item)
|
||||||
|
itemTokens := util.CalculateTokens(itemText, tokenConfig)
|
||||||
|
|
||||||
|
// 单个元素超过窗口,单独成一批
|
||||||
|
if itemTokens > maxTokens {
|
||||||
|
if len(currentBatch) > 0 {
|
||||||
|
batches = append(batches, currentBatch)
|
||||||
|
currentBatch = make([]map[string]any, 0)
|
||||||
|
currentTokens = 0
|
||||||
|
}
|
||||||
|
batches = append(batches, []map[string]any{item})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 判断是否需要新开一批
|
||||||
|
if currentTokens+itemTokens > maxTokens && len(currentBatch) > 0 {
|
||||||
|
batches = append(batches, currentBatch)
|
||||||
|
currentBatch = make([]map[string]any, 0)
|
||||||
|
currentTokens = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
currentBatch = append(currentBatch, item)
|
||||||
|
currentTokens += itemTokens
|
||||||
|
|
||||||
|
// 最后一批
|
||||||
|
if i == len(userForm)-1 && len(currentBatch) > 0 {
|
||||||
|
batches = append(batches, currentBatch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return batches
|
||||||
|
}
|
||||||
|
|
||||||
|
// getItemText 获取 item 中的所有文本内容(合并所有字段的值)
|
||||||
|
func getItemText(item map[string]any) string {
|
||||||
|
if len(item) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var parts []string
|
||||||
|
for key, value := range item {
|
||||||
|
// 跳过分批时添加的元数据字段
|
||||||
|
if key == "batch_index" || key == "total_batches" || key == "item_count" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parts = append(parts, fmt.Sprintf("%v", value))
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(parts, "\n")
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user