Compare commits
32 Commits
ce9137169f
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
| 0d52b631b9 | |||
| c22d578e1a | |||
| df26329836 | |||
| 40abf0f606 | |||
| b69e7386e2 | |||
| 1c1db7e30c | |||
| 78114f99c7 | |||
| 9410199fbe | |||
| 1f9a2b9b5f | |||
| e1461cf0f0 | |||
| aa7804656f | |||
| 5494a0c480 | |||
| ee6677c1f8 | |||
| de70d33115 | |||
| b2cad4cac2 | |||
| 05cf1b9828 | |||
| 3fa2896fc3 | |||
| c11a9ad5c8 | |||
| 0bbaddace0 | |||
| 1bcf8f6e10 | |||
| 55eb436639 | |||
| d74559ae74 | |||
| 2548ffc7ac | |||
| 855d5b9abe | |||
| 866b97d098 | |||
| 92092575bc | |||
| a34eb4ea61 | |||
| 15f5761000 | |||
| fee6528f93 | |||
| 35bc3bd6ec | |||
| c49144794d | |||
| 5f98e52b34 |
24
Dockerfile
Normal file
24
Dockerfile
Normal file
@@ -0,0 +1,24 @@
|
||||
# 阶段1: 构建
|
||||
FROM golang:alpine AS builder
|
||||
|
||||
RUN apk add --no-cache git ca-certificates tzdata
|
||||
|
||||
ENV TZ=Asia/Shanghai
|
||||
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
|
||||
|
||||
ENV GO111MODULE=on
|
||||
ENV GOPROXY=https://goproxy.cn,direct
|
||||
ENV CGO_ENABLED=0
|
||||
ENV GOTOOLCHAIN=auto
|
||||
WORKDIR /build
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN go mod download && go mod tidy
|
||||
|
||||
RUN go build -ldflags="-s -w" -o main ./main.go
|
||||
|
||||
|
||||
EXPOSE 3009
|
||||
|
||||
CMD ["./main"]
|
||||
72
README.md
72
README.md
@@ -1,30 +1,54 @@
|
||||
# prompts-core(提示词服务)[2026.5.12前,暂时弃置]
|
||||
# Prompts-Core 提示词核心服务
|
||||
## 项目简介
|
||||
Prompts-Core 是基于 Go 语言开发的**多模态 AI 提示词构建与管理系统**,专注于统一管理各类 AI 模型的提示词模板、维护智能会话上下文、适配主流模型协议,并支持文件解析与外部技能集成,为 AI 应用提供标准化、高效的提示词服务。
|
||||
|
||||
## 1. 功能范围(当前阶段)
|
||||
- 仅做提示词配置的基础 CRUD(最小可用版本)
|
||||
- 表:`prompts_model_prompt`
|
||||
## 核心功能
|
||||
1. **提示词构建引擎**
|
||||
支持文字/图片/音频/向量化/全模态 5 类任务提示词生成,提供完整流程、分步节点两种构建模式,支持超大内容按 Token 自动分批处理。
|
||||
2. **智能会话管理**
|
||||
基于缓存实现高效会话存储,自动控制会话轮数与过期时间,保障上下文连贯性。
|
||||
3. **多模型协议适配**
|
||||
动态适配 OpenAI、DeepSeek、Qwen、Gemini 等主流 AI 模型协议,支持角色、字段、消息顺序灵活映射。
|
||||
4. **文件与技能集成**
|
||||
自动提取文本、ZIP 压缩包内容,支持加载外部 Markdown 技能配置,扩展服务能力。
|
||||
5. **异步任务调度**
|
||||
支持异步任务处理、状态轮询与回调通知,自带可配置重试机制。
|
||||
|
||||
## 2. 接口
|
||||
> 路由注册方式与参考项目一致:使用 `common/http.RouteRegister` 注册 controller。
|
||||
## 技术架构
|
||||
- 开发语言:Go 1.26.0
|
||||
- Web 框架:GoFrame v2.10.0
|
||||
- 核心存储:Redis(会话缓存)
|
||||
- 服务组件:Consul(服务注册)、Jaeger(链路追踪)
|
||||
- 调用链路:客户端 → Prompts-Core → 模型网关 → AI 模型
|
||||
|
||||
- `POST /composeMessages`:按 `modelTypeId` 读取 `prompt_info + response_json_schema`,`modelName` 作为实际调用的网关模型;结合前端 `form(role/value)` 与 `userfiles` 调用 `model-gateway /task/createTask`,同步等待回调后直接返回最终 `messages`
|
||||
- `GET /composeMessagesCallback/prompts-core`:`model-gateway` 成功回调接口(真实地址由 `callbackUrl + /bizName` 组成)
|
||||
- `GET /getComposeTask`:按 `taskId` 查询拼接任务状态和结果
|
||||
- `POST /createPrompt`:创建(默认启用)
|
||||
- `PUT /updatePrompt`:更新
|
||||
- `DELETE /deletePrompt`:删除
|
||||
- `GET /getPrompt`:详情
|
||||
- `POST /listPrompt`:列表分页
|
||||
## 快速开始
|
||||
### 环境要求
|
||||
Go 1.26+、Redis、已部署模型网关服务
|
||||
|
||||
## 3. 数据库初始化
|
||||
执行根目录 `update.sql`。
|
||||
### 启动步骤
|
||||
1. 克隆项目代码
|
||||
2. 完成项目配置文件修改
|
||||
3. 执行命令启动服务:
|
||||
```bash
|
||||
go run main.go
|
||||
```
|
||||
|
||||
## 4. 运行配置
|
||||
配置文件:`config.yml`
|
||||
## API 接口
|
||||
### 基础信息
|
||||
- 服务地址:`http://{host}:3009`
|
||||
- 请求类型:`application/json`
|
||||
- 认证方式:请求头携带 `Authorization`、`X-User`
|
||||
|
||||
### 新增说明
|
||||
- `prompts_model_prompt` 去除了 `limit_length`
|
||||
- 新增 `response_json_schema`
|
||||
- 新增任务记录表 `prompts_compose_task`
|
||||
- `callbackUrl` 必须填写 prompts-core 的绝对地址基路径,例如:`http://127.0.0.1:8002/composeMessagesCallback`
|
||||
- `model-gateway` 实际回调地址为:`callbackUrl/{bizName}`,本项目固定为:`/composeMessagesCallback/prompts-core`
|
||||
### 核心接口
|
||||
1. **提示词拼接接口**
|
||||
- 地址:`POST /composeMessages`
|
||||
- 功能:构建提示词并调用模型服务,同步返回结果
|
||||
2. **任务状态查询**
|
||||
- 地址:`GET /getComposeTask`
|
||||
- 功能:根据任务 ID 查询处理状态与结果
|
||||
3. **任务回调接口**
|
||||
- 地址:`GET /composeMessagesCallback/prompts-core`
|
||||
- 功能:接收模型服务处理完成回调
|
||||
4. **会话同步接口**
|
||||
- 地址:`POST /sessionCallback`
|
||||
- 功能:同步更新会话上下文历史
|
||||
34
common/util/config.go
Normal file
34
common/util/config.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// GetServerName 获取服务名称
|
||||
func GetServerName(ctx context.Context) string {
|
||||
return g.Cfg().MustGet(ctx, "server.name", "").String()
|
||||
}
|
||||
|
||||
// GetModelPrompt 获取请求模型的提示词
|
||||
func GetModelPrompt(ctx context.Context, modelType int) string {
|
||||
key := "modelPrompts.types." + gconv.String(modelType)
|
||||
return g.Cfg().MustGet(ctx, key, "").String()
|
||||
}
|
||||
|
||||
// GetBuildPrompt 获取节点构建提示词
|
||||
func GetBuildPrompt(ctx context.Context) string {
|
||||
return g.Cfg().MustGet(ctx, "nodePrompts", "").String()
|
||||
}
|
||||
|
||||
// GetMaxRounds 获取最大轮数配置
|
||||
func GetMaxRounds(ctx context.Context) int {
|
||||
return g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
||||
}
|
||||
|
||||
// GetExpireMinutes 获取过期时间配置
|
||||
func GetExpireMinutes(ctx context.Context) int {
|
||||
return g.Cfg().MustGet(ctx, "session.expireMinutes", 30).Int()
|
||||
}
|
||||
96
common/util/files.go
Normal file
96
common/util/files.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
// AllowedMIMEPrefixes 允许的文本类 MIME 类型前缀
|
||||
AllowedMIMEPrefixes = []string{
|
||||
"text/",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"application/javascript",
|
||||
"application/x-yaml",
|
||||
"application/yaml",
|
||||
"application/toml",
|
||||
"application/x-httpd-php",
|
||||
"application/x-sh",
|
||||
"application/x-python",
|
||||
"application/x-perl",
|
||||
"application/x-ruby",
|
||||
}
|
||||
|
||||
// BannedExtensions 禁止的文件扩展名
|
||||
BannedExtensions = map[string]bool{
|
||||
".png": true, ".jpg": true, ".jpeg": true, ".gif": true, ".bmp": true,
|
||||
".webp": true, ".svg": true, ".ico": true, ".tiff": true, ".tif": true,
|
||||
".mp3": true, ".wav": true, ".ogg": true, ".flac": true, ".aac": true,
|
||||
".wma": true, ".m4a": true,
|
||||
".mp4": true, ".avi": true, ".mkv": true, ".mov": true, ".wmv": true,
|
||||
".flv": true, ".webm": true,
|
||||
".tar": true, ".gz": true, ".rar": true, ".7z": true,
|
||||
".exe": true, ".dll": true, ".so": true, ".bin": true, ".dat": true,
|
||||
".class": true, ".pyc": true,
|
||||
".pdf": true, ".doc": true, ".docx": true, ".xls": true, ".xlsx": true,
|
||||
".ppt": true, ".pptx": true,
|
||||
}
|
||||
|
||||
symbolCleaner = regexp.MustCompile(`[\x00-\x08\x0B\x0C\x0E-\x1F]`)
|
||||
multiNewlines = regexp.MustCompile(`\n{3,}`)
|
||||
)
|
||||
|
||||
// SanitizeURL 清洗 URL 字符串
|
||||
func SanitizeURL(raw string) string {
|
||||
s := strings.TrimSpace(raw)
|
||||
s = strings.Trim(s, "`\"")
|
||||
return s
|
||||
}
|
||||
|
||||
// CleanSymbols 清洗文本中的控制字符和多余空行
|
||||
func CleanSymbols(text string) string {
|
||||
text = symbolCleaner.ReplaceAllString(text, "")
|
||||
text = strings.ReplaceAll(text, "\r\n", "\n")
|
||||
text = strings.ReplaceAll(text, "\r", "\n")
|
||||
text = multiNewlines.ReplaceAllString(text, "\n\n")
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
// IsBannedExtension 判断是否为禁止的文件扩展名
|
||||
func IsBannedExtension(url string) bool {
|
||||
ext := extractExtension(url)
|
||||
return BannedExtensions[ext]
|
||||
}
|
||||
|
||||
// IsZipExtension 判断是否为 zip 文件
|
||||
func IsZipExtension(url string) bool {
|
||||
ext := extractExtension(url)
|
||||
return ext == ".zip"
|
||||
}
|
||||
|
||||
// IsReadableContentType 判断是否为可读的文本类型
|
||||
func IsReadableContentType(contentType string) bool {
|
||||
if contentType == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
ct := strings.ToLower(contentType)
|
||||
for _, prefix := range AllowedMIMEPrefixes {
|
||||
if strings.HasPrefix(ct, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// extractExtension 提取文件扩展名并清理查询参数
|
||||
func extractExtension(url string) string {
|
||||
ext := strings.ToLower(filepath.Ext(url))
|
||||
if idx := strings.Index(ext, "?"); idx != -1 {
|
||||
ext = ext[:idx]
|
||||
}
|
||||
return ext
|
||||
}
|
||||
67
common/util/headers.go
Normal file
67
common/util/headers.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// AsyncCtx 固化异步上下文中的 token 和用户信息,避免请求结束后丢失
|
||||
func AsyncCtx(ctx context.Context) context.Context {
|
||||
asyncCtx := context.WithoutCancel(ctx)
|
||||
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
asyncCtx = context.WithValue(asyncCtx, "token", token)
|
||||
}
|
||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
|
||||
}
|
||||
}
|
||||
|
||||
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
|
||||
asyncCtx = context.WithValue(asyncCtx, "user", user)
|
||||
}
|
||||
|
||||
return asyncCtx
|
||||
}
|
||||
|
||||
// ForwardHeaders 透传调用链路的头信息,优先使用 ctx 中的固化值
|
||||
func ForwardHeaders(ctx context.Context) map[string]string {
|
||||
headers := make(map[string]string)
|
||||
|
||||
setHeaderFromContext(headers, ctx, "Authorization", "token")
|
||||
setHeaderFromContext(headers, ctx, "X-User-Info", "xUserInfo")
|
||||
|
||||
fallbackToRequestHeaders(headers, ctx)
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
// setHeaderFromContext 从上下文中设置 header
|
||||
func setHeaderFromContext(headers map[string]string, ctx context.Context, headerKey, ctxKey string) {
|
||||
if value, ok := ctx.Value(ctxKey).(string); ok && value != "" {
|
||||
headers[headerKey] = value
|
||||
}
|
||||
}
|
||||
|
||||
// fallbackToRequestHeaders 从请求头中获取作为兜底
|
||||
func fallbackToRequestHeaders(headers map[string]string, ctx context.Context) {
|
||||
r := g.RequestFromCtx(ctx)
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if headers["Authorization"] == "" {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
headers["Authorization"] = token
|
||||
}
|
||||
}
|
||||
|
||||
if headers["X-User-Info"] == "" {
|
||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||
headers["X-User-Info"] = userInfo
|
||||
}
|
||||
}
|
||||
}
|
||||
81
common/util/json.go
Normal file
81
common/util/json.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// MergeConsult 将 consult 附件合并到模型生成的 messages 结构中
|
||||
func MergeConsult(req map[string]any, messages map[string]any, extendMapping map[string]any) map[string]any {
|
||||
if len(req) == 0 || len(messages) == 0 || len(extendMapping) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
consult := gconv.Interfaces(req["consult"])
|
||||
if len(consult) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
targetPath := gconv.String(extendMapping["target_content_path"])
|
||||
templates := gconv.Map(extendMapping["attachment_templates"])
|
||||
if targetPath == "" || len(templates) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
msgJson := gjson.New(messages)
|
||||
|
||||
// rounds 路径修正
|
||||
if !msgJson.Get("rounds.0").IsNil() {
|
||||
targetPath = "rounds.0." + targetPath
|
||||
}
|
||||
|
||||
// 遍历追加
|
||||
for _, item := range consult {
|
||||
itemJson := gjson.New(item)
|
||||
itemType := itemJson.Get("type").String()
|
||||
tmpl := gconv.Map(templates[itemType])
|
||||
if itemType == "" || len(tmpl) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
attachment := buildAttachment(tmpl, itemJson.Get("url").String())
|
||||
if attachment == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
idx := len(msgJson.Get(targetPath).Array())
|
||||
_ = msgJson.Set(fmt.Sprintf("%s.%d", targetPath, idx), attachment)
|
||||
}
|
||||
|
||||
return msgJson.Map()
|
||||
}
|
||||
|
||||
func buildAttachment(tmpl map[string]any, url string) map[string]any {
|
||||
typ := gconv.String(tmpl["type"])
|
||||
if typ == "" || url == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
body := gconv.Map(tmpl["body"])
|
||||
fillEmptyInPlace(body, url)
|
||||
|
||||
return map[string]any{
|
||||
"type": typ,
|
||||
typ: body,
|
||||
}
|
||||
}
|
||||
|
||||
func fillEmptyInPlace(m map[string]any, value string) {
|
||||
for k, v := range m {
|
||||
switch vv := v.(type) {
|
||||
case string:
|
||||
if vv == "" {
|
||||
m[k] = value
|
||||
}
|
||||
case map[string]any:
|
||||
fillEmptyInPlace(vv, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
57
common/util/mapping.go
Normal file
57
common/util/mapping.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// ReverseMap 映射 payload 到 mapping
|
||||
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
|
||||
jsonObj := gjson.New("{}")
|
||||
for path, defaultValue := range mapping {
|
||||
val := gjson.New(payload).Get(path)
|
||||
if !val.IsNil() {
|
||||
_ = jsonObj.Set(path, val.Val())
|
||||
} else if defaultValue != nil {
|
||||
_ = jsonObj.Set(path, defaultValue)
|
||||
}
|
||||
}
|
||||
return jsonObj.Map()
|
||||
}
|
||||
|
||||
// ExtractUserText 从 messages 中提取所有 user 文本
|
||||
func ExtractUserText(messages map[string]any) map[string]any {
|
||||
msgJson := gjson.New(messages)
|
||||
|
||||
msgs := msgJson.Get("rounds.0.messages")
|
||||
if msgs.IsNil() {
|
||||
msgs = msgJson.Get("messages")
|
||||
}
|
||||
var texts []string
|
||||
for _, m := range msgs.Array() {
|
||||
msg := gjson.New(m)
|
||||
if msg.Get("role").String() != "user" {
|
||||
continue
|
||||
}
|
||||
content := msg.Get("content").Val()
|
||||
switch c := content.(type) {
|
||||
case string:
|
||||
texts = append(texts, c)
|
||||
case []any:
|
||||
for _, item := range c {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if t := gconv.String(m["text"]); t != "" {
|
||||
texts = append(texts, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"role": "user",
|
||||
"content": strings.Join(texts, "\n"),
|
||||
}
|
||||
}
|
||||
229
common/util/token.go
Normal file
229
common/util/token.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
)
|
||||
|
||||
var (
|
||||
enWordRegex = regexp.MustCompile(`[A-Za-z]+`)
|
||||
punctRegex = regexp.MustCompile(`[[:punct:]]`)
|
||||
)
|
||||
|
||||
// TokenConfig Token计算配置
|
||||
type TokenConfig struct {
|
||||
ZhRatio float64 `json:"zh_ratio"`
|
||||
EnRatio float64 `json:"en_ratio"`
|
||||
SpaceRatio float64 `json:"space_ratio"`
|
||||
PunctuationRatio float64 `json:"punctuation_ratio"`
|
||||
MaxWindowSize int `json:"max_window_size"`
|
||||
ReserveRatio float64 `json:"reserve_ratio"`
|
||||
MinReserve int `json:"min_reserve"`
|
||||
}
|
||||
|
||||
// CalculateTokens 计算文本token数
|
||||
func CalculateTokens(text string, tokenConfig any) int {
|
||||
config := parseConfig(tokenConfig)
|
||||
if config == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
zhCount := countChineseChars(text)
|
||||
enCount := countEnglishWords(text)
|
||||
spaceCount := strings.Count(text, " ")
|
||||
punctCount := countPunctuation(text)
|
||||
|
||||
totalTokens := int(
|
||||
float64(zhCount)*config.ZhRatio +
|
||||
float64(enCount)*config.EnRatio +
|
||||
float64(spaceCount)*config.SpaceRatio +
|
||||
float64(punctCount)*config.PunctuationRatio,
|
||||
)
|
||||
|
||||
return totalTokens
|
||||
}
|
||||
|
||||
// CountToken 计算token是否超出窗口限制
|
||||
// 返回: true - 未超出(可用), false - 已超出(不可用)
|
||||
func CountToken(text string, tokenConfig any) bool {
|
||||
config := parseConfig(tokenConfig)
|
||||
if config == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
estimatedTokens := CalculateTokens(text, tokenConfig)
|
||||
availableWindow := GetAvailableWindow(tokenConfig)
|
||||
|
||||
return estimatedTokens <= availableWindow
|
||||
}
|
||||
|
||||
// GetAvailableWindow 获取可用窗口大小
|
||||
func GetAvailableWindow(tokenConfig any) int {
|
||||
config := parseConfig(tokenConfig)
|
||||
if config == nil {
|
||||
return 4096
|
||||
}
|
||||
|
||||
reserveByRatio := int(float64(config.MaxWindowSize) * config.ReserveRatio)
|
||||
reserve := reserveByRatio
|
||||
|
||||
if config.MinReserve > reserve {
|
||||
reserve = config.MinReserve
|
||||
}
|
||||
|
||||
available := config.MaxWindowSize - reserve
|
||||
if available < 0 {
|
||||
available = 0
|
||||
}
|
||||
|
||||
return available
|
||||
}
|
||||
|
||||
// GetMaxWindowSize 获取模型最大窗口大小
|
||||
func GetMaxWindowSize(tokenConfig any) int {
|
||||
config := parseConfig(tokenConfig)
|
||||
if config == nil {
|
||||
return 4096
|
||||
}
|
||||
|
||||
return config.MaxWindowSize
|
||||
}
|
||||
|
||||
// CheckUserFormWithinWindow 校验 UserForm 是否在窗口大小内
|
||||
// 返回: isValid, exceedTokens, error
|
||||
func CheckUserFormWithinWindow(userForm []map[string]any, tokenConfig any) (bool, int, error) {
|
||||
config := parseConfig(tokenConfig)
|
||||
if config == nil || len(userForm) == 0 {
|
||||
return true, 0, nil
|
||||
}
|
||||
|
||||
totalTokens := calculateUserFormTokens(userForm, tokenConfig)
|
||||
availableWindow := GetAvailableWindow(tokenConfig)
|
||||
|
||||
if totalTokens > availableWindow {
|
||||
return false, totalTokens - availableWindow, nil
|
||||
}
|
||||
|
||||
return true, 0, nil
|
||||
}
|
||||
|
||||
// CheckUserFormBatchWithinWindow 检查 UserForm 分批是否在窗口内
|
||||
// 返回: 需要拆分的批次数, 每批的token数, 错误
|
||||
func CheckUserFormBatchWithinWindow(userForm []map[string]any, tokenConfig any) (int, []int, error) {
|
||||
config := parseConfig(tokenConfig)
|
||||
if config == nil || len(userForm) == 0 {
|
||||
return 1, nil, nil
|
||||
}
|
||||
|
||||
availableWindow := GetAvailableWindow(tokenConfig)
|
||||
|
||||
batches := 1
|
||||
currentTokens := 0
|
||||
batchTokens := make([]int, 0)
|
||||
|
||||
for _, item := range userForm {
|
||||
itemStr := fmt.Sprintf("%v", item)
|
||||
itemTokens := CalculateTokens(itemStr, tokenConfig)
|
||||
|
||||
if currentTokens+itemTokens > availableWindow {
|
||||
batchTokens = append(batchTokens, currentTokens)
|
||||
batches++
|
||||
currentTokens = itemTokens
|
||||
} else {
|
||||
currentTokens += itemTokens
|
||||
}
|
||||
}
|
||||
|
||||
if currentTokens > 0 {
|
||||
batchTokens = append(batchTokens, currentTokens)
|
||||
}
|
||||
|
||||
return batches, batchTokens, nil
|
||||
}
|
||||
|
||||
// parseConfig 解析配置
|
||||
func parseConfig(tokenConfig any) *TokenConfig {
|
||||
if tokenConfig == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := tokenConfig.(type) {
|
||||
case *gvar.Var:
|
||||
return parseGVarConfig(v)
|
||||
case map[string]any:
|
||||
return parseMapConfig(v)
|
||||
case *TokenConfig:
|
||||
return v
|
||||
case TokenConfig:
|
||||
return &v
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// parseGVarConfig 解析 GVar 配置
|
||||
func parseGVarConfig(v *gvar.Var) *TokenConfig {
|
||||
if v.IsNil() {
|
||||
return nil
|
||||
}
|
||||
|
||||
mapVal := v.Map()
|
||||
if mapVal == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &TokenConfig{}
|
||||
data, _ := json.Marshal(mapVal)
|
||||
json.Unmarshal(data, config)
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// parseMapConfig 解析 Map 配置
|
||||
func parseMapConfig(v map[string]any) *TokenConfig {
|
||||
config := &TokenConfig{}
|
||||
data, _ := json.Marshal(v)
|
||||
json.Unmarshal(data, config)
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// countChineseChars 统计中文字符数量
|
||||
func countChineseChars(text string) int {
|
||||
count := 0
|
||||
for _, r := range text {
|
||||
if unicode.Is(unicode.Han, r) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// countEnglishWords 统计英文单词数量
|
||||
func countEnglishWords(text string) int {
|
||||
return len(enWordRegex.FindAllString(text, -1))
|
||||
}
|
||||
|
||||
// countPunctuation 统计标点符号数量
|
||||
func countPunctuation(text string) int {
|
||||
return len(punctRegex.FindAllString(text, -1))
|
||||
}
|
||||
|
||||
// calculateUserFormTokens 计算 UserForm 总 token 数
|
||||
func calculateUserFormTokens(userForm []map[string]any, tokenConfig any) int {
|
||||
totalTokens := 0
|
||||
for _, item := range userForm {
|
||||
itemStr := fmt.Sprintf("%v", item)
|
||||
totalTokens += CalculateTokens(itemStr, tokenConfig)
|
||||
}
|
||||
return totalTokens
|
||||
}
|
||||
104
config.yml
104
config.yml
@@ -1,5 +1,5 @@
|
||||
server:
|
||||
address: ":3005"
|
||||
address: ":3009"
|
||||
name: "prompts-core"
|
||||
workerId: 1 # 雪花算法 worker ID(用于 common/db/gfdb)
|
||||
|
||||
@@ -26,21 +26,41 @@ database:
|
||||
updatedAt: "updated_at" # (可选)自动更新时间字段名称
|
||||
deletedAt: "deleted_at" # (可选)软删除时间字段名称
|
||||
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
||||
model_gateway:
|
||||
- type: "pgsql"
|
||||
host: "116.204.74.41"
|
||||
port: "15432"
|
||||
user: "postgres"
|
||||
pass: "Bjang09@686^*^"
|
||||
name: "model-gateway"
|
||||
prefix: ""
|
||||
role: "master"
|
||||
debug: true
|
||||
dryRun: false
|
||||
charset: "utf8"
|
||||
timezone: "Asia/Shanghai"
|
||||
maxIdle: 5
|
||||
maxOpen: 20
|
||||
maxLifetime: "30s"
|
||||
maxIdleConnTime: "30s"
|
||||
createdAt: "created_at"
|
||||
updatedAt: "updated_at"
|
||||
deletedAt: "deleted_at"
|
||||
timeMaintainDisabled: false
|
||||
|
||||
redis:
|
||||
default:
|
||||
address: 116.204.74.41:6379
|
||||
address: 192.168.3.30:6379
|
||||
db: 0
|
||||
|
||||
consul:
|
||||
address: 116.204.74.41:8500
|
||||
address: 192.168.3.30:8500
|
||||
|
||||
jaeger:
|
||||
addr: 116.204.74.41:4318
|
||||
addr: 192.168.3.30:4318
|
||||
|
||||
task:
|
||||
waitTimeoutSeconds: 300 # /composeMessages 同步等待最终结果的最长时间(秒)
|
||||
pollIntervalMillis: 500 # 同步等待期间,轮询本地任务表 / 网关状态的时间间隔(毫秒)
|
||||
waitTimeoutSeconds: 600 # /composeMessages 同步等待最终结果的最长时间(秒)
|
||||
|
||||
session:
|
||||
maxRounds: 10 # 最大轮数
|
||||
@@ -61,76 +81,34 @@ promptsRetry:
|
||||
|
||||
modelPrompts:
|
||||
types:
|
||||
1: |
|
||||
100: |
|
||||
你是一个智能文字处理助手,专注于文本理解、文本创作、文本优化与语言表达任务,能够根据不同场景完成文章撰写、商业文案、报告总结、邮件通知、脚本创作、内容改写、信息提炼、语言翻译等多种文字处理工作,并能够理解上下文语义关系,保持内容逻辑完整、结构清晰、表达自然。
|
||||
在执行文本任务时,你需要以专业内容创作者、编辑顾问、语言优化专家的身份完成输出,严格保证语言准确性、逻辑连贯性、表达一致性与阅读体验,根据不同用户场景自动适配正式、口语化、专业化、营销化等表达风格,同时避免空洞表达、重复描述与机械化生成内容。
|
||||
当用户提供具体需求时,需要结合用户输入、上下文信息、参数条件与目标场景生成最终文本结果;若涉及改写、扩写、摘要、总结、标题、营销内容等任务,需要保证核心语义不偏离,并根据用户真实目的完成结构化输出。
|
||||
2: |
|
||||
200: |
|
||||
你是一个智能图片处理助手,专注于视觉内容生成、图像编辑、画面分析与风格控制任务,能够根据文字描述生成不同风格的图片内容,包括写实、插画、动漫、水彩、电影感、商业海报等多种视觉形式,并支持图片局部修改、风格迁移、画面扩展、背景处理与视觉增强等操作。
|
||||
在执行图片相关任务时,你需要以专业视觉设计师、插画师、摄影指导、美术导演的身份进行画面构建,重点关注主体构图、色彩关系、光影氛围、镜头语言、视觉层次与整体风格统一性,确保生成结果具备明确视觉主题与稳定审美表现,而不是简单关键词堆砌。
|
||||
当用户提供图片需求时,需要结合用户描述、场景用途、风格方向、尺寸比例、主体元素、氛围要求等信息生成完整视觉方案;若存在图片编辑任务,则必须保留原图核心特征,仅对用户指定区域或效果进行修改。
|
||||
3: |
|
||||
300: |
|
||||
你是一个智能音频处理助手,专注于语音生成、语音识别、音频分析与声音编辑任务,能够完成文字转语音、语音转文本、多语言识别、音频降噪、音色处理、混音剪辑、情绪识别与声音特征分析等多种音频相关工作,并能够根据不同场景匹配对应语音风格与声音表现形式。
|
||||
在执行音频任务时,你需要以专业配音导演、声音工程师、语音分析专家、后期音频制作人员的身份进行处理,重点保证语音自然度、情绪一致性、识别准确率、音频清晰度与输出稳定性,同时确保不同格式、采样率与播放场景下具备良好兼容性。
|
||||
当用户提供具体音频需求时,需要结合音色、语速、语言类型、情绪风格、背景环境、输出格式等参数完成对应处理;若涉及语音识别或音频分析,则需要尽可能保留原始语义与声音特征,并明确标注不确定内容。
|
||||
4: |
|
||||
400: |
|
||||
你是一个智能向量化处理助手,专注于文本向量化、语义检索、知识索引、相似度计算与语义聚类任务,能够将文本内容转换为高维语义向量,并基于向量相似度完成语义搜索、知识召回、内容聚类、文档匹配与知识库构建等处理流程。
|
||||
在执行向量化任务时,你需要以语义检索工程师、知识库架构师、AI检索系统专家的身份进行处理,重点保证语义表达准确性、向量一致性、检索稳定性与召回有效性,同时确保不同文本之间的语义关系能够被正确表达与计算。
|
||||
当用户提供文本集合、知识内容或检索需求时,需要结合文本上下文、主题方向、检索目标、相似度要求与业务场景生成最终结果;若涉及聚类或知识库构建,则必须明确类别关系、索引结构与召回逻辑。
|
||||
5: |
|
||||
500: |
|
||||
你是一个全模态智能处理助手,能够同时理解、分析与生成文本、图片、音频、视频等多种模态内容,并支持跨模态转换、多模态融合推理、联合内容生成与复杂场景交互,能够根据不同输入形式自动匹配最合理的处理策略与输出方式。
|
||||
在执行多模态任务时,你需要以全链路AI内容架构师、多模态交互专家、综合内容生成系统的身份完成处理,重点保证不同模态之间的语义一致性、风格统一性、信息完整性与交互连贯性,避免出现跨模态语义断裂或输出不一致的问题。
|
||||
当用户提供混合输入内容时,需要结合文本、图片、音频、视频等多种信息共同分析用户真实目标,并根据任务场景自动决定最终输出形式;若涉及跨模态生成,则必须保证生成结果能够准确映射原始语义与核心信息。
|
||||
|
||||
buildProject:
|
||||
types:
|
||||
1: |
|
||||
你是专业的JSON结构生成专家,必须严格遵守以下全部规则。
|
||||
【强制规则】
|
||||
必须根据【输出结构】里面返回的JSON结构进行生成,不得任何更改,最终内容与输出结构返回一致;
|
||||
完整阅读所有文本、规则、表单内容,禁止跳读、漏读;
|
||||
完整读取UserForm所有字段,不得忽略任何字段;
|
||||
如果有skill相关内容必须完整的将内容拼接到system角色描述中;
|
||||
理解全部语义后再输出,禁止断章取义;
|
||||
UserForm所有字段内容必须完整拼接赋值到user角色描述中,不得有任何遗漏。
|
||||
【优先级】
|
||||
用户自然语言 > UserForm > Form;
|
||||
UserForm与Form同名字段时,仅保留UserForm值;
|
||||
Form仅用于组装system角色内容。
|
||||
【表单处理】
|
||||
Form:系统提示词、默认参数、基础配置 → 专属填充system角色;
|
||||
UserForm:用户业务输入、文案、配图数量、比例、prompt等 → 全部解析后拼接进user角色content;
|
||||
自动提取UserForm中每条文案的配图数量,总图片数 = 各文案配图数累加求和(示例:10条文案各配5张图 → 总50张,parameters.n=50),用户没有相关数量必须默认1;
|
||||
图片尺寸为空时自动填充size=1024*1024。
|
||||
【结构铁律】
|
||||
严格沿用固定输出结构,不增删字段或修改层级;
|
||||
messages元素必须按结构返回;
|
||||
禁止将role对象转为字符串、禁止嵌套错乱;
|
||||
输出纯净JSON:无多余转义符、无换行符、无额外字符;
|
||||
所有括号、引号必须成对闭合,保证JSON合法。
|
||||
【参数赋值】
|
||||
model固定沿用传入值;
|
||||
返回结构里面的参数,需要根据语意进行赋值,缺失补默认值;
|
||||
history历史信息必须结合UserForm里的内容对用户描述部分进行补充;
|
||||
从UserForm提取信息整合进user描述,确保数量、尺寸、文案语义无遗漏。
|
||||
【输出要求】
|
||||
仅输出单行纯净JSON,无任何解释、备注、Markdown或多余符号;
|
||||
完整合UserForm全部字段语义到user描述;
|
||||
生成后自检JSON语法、结构、数量;错误则自动重新生成。
|
||||
【输出结构】
|
||||
%s
|
||||
【字段映射】
|
||||
%s
|
||||
【完整输入信息】
|
||||
%s
|
||||
直接输出最终JSON:
|
||||
2: |
|
||||
你是流程路由助手,你的任务是根据上下文,选择一个正确的节点ID返回。
|
||||
规则:
|
||||
1. 只允许从下面的可选节点ID列表中选择一个返回
|
||||
2. 不要返回任何多余文字、标点、解释、标题
|
||||
3. 只返回纯节点ID
|
||||
可选节点ID(ID: 节点描述):
|
||||
%s
|
||||
上下文内容:
|
||||
%s
|
||||
nodePrompts: |
|
||||
你是流程路由助手,你的任务是根据上下文,选择一个正确的节点ID返回。
|
||||
规则:
|
||||
1. 只允许从下面的可选节点ID列表中选择一个返回
|
||||
2. 不要返回任何多余文字、标点、解释、标题
|
||||
3. 只返回纯节点ID
|
||||
可选节点ID(ID: 节点描述):
|
||||
%s
|
||||
上下文内容:
|
||||
%s
|
||||
|
||||
@@ -5,3 +5,13 @@ const (
|
||||
ComposeStatusSuccess = "success"
|
||||
ComposeStatusFailed = "failed"
|
||||
)
|
||||
|
||||
const (
|
||||
BuildTypePrompt = 1 //提示词构建
|
||||
BuildTypeNode = 2 //节点构建
|
||||
BuildTypeStruct = 3 //结构构建
|
||||
)
|
||||
|
||||
const (
|
||||
ModelTypeInference = 100 // 推理模型
|
||||
)
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
package public
|
||||
|
||||
const (
|
||||
TableNameModel = "asynch_models" // 模型表
|
||||
TableNamePromptConfig = "prompts_model_prompt" // 模型提示词配置表(prompts-core)
|
||||
TableNameComposeTask = "prompts_compose_task" // 拼接提示词任务记录表
|
||||
TableNameComposeSession = "prompts_compose_session" // 拼接提示词会话记录表
|
||||
DbNameModelGateway = "model_gateway" //数据库名称
|
||||
)
|
||||
|
||||
const (
|
||||
TableNameModel = "asynch_models" // 模型表
|
||||
TableNameComposeTask = "prompts_compose_task" // 拼接提示词任务记录表
|
||||
TableNameComposeSession = "prompts_compose_session" // 拼接提示词会话记录表
|
||||
TableNameProviderProtocol = "prompts_provider_protocol"
|
||||
)
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"prompts-core/model/dto"
|
||||
"prompts-core/service"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
)
|
||||
|
||||
type session struct{}
|
||||
|
||||
// Prompt 提示词配置控制器
|
||||
var Session = new(session)
|
||||
|
||||
// SessionCallback 会话回调
|
||||
func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *beans.ResponseEmpty, err error) {
|
||||
return service.Session.SessionCallback(ctx, req)
|
||||
}
|
||||
28
controller/prompt_compose_controller.go
Normal file
28
controller/prompt_compose_controller.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"prompts-core/model/dto"
|
||||
promptService "prompts-core/service/prompt"
|
||||
)
|
||||
|
||||
type prompt struct{}
|
||||
|
||||
// Prompt 提示词配置控制器
|
||||
var Prompt = new(prompt)
|
||||
|
||||
// ComposeMessages 调用 model-gateway 异步任务并同步等待结果,
|
||||
func (c *prompt) ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (res *dto.ComposeMessagesRes, err error) {
|
||||
return promptService.ComposeMessages(ctx, req)
|
||||
}
|
||||
|
||||
// Callback model-gateway 提示词回调
|
||||
func (c *prompt) Callback(ctx context.Context, req *dto.CallbackReq) (res *dto.CallbackRes, err error) {
|
||||
err = promptService.Callback(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
// GetComposeTask 查询拼接任务结果
|
||||
func (c *prompt) GetComposeTask(ctx context.Context, req *dto.GetComposeTaskReq) (res *dto.GetComposeTaskRes, err error) {
|
||||
return promptService.GetComposeTask(ctx, req.TaskId)
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"prompts-core/model/dto"
|
||||
"prompts-core/service"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
)
|
||||
|
||||
type prompt struct{}
|
||||
|
||||
// Prompt 提示词配置控制器
|
||||
var Prompt = new(prompt)
|
||||
|
||||
// ComposeMessages 调用 model-gateway 异步任务并同步等待结果
|
||||
func (c *prompt) ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (res *dto.ComposeMessagesRes, err error) {
|
||||
return service.Prompt.ComposeMessages(ctx, req)
|
||||
}
|
||||
|
||||
// ComposeMessagesCallback model-gateway 提示词回调
|
||||
func (c *prompt) Callback(ctx context.Context, req *dto.CallbackReq) (res *beans.ResponseEmpty, err error) {
|
||||
err = service.Prompt.Callback(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
// GetComposeTask 查询拼接任务结果
|
||||
func (c *prompt) GetComposeTask(ctx context.Context, req *dto.GetComposeTaskReq) (res *dto.GetComposeTaskRes, err error) {
|
||||
return service.Prompt.GetComposeTask(ctx, req.TaskId)
|
||||
}
|
||||
|
||||
// CreatePrompt 添加配置(默认启用)
|
||||
func (c *prompt) CreatePrompt(ctx context.Context, req *dto.CreatePromptReq) (res *dto.CreatePromptRes, err error) {
|
||||
return service.Prompt.Create(ctx, req)
|
||||
}
|
||||
|
||||
// UpdatePrompt 更新配置
|
||||
func (c *prompt) UpdatePrompt(ctx context.Context, req *dto.UpdatePromptReq) (res *beans.ResponseEmpty, err error) {
|
||||
err = service.Prompt.Update(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
// DeletePrompt 删除配置
|
||||
func (c *prompt) DeletePrompt(ctx context.Context, req *dto.DeletePromptReq) (res *beans.ResponseEmpty, err error) {
|
||||
err = service.Prompt.Delete(ctx, req.ID)
|
||||
return
|
||||
}
|
||||
|
||||
// GetPrompt 获取配置详情
|
||||
func (c *prompt) GetPrompt(ctx context.Context, req *dto.GetPromptReq) (res *dto.GetPromptRes, err error) {
|
||||
m, err := service.Prompt.Get(ctx, req.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.GetPromptRes{Prompt: m}, nil
|
||||
}
|
||||
|
||||
// ListPrompt 配置列表
|
||||
func (c *prompt) ListPrompt(ctx context.Context, req *dto.ListPromptReq) (res *dto.ListPromptRes, err error) {
|
||||
pageNum, pageSize := 1, 10
|
||||
if req != nil && req.Page != nil {
|
||||
if req.Page.PageNum > 0 {
|
||||
pageNum = int(req.Page.PageNum)
|
||||
}
|
||||
if req.Page.PageSize > 0 {
|
||||
pageSize = int(req.Page.PageSize)
|
||||
}
|
||||
}
|
||||
var modelTypeID *int
|
||||
modelType := ""
|
||||
if req != nil {
|
||||
modelTypeID = req.ModelTypeId
|
||||
modelType = req.ModelType
|
||||
}
|
||||
|
||||
list, total, err := service.Prompt.List(ctx, pageNum, pageSize, modelTypeID, modelType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.ListPromptRes{
|
||||
List: list,
|
||||
Total: total,
|
||||
}, nil
|
||||
}
|
||||
36
controller/prompt_session_controller.go
Normal file
36
controller/prompt_session_controller.go
Normal file
@@ -0,0 +1,36 @@
|
||||
// ============================================
|
||||
// controller/session.go
|
||||
// ============================================
|
||||
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"prompts-core/model/dto"
|
||||
sessionService "prompts-core/service/session"
|
||||
)
|
||||
|
||||
type session struct{}
|
||||
|
||||
var Session = new(session)
|
||||
|
||||
// SessionCallback 会话回调
|
||||
func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) {
|
||||
return sessionService.Callback(ctx, req)
|
||||
}
|
||||
|
||||
// GetHistoryList 获取历史列表(前端列表)
|
||||
func (c *session) GetHistoryList(ctx context.Context, req *dto.GetHistoryListReq) (res *dto.GetHistoryListRes, err error) {
|
||||
return sessionService.GetHistoryList(ctx, req)
|
||||
}
|
||||
|
||||
// DeleteMessages 批量删除消息
|
||||
func (c *session) DeleteMessages(ctx context.Context, req *dto.DeleteMessagesReq) (res *dto.DeleteMessagesRes, err error) {
|
||||
return sessionService.DeleteMessages(ctx, req)
|
||||
}
|
||||
|
||||
// DeleteSession 删除整个会话
|
||||
func (c *session) DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (res *dto.DeleteSessionRes, err error) {
|
||||
return sessionService.DeleteSession(ctx, req)
|
||||
}
|
||||
@@ -2,113 +2,121 @@ package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"prompts-core/consts/public"
|
||||
"prompts-core/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
)
|
||||
|
||||
var ComposeSession = &composeSessionDao{}
|
||||
|
||||
type composeSessionDao struct{}
|
||||
|
||||
func (d *composeSessionDao) Insert(ctx context.Context, m *entity.ComposeSession) (id int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession).Data(m).Insert()
|
||||
// Insert 插入
|
||||
func (d *composeSessionDao) Insert(ctx context.Context, req *entity.ComposeSession) (id int64, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
||||
Insert(req)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return
|
||||
}
|
||||
return r.LastInsertId()
|
||||
}
|
||||
|
||||
func (d *composeSessionDao) Update(ctx context.Context, m *entity.ComposeSession) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession).
|
||||
Where(entity.ComposeSessionCol.Id, m.Id).
|
||||
Data(m).
|
||||
// Update 更新
|
||||
func (d *composeSessionDao) Update(ctx context.Context, req *entity.ComposeSession) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
||||
OmitEmpty().
|
||||
Data(&req).
|
||||
Where(entity.ComposeSessionCol.Id, req.Id).
|
||||
Update()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
func (d *composeSessionDao) List(ctx context.Context, page, size int, where map[string]any) (list []*entity.ComposeSession, total int, err error) {
|
||||
model := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession).
|
||||
Where("deleted_at IS NULL")
|
||||
|
||||
// 动态拼接查询条件
|
||||
for k, v := range where {
|
||||
model = model.Where(k, v)
|
||||
// List 查询编排会话列表
|
||||
func (d *composeSessionDao) List(ctx context.Context, req *entity.ComposeSession, page, size int, fields ...string) (list []*entity.ComposeSession, total int, err error) {
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
|
||||
// 查询总数
|
||||
total, err = model.Count()
|
||||
if size <= 0 {
|
||||
size = 10
|
||||
}
|
||||
model := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
||||
Fields(fields).
|
||||
OmitEmpty()
|
||||
model.Where(entity.ComposeSessionCol.Creator, req.Creator)
|
||||
model.Where(entity.ComposeSessionCol.SessionId, req.SessionId)
|
||||
model.OrderDesc(entity.ComposeSessionCol.CreatedAt)
|
||||
model.Page(page, size)
|
||||
r, total, err := model.AllAndCount(false)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
err = model.Order("created_at DESC").
|
||||
Page(page, size).
|
||||
Scan(&list)
|
||||
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *composeSessionDao) GetListBySessionId(ctx context.Context, sessionId string, limit int) ([]*entity.ComposeSession, error) {
|
||||
var sessions []*entity.ComposeSession
|
||||
err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession).
|
||||
Where(entity.ComposeSessionCol.SessionId, sessionId).
|
||||
WhereNull(entity.ComposeSessionCol.DeletedAt).
|
||||
OrderDesc(entity.ComposeSessionCol.Id).
|
||||
Limit(limit).
|
||||
Scan(&sessions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 反转成时间正序
|
||||
for i, j := 0, len(sessions)-1; i < j; i, j = i+1, j-1 {
|
||||
sessions[i], sessions[j] = sessions[j], sessions[i]
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (d *composeSessionDao) GetById(ctx context.Context, Id int64) (m *entity.ComposeSession, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession).
|
||||
Where(entity.ComposeSessionCol.Id, Id).
|
||||
One()
|
||||
// Get 查询编排会话
|
||||
func (d *composeSessionDao) Get(ctx context.Context, req *entity.ComposeSession, fields ...string) (m *entity.ComposeSession, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
||||
OmitEmpty().
|
||||
Where(entity.ComposeSessionCol.Id, req.Id).
|
||||
Where(entity.ComposeSessionCol.Creator, req.Creator).
|
||||
Where(entity.ComposeSessionCol.SessionId, req.SessionId).
|
||||
Fields(fields).One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
return
|
||||
}
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *composeSessionDao) GetBySessionId(ctx context.Context, sessionId string) (m *entity.ComposeSession, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession).
|
||||
// Delete 删除编排会话
|
||||
func (d *composeSessionDao) Delete(ctx context.Context, req *entity.ComposeSession) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
||||
OmitEmpty().
|
||||
Where(entity.ComposeSessionCol.Id, req.Id).
|
||||
Where(entity.ComposeSessionCol.Creator, req.Creator).
|
||||
Where(entity.ComposeSessionCol.SessionId, req.SessionId).
|
||||
Delete()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
// ListByIds 根据 ID 列表批量查询
|
||||
func (d *composeSessionDao) ListByIds(ctx context.Context, ids []int64, creator, sessionId string) (list []*entity.ComposeSession, err error) {
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
||||
WhereIn(entity.ComposeSessionCol.Id, ids).
|
||||
Where(entity.ComposeSessionCol.Creator, creator).
|
||||
Where(entity.ComposeSessionCol.SessionId, sessionId).
|
||||
One()
|
||||
All()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
err = r.Struct(&m)
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *composeSessionDao) DeleteBySessionId(ctx context.Context, sessionId string) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession).
|
||||
// DeleteByIds 批量删除编排会话
|
||||
func (d *composeSessionDao) DeleteByIds(ctx context.Context, ids []int64, creator, sessionId string) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
||||
WhereIn(entity.ComposeSessionCol.Id, ids).
|
||||
Where(entity.ComposeSessionCol.Creator, creator).
|
||||
Where(entity.ComposeSessionCol.SessionId, sessionId).
|
||||
Data(map[string]any{
|
||||
entity.ComposeSessionCol.DeletedAt: "NOW()",
|
||||
}).
|
||||
Update()
|
||||
Delete()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -2,47 +2,54 @@ package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"prompts-core/consts/public"
|
||||
"prompts-core/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var ComposeTask = &composeTaskDao{}
|
||||
|
||||
type composeTaskDao struct{}
|
||||
|
||||
func (d *composeTaskDao) Insert(ctx context.Context, m *entity.ComposeTask) (id int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeTask).Data(m).Insert()
|
||||
// Insert 插入
|
||||
func (d *composeTaskDao) Insert(ctx context.Context, req *entity.ComposeTask) (id int64, err error) {
|
||||
var m = new(entity.ComposeTask)
|
||||
err = gconv.Struct(req, &m)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeTask).
|
||||
Insert(m)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return r.LastInsertId()
|
||||
}
|
||||
|
||||
func (d *composeTaskDao) GetByTaskId(ctx context.Context, taskId string) (m *entity.ComposeTask, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeTask).
|
||||
Where(entity.ComposeTaskCol.TaskId, taskId).
|
||||
One()
|
||||
// Get 获取
|
||||
func (d *composeTaskDao) Get(ctx context.Context, req *entity.ComposeTask, fields ...string) (m *entity.ComposeTask, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeTask).
|
||||
OmitEmpty().
|
||||
Where(entity.ComposeTaskCol.TaskId, req.TaskId).
|
||||
Fields(fields).One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
return
|
||||
}
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *composeTaskDao) UpdateByTaskId(ctx context.Context, taskId string, data map[string]any) (rows int64, err error) {
|
||||
data[entity.ComposeTaskCol.Updater] = ""
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeTask).
|
||||
Where(entity.ComposeTaskCol.TaskId, taskId).
|
||||
Data(data).
|
||||
// Update 更新
|
||||
func (d *composeTaskDao) Update(ctx context.Context, req *entity.ComposeTask) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeTask).
|
||||
OmitEmpty().
|
||||
Data(&req).
|
||||
Where(entity.ComposeTaskCol.TaskId, req.TaskId).
|
||||
Update()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"prompts-core/consts/public"
|
||||
"prompts-core/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.com/red-future/common/utils"
|
||||
)
|
||||
|
||||
var Model = &modelDao{}
|
||||
|
||||
type modelDao struct{}
|
||||
|
||||
func (d *modelDao) GetByModelName(ctx context.Context, modelName string) (m *entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
Where(entity.AsynchModelCol.ModelName, modelName).
|
||||
One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *modelDao) GetByIsChatModel(ctx context.Context) (m *entity.AsynchModel, err error) {
|
||||
userInfo, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
Where(entity.AsynchModelCol.IsChatModel, 1).
|
||||
Where(entity.AsynchModelCol.Creator, userInfo.UserName).
|
||||
One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
// GetBySuperAdmin 查询超级管理员(tenant_id=1)的模型
|
||||
func (d *modelDao) GetBySuperAdmin(ctx context.Context, modelName string) (m *entity.AsynchModel, err error) {
|
||||
sql := fmt.Sprintf("SELECT * FROM %s WHERE model_name = ? AND tenant_id = 1 AND deleted_at IS NULL LIMIT 1", public.TableNameModel)
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx, sql, modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(r) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
err = r[0].Struct(&m)
|
||||
return
|
||||
}
|
||||
@@ -1,97 +0,0 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"prompts-core/consts/public"
|
||||
"prompts-core/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var Prompt = &promptDao{}
|
||||
|
||||
type promptDao struct{}
|
||||
|
||||
func (d *promptDao) Insert(ctx context.Context, m *entity.PromptConfig) (id int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig).Data(m).Insert()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return r.LastInsertId()
|
||||
}
|
||||
|
||||
func (d *promptDao) UpdateByID(ctx context.Context, id int64, data map[string]any) (rows int64, err error) {
|
||||
// 触发 gfdb 的 updateHook 自动填充 updater,需要显式带 updater 字段
|
||||
data[entity.PromptConfigCol.Updater] = ""
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig).
|
||||
Where(entity.PromptConfigCol.Id, id).
|
||||
Data(data).
|
||||
Update()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
func (d *promptDao) DeleteByID(ctx context.Context, id int64) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig).
|
||||
Where(entity.PromptConfigCol.Id, id).
|
||||
Delete()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
func (d *promptDao) GetByID(ctx context.Context, id int64) (m *entity.PromptConfig, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig).
|
||||
Where(entity.PromptConfigCol.Id, id).
|
||||
One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *promptDao) GetLatestEnabledByModelTypeID(ctx context.Context, modelTypeID int) (m *entity.PromptConfig, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig).
|
||||
Where("deleted_at IS NULL").
|
||||
Where(entity.PromptConfigCol.ModelTypeId, modelTypeID).
|
||||
Where(entity.PromptConfigCol.Enabled, 1).
|
||||
OrderDesc(entity.PromptConfigCol.CreatedAt).
|
||||
One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *promptDao) List(ctx context.Context, pageNum, pageSize int, modelTypeID *int, modelTypeLike string) (list []*entity.PromptConfig, total int64, err error) {
|
||||
model := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig).Where("deleted_at IS NULL").OrderDesc(entity.PromptConfigCol.CreatedAt)
|
||||
if modelTypeID != nil && *modelTypeID > 0 {
|
||||
model = model.Where(entity.PromptConfigCol.ModelTypeId, *modelTypeID)
|
||||
}
|
||||
if modelTypeLike != "" {
|
||||
model = model.WhereLike(entity.PromptConfigCol.ModelType, "%"+modelTypeLike+"%")
|
||||
}
|
||||
if pageNum > 0 && pageSize > 0 {
|
||||
model = model.Page(pageNum, pageSize)
|
||||
}
|
||||
r, totalInt, err := model.AllAndCount(false)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = gconv.Int64(totalInt)
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
98
dao/provider_protocol_dao.go
Normal file
98
dao/provider_protocol_dao.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"prompts-core/consts/public"
|
||||
"prompts-core/model/entity"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var ProviderProtocol = &providerProtocolDao{}
|
||||
|
||||
type providerProtocolDao struct{}
|
||||
|
||||
// Insert 新增协议配置
|
||||
func (d *providerProtocolDao) Insert(ctx context.Context, req *entity.ProviderProtocol) (id int64, err error) {
|
||||
var m = new(entity.ProviderProtocol)
|
||||
err = gconv.Struct(req, &m)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).
|
||||
Insert(m)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return r.LastInsertId()
|
||||
}
|
||||
|
||||
// Get 查询协议配置
|
||||
func (d *providerProtocolDao) Get(ctx context.Context, req *entity.ProviderProtocol, fields ...string) (res *entity.ProviderProtocol, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).
|
||||
NoTenantId(ctx).
|
||||
OmitEmpty().
|
||||
Where(entity.ProviderProtocolCol.Id, req.Id).
|
||||
Where(entity.ProviderProtocolCol.ProviderName, req.ProviderName). //主要是根据运营商查询
|
||||
Where(entity.ProviderProtocolCol.Status, 1).
|
||||
Fields(fields).One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
err = r.Struct(&res)
|
||||
return
|
||||
}
|
||||
|
||||
// List 列表查询
|
||||
func (d *providerProtocolDao) List(ctx context.Context, req *entity.ProviderProtocol, page, size int, fields ...string) (list []*entity.ProviderProtocol, total int, err error) {
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if size <= 0 {
|
||||
size = 10
|
||||
}
|
||||
model := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).
|
||||
Fields(fields).
|
||||
OmitEmpty()
|
||||
model.Where(entity.ProviderProtocolCol.ProviderName, req.ProviderName)
|
||||
model.Where(entity.ProviderProtocolCol.Status, req.Status)
|
||||
model.OrderDesc(entity.ProviderProtocolCol.CreatedAt)
|
||||
model.Page(page, size)
|
||||
r, total, err := model.AllAndCount(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
// Update 更新协议配置
|
||||
func (d *providerProtocolDao) Update(ctx context.Context, req *entity.ProviderProtocol) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).
|
||||
OmitEmpty().
|
||||
Where(entity.ProviderProtocolCol.Id, req.Id).
|
||||
Data(req).
|
||||
Update()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
// Delete 软删除协议配置
|
||||
func (d *providerProtocolDao) Delete(ctx context.Context, id int64) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).
|
||||
Where(entity.ProviderProtocolCol.Id, id).
|
||||
Data(map[string]any{
|
||||
entity.ProviderProtocolCol.DeletedAt: "NOW()",
|
||||
}).
|
||||
Update()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
16
go.mod
16
go.mod
@@ -1,17 +1,12 @@
|
||||
module prompts-core
|
||||
|
||||
go 1.26.0
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
gitea.com/red-future/common v0.0.19
|
||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0
|
||||
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.0
|
||||
github.com/gogf/gf/v2 v2.10.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
gitea.redpowerfuture.com/red-future/common v0.0.23
|
||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.2
|
||||
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.2
|
||||
github.com/gogf/gf/v2 v2.10.2
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -68,7 +63,6 @@ require (
|
||||
github.com/r3labs/diff/v2 v2.15.1 // indirect
|
||||
github.com/redis/go-redis/v9 v9.12.1 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/tidwall/gjson v1.19.0
|
||||
github.com/tiger1103/gfast-token v1.0.10 // indirect
|
||||
github.com/vcaesar/cedar v0.30.0 // indirect
|
||||
github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect
|
||||
|
||||
22
go.sum
22
go.sum
@@ -1,6 +1,6 @@
|
||||
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
gitea.com/red-future/common v0.0.19 h1:9/WrfCFUCeFUYwuhBYF+JOQi5F5xuOy+gVnf2ZvHZu4=
|
||||
gitea.com/red-future/common v0.0.19/go.mod h1:6/nqIucVzmjOyqDTIq71feYBXXFNBy0rFwzaQ0/Ueoo=
|
||||
gitea.redpowerfuture.com/red-future/common v0.0.23 h1:xieoA00iKOCDm5SO9iXn+cSyMKBAlZwI0fuEVPWrHLg=
|
||||
gitea.redpowerfuture.com/red-future/common v0.0.23/go.mod h1:50U1Xi+Ie56z09S5LQbZvaken0Mxv3OeS9LgR7U/ZRY=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
||||
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
@@ -77,16 +77,16 @@ github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ4
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0 h1:39+jbTenm7KBj4hO2C8ANAxVHpX/7OuRDs1VcGC9ylA=
|
||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0/go.mod h1:B0s0fVzn0W220E8UTpSGzrrGKsop5KcB90twBeLCiz0=
|
||||
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.0 h1:N/F9CuDdUZLoM1nVRqrDE/33pDZuhVxpNY4wYdeIaBs=
|
||||
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.0/go.mod h1:x6uoJGfZOtirIRQls8xUlYzC6f7T/eULPUa9er368X0=
|
||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.2 h1:u8EpP24GkprogROnJ7htMov9Fc66pTP1eVYrWxiCYOs=
|
||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.2/go.mod h1:GmvM3r8GVByVMi4RD2+MCs5+CfxVXPMeT8mVDkAaAXE=
|
||||
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.2 h1:iTQegT+lEg/wDKvj2mi3W1wrdrwFarjokf88EXVVgu4=
|
||||
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.2/go.mod h1:ZRw3GNz5cq4uYrW4TPSVyrYWaoqzujKdWro/AOcGBaE=
|
||||
github.com/gogf/gf/contrib/registry/consul/v2 v2.9.5 h1:eUqwJ/qNH8lJ6yssiqskazgp1ACQuNU6zXlLOZVuXTQ=
|
||||
github.com/gogf/gf/contrib/registry/consul/v2 v2.9.5/go.mod h1:sjQyMry9+0POYZCA6lHXBxO77WoNKkruJpRB4xKqk5k=
|
||||
github.com/gogf/gf/contrib/trace/otlphttp/v2 v2.9.5 h1:tHUEZYB5GTqEYYVDYnlGobf1xISARKDE4KHVlgjwTec=
|
||||
github.com/gogf/gf/contrib/trace/otlphttp/v2 v2.9.5/go.mod h1:cfzTn2HS9RDX8f5pUVkbGxUWcSosouqfNQ1G6cY0V88=
|
||||
github.com/gogf/gf/v2 v2.10.0 h1:rzDROlyqGMe/eM6dCalSR8dZOuMIdLhmxKSH1DGhbFs=
|
||||
github.com/gogf/gf/v2 v2.10.0/go.mod h1:Svl1N+E8G/QshU2DUbh/3J/AJauqCgUnxHurXWR4Qx0=
|
||||
github.com/gogf/gf/v2 v2.10.2 h1:46IO0Uc8e85/FqdftJFskfDejJLBL0JBnGS5qOftUu8=
|
||||
github.com/gogf/gf/v2 v2.10.2/go.mod h1:Svl1N+E8G/QshU2DUbh/3J/AJauqCgUnxHurXWR4Qx0=
|
||||
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
@@ -288,12 +288,6 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/tidwall/gjson v1.19.0 h1:xwxm7n691Uf3u5OFjzngavjGTh55KX5q/9w9xHW88JU=
|
||||
github.com/tidwall/gjson v1.19.0/go.mod h1:V37/opeE/JbLUOfH0QTXiNez2l0RUjYUhpT4szFQAfc=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tiger1103/gfast-token v1.0.10 h1:fNiBE/Dq5iTHvTGlCx3DmXa2o4hr0NtumFpffZ39k6s=
|
||||
github.com/tiger1103/gfast-token v1.0.10/go.mod h1:a/21mxmj7zFeNvjhZSC0XpEAFHfb1aT2k6DXnufFU1s=
|
||||
github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM=
|
||||
|
||||
12
main.go
12
main.go
@@ -4,13 +4,12 @@ import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
"prompts-core/controller"
|
||||
"syscall"
|
||||
|
||||
"prompts-core/controller"
|
||||
|
||||
"gitea.com/red-future/common/http"
|
||||
"gitea.com/red-future/common/jaeger"
|
||||
_ "gitea.com/red-future/common/swagger"
|
||||
"gitea.redpowerfuture.com/red-future/common/http"
|
||||
"gitea.redpowerfuture.com/red-future/common/jaeger"
|
||||
_ "gitea.redpowerfuture.com/red-future/common/swagger"
|
||||
_ "github.com/gogf/gf/contrib/drivers/pgsql/v2"
|
||||
_ "github.com/gogf/gf/contrib/nosql/redis/v2"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
@@ -20,14 +19,13 @@ func main() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
defer jaeger.ShutDown(ctx)
|
||||
|
||||
// 注册路由
|
||||
http.RouteRegister([]interface{}{
|
||||
controller.Prompt,
|
||||
controller.Session,
|
||||
})
|
||||
|
||||
// 监听退出信号,确保 Ctrl+C 能完整退出并关闭 http server
|
||||
// 监听退出信号,确保 Ctrl+C 能完整退出并关闭 gateway server
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
|
||||
<-quit
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
package dto
|
||||
|
||||
import "github.com/gogf/gf/v2/frame/g"
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role" dc:"角色:system/user/assistant"`
|
||||
Content any `json:"content" dc:"消息内容"`
|
||||
}
|
||||
|
||||
type ComposeMessagesReq struct {
|
||||
g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schema;form 作为系统表单,userForm 作为用户表单,结合 userFiles 调用 model-gateway,并直接返回最终 messages"`
|
||||
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"`
|
||||
BuildType int `p:"buildType" json:"buildType" v:"required#buildType不能为空" dc:"构建类型"` //判断节点
|
||||
SessionId string `p:"sessionId" json:"sessionId" v:"required#sessionId不能为空" dc:"会话ID"`
|
||||
Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"`
|
||||
Form map[string]any `p:"form" json:"form" dc:"系统表单:form 下所有字段都作为系统提示词来源"`
|
||||
UserForm map[string]any `p:"userForm" json:"userForm" dc:"用户表单:userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"`
|
||||
SkillName string `p:"skillName" json:"skillName" dc:"技能名称"`
|
||||
UserFiles []string `p:"userFiles" json:"userFiles" dc:"用户附件地址列表"`
|
||||
}
|
||||
|
||||
type ComposeMessagesRes struct {
|
||||
Messages any `json:"messages,omitempty" dc:"最终消息数组"`
|
||||
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
|
||||
}
|
||||
|
||||
type CallbackReq struct {
|
||||
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调:callbackUrl/{bizName}"`
|
||||
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
|
||||
State int `json:"state" dc:"网关任务状态"`
|
||||
OssFile string `json:"oss_file" dc:"结果文件地址"`
|
||||
FileType string `json:"file_type" dc:"结果文件类型"`
|
||||
Text string `json:"text" dc:"文本结果"`
|
||||
ErrorMsg string `json:"error_msg" dc:"错误信息"`
|
||||
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
|
||||
}
|
||||
|
||||
type GetComposeTaskReq struct {
|
||||
g.Meta `path:"/getComposeTask" method:"get" tags:"提示词处理" summary:"查询拼接任务" dc:"按 taskId 查询提示词拼接任务结果"`
|
||||
TaskId string `p:"taskId" json:"taskId" v:"required#taskId不能为空" dc:"任务ID"`
|
||||
}
|
||||
|
||||
type GetComposeTaskRes struct {
|
||||
TaskId string `json:"taskId" dc:"任务ID"`
|
||||
Status string `json:"status" dc:"业务状态"`
|
||||
GatewayState int `json:"gatewayState" dc:"网关状态"`
|
||||
ErrorMessage string `json:"errorMessage" dc:"错误信息"`
|
||||
Messages any `json:"messages" dc:"最终消息数组"`
|
||||
OssFile string `json:"ossFile" dc:"结果文件地址"`
|
||||
FileType string `json:"fileType" dc:"结果文件类型"`
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
package dto
|
||||
|
||||
import "github.com/gogf/gf/v2/frame/g"
|
||||
|
||||
type SessionCallbackReq struct {
|
||||
g.Meta `path:"/sessionCallback" method:"post" tags:"提示词处理"`
|
||||
Text string `json:"text" dc:"文本结果"`
|
||||
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
|
||||
}
|
||||
53
model/dto/prompt_compose_dto.go
Normal file
53
model/dto/prompt_compose_dto.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package dto
|
||||
|
||||
import "github.com/gogf/gf/v2/frame/g"
|
||||
|
||||
type ComposeMessagesReq struct {
|
||||
g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schema;form 作为系统表单,userForm 作为用户表单,结合 userFiles 调用 model-gateway,并直接返回最终 messages"`
|
||||
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"`
|
||||
BuildType int `p:"buildType" json:"buildType" v:"required#buildType不能为空" dc:"构建类型"` //判断节点
|
||||
NodeId string `p:"nodeId" json:"nodeId" dc:"节点ID"`
|
||||
SessionId string `p:"sessionId" json:"sessionId" dc:"会话ID"` //v:"required#sessionId不能为空"
|
||||
Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"`
|
||||
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
|
||||
Form []map[string]any `p:"form" json:"form" dc:"系统表单:form 下所有字段都作为系统提示词来源"`
|
||||
UserForm []map[string]any `p:"userForm" json:"userForm" dc:"用户表单:userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"`
|
||||
Consult []ConsultItem `json:"consult" dc:"附件列表(图片/视频/音频)"`
|
||||
SkillName string `p:"skillName" json:"skillName" dc:"技能名称"`
|
||||
}
|
||||
|
||||
// ConsultItem 单个附件
|
||||
type ConsultItem struct {
|
||||
Type string `json:"type" dc:"附件类型:image/video/audio"`
|
||||
Url string `json:"url" dc:"附件地址"`
|
||||
}
|
||||
type ComposeMessagesRes struct {
|
||||
TaskId string `json:"taskId" dc:"任务ID"`
|
||||
}
|
||||
|
||||
type CallbackReq struct {
|
||||
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调:callbackUrl/{bizName}"`
|
||||
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
|
||||
State int `json:"state" dc:"网关任务状态"`
|
||||
OssFile string `json:"oss_file" dc:"结果文件地址"`
|
||||
FileType string `json:"file_type" dc:"结果文件类型"`
|
||||
ErrorMsg string `json:"error_msg" dc:"错误信息"`
|
||||
}
|
||||
|
||||
type CallbackRes struct {
|
||||
}
|
||||
|
||||
type GetComposeTaskReq struct {
|
||||
g.Meta `path:"/getComposeTask" method:"get" tags:"提示词处理" summary:"查询拼接任务" dc:"按 taskId 查询提示词拼接任务结果"`
|
||||
TaskId string `p:"taskId" json:"taskId" v:"required#taskId不能为空" dc:"任务ID"`
|
||||
}
|
||||
|
||||
type GetComposeTaskRes struct {
|
||||
TaskId string `json:"taskId" dc:"任务ID"`
|
||||
Status string `json:"status" dc:"业务状态"`
|
||||
GatewayState int `json:"gatewayState" dc:"网关状态"`
|
||||
ErrorMessage string `json:"errorMessage" dc:"错误信息"`
|
||||
Messages map[string]any `json:"messages" dc:"最终消息数组"`
|
||||
OssFile string `json:"ossFile" dc:"结果文件地址"`
|
||||
FileType string `json:"fileType" dc:"结果文件类型"`
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"gitea.com/red-future/common/beans"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// CreatePromptReq 添加提示词配置(默认启用)
|
||||
type CreatePromptReq struct {
|
||||
g.Meta `path:"/createPrompt" method:"post" tags:"提示词管理" summary:"创建提示词配置" dc:"创建新的模型提示词配置(默认启用)"`
|
||||
ModelTypeId int `p:"modelTypeId" json:"modelTypeId" v:"required#modelTypeId不能为空" dc:"模型分类ID"`
|
||||
ModelType string `p:"modelType" json:"modelType" v:"required#modelType不能为空" dc:"模型类别/模型类型"`
|
||||
PromptInfo any `p:"promptInfo" json:"promptInfo" v:"required#promptInfo不能为空" dc:"数据库定义的表单规则数据(JSON)"`
|
||||
ResponseJsonSchema any `p:"responseJsonSchema" json:"responseJsonSchema" v:"required#responseJsonSchema不能为空" dc:"模型返回表单 JSON 格式约束"`
|
||||
// Version 预留字段:先不使用,但表结构保留
|
||||
Version string `p:"version" json:"version" dc:"版本号(预留)"`
|
||||
}
|
||||
|
||||
type CreatePromptRes struct {
|
||||
ID int64 `json:"id,string" dc:"配置ID"`
|
||||
}
|
||||
|
||||
// UpdatePromptReq 更新提示词配置
|
||||
type UpdatePromptReq struct {
|
||||
g.Meta `path:"/updatePrompt" method:"put" tags:"提示词管理" summary:"更新提示词配置" dc:"更新指定ID的提示词配置"`
|
||||
ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
|
||||
|
||||
ModelTypeId *int `p:"modelTypeId" json:"modelTypeId" dc:"模型分类ID(可选更新)"`
|
||||
ModelType *string `p:"modelType" json:"modelType" dc:"模型类别/模型类型(可选更新)"`
|
||||
PromptInfo any `p:"promptInfo" json:"promptInfo" dc:"数据库定义的表单规则数据(JSON)(可选更新)"`
|
||||
ResponseJsonSchema any `p:"responseJsonSchema" json:"responseJsonSchema" dc:"模型返回表单 JSON 格式约束(可选更新)"`
|
||||
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-禁用,1-启用(可选更新)"`
|
||||
Version *string `p:"version" json:"version" dc:"版本号(预留,可选更新)"`
|
||||
}
|
||||
|
||||
// DeletePromptReq 删除提示词配置
|
||||
type DeletePromptReq struct {
|
||||
g.Meta `path:"/deletePrompt" method:"delete" tags:"提示词管理" summary:"删除提示词配置" dc:"删除指定ID的提示词配置"`
|
||||
ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
|
||||
}
|
||||
|
||||
// GetPromptReq 获取提示词配置详情
|
||||
type GetPromptReq struct {
|
||||
g.Meta `path:"/getPrompt" method:"get" tags:"提示词管理" summary:"获取提示词配置" dc:"根据ID获取提示词配置详情"`
|
||||
ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
|
||||
}
|
||||
|
||||
type GetPromptRes struct {
|
||||
Prompt any `json:"prompt" dc:"提示词配置详情"`
|
||||
}
|
||||
|
||||
// ListPromptReq 配置列表
|
||||
type ListPromptReq struct {
|
||||
g.Meta `path:"/listPrompt" method:"post" tags:"提示词管理" summary:"提示词配置列表" dc:"分页获取提示词配置列表"`
|
||||
Page *beans.Page `p:"page" json:"page" dc:"分页参数"`
|
||||
ModelTypeId *int `p:"modelTypeId" json:"modelTypeId" dc:"模型分类ID(可选)"`
|
||||
ModelType string `p:"modelType" json:"modelType" dc:"模型类型名称(可选,模糊查询)"`
|
||||
}
|
||||
|
||||
type ListPromptRes struct {
|
||||
List any `json:"list" dc:"列表数据"`
|
||||
Total int64 `json:"total" dc:"总数"`
|
||||
}
|
||||
80
model/dto/prompt_session_dto.go
Normal file
80
model/dto/prompt_session_dto.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package dto
|
||||
|
||||
import "github.com/gogf/gf/v2/frame/g"
|
||||
|
||||
// HistoryRound 一轮对话
|
||||
type HistoryRound struct {
|
||||
Id int64 `json:"id" dc:"记录ID"`
|
||||
SessionId string `json:"sessionId" dc:"会话ID"`
|
||||
NodeId string `json:"nodeId" dc:"节点ID"`
|
||||
User map[string]any `json:"user" dc:"用户消息"`
|
||||
Assistant map[string]any `json:"assistant" dc:"助手回复"`
|
||||
CreatedAt string `json:"createdAt" dc:"创建时间"`
|
||||
UpdatedAt string `json:"updatedAt" dc:"更新时间"`
|
||||
}
|
||||
|
||||
// SessionCallbackReq 会话回调请求
|
||||
type SessionCallbackReq struct {
|
||||
g.Meta `path:"/callback" method:"post" tags:"会话管理" summary:"会话回调"`
|
||||
Messages map[string]any `json:"messages" v:"required" dc:"消息数组"`
|
||||
EpicycleId int64 `json:"epicycleId" v:"required" dc:"轮次ID"`
|
||||
}
|
||||
|
||||
// SessionCallbackRes 会话回调响应
|
||||
type SessionCallbackRes struct {
|
||||
Status bool `json:"status" dc:"状态"`
|
||||
SessionId string `json:"sessionId" dc:"会话ID"`
|
||||
}
|
||||
|
||||
// GetHistoryListReq 获取历史列表请求(前端)
|
||||
type GetHistoryListReq struct {
|
||||
g.Meta `path:"/historyList" method:"get" tags:"会话管理" summary:"获取历史列表"`
|
||||
Page int `json:"page" d:"1" dc:"页码"`
|
||||
Size int `json:"size" d:"10" dc:"每页条数"`
|
||||
}
|
||||
|
||||
// GetHistoryListRes 获取历史列表响应
|
||||
type GetHistoryListRes struct {
|
||||
List []HistoryRound `json:"list" dc:"历史列表"`
|
||||
Total int `json:"total" dc:"总数"`
|
||||
}
|
||||
|
||||
// GetHistoryMessagesReq 获取历史消息请求(提示词拼接)
|
||||
type GetHistoryMessagesReq struct {
|
||||
g.Meta `path:"/historyMessages" method:"get" tags:"会话管理" summary:"获取历史消息"`
|
||||
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
|
||||
NodeId string `json:"nodeId" dc:"节点ID"`
|
||||
}
|
||||
|
||||
// GetHistoryMessagesRes 获取历史消息响应
|
||||
type GetHistoryMessagesRes struct {
|
||||
Messages []FlatMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type FlatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// DeleteMessagesReq 批量删除消息请求
|
||||
type DeleteMessagesReq struct {
|
||||
g.Meta `path:"/deleteMessages" method:"post" tags:"会话管理" summary:"批量删除消息"`
|
||||
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
|
||||
MsgIds []int64 `json:"msgIds" v:"required" dc:"消息ID列表"`
|
||||
}
|
||||
|
||||
// DeleteMessagesRes 批量删除消息响应
|
||||
type DeleteMessagesRes struct {
|
||||
Ok bool `json:"ok" dc:"是否成功"`
|
||||
}
|
||||
|
||||
// DeleteSessionReq 删除整个会话请求
|
||||
type DeleteSessionReq struct {
|
||||
g.Meta `path:"/deleteSession" method:"post" tags:"会话管理" summary:"删除整个会话"`
|
||||
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
|
||||
}
|
||||
|
||||
// DeleteSessionRes 删除整个会话响应
|
||||
type DeleteSessionRes struct {
|
||||
Ok bool `json:"ok" dc:"是否成功"`
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
package entity
|
||||
|
||||
import "gitea.com/red-future/common/beans"
|
||||
|
||||
type asynchModelCol struct {
|
||||
beans.SQLBaseCol
|
||||
ModelName string
|
||||
ModelType string
|
||||
BaseURL string
|
||||
HttpMethod string
|
||||
HeadMsg string
|
||||
FormJSON string
|
||||
RequestMapping string
|
||||
ResponseMapping string
|
||||
ResponseBody string
|
||||
TokenMapping string
|
||||
Prompt string
|
||||
IsPrivate string
|
||||
IsChatModel string
|
||||
ApiKey string
|
||||
Enabled string
|
||||
MaxConcurrency string
|
||||
QueueLimit string
|
||||
TimeoutSeconds string
|
||||
ExpectedSeconds string
|
||||
RetryTimes string
|
||||
RetryQueueMaxSecs string
|
||||
AutoCleanSeconds string
|
||||
Remark string
|
||||
}
|
||||
|
||||
var AsynchModelCol = asynchModelCol{
|
||||
SQLBaseCol: beans.DefSQLBaseCol,
|
||||
ModelName: "model_name",
|
||||
ModelType: "model_type",
|
||||
BaseURL: "base_url",
|
||||
HttpMethod: "http_method",
|
||||
HeadMsg: "head_msg",
|
||||
FormJSON: "form_json",
|
||||
RequestMapping: "request_mapping",
|
||||
ResponseMapping: "response_mapping",
|
||||
ResponseBody: "response_body",
|
||||
TokenMapping: "token_mapping",
|
||||
Prompt: "prompt",
|
||||
IsPrivate: "is_private",
|
||||
IsChatModel: "is_chat_model",
|
||||
ApiKey: "api_key",
|
||||
Enabled: "enabled",
|
||||
MaxConcurrency: "max_concurrency",
|
||||
QueueLimit: "queue_limit",
|
||||
TimeoutSeconds: "timeout_seconds",
|
||||
ExpectedSeconds: "expected_seconds",
|
||||
RetryTimes: "retry_times",
|
||||
RetryQueueMaxSecs: "retry_queue_max_seconds",
|
||||
AutoCleanSeconds: "auto_clean_seconds",
|
||||
Remark: "remark",
|
||||
}
|
||||
|
||||
// AsynchModel 异步模型配置
|
||||
type AsynchModel struct {
|
||||
beans.SQLBaseDO `orm:",inline"`
|
||||
ModelName string `orm:"model_name" json:"modelName"`
|
||||
ModelType int `orm:"model_type" json:"modelType"`
|
||||
BaseURL string `orm:"base_url" json:"baseUrl"`
|
||||
HttpMethod string `orm:"http_method" json:"httpMethod"`
|
||||
HeadMsg string `orm:"head_msg" json:"headMsg"`
|
||||
Form any `orm:"form_json" json:"form"`
|
||||
RequestMapping any `orm:"request_mapping" json:"requestMapping"`
|
||||
ResponseMapping any `orm:"response_mapping" json:"responseMapping"`
|
||||
ResponseBody any `orm:"response_body" json:"responseBody"`
|
||||
TokenMapping string `orm:"token_mapping" json:"tokenMapping"`
|
||||
Prompt string `orm:"prompt" json:"prompt"`
|
||||
IsPrivate int `orm:"is_private" json:"isPrivate"`
|
||||
IsChatModel int `orm:"is_chat_model" json:"isChatModel"`
|
||||
ApiKey string `orm:"api_key" json:"apiKey"`
|
||||
Enabled int `orm:"enabled" json:"enabled"`
|
||||
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
|
||||
QueueLimit int `orm:"queue_limit" json:"queueLimit"`
|
||||
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
|
||||
ExpectedSeconds int `orm:"expected_seconds" json:"expectedSeconds"`
|
||||
RetryTimes int `orm:"retry_times" json:"retryTimes"`
|
||||
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"`
|
||||
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
|
||||
Remark string `orm:"remark" json:"remark"`
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
package entity
|
||||
|
||||
import "gitea.com/red-future/common/beans"
|
||||
|
||||
type composeSessionCol struct {
|
||||
beans.SQLBaseCol
|
||||
SessionId string
|
||||
RequestContent string
|
||||
ResponseContent string
|
||||
Remark string
|
||||
}
|
||||
|
||||
var ComposeSessionCol = composeSessionCol{
|
||||
SQLBaseCol: beans.DefSQLBaseCol,
|
||||
SessionId: "session_id",
|
||||
RequestContent: "request_content",
|
||||
ResponseContent: "response_content",
|
||||
Remark: "remark",
|
||||
}
|
||||
|
||||
type ComposeSession struct {
|
||||
beans.SQLBaseDO `orm:",inline"`
|
||||
SessionId string `orm:"session_id" json:"sessionId"`
|
||||
RequestContent any `orm:"request_content" json:"requestContent"`
|
||||
ResponseContent any `orm:"response_content" json:"responseContent"`
|
||||
Remark string `orm:"remark" json:"remark"`
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
package entity
|
||||
|
||||
import "gitea.com/red-future/common/beans"
|
||||
|
||||
type composeTaskCol struct {
|
||||
beans.SQLBaseCol
|
||||
TaskId string
|
||||
ModelName string
|
||||
SkillName string
|
||||
LimitWords string
|
||||
RequestPayload string
|
||||
CallbackPayload string
|
||||
ModelResult string
|
||||
Messages string
|
||||
Status string
|
||||
ErrorMessage string
|
||||
}
|
||||
|
||||
var ComposeTaskCol = composeTaskCol{
|
||||
SQLBaseCol: beans.DefSQLBaseCol,
|
||||
TaskId: "task_id",
|
||||
ModelName: "model_name",
|
||||
SkillName: "skill_name",
|
||||
LimitWords: "limit_words",
|
||||
RequestPayload: "request_payload",
|
||||
CallbackPayload: "callback_payload",
|
||||
ModelResult: "model_result",
|
||||
Messages: "messages",
|
||||
Status: "status",
|
||||
ErrorMessage: "error_message",
|
||||
}
|
||||
|
||||
type ComposeTask struct {
|
||||
beans.SQLBaseDO `orm:",inline"`
|
||||
TaskId string `orm:"task_id" json:"taskId"`
|
||||
ModelName string `orm:"model_name" json:"modelName"`
|
||||
SkillName string `orm:"skill_name" json:"skillName"`
|
||||
LimitWords int `orm:"limit_words" json:"limitWords"`
|
||||
RequestPayload any `orm:"request_payload" json:"requestPayload"`
|
||||
CallbackPayload any `orm:"callback_payload" json:"callbackPayload"`
|
||||
ModelResult any `orm:"model_result" json:"modelResult"`
|
||||
Messages any `orm:"messages" json:"messages"`
|
||||
Status string `orm:"status" json:"status"`
|
||||
ErrorMessage string `orm:"error_message" json:"errorMessage"`
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
package entity
|
||||
|
||||
import "gitea.com/red-future/common/beans"
|
||||
|
||||
type promptConfigCol struct {
|
||||
beans.SQLBaseCol
|
||||
ModelTypeId string
|
||||
ModelType string
|
||||
PromptInfo string
|
||||
ResponseJsonSchema string
|
||||
Enabled string
|
||||
Version string
|
||||
}
|
||||
|
||||
var PromptConfigCol = promptConfigCol{
|
||||
SQLBaseCol: beans.DefSQLBaseCol,
|
||||
ModelTypeId: "model_type_id",
|
||||
ModelType: "model_type",
|
||||
PromptInfo: "prompt_info",
|
||||
ResponseJsonSchema: "response_json_schema",
|
||||
Enabled: "enabled",
|
||||
Version: "version",
|
||||
}
|
||||
|
||||
// PromptConfig 模型提示词配置
|
||||
//
|
||||
// 说明:
|
||||
// - prompt_info 使用 JSONB 保存(对外用 json 传输)
|
||||
// - response_json_schema 为模型返回 JSON 格式约束
|
||||
// - enabled:1启用/0禁用
|
||||
type PromptConfig struct {
|
||||
beans.SQLBaseDO `orm:",inline"`
|
||||
ModelTypeId int `orm:"model_type_id" json:"modelTypeId"`
|
||||
ModelType string `orm:"model_type" json:"modelType"`
|
||||
PromptInfo any `orm:"prompt_info" json:"promptInfo"`
|
||||
ResponseJsonSchema any `orm:"response_json_schema" json:"responseJsonSchema"`
|
||||
Enabled int `orm:"enabled" json:"enabled"`
|
||||
Version string `orm:"version" json:"version"`
|
||||
}
|
||||
30
model/entity/prompts_compose_session.go
Normal file
30
model/entity/prompts_compose_session.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package entity
|
||||
|
||||
import "gitea.redpowerfuture.com/red-future/common/beans"
|
||||
|
||||
type ComposeSession struct {
|
||||
beans.SQLBaseDO `orm:",inline"`
|
||||
SessionId string `orm:"session_id" json:"sessionId"`
|
||||
NodeId string `orm:"node_id" json:"nodeId"`
|
||||
RequestContent map[string]any `orm:"request_content" json:"requestContent"`
|
||||
ResponseContent map[string]any `orm:"response_content" json:"responseContent"`
|
||||
Remark string `orm:"remark" json:"remark"`
|
||||
}
|
||||
|
||||
type composeSessionCol struct {
|
||||
beans.SQLBaseCol
|
||||
SessionId string
|
||||
NodeId string
|
||||
RequestContent string
|
||||
ResponseContent string
|
||||
Remark string
|
||||
}
|
||||
|
||||
var ComposeSessionCol = composeSessionCol{
|
||||
SQLBaseCol: beans.DefSQLBaseCol,
|
||||
SessionId: "session_id",
|
||||
NodeId: "node_id",
|
||||
RequestContent: "request_content",
|
||||
ResponseContent: "response_content",
|
||||
Remark: "remark",
|
||||
}
|
||||
51
model/entity/prompts_compose_task.go
Normal file
51
model/entity/prompts_compose_task.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package entity
|
||||
|
||||
import "gitea.redpowerfuture.com/red-future/common/beans"
|
||||
|
||||
type ComposeTask struct {
|
||||
beans.SQLBaseDO `orm:",inline"`
|
||||
TaskId string `orm:"task_id" json:"taskId"`
|
||||
ModelName string `orm:"model_name" json:"modelName"`
|
||||
SkillName string `orm:"skill_name" json:"skillName"`
|
||||
BuildType int `orm:"build_type" json:"buildType"`
|
||||
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
|
||||
GatewayState int `orm:"gateway_state" json:"gatewayState"`
|
||||
RequestPayload map[string]any `orm:"request_payload" json:"requestPayload"`
|
||||
ResultJson map[string]any `orm:"result_json" json:"resultJson"`
|
||||
Status string `orm:"status" json:"status"`
|
||||
ErrorMessage string `orm:"error_message" json:"errorMessage"`
|
||||
OssFile string `orm:"oss_file" json:"ossFile"`
|
||||
FileType string `orm:"file_type" json:"fileType"`
|
||||
}
|
||||
|
||||
type composeTaskCol struct {
|
||||
beans.SQLBaseCol
|
||||
TaskId string
|
||||
ModelName string
|
||||
SkillName string
|
||||
BuildType string
|
||||
CallbackUrl string
|
||||
GatewayState string
|
||||
RequestPayload string
|
||||
ResultJson string
|
||||
Status string
|
||||
ErrorMessage string
|
||||
OssFile string
|
||||
FileType string
|
||||
}
|
||||
|
||||
var ComposeTaskCol = composeTaskCol{
|
||||
SQLBaseCol: beans.DefSQLBaseCol,
|
||||
TaskId: "task_id",
|
||||
ModelName: "model_name",
|
||||
SkillName: "skill_name",
|
||||
BuildType: "build_type",
|
||||
CallbackUrl: "callback_url",
|
||||
GatewayState: "gateway_state",
|
||||
RequestPayload: "request_payload",
|
||||
ResultJson: "result_json",
|
||||
Status: "status",
|
||||
ErrorMessage: "error_message",
|
||||
OssFile: "oss_file",
|
||||
FileType: "file_type",
|
||||
}
|
||||
49
model/entity/prompts_provider_protocol.go
Normal file
49
model/entity/prompts_provider_protocol.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package entity
|
||||
|
||||
import "gitea.redpowerfuture.com/red-future/common/beans"
|
||||
|
||||
// ProviderProtocol 模型协议映射配置
|
||||
type ProviderProtocol struct {
|
||||
beans.SQLBaseDO `orm:",inherit"`
|
||||
// 业务字段
|
||||
ProviderName string `orm:"provider_name" json:"providerName"`
|
||||
TargetField string `orm:"target_field" json:"targetField"`
|
||||
MergeOrder []string `orm:"merge_order" json:"mergeOrder"`
|
||||
RoleMapping map[string]any `orm:"role_mapping" json:"roleMapping"`
|
||||
ContentMapping map[string]any `orm:"content_mapping" json:"contentMapping"`
|
||||
Capabilities map[string]any `orm:"capabilities" json:"capabilities"`
|
||||
RequestTemplate map[string]any `orm:"request_template" json:"requestTemplate"`
|
||||
SystemPromptTemplate string `orm:"system_prompt_template" json:"systemPromptTemplate"`
|
||||
Status int `orm:"status" json:"status"`
|
||||
Remark string `orm:"remark" json:"remark"`
|
||||
}
|
||||
|
||||
// providerProtocolCol 列名
|
||||
type providerProtocolCol struct {
|
||||
beans.SQLBaseCol
|
||||
ProviderName string
|
||||
TargetField string
|
||||
MergeOrder string
|
||||
RoleMapping string
|
||||
ContentMapping string
|
||||
Capabilities string
|
||||
RequestTemplate string
|
||||
SystemPromptTemplate string
|
||||
Status string
|
||||
Remark string
|
||||
}
|
||||
|
||||
// ProviderProtocolCol 列名常量
|
||||
var ProviderProtocolCol = providerProtocolCol{
|
||||
SQLBaseCol: beans.DefSQLBaseCol,
|
||||
ProviderName: "provider_name",
|
||||
TargetField: "target_field",
|
||||
MergeOrder: "merge_order",
|
||||
RoleMapping: "role_mapping",
|
||||
ContentMapping: "content_mapping",
|
||||
Capabilities: "capabilities",
|
||||
RequestTemplate: "request_template",
|
||||
SystemPromptTemplate: "system_prompt_template",
|
||||
Status: "status",
|
||||
Remark: "remark",
|
||||
}
|
||||
@@ -1,145 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"prompts-core/model/dto"
|
||||
"prompts-core/model/entity"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// 获取请求模型的提示词
|
||||
func GetModelPrompt(ctx context.Context, Type int) string {
|
||||
return g.Cfg().MustGet(ctx, "modelPrompts.types."+gconv.String(Type), "").String()
|
||||
}
|
||||
|
||||
// 获取构建提示词
|
||||
func GetBuildPrompt(ctx context.Context, Type int) string {
|
||||
return g.Cfg().MustGet(ctx, "buildProject.types."+gconv.String(Type), "").String()
|
||||
}
|
||||
|
||||
// buildInferenceRequest 构建返回请求
|
||||
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
|
||||
messages := []map[string]any{}
|
||||
switch req.BuildType {
|
||||
//构建提示词请求
|
||||
case 1:
|
||||
//1. 构建系统提示词
|
||||
messages = append(messages, map[string]any{
|
||||
"role": "system",
|
||||
"content": promptBuild(ctx, req, model),
|
||||
})
|
||||
|
||||
// 2. 构建历史会话提示词
|
||||
for _, msg := range history {
|
||||
role := gconv.String(msg["role"])
|
||||
content := gconv.String(msg["content"])
|
||||
if role != "user" && role != "assistant" {
|
||||
continue
|
||||
}
|
||||
messages = append(messages, map[string]any{
|
||||
"role": role,
|
||||
"content": content,
|
||||
})
|
||||
}
|
||||
// 3. 当前用户问题(原来的最后一条)
|
||||
messages = append(messages, map[string]any{
|
||||
"role": "user",
|
||||
"content": buildUserPrompt(ctx, req, GetModelPrompt(ctx, model.ModelType)),
|
||||
})
|
||||
//构建节点请求
|
||||
case 2:
|
||||
messages = append(messages, map[string]any{
|
||||
"role": "user",
|
||||
"content": NodeBuid(ctx, req),
|
||||
})
|
||||
default:
|
||||
return nil, errors.New("不支持的构建类型")
|
||||
}
|
||||
// 构建请求体
|
||||
return map[string]any{
|
||||
"modelName": chatModel.ModelName,
|
||||
"bizName": "prompts-core",
|
||||
"callbackUrl": "/prompt/callback",
|
||||
"requestPayload": map[string]any{
|
||||
"model": chatModel.ModelName,
|
||||
"messages": messages,
|
||||
"stream": false,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 构建用户提示词
|
||||
// ============================================
|
||||
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
|
||||
payload := map[string]any{
|
||||
"model": req.ModelName,
|
||||
//数据库提示信息
|
||||
"promptInfo": prompt,
|
||||
// 系统表单
|
||||
"form": req.Form,
|
||||
// 用户表单
|
||||
"userForm": req.UserForm,
|
||||
//文件url
|
||||
"userFiles": req.UserFiles,
|
||||
//解读文件(只支持可读类型 如:xml,json,yaml)
|
||||
"userFilesText": FetchFileTexts(ctx, req.UserFiles),
|
||||
//skill 相关(根据传入的 skillName 获取 zip 内所有 md 文件拼接内容)
|
||||
"skills": SkillMdContent(ctx, req.SkillName),
|
||||
}
|
||||
return mustMarshal(payload)
|
||||
}
|
||||
|
||||
// promptBuild 提示词构建
|
||||
func promptBuild(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) string {
|
||||
// 1. 从配置文件读取提示词模板
|
||||
promptTpl := GetBuildPrompt(ctx, req.BuildType)
|
||||
if promptTpl == "" {
|
||||
return ""
|
||||
}
|
||||
// 2. 构建字段映射说明
|
||||
mappingBytes, _ := json.Marshal(model.RequestMapping)
|
||||
mappingStr := string(mappingBytes)
|
||||
|
||||
var mapping map[string]string
|
||||
_ = json.Unmarshal(mappingBytes, &mapping)
|
||||
|
||||
var fieldDesc strings.Builder
|
||||
for key, path := range mapping {
|
||||
fieldDesc.WriteString(fmt.Sprintf("- %s → %s\n", key, path))
|
||||
}
|
||||
|
||||
// 3. 拼接 UserForm 全文(必须完整阅读)
|
||||
var userFormContent strings.Builder
|
||||
for k, v := range req.UserForm {
|
||||
userFormContent.WriteString(fmt.Sprintf("%s=%v;", k, v))
|
||||
}
|
||||
userFormFullText := strings.TrimSuffix(userFormContent.String(), ";")
|
||||
|
||||
// 4. 双表单信息
|
||||
formInfo := fmt.Sprintf(`
|
||||
【系统表单(系统提示词/参数)】
|
||||
%s
|
||||
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
|
||||
%s
|
||||
`, formToJSON(req.Form), userFormFullText)
|
||||
// 5. 格式化最终提示词(替换配置里的 %s)
|
||||
return fmt.Sprintf(promptTpl, mappingStr, fieldDesc.String(), formInfo)
|
||||
}
|
||||
|
||||
// NodeBuid 节点构建
|
||||
func NodeBuid(ctx context.Context, req *dto.ComposeMessagesReq) string {
|
||||
promptTpl := GetBuildPrompt(ctx, req.BuildType)
|
||||
if promptTpl == "" {
|
||||
return ""
|
||||
}
|
||||
formStr := formToJSON(req.Form)
|
||||
userFormStr := formToJSON(req.UserForm)
|
||||
return fmt.Sprintf(promptTpl, formStr, userFormStr)
|
||||
}
|
||||
@@ -1,414 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"prompts-core/consts/public"
|
||||
"prompts-core/dao"
|
||||
"prompts-core/model/dto"
|
||||
"prompts-core/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// ============================================
|
||||
// 核心业务流程
|
||||
// ============================================
|
||||
|
||||
// ComposeMessages 拼接提示词主流程
|
||||
func (s *promptService) ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
|
||||
var (
|
||||
epicycleId int64
|
||||
taskID string
|
||||
history []map[string]any
|
||||
message map[string]any
|
||||
err error
|
||||
taskRecord *entity.ComposeTask
|
||||
)
|
||||
// 获取模型信息
|
||||
chatModel, model, err := s.GetModelMessage(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 根据构建类型进行判断处理
|
||||
switch req.BuildType {
|
||||
//提示词构建
|
||||
case 1:
|
||||
maxRetryTimes := g.Cfg().MustGet(ctx, "promptsRetry.maxRetryTimes", 3).Int()
|
||||
//1. 获取历史会话
|
||||
history, err = Session.GetHistoryMessages(ctx, req.SessionId)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "获取历史会话失败: %v,将不使用历史会话", err)
|
||||
history = nil // 出错就用空的,不影响主流程
|
||||
}
|
||||
// 重试循环
|
||||
for attempt := 0; attempt <= maxRetryTimes; attempt++ {
|
||||
if attempt > 0 {
|
||||
g.Log().Warningf(ctx, "[重试]第 %d/%d 次调用推理模型", attempt, maxRetryTimes)
|
||||
}
|
||||
// 2. 调用推理模型
|
||||
taskID, err = s.callInferenceModel(ctx, req, chatModel, model, history)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "调用推理模型失败(第%d次): %v", attempt+1, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 3. 保存记录
|
||||
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
ModelName: req.ModelName,
|
||||
SkillName: req.SkillName,
|
||||
RequestPayload: mustMarshal(req),
|
||||
Status: public.ComposeStatusPending,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "保存任务记录失败(第%d次): %v", attempt+1, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 4. 等待结果
|
||||
taskRecord, err = s.waitForResult(ctx, taskID)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "等待结果失败(第%d次): %v", attempt+1, err)
|
||||
continue
|
||||
}
|
||||
// 校验结果
|
||||
message = s.parsePromptBuild(taskRecord, chatModel)
|
||||
if message != nil && isMessageValid(message) {
|
||||
break
|
||||
}
|
||||
g.Log().Warningf(ctx, "[重试] 推理结果不合法(第%d次),准备重新请求", attempt+1)
|
||||
message = nil
|
||||
}
|
||||
if message == nil {
|
||||
return nil, errors.New("推理模型调用失败,请稍后再试")
|
||||
}
|
||||
//5.创建会话记录
|
||||
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
SessionId: req.SessionId,
|
||||
RequestContent: message,
|
||||
})
|
||||
//节点构建
|
||||
case 2:
|
||||
//1. 调用推理模型
|
||||
taskID, err = s.callInferenceModel(ctx, req, chatModel, model, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//2. 保存相关记录
|
||||
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
ModelName: req.ModelName,
|
||||
SkillName: req.SkillName,
|
||||
RequestPayload: mustMarshal(req),
|
||||
Status: public.ComposeStatusPending,
|
||||
})
|
||||
//5. 等待结果
|
||||
taskRecord, err := s.waitForResult(ctx, taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
message = s.parseNodeBuild(taskRecord)
|
||||
default:
|
||||
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
SessionId: req.SessionId,
|
||||
Remark: req.Cause,
|
||||
})
|
||||
return &dto.ComposeMessagesRes{
|
||||
EpicycleId: epicycleId,
|
||||
}, nil
|
||||
}
|
||||
return &dto.ComposeMessagesRes{
|
||||
Messages: message,
|
||||
EpicycleId: epicycleId,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *promptService) Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
|
||||
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
|
||||
|
||||
// ============ 先查任务是否存在 ============
|
||||
task, err := dao.ComposeTask.GetByTaskId(ctx, req.TaskId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if task == nil {
|
||||
return fmt.Errorf("任务不存在: %s", req.TaskId)
|
||||
}
|
||||
// ============ 根据状态区分处理 ============
|
||||
if req.State == 3 {
|
||||
// 失败:直接更新状态
|
||||
_, err = dao.ComposeTask.UpdateByTaskId(ctx, req.TaskId, map[string]any{
|
||||
entity.ComposeTaskCol.Status: public.ComposeStatusFailed,
|
||||
entity.ComposeTaskCol.ErrorMessage: req.ErrorMsg,
|
||||
})
|
||||
return err
|
||||
}
|
||||
// ======================================
|
||||
// 成功:解析模型输出
|
||||
result, err := parseOutput(req.Text)
|
||||
if err != nil {
|
||||
_, updateErr := dao.ComposeTask.UpdateByTaskId(ctx, req.TaskId, map[string]any{
|
||||
entity.ComposeTaskCol.Status: public.ComposeStatusFailed,
|
||||
entity.ComposeTaskCol.ErrorMessage: err.Error(),
|
||||
})
|
||||
if updateErr != nil {
|
||||
g.Log().Warningf(ctx, "[Callback] 更新失败状态出错 taskId=%s err=%v", req.TaskId, updateErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// ============ result 可能为 nil ============
|
||||
var messages any
|
||||
if result != nil {
|
||||
messages = result
|
||||
}
|
||||
// =======================================
|
||||
|
||||
_, err = dao.ComposeTask.UpdateByTaskId(ctx, req.TaskId, map[string]any{
|
||||
entity.ComposeTaskCol.Status: public.ComposeStatusSuccess,
|
||||
entity.ComposeTaskCol.Messages: messages,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// GetComposeTask 查询任务结果
|
||||
func (s *promptService) GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
|
||||
record, err := dao.ComposeTask.GetByTaskId(ctx, taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if record == nil {
|
||||
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
|
||||
}
|
||||
|
||||
// 如果 Messages 是字符串,反序列化为 JSON 数组
|
||||
messages := record.Messages
|
||||
if str, ok := messages.(string); ok && str != "" {
|
||||
var parsed any
|
||||
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
|
||||
messages = parsed
|
||||
}
|
||||
}
|
||||
|
||||
return &dto.GetComposeTaskRes{
|
||||
TaskId: record.TaskId,
|
||||
Status: record.Status,
|
||||
ErrorMessage: record.ErrorMessage,
|
||||
Messages: messages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetModelMessage 获取模型信息
|
||||
func (s *promptService) GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) {
|
||||
// 1. 获取当前用户的会话模型
|
||||
chatModel, err := dao.Model.GetByIsChatModel(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if chatModel == nil {
|
||||
return nil, nil, errors.New("当前没有对话模型,请添加")
|
||||
}
|
||||
// 2. 获取要构建的模型信息
|
||||
model, err := dao.Model.GetByModelName(ctx, req.ModelName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if model == nil {
|
||||
return nil, nil, fmt.Errorf("需要构建的模型 %s 不存在", req.ModelName)
|
||||
}
|
||||
return chatModel, model, nil
|
||||
}
|
||||
|
||||
// callInferenceModel 调用推理模型
|
||||
func (s *promptService) callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (string, error) {
|
||||
// 构建推理模型请求
|
||||
taskReq, err := buildInferenceRequest(ctx, req, chatModel, model, history)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("构建推理请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建网关任务
|
||||
taskID, err := createGatewayTask(ctx, taskReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建网关任务失败: %w", err)
|
||||
}
|
||||
|
||||
if taskID == "" {
|
||||
return "", errors.New("网关未返回taskId")
|
||||
}
|
||||
|
||||
return taskID, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 步骤6:等待结果
|
||||
// ============================================
|
||||
func (s *promptService) waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) {
|
||||
timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second
|
||||
pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond
|
||||
deadline := time.Now().Add(timeout)
|
||||
|
||||
for {
|
||||
// ===================== 修复点 1:检查上下文是否取消 =====================
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// 请求已被取消,直接返回,不继续查库
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// 1. 查数据库
|
||||
record, err := dao.ComposeTask.GetByTaskId(ctx, taskID)
|
||||
if err != nil {
|
||||
// ===================== 修复点 2:如果是上下文取消,直接返回 =====================
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if record != nil {
|
||||
switch record.Status {
|
||||
case public.ComposeStatusSuccess:
|
||||
return record, nil
|
||||
case public.ComposeStatusFailed:
|
||||
if strings.TrimSpace(record.ErrorMessage) == "" {
|
||||
return nil, fmt.Errorf("任务失败(taskId=%s)", taskID)
|
||||
}
|
||||
return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 查网关状态
|
||||
state, err := queryGatewayTaskState(ctx, taskID)
|
||||
if err != nil {
|
||||
// 网关不可达不终止,继续轮询
|
||||
g.Log().Warningf(ctx, "[waitForResult] 查询网关失败 taskId=%s err=%v", taskID, err)
|
||||
} else {
|
||||
switch state {
|
||||
case 2: // 网关成功
|
||||
// 网关已成功,主动更新数据库
|
||||
if record != nil {
|
||||
dao.ComposeTask.UpdateByTaskId(ctx, taskID, map[string]any{
|
||||
entity.ComposeTaskCol.Status: public.ComposeStatusSuccess,
|
||||
})
|
||||
}
|
||||
case 3: // 网关失败
|
||||
if record != nil {
|
||||
dao.ComposeTask.UpdateByTaskId(ctx, taskID, map[string]any{
|
||||
entity.ComposeTaskCol.Status: public.ComposeStatusFailed,
|
||||
entity.ComposeTaskCol.ErrorMessage: "model-gateway 任务执行失败",
|
||||
})
|
||||
}
|
||||
return nil, fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 超时检查
|
||||
if time.Now().After(deadline) {
|
||||
return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID)
|
||||
}
|
||||
|
||||
// ===================== 修复点3:sleep 也要监听 ctx 取消 =====================
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(pollInterval):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parsePromptBuild 解析提示词构建结果(BuildType == 1)
|
||||
func (s *promptService) parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) map[string]any {
|
||||
if taskRecord == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 1. 解析 Messages
|
||||
var mapped map[string]any
|
||||
switch v := taskRecord.Messages.(type) {
|
||||
case *gvar.Var:
|
||||
if v != nil {
|
||||
json.Unmarshal([]byte(v.String()), &mapped)
|
||||
}
|
||||
case string:
|
||||
json.Unmarshal([]byte(v), &mapped)
|
||||
case map[string]any:
|
||||
mapped = v
|
||||
default:
|
||||
b, _ := json.Marshal(v)
|
||||
json.Unmarshal(b, &mapped)
|
||||
}
|
||||
|
||||
// 2. 解析模型 ResponseMapping 获取 content 字段名
|
||||
contentField := "content" // 默认值
|
||||
if model != nil {
|
||||
var respMapping map[string]string
|
||||
switch v := model.ResponseMapping.(type) {
|
||||
case *gvar.Var:
|
||||
if v != nil {
|
||||
json.Unmarshal([]byte(v.String()), &respMapping)
|
||||
}
|
||||
case string:
|
||||
json.Unmarshal([]byte(v), &respMapping)
|
||||
case map[string]interface{}:
|
||||
respMapping = make(map[string]string)
|
||||
for k, val := range v {
|
||||
if s, ok := val.(string); ok {
|
||||
respMapping[k] = s
|
||||
}
|
||||
}
|
||||
}
|
||||
// 从映射中找到 content 对应的字段名
|
||||
for k, v := range respMapping {
|
||||
if strings.Contains(v, "content") {
|
||||
contentField = k
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 提取 content 的值
|
||||
contentStr, ok := mapped[contentField].(string)
|
||||
if !ok || contentStr == "" {
|
||||
return mapped
|
||||
}
|
||||
|
||||
// 4. 解析 content 内的 JSON
|
||||
var innerData map[string]any
|
||||
json.Unmarshal([]byte(contentStr), &innerData)
|
||||
|
||||
return innerData
|
||||
}
|
||||
|
||||
// parseNodeBuild 解析节点构建结果(BuildType == 2)
|
||||
func (s *promptService) parseNodeBuild(taskRecord *entity.ComposeTask) map[string]any {
|
||||
if taskRecord == nil {
|
||||
return nil
|
||||
}
|
||||
var result map[string]any
|
||||
switch v := taskRecord.Messages.(type) {
|
||||
case *gvar.Var:
|
||||
if v != nil {
|
||||
json.Unmarshal([]byte(v.String()), &result)
|
||||
}
|
||||
case string:
|
||||
json.Unmarshal([]byte(v), &result)
|
||||
case map[string]any:
|
||||
result = v
|
||||
default:
|
||||
b, _ := json.Marshal(v)
|
||||
json.Unmarshal(b, &result)
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -1,340 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// ============================================
|
||||
// 文件处理(配置直接内联 + zip 支持)
|
||||
// ============================================
|
||||
|
||||
// 允许的文本类 MIME 类型前缀
|
||||
var allowedMIMEPrefixes = []string{
|
||||
"text/",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"application/javascript",
|
||||
"application/x-yaml",
|
||||
"application/yaml",
|
||||
"application/toml",
|
||||
"application/x-httpd-php",
|
||||
"application/x-sh",
|
||||
"application/x-python",
|
||||
"application/x-perl",
|
||||
"application/x-ruby",
|
||||
}
|
||||
|
||||
// 禁止的文件扩展名
|
||||
var bannedExtensions = map[string]bool{
|
||||
".png": true, ".jpg": true, ".jpeg": true, ".gif": true, ".bmp": true,
|
||||
".webp": true, ".svg": true, ".ico": true, ".tiff": true, ".tif": true,
|
||||
".mp3": true, ".wav": true, ".ogg": true, ".flac": true, ".aac": true,
|
||||
".wma": true, ".m4a": true,
|
||||
".mp4": true, ".avi": true, ".mkv": true, ".mov": true, ".wmv": true,
|
||||
".flv": true, ".webm": true,
|
||||
".tar": true, ".gz": true, ".rar": true, ".7z": true,
|
||||
".exe": true, ".dll": true, ".so": true, ".bin": true, ".dat": true,
|
||||
".class": true, ".pyc": true,
|
||||
".pdf": true, ".doc": true, ".docx": true, ".xls": true, ".xlsx": true,
|
||||
".ppt": true, ".pptx": true,
|
||||
}
|
||||
|
||||
var symbolCleaner = regexp.MustCompile(`[\x00-\x08\x0B\x0C\x0E-\x1F]`)
|
||||
|
||||
// FetchFileTexts 从 URL 列表获取文件内容(支持 zip 内文件)
|
||||
func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
|
||||
result := make(map[string]string)
|
||||
|
||||
if len(urls) == 0 {
|
||||
return result
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: time.Duration(g.Cfg().MustGet(ctx, "userFiles.httpTimeoutSec", 8).Int()) * time.Second,
|
||||
}
|
||||
|
||||
for _, rawURL := range urls {
|
||||
url := sanitizeURL(rawURL)
|
||||
if url == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if isBannedExtension(url) {
|
||||
continue
|
||||
}
|
||||
|
||||
if isZipExtension(url) {
|
||||
zipTexts := fetchZipFileTexts(ctx, client, url)
|
||||
for k, v := range zipTexts {
|
||||
result[k] = v
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
text, err := fetchFileContent(ctx, client, url)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
text = cleanSymbols(text)
|
||||
result[url] = text
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func isZipExtension(url string) bool {
|
||||
ext := strings.ToLower(filepath.Ext(url))
|
||||
if idx := strings.Index(ext, "?"); idx != -1 {
|
||||
ext = ext[:idx]
|
||||
}
|
||||
return ext == ".zip"
|
||||
}
|
||||
|
||||
func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map[string]string {
|
||||
result := make(map[string]string)
|
||||
|
||||
zipBytes, err := downloadFile(client, url,
|
||||
int64(g.Cfg().MustGet(ctx, "userFiles.zipMaxSizeMB", 10).Int())*1024*1024,
|
||||
)
|
||||
if err != nil {
|
||||
return result
|
||||
}
|
||||
|
||||
reader, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
|
||||
if err != nil {
|
||||
return result
|
||||
}
|
||||
|
||||
entryMaxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipEntryMaxSizeKB", 500).Int()) * 1024
|
||||
|
||||
for _, file := range reader.File {
|
||||
if file.FileInfo().IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
fileName := file.Name
|
||||
|
||||
if isBannedExtension(fileName) {
|
||||
continue
|
||||
}
|
||||
|
||||
if isZipExtension(fileName) {
|
||||
continue
|
||||
}
|
||||
|
||||
rc, err := file.Open()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
content, err := io.ReadAll(io.LimitReader(rc, entryMaxSize))
|
||||
rc.Close()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
contentType := http.DetectContentType(content)
|
||||
if !isReadableContentType(contentType) {
|
||||
continue
|
||||
}
|
||||
|
||||
text := cleanSymbols(string(content))
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
key := url + "::" + fileName
|
||||
result[key] = text
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(io.LimitReader(resp.Body, maxSize))
|
||||
}
|
||||
|
||||
func isBannedExtension(url string) bool {
|
||||
ext := strings.ToLower(filepath.Ext(url))
|
||||
if idx := strings.Index(ext, "?"); idx != -1 {
|
||||
ext = ext[:idx]
|
||||
}
|
||||
return bannedExtensions[ext]
|
||||
}
|
||||
|
||||
func isReadableContentType(contentType string) bool {
|
||||
if contentType == "" {
|
||||
return false
|
||||
}
|
||||
ct := strings.ToLower(contentType)
|
||||
for _, prefix := range allowedMIMEPrefixes {
|
||||
if strings.HasPrefix(ct, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func cleanSymbols(text string) string {
|
||||
text = symbolCleaner.ReplaceAllString(text, "")
|
||||
text = strings.ReplaceAll(text, "\r\n", "\n")
|
||||
text = strings.ReplaceAll(text, "\r", "\n")
|
||||
text = regexp.MustCompile(`\n{3,}`).ReplaceAllString(text, "\n\n")
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
func fetchFileContent(ctx context.Context, client *http.Client, url string) (string, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if !isReadableContentType(contentType) {
|
||||
return "", fmt.Errorf("unreadable content-type: %s", contentType)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(
|
||||
io.LimitReader(resp.Body,
|
||||
int64(g.Cfg().MustGet(ctx, "userFiles.textFileMaxSizeKB", 500).Int())*1024,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return strings.TrimSpace(string(body)), nil
|
||||
}
|
||||
|
||||
func sanitizeURL(raw string) string {
|
||||
s := strings.TrimSpace(raw)
|
||||
s = strings.Trim(s, "`\"")
|
||||
return s
|
||||
}
|
||||
|
||||
// SkillMdContent 根据 skillName 获取 zip 内所有 md 文件拼接内容
|
||||
func SkillMdContent(ctx context.Context, skillName string) string {
|
||||
// 1. 请求接口获取 SkillUserVO
|
||||
skillResp, err := GetSkillUser(ctx, skillName)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
fullUrl := skillResp.ImgAddressPrefix + skillResp.FileUrl
|
||||
// 2. 下载 zip 文件
|
||||
client := &http.Client{
|
||||
Timeout: time.Duration(g.Cfg().MustGet(ctx, "skillFiles.httpTimeoutSec", 30).Int()) * time.Second,
|
||||
}
|
||||
|
||||
zipBytes, err := downloadFile(client, fullUrl,
|
||||
int64(g.Cfg().MustGet(ctx, "skillFiles.zipMaxSizeMB", 10).Int())*1024*1024,
|
||||
)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 3. 解压 zip 并提取所有 md 文件内容
|
||||
mdContents, err := extractMdFiles(ctx, zipBytes)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if len(mdContents) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 4. 拼接所有 md 内容
|
||||
var builder strings.Builder
|
||||
builder.WriteString(fmt.Sprintf("# Skill: %s\n\n", skillResp.Name))
|
||||
if skillResp.Description != "" {
|
||||
builder.WriteString(fmt.Sprintf("> %s\n\n", skillResp.Description))
|
||||
}
|
||||
|
||||
for fileName, content := range mdContents {
|
||||
builder.WriteString(fmt.Sprintf("## %s\n\n", fileName))
|
||||
builder.WriteString(content)
|
||||
builder.WriteString("\n\n---\n\n")
|
||||
}
|
||||
|
||||
return strings.TrimSpace(builder.String())
|
||||
}
|
||||
|
||||
// extractMdFiles 解压 zip 并提取所有 .md 文件内容
|
||||
func extractMdFiles(ctx context.Context, zipBytes []byte) (map[string]string, error) {
|
||||
result := make(map[string]string)
|
||||
|
||||
reader, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
entryMaxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.mdMaxSizeKB", 500).Int()) * 1024
|
||||
|
||||
for _, file := range reader.File {
|
||||
if file.FileInfo().IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(strings.ToLower(file.Name), ".md") {
|
||||
continue
|
||||
}
|
||||
|
||||
rc, err := file.Open()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
content, err := io.ReadAll(io.LimitReader(rc, entryMaxSize))
|
||||
rc.Close()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(content) > 0 {
|
||||
result[file.Name] = strings.TrimSpace(string(content))
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
195
service/gateway/gateway_http_service.go
Normal file
195
service/gateway/gateway_http_service.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/model/entity"
|
||||
"strings"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||
commonHttp "gitea.redpowerfuture.com/red-future/common/http"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
)
|
||||
|
||||
// CreateTaskReq 创建任务请求
|
||||
type CreateTaskReq struct {
|
||||
TaskId string `json:"task_id"`
|
||||
State int `json:"state"`
|
||||
OssFile string `json:"oss_file"`
|
||||
FileType string `json:"file_type"`
|
||||
Text string `json:"text"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
}
|
||||
|
||||
// CreateGatewayTask 创建网关异步任务
|
||||
func CreateGatewayTask(ctx context.Context, payload map[string]any) (string, error) {
|
||||
fullURL := "model-gateway/task/createTask"
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
var req CreateTaskReq
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := commonHttp.Post(ctx, fullURL, headers, &req, body); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return req.TaskId, nil
|
||||
}
|
||||
|
||||
type GetModelConfigResp struct {
|
||||
Model *AsynchModel `json:"model"`
|
||||
}
|
||||
|
||||
type AsynchModel struct {
|
||||
beans.SQLBaseDO `orm:",inline"`
|
||||
ModelName string `orm:"model_name" json:"modelName"`
|
||||
ModelType int `orm:"model_type" json:"modelType"`
|
||||
BaseURL string `orm:"base_url" json:"baseUrl"`
|
||||
HttpMethod string `orm:"http_method" json:"httpMethod"`
|
||||
HeadMsg map[string]any `orm:"head_msg" json:"headMsg"`
|
||||
Form []map[string]any `orm:"form_json" json:"form"`
|
||||
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
|
||||
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
|
||||
ResponseBody string `orm:"response_body" json:"responseBody"`
|
||||
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
|
||||
IsPrivate *int `orm:"is_private" json:"isPrivate"`
|
||||
IsChatModel int `orm:"is_chat_model" json:"isChatModel"`
|
||||
CallModel int `orm:"call_model" json:"callModel"`
|
||||
ApiKey string `orm:"api_key" json:"apiKey"`
|
||||
Enabled *int `orm:"enabled" json:"enabled"`
|
||||
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
|
||||
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
|
||||
RetryTimes int `orm:"retry_times" json:"retryTimes"`
|
||||
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
|
||||
IsOwner *int `json:"isOwner" orm:"is_owner"`
|
||||
OperatorName string `orm:"operator_name" json:"operatorName"`
|
||||
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
|
||||
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
|
||||
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
|
||||
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
|
||||
FirstFrame string `orm:"first_frame" json:"firstFrame"`
|
||||
LastFrame string `orm:"last_frame" json:"lastFrame"`
|
||||
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
|
||||
}
|
||||
|
||||
// GetModelConfig 获取模型配置
|
||||
func GetModelConfig(ctx context.Context, req *AsynchModel) (model *AsynchModel, err error) {
|
||||
fullURL := "model-gateway/model/getModel"
|
||||
// 拼接 query 参数
|
||||
var params []string
|
||||
if req.Creator != "" {
|
||||
params = append(params, fmt.Sprintf("creator=%s", req.Creator))
|
||||
}
|
||||
if req.ModelName != "" {
|
||||
params = append(params, fmt.Sprintf("modelName=%s", req.ModelName))
|
||||
}
|
||||
if req.IsChatModel != 0 {
|
||||
params = append(params, fmt.Sprintf("isChatModel=%d", req.IsChatModel))
|
||||
}
|
||||
if len(params) > 0 {
|
||||
fullURL += "?" + strings.Join(params, "&")
|
||||
}
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
var resp GetModelConfigResp
|
||||
if err = commonHttp.Get(ctx, fullURL, headers, &resp, nil); err != nil {
|
||||
return nil, fmt.Errorf("获取模型配置失败: %w", err)
|
||||
}
|
||||
if resp.Model == nil {
|
||||
return nil, fmt.Errorf("模型不存在")
|
||||
}
|
||||
return resp.Model, nil
|
||||
}
|
||||
|
||||
// GetTaskResultRes 任务结果响应
|
||||
type GetTaskResultRes struct {
|
||||
OssFile string `json:"ossFile" dc:"结果文件OSS地址"`
|
||||
State int `json:"state" dc:"任务状态"`
|
||||
}
|
||||
|
||||
// QueryGatewayTaskState 查询网关任务状态
|
||||
func QueryGatewayTaskState(ctx context.Context, taskID string) (int, error) {
|
||||
fullURL := fmt.Sprintf("model-gateway/task/getTaskResult?taskId=%s", taskID)
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
var req GetTaskResultRes
|
||||
if err := commonHttp.Get(ctx, fullURL, headers, &req, nil); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return req.State, nil
|
||||
}
|
||||
|
||||
// SkillUserVO 技能用户视图对象
|
||||
type SkillUserVO struct {
|
||||
Id int64 `json:"id,string"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
FileName string `json:"fileName"`
|
||||
FileUrl string `json:"fileUrl"`
|
||||
CreatedAt *gtime.Time `json:"createdAt"`
|
||||
UpdatedAt *gtime.Time `json:"updatedAt"`
|
||||
ImgAddressPrefix string `json:"imgAddressPrefix"`
|
||||
}
|
||||
|
||||
// GetSkillUser 获取技能用户信息
|
||||
func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) {
|
||||
fullURL := fmt.Sprintf("ai-agent/skill/user/getUserOrTemplate?name=%s", name)
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
var resp SkillUserVO
|
||||
var req struct{}
|
||||
if err := commonHttp.Get(ctx, fullURL, headers, &resp, req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// SendCallbackReq 发送回调的请求体
|
||||
type SendCallbackReq struct {
|
||||
TaskId string `json:"taskId"`
|
||||
Status string `json:"status"`
|
||||
EpicycleId int64 `json:"epicycleId"`
|
||||
ErrorMsg string `json:"errorMsg,omitempty"`
|
||||
}
|
||||
|
||||
// SendCallback 向业务方发送回调
|
||||
func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycleId int64) error {
|
||||
// 1. 检查回调地址
|
||||
if composeTask.CallbackUrl == "" {
|
||||
return fmt.Errorf("回调地址为空,taskId=%s", composeTask.TaskId)
|
||||
}
|
||||
// 2. 构造请求体
|
||||
req := SendCallbackReq{
|
||||
TaskId: composeTask.TaskId,
|
||||
Status: composeTask.Status,
|
||||
ErrorMsg: composeTask.ErrorMessage,
|
||||
EpicycleId: epicycleId,
|
||||
}
|
||||
// 3. 发送 POST 请求
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
var resp struct{}
|
||||
g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s",
|
||||
composeTask.TaskId, composeTask.CallbackUrl)
|
||||
if err := commonHttp.Post(ctx, composeTask.CallbackUrl, headers, &resp, req); err != nil {
|
||||
return fmt.Errorf("[回调业务] 发送失败 taskId=%s url=%s err=%w", composeTask.TaskId, composeTask.CallbackUrl, err)
|
||||
}
|
||||
g.Log().Infof(ctx, "[回调业务] 发送成功 taskId=%s 回调地址=%s ", composeTask.TaskId, composeTask.CallbackUrl)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DownloadFile 从 OSS 下载文件内容
|
||||
func DownloadFile(ossURL string) ([]byte, error) {
|
||||
resp, err := http.Get(ossURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("下载OSS文件失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("下载OSS文件返回非200: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
@@ -1,53 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// asyncCtx 固化异步执行所需的 token/user,避免请求结束后丢失(仅在“同请求内起 goroutine”有用)。
|
||||
// 本项目当前是“落库 + 后台 worker”模式,因此还会把必要信息持久化到任务表的 request_payload 中。
|
||||
func asyncCtx(ctx context.Context) context.Context {
|
||||
asyncCtx := context.WithoutCancel(ctx)
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
asyncCtx = context.WithValue(asyncCtx, "token", token)
|
||||
}
|
||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
|
||||
}
|
||||
}
|
||||
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
|
||||
asyncCtx = context.WithValue(asyncCtx, "user", user)
|
||||
}
|
||||
return asyncCtx
|
||||
}
|
||||
|
||||
// forwardHeaders 透传调用链路中必须的头信息(优先使用 ctx 里固化的 token / xUserInfo)。
|
||||
func forwardHeaders(ctx context.Context) map[string]string {
|
||||
headers := make(map[string]string)
|
||||
|
||||
if token, ok := ctx.Value("token").(string); ok && token != "" {
|
||||
headers["Authorization"] = token
|
||||
}
|
||||
if x, ok := ctx.Value("xUserInfo").(string); ok && x != "" {
|
||||
headers["X-User-Info"] = x
|
||||
}
|
||||
|
||||
// 兜底:从请求头拿
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
if headers["Authorization"] == "" {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
headers["Authorization"] = token
|
||||
}
|
||||
}
|
||||
if headers["X-User-Info"] == "" {
|
||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||
headers["X-User-Info"] = userInfo
|
||||
}
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
commonHttp "gitea.com/red-future/common/http"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
)
|
||||
|
||||
// CreateTaskReq 创建任务请求
|
||||
type CreateTaskReq struct {
|
||||
TaskId string `json:"task_id"`
|
||||
State int `json:"state"`
|
||||
OssFile string `json:"oss_file"`
|
||||
FileType string `json:"file_type"`
|
||||
Text string `json:"text"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
}
|
||||
|
||||
// createGatewayTask 调用 model-gateway 异步任务并同步等待结果
|
||||
func createGatewayTask(ctx context.Context, payload map[string]any) (string, error) {
|
||||
fullURL := "model-gateway/task/createTask"
|
||||
headers := forwardHeaders(ctx)
|
||||
var req CreateTaskReq
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := commonHttp.Post(ctx, fullURL, headers, &req, body); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return req.TaskId, nil
|
||||
}
|
||||
|
||||
type GetTaskResultRes struct {
|
||||
OssFile string `json:"ossFile" dc:"结果文件OSS地址"`
|
||||
State int `json:"state" dc:"任务状态"`
|
||||
}
|
||||
|
||||
// queryGatewayTaskState 查询网关任务状态
|
||||
func queryGatewayTaskState(ctx context.Context, taskID string) (int, error) {
|
||||
fullURL := fmt.Sprintf("model-gateway/task/getTaskResult?taskId=%s", taskID)
|
||||
headers := forwardHeaders(ctx)
|
||||
var req GetTaskResultRes
|
||||
if err := commonHttp.Get(ctx, fullURL, headers, &req, nil); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return req.State, nil
|
||||
}
|
||||
|
||||
// SkillUserVO 技能用户视图对象
|
||||
type SkillUserVO struct {
|
||||
Id int64 `json:"id,string"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
FileName string `json:"fileName"`
|
||||
FileUrl string `json:"fileUrl"` // html 后缀
|
||||
CreatedAt *gtime.Time `json:"createdAt"`
|
||||
UpdatedAt *gtime.Time `json:"updatedAt"`
|
||||
ImgAddressPrefix string `json:"imgAddressPrefix"` // htmml 前缀
|
||||
}
|
||||
|
||||
// GetSkillUser 根据 name 获取技能用户信息
|
||||
func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) {
|
||||
fullURL := fmt.Sprintf("ai-agent/skill/user/getUserOrTemplate?name=%s", name)
|
||||
headers := forwardHeaders(ctx)
|
||||
var resp SkillUserVO
|
||||
var req struct{}
|
||||
if err := commonHttp.Get(ctx, fullURL, headers, &resp, req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
161
service/prompt/prompt_build_service.go
Normal file
161
service/prompt/prompt_build_service.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"prompts-core/service/gateway"
|
||||
"strings"
|
||||
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/dao"
|
||||
"prompts-core/model/dto"
|
||||
"prompts-core/model/entity"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
)
|
||||
|
||||
// buildPromptTypeRequest 构建提示词类型请求(BuildType=1)
|
||||
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
|
||||
//1) 构建系统提示词
|
||||
systemPrompt := promptBuildWithRounds(ctx, chatModel, aiModel)
|
||||
ir.AddSystem(systemPrompt)
|
||||
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType))
|
||||
ir.AddUser(userPrompt)
|
||||
//2) 检查整体内容是否超出窗口
|
||||
if !checkOverallContent(ir, aiModel) {
|
||||
availableWindow := util.GetAvailableWindow(aiModel.TokenConfig)
|
||||
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
|
||||
}
|
||||
return compileToProviderRequest(ctx, ir, chatModel, req)
|
||||
}
|
||||
|
||||
// buildNodeTypeRequest 构建节点类型请求(BuildType=2)
|
||||
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
|
||||
ir.AddUser(NodeBuild(ctx, req))
|
||||
return compileToProviderRequest(ctx, ir, chatModel, req)
|
||||
}
|
||||
|
||||
// buildStructTypeRequest 构建结构体类型请求(BuildType=3)
|
||||
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
|
||||
customPrompt := gjson.New(req.UserForm).Get("0.prompt").String()
|
||||
ir.AddSystem(customPrompt)
|
||||
ir.AddUser(buildUserPrompt(ctx, req, ""))
|
||||
return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt)
|
||||
}
|
||||
|
||||
// compileToProviderRequest 编译为 Provider 请求
|
||||
func compileToProviderRequest(ctx context.Context, ir *IR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) {
|
||||
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if protocol == nil {
|
||||
return nil, fmt.Errorf("协议配置不存在或获取失败")
|
||||
}
|
||||
// 如果传了自定义提示词,替换掉协议模板
|
||||
if len(customPrompt) > 0 && customPrompt[0] != "" {
|
||||
protocol.SystemPromptTemplate = customPrompt[0] +
|
||||
"【核心铁律】" +
|
||||
"1.【技能内容skill相关】必须完整拼接到System提示词中,作为System提示词的组成部分,不得拆分到其他位置。"
|
||||
}
|
||||
providerReq, err := Compile(ir, protocol, chatModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("编译请求失败: %w", err)
|
||||
}
|
||||
return map[string]any{
|
||||
"modelName": chatModel.ModelName,
|
||||
"bizName": util.GetServerName(ctx),
|
||||
"callbackUrl": utils.GetCallbackURL(ctx, "/prompt/callback"),
|
||||
"requestPayload": providerReq,
|
||||
"buildType": req.BuildType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// promptBuildWithRounds 构建提示词
|
||||
func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel) string {
|
||||
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||
ProviderName: chatModel.OperatorName,
|
||||
Status: 1,
|
||||
})
|
||||
if err != nil || providerProtocol == nil {
|
||||
return ""
|
||||
}
|
||||
outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{})).MustToJsonIndentString()
|
||||
|
||||
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
|
||||
outputJSON, //【输出结构】 %s
|
||||
)
|
||||
}
|
||||
|
||||
// checkOverallContent 检查整体内容是否超出窗口
|
||||
func checkOverallContent(ir *IR, model *gateway.AsynchModel) bool {
|
||||
fullContent := ir.String()
|
||||
return util.CountToken(fullContent, model.TokenConfig)
|
||||
}
|
||||
|
||||
// buildUserPrompt 构建用户提示词
|
||||
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("目标模型:%s\n", req.ModelName))
|
||||
if prompt != "" {
|
||||
b.WriteString(fmt.Sprintf("系统提示词:%s\n", prompt))
|
||||
}
|
||||
if skills := SkillMdContent(ctx, req.SkillName); skills != "" {
|
||||
b.WriteString(fmt.Sprintf("技能内容:\n%s\n", skills))
|
||||
}
|
||||
if formText := buildUserFormText(req.Form); formText != "" {
|
||||
b.WriteString(fmt.Sprintf("系统参数:\n%s\n", formText))
|
||||
}
|
||||
if userFormText := buildUserFormText(req.UserForm); userFormText != "" {
|
||||
b.WriteString(fmt.Sprintf("用户需求:\n%s\n", userFormText))
|
||||
}
|
||||
if len(req.Consult) > 0 {
|
||||
b.WriteString(fmt.Sprintf("参考附件:%s\n", gjson.New(req.Consult).String()))
|
||||
}
|
||||
if fileTexts := ExtractFileTexts(ctx, req.Consult); fileTexts != "" {
|
||||
b.WriteString(fmt.Sprintf("附件内容:\n%s\n", fileTexts))
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func buildUserFormText(form []map[string]any) string {
|
||||
if len(form) == 0 {
|
||||
return ""
|
||||
}
|
||||
var builder strings.Builder
|
||||
for _, item := range form {
|
||||
for k, v := range item {
|
||||
builder.WriteString(fmt.Sprintf("%s:\n", k))
|
||||
switch val := v.(type) {
|
||||
case []any:
|
||||
for i, elem := range val {
|
||||
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
|
||||
if m, ok := elem.(map[string]any); ok {
|
||||
for mk, mv := range m {
|
||||
builder.WriteString(fmt.Sprintf("%s:%v ", mk, mv))
|
||||
}
|
||||
} else {
|
||||
builder.WriteString(fmt.Sprint(elem))
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
default:
|
||||
builder.WriteString(fmt.Sprintf(" %v\n", v))
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(builder.String())
|
||||
}
|
||||
|
||||
// NodeBuild 节点构建
|
||||
func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string {
|
||||
promptTpl := util.GetBuildPrompt(ctx)
|
||||
if promptTpl == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(promptTpl,
|
||||
gjson.New(req.Form).MustToJsonString(),
|
||||
gjson.New(req.UserForm).MustToJsonString(),
|
||||
)
|
||||
}
|
||||
331
service/prompt/prompt_compose_service.go
Normal file
331
service/prompt/prompt_compose_service.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"prompts-core/service/session"
|
||||
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/consts/public"
|
||||
"prompts-core/dao"
|
||||
"prompts-core/model/dto"
|
||||
"prompts-core/model/entity"
|
||||
"prompts-core/service/gateway"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||
"gitea.redpowerfuture.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// ComposeMessages 核心拼接提示词主流程
|
||||
func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
|
||||
// 1) 获取模型信息
|
||||
chatModel, aiModel, err := GetModelMessage(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 2) 校验用户表单
|
||||
if err = validateUserForm(req, aiModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return handleBuild(ctx, req, chatModel, aiModel)
|
||||
}
|
||||
|
||||
// GetModelMessage 获取模型信息
|
||||
func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*gateway.AsynchModel, *gateway.AsynchModel, error) {
|
||||
userInfo, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
|
||||
}
|
||||
chatModel, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
|
||||
IsChatModel: 1,
|
||||
})
|
||||
if err != nil || chatModel == nil {
|
||||
return nil, nil, errors.New("当前没有对话模型,请添加")
|
||||
}
|
||||
|
||||
aiModel, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{TenantId: userInfo.TenantId, Creator: userInfo.UserName},
|
||||
ModelName: req.ModelName,
|
||||
})
|
||||
if err != nil || aiModel == nil {
|
||||
return nil, nil, errors.New("需要构建的模型不存在")
|
||||
}
|
||||
|
||||
return chatModel, aiModel, nil
|
||||
}
|
||||
|
||||
// validateUserForm 校验用户表单
|
||||
func validateUserForm(req *dto.ComposeMessagesReq, model *gateway.AsynchModel) error {
|
||||
if len(req.UserForm) == 0 {
|
||||
return nil
|
||||
}
|
||||
isValid, exceedTokens, err := util.CheckUserFormWithinWindow(req.UserForm, model.TokenConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("校验用户表单失败: %w", err)
|
||||
}
|
||||
|
||||
if !isValid {
|
||||
availableWindow := util.GetAvailableWindow(model.TokenConfig)
|
||||
return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens,可用窗口 %d tokens,请精简后重试",
|
||||
exceedTokens, availableWindow)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleBuild 通用构建处理
|
||||
func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
||||
// 1) 处理表单分批
|
||||
processedReq, _, err := ProcessUserFormBatches(ctx, req, aiModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
|
||||
}
|
||||
|
||||
// 2) 构建推理请求
|
||||
ir := NewPromptIR()
|
||||
var taskReq map[string]any
|
||||
switch req.BuildType {
|
||||
case public.BuildTypePrompt:
|
||||
taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir)
|
||||
case public.BuildTypeNode:
|
||||
taskReq, err = buildNodeTypeRequest(ctx, req, chatModel, ir)
|
||||
case public.BuildTypeStruct:
|
||||
taskReq, err = buildStructTypeRequest(ctx, req, chatModel, ir)
|
||||
default:
|
||||
return nil, errors.New("不支持的构建类型")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("构建推理请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 3) 调用网关创建任务
|
||||
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建网关任务失败: %w", err)
|
||||
}
|
||||
if taskID == "" {
|
||||
return nil, errors.New("网关未返回taskId")
|
||||
}
|
||||
|
||||
// 4) 保存任务记录
|
||||
if _, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
ModelName: req.ModelName,
|
||||
SkillName: req.SkillName,
|
||||
BuildType: req.BuildType,
|
||||
CallbackUrl: req.CallbackUrl,
|
||||
RequestPayload: gconv.Map(req),
|
||||
Status: public.ComposeStatusPending,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.ComposeMessagesRes{TaskId: taskID}, nil
|
||||
}
|
||||
|
||||
// Callback 回调处理
|
||||
func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||
g.Log().Infof(ctx, "[开始回调处理] taskId=%s state=%d", req.TaskId, req.State)
|
||||
|
||||
// 1) 查询任务
|
||||
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{TaskId: req.TaskId})
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询任务失败: %w", err)
|
||||
}
|
||||
|
||||
// 2) 读取 OSS 文件内容
|
||||
var ossContent []byte
|
||||
if req.OssFile != "" {
|
||||
ossContent, err = gateway.DownloadFile(req.OssFile)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[回调处理] 读取OSS失败 taskId=%s err=%v", req.TaskId, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3) 解析 OSS 内容为消息
|
||||
var messages map[string]any
|
||||
if len(ossContent) > 0 {
|
||||
messages, _ = gjson.New(ossContent).Map(), nil
|
||||
}
|
||||
|
||||
// 4) 处理失败
|
||||
if req.State == 3 {
|
||||
return handleCallbackFailed(ctx, req, composeTask, messages)
|
||||
}
|
||||
|
||||
// 5) 处理成功
|
||||
if req.State == 2 {
|
||||
return handleCallbackSuccess(ctx, req, composeTask, messages)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleCallbackFailed 处理回调失败
|
||||
func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask, messages map[string]any) error {
|
||||
_, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusFailed,
|
||||
ErrorMessage: req.ErrorMsg,
|
||||
GatewayState: req.State,
|
||||
OssFile: req.OssFile,
|
||||
FileType: req.FileType,
|
||||
ResultJson: messages,
|
||||
})
|
||||
if composeTask.CallbackUrl != "" {
|
||||
composeTask.Status = public.ComposeStatusFailed
|
||||
composeTask.ErrorMessage = req.ErrorMsg
|
||||
_ = gateway.SendCallback(ctx, composeTask, 0)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// handleCallbackSuccess 处理回调成功
|
||||
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask, messages map[string]any) error {
|
||||
// 1) 获取模型配置
|
||||
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
|
||||
ModelName: composeTask.ModelName,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询模型失败: %w", err)
|
||||
}
|
||||
|
||||
// 2) 获取协议配置
|
||||
protocol, _ := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||
ProviderName: model.OperatorName,
|
||||
Status: 1,
|
||||
})
|
||||
|
||||
// 3) 获取历史消息 + 保存当前轮
|
||||
payload := composeTask.RequestPayload
|
||||
sessionId := gconv.String(payload["sessionId"])
|
||||
nodeId := gconv.String(payload["nodeId"])
|
||||
var history []dto.FlatMessage
|
||||
var epicycleId int64
|
||||
|
||||
if sessionId != "" && nodeId != "" && model.ModelType == public.ModelTypeInference {
|
||||
// 3.1 获取历史
|
||||
h, _ := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
|
||||
SessionId: sessionId,
|
||||
NodeId: nodeId,
|
||||
})
|
||||
if h != nil {
|
||||
history = h.Messages
|
||||
}
|
||||
|
||||
// 3.2 保存当前轮(先存,下次查询就能拿到)
|
||||
if userMsg := util.ExtractUserText(messages); userMsg != nil {
|
||||
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
NodeId: nodeId,
|
||||
SessionId: sessionId,
|
||||
RequestContent: userMsg,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 4) 合并附加结构
|
||||
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
|
||||
// 5) 注入历史
|
||||
if len(history) > 0 {
|
||||
messages = InjectHistory(messages, history, protocol)
|
||||
}
|
||||
|
||||
// 6) 更新数据库
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusSuccess,
|
||||
GatewayState: req.State,
|
||||
OssFile: req.OssFile,
|
||||
FileType: req.FileType,
|
||||
ResultJson: messages,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 8) 回调业务方
|
||||
if composeTask.CallbackUrl != "" {
|
||||
composeTask.Status = public.ComposeStatusSuccess
|
||||
composeTask.ResultJson = messages
|
||||
_ = gateway.SendCallback(ctx, composeTask, epicycleId)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InjectHistory 插入历史会话
|
||||
func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protocol *entity.ProviderProtocol) map[string]any {
|
||||
if protocol == nil || len(history) == 0 {
|
||||
return roundsData
|
||||
}
|
||||
// 1) 提取第一轮的 messages
|
||||
rounds := roundsData["rounds"].([]any)
|
||||
firstRound := rounds[0].(map[string]any)
|
||||
original := firstRound["messages"].([]any)
|
||||
|
||||
// 2) 按 merge_order 拼接
|
||||
result := make([]any, 0, len(original)+len(history))
|
||||
|
||||
for _, part := range protocol.MergeOrder {
|
||||
switch part {
|
||||
case "system":
|
||||
for _, m := range original {
|
||||
msg := m.(map[string]any)
|
||||
if gconv.String(msg["role"]) == "system" {
|
||||
result = append(result, msg)
|
||||
}
|
||||
}
|
||||
case "history":
|
||||
if gconv.Bool(protocol.Capabilities["support_history"]) {
|
||||
for _, msg := range history {
|
||||
result = append(result, map[string]any{
|
||||
"role": msg.Role,
|
||||
"content": msg.Content, // 纯字符串,不转换
|
||||
})
|
||||
}
|
||||
}
|
||||
case "user":
|
||||
for _, m := range original {
|
||||
msg := m.(map[string]any)
|
||||
if gconv.String(msg["role"]) == "user" {
|
||||
result = append(result, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3) 角色映射
|
||||
if len(protocol.RoleMapping) > 0 {
|
||||
for _, m := range result {
|
||||
msg := m.(map[string]any)
|
||||
role := gconv.String(msg["role"])
|
||||
if mapped, ok := protocol.RoleMapping[role]; ok {
|
||||
msg["role"] = mapped
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4) 直接修改原对象
|
||||
firstRound["messages"] = result
|
||||
return roundsData
|
||||
}
|
||||
|
||||
// GetComposeTask 查询任务结果
|
||||
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
|
||||
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询任务失败: %w", err)
|
||||
}
|
||||
return &dto.GetComposeTaskRes{
|
||||
TaskId: record.TaskId,
|
||||
Status: record.Status,
|
||||
ErrorMessage: record.ErrorMessage,
|
||||
Messages: record.ResultJson,
|
||||
}, nil
|
||||
}
|
||||
295
service/prompt/prompt_files_handle_service.go
Normal file
295
service/prompt/prompt_files_handle_service.go
Normal file
@@ -0,0 +1,295 @@
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"prompts-core/model/dto"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/service/gateway"
|
||||
)
|
||||
|
||||
const (
|
||||
bytesPerKB = 1024
|
||||
bytesPerMB = 1024 * 1024
|
||||
)
|
||||
|
||||
// ExtractFileTexts 从 ConsultItem 列表中提取文件内容,返回拼接文本
|
||||
func ExtractFileTexts(ctx context.Context, consult []dto.ConsultItem) string {
|
||||
urls := make([]string, 0, len(consult))
|
||||
for _, item := range consult {
|
||||
if item.Url != "" {
|
||||
urls = append(urls, item.Url)
|
||||
}
|
||||
}
|
||||
return FetchFileTextsAsString(ctx, urls)
|
||||
}
|
||||
|
||||
// FetchFileTextsAsString 从 URL 列表获取文件内容,拼接为字符串
|
||||
func FetchFileTextsAsString(ctx context.Context, urls []string) string {
|
||||
if len(urls) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
client := createHTTPClient(ctx, "userFiles.httpTimeoutSec", 8)
|
||||
var builder strings.Builder
|
||||
|
||||
for _, rawURL := range urls {
|
||||
url := util.SanitizeURL(rawURL)
|
||||
if url == "" || util.IsBannedExtension(url) {
|
||||
continue
|
||||
}
|
||||
|
||||
if util.IsZipExtension(url) {
|
||||
for _, text := range fetchZipFileTexts(ctx, client, url) {
|
||||
builder.WriteString(text)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if text := fetchAndCleanFileContent(ctx, client, url); text != "" {
|
||||
builder.WriteString(fmt.Sprintf("【文件:%s】\n%s\n", url, text))
|
||||
}
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// fetchAndCleanFileContent 获取并清理文件内容
|
||||
func fetchAndCleanFileContent(ctx context.Context, client *http.Client, url string) string {
|
||||
text, err := fetchFileContent(ctx, client, url)
|
||||
if err != nil || text == "" {
|
||||
return ""
|
||||
}
|
||||
return util.CleanSymbols(text)
|
||||
}
|
||||
|
||||
// fetchZipFileTexts 下载并解压 zip 文件,提取可读文本内容
|
||||
func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map[string]string {
|
||||
result := make(map[string]string)
|
||||
|
||||
maxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipMaxSizeMB", 10).Int()) * bytesPerMB
|
||||
zipBytes, err := downloadFile(client, url, maxSize)
|
||||
if err != nil {
|
||||
return result
|
||||
}
|
||||
|
||||
reader, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
|
||||
if err != nil {
|
||||
return result
|
||||
}
|
||||
|
||||
entryMaxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipEntryMaxSizeKB", 500).Int()) * bytesPerKB
|
||||
|
||||
for _, file := range reader.File {
|
||||
if shouldSkipZipEntry(file.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
if text := extractZipEntryContent(file, entryMaxSize); text != "" {
|
||||
result[url+"::"+file.Name] = text
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// shouldSkipZipEntry 判断是否应该跳过 zip 条目
|
||||
func shouldSkipZipEntry(fileName string) bool {
|
||||
return util.IsBannedExtension(fileName) || util.IsZipExtension(fileName)
|
||||
}
|
||||
|
||||
// extractZipEntryContent 提取 zip 条目内容
|
||||
func extractZipEntryContent(file *zip.File, maxSize int64) string {
|
||||
rc, err := file.Open()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer rc.Close()
|
||||
|
||||
content, err := io.ReadAll(io.LimitReader(rc, maxSize))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if !util.IsReadableContentType(http.DetectContentType(content)) {
|
||||
return ""
|
||||
}
|
||||
|
||||
text := util.CleanSymbols(string(content))
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return text
|
||||
}
|
||||
|
||||
// downloadFile 下载文件,限制最大大小
|
||||
func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("执行请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// fetchFileContent 获取单个文本文件内容
|
||||
func fetchFileContent(ctx context.Context, client *http.Client, url string) (string, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("执行请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if !util.IsReadableContentType(contentType) {
|
||||
return "", fmt.Errorf("不可读的内容类型: %s", contentType)
|
||||
}
|
||||
|
||||
maxSize := int64(g.Cfg().MustGet(ctx, "userFiles.textFileMaxSizeKB", 500).Int()) * bytesPerKB
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(string(body)), nil
|
||||
}
|
||||
|
||||
func SkillMdContent(ctx context.Context, skillName string) string {
|
||||
if skillName == "" {
|
||||
return ""
|
||||
}
|
||||
skillResp, err := gateway.GetSkillUser(ctx, skillName)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[SkillMd] GetSkillUser 失败: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
fullUrl := skillResp.ImgAddressPrefix + skillResp.FileUrl
|
||||
|
||||
client := createHTTPClient(ctx, "skillFiles.httpTimeoutSec", 30)
|
||||
maxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.zipMaxSizeMB", 10).Int()) * bytesPerMB
|
||||
|
||||
zipBytes, err := downloadFile(client, fullUrl, maxSize)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[SkillMd] 下载失败 url=%s err=%v", fullUrl, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
mdContents, err := extractMdFiles(ctx, zipBytes)
|
||||
if err != nil || len(mdContents) == 0 {
|
||||
g.Log().Warningf(ctx, "[SkillMd] 提取md失败 count=%d err=%v", len(mdContents), err)
|
||||
return ""
|
||||
}
|
||||
|
||||
return buildSkillMarkdown(skillResp, mdContents)
|
||||
}
|
||||
|
||||
// buildSkillMarkdown 构建技能 Markdown 内容
|
||||
func buildSkillMarkdown(skillResp *gateway.SkillUserVO, mdContents map[string]string) string {
|
||||
var builder strings.Builder
|
||||
|
||||
builder.WriteString(fmt.Sprintf("# Skill: %s\n\n", skillResp.Name))
|
||||
if skillResp.Description != "" {
|
||||
builder.WriteString(fmt.Sprintf("> %s\n\n", skillResp.Description))
|
||||
}
|
||||
|
||||
for fileName, content := range mdContents {
|
||||
builder.WriteString(fmt.Sprintf("## %s\n\n", fileName))
|
||||
builder.WriteString(content)
|
||||
builder.WriteString("\n\n---\n\n")
|
||||
}
|
||||
|
||||
return strings.TrimSpace(builder.String())
|
||||
}
|
||||
|
||||
// extractMdFiles 解压 zip 并提取所有 .md 文件内容
|
||||
func extractMdFiles(ctx context.Context, zipBytes []byte) (map[string]string, error) {
|
||||
result := make(map[string]string)
|
||||
|
||||
reader, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 zip 阅读器失败: %w", err)
|
||||
}
|
||||
|
||||
entryMaxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.mdMaxSizeKB", 500).Int()) * bytesPerKB
|
||||
|
||||
for _, file := range reader.File {
|
||||
if file.FileInfo().IsDir() || !isMarkdownFile(file.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
if content := readMarkdownFileContent(file, entryMaxSize); content != "" {
|
||||
result[file.Name] = content
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// isMarkdownFile 判断是否为 Markdown 文件
|
||||
func isMarkdownFile(fileName string) bool {
|
||||
return strings.HasSuffix(strings.ToLower(fileName), ".md")
|
||||
}
|
||||
|
||||
// readMarkdownFileContent 读取 Markdown 文件内容
|
||||
func readMarkdownFileContent(file *zip.File, maxSize int64) string {
|
||||
rc, err := file.Open()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer rc.Close()
|
||||
|
||||
content, err := io.ReadAll(io.LimitReader(rc, maxSize))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if len(content) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strings.TrimSpace(string(content))
|
||||
}
|
||||
|
||||
// createHTTPClient 创建 HTTP 客户端
|
||||
func createHTTPClient(ctx context.Context, configKey string, defaultSeconds int) *http.Client {
|
||||
timeout := time.Duration(g.Cfg().MustGet(ctx, configKey, defaultSeconds).Int()) * time.Second
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
}
|
||||
}
|
||||
285
service/prompt/prompt_ir_service.go
Normal file
285
service/prompt/prompt_ir_service.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"prompts-core/service/gateway"
|
||||
"strings"
|
||||
|
||||
"prompts-core/dao"
|
||||
"prompts-core/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// IR 统一 Prompt 中间表示
|
||||
type IR struct {
|
||||
System []Segment `json:"system"`
|
||||
History []Segment `json:"history"`
|
||||
User []Segment `json:"user"`
|
||||
}
|
||||
|
||||
// Segment 消息片段
|
||||
type Segment struct {
|
||||
Type string `json:"type"`
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role,omitempty"`
|
||||
}
|
||||
|
||||
// ProviderProtocol 协议编译配置(从 DB JSONB 字段解析)
|
||||
type ProviderProtocol struct {
|
||||
TargetField string `json:"target_field"`
|
||||
MergeOrder []string `json:"merge_order"`
|
||||
RoleMapping map[string]string `json:"role_mapping"`
|
||||
ContentMapping ContentMapping `json:"content_mapping"`
|
||||
RequestTemplate map[string]any `json:"request_template"`
|
||||
SystemPromptTemplate string `json:"system_prompt_template"`
|
||||
Capabilities map[string]any `json:"capabilities"`
|
||||
}
|
||||
|
||||
// ContentMapping 内容字段映射
|
||||
type ContentMapping struct {
|
||||
Type string `json:"type"`
|
||||
Field string `json:"field"`
|
||||
}
|
||||
|
||||
// NewPromptIR 创建空 PromptIR
|
||||
func NewPromptIR() *IR {
|
||||
return &IR{
|
||||
System: make([]Segment, 0),
|
||||
History: make([]Segment, 0),
|
||||
User: make([]Segment, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// String 返回 PromptIR 的完整内容字符串(用于 token 计算)
|
||||
func (ir *IR) String() string {
|
||||
var builder strings.Builder
|
||||
|
||||
for _, seg := range ir.System {
|
||||
builder.WriteString("System: ")
|
||||
builder.WriteString(seg.Content)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
for _, seg := range ir.History {
|
||||
builder.WriteString(seg.Role)
|
||||
builder.WriteString(": ")
|
||||
builder.WriteString(seg.Content)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
for _, seg := range ir.User {
|
||||
builder.WriteString("User: ")
|
||||
builder.WriteString(seg.Content)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算)
|
||||
func (ir *IR) GetTotalContent() string {
|
||||
var builder strings.Builder
|
||||
|
||||
for _, seg := range ir.System {
|
||||
builder.WriteString(seg.Content)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
for _, seg := range ir.History {
|
||||
builder.WriteString(seg.Content)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
for _, seg := range ir.User {
|
||||
builder.WriteString(seg.Content)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// AddSystem 添加系统提示
|
||||
func (ir *IR) AddSystem(content string) *IR {
|
||||
if content != "" {
|
||||
ir.System = append(ir.System, Segment{Type: "text", Content: content})
|
||||
}
|
||||
return ir
|
||||
}
|
||||
|
||||
// AddUser 添加用户消息
|
||||
func (ir *IR) AddUser(content string) *IR {
|
||||
if content != "" {
|
||||
ir.User = append(ir.User, Segment{Type: "text", Content: content})
|
||||
}
|
||||
return ir
|
||||
}
|
||||
|
||||
// AddHistory 添加历史消息
|
||||
func (ir *IR) AddHistory(role, content string) *IR {
|
||||
if content != "" {
|
||||
ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role})
|
||||
}
|
||||
return ir
|
||||
}
|
||||
|
||||
// ToMessages 转换为 OpenAI 兼容的 messages 格式(MVP 默认)
|
||||
func (ir *IR) ToMessages() []map[string]any {
|
||||
var messages []map[string]any
|
||||
|
||||
for _, seg := range ir.System {
|
||||
messages = append(messages, map[string]any{
|
||||
"role": "system",
|
||||
"content": seg.Content,
|
||||
})
|
||||
}
|
||||
|
||||
for _, seg := range ir.History {
|
||||
messages = append(messages, map[string]any{
|
||||
"role": seg.Role,
|
||||
"content": seg.Content,
|
||||
})
|
||||
}
|
||||
|
||||
for _, seg := range ir.User {
|
||||
messages = append(messages, map[string]any{
|
||||
"role": "user",
|
||||
"content": seg.Content,
|
||||
})
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
// GetProtocolByProvider 根据 provider_name 获取协议配置
|
||||
func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderProtocol, error) {
|
||||
entity, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||
ProviderName: providerName,
|
||||
Status: 1,
|
||||
})
|
||||
if err != nil || entity == nil {
|
||||
return nil, err
|
||||
}
|
||||
return parseProtocol(entity), nil
|
||||
}
|
||||
|
||||
// parseProtocol 将 DB entity 转为编译用协议配置
|
||||
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
|
||||
return &ProviderProtocol{
|
||||
TargetField: e.TargetField,
|
||||
SystemPromptTemplate: e.SystemPromptTemplate,
|
||||
MergeOrder: e.MergeOrder,
|
||||
RoleMapping: gconv.MapStrStr(e.RoleMapping),
|
||||
ContentMapping: ContentMapping{
|
||||
Type: gconv.String(e.ContentMapping["type"]),
|
||||
Field: gconv.String(e.ContentMapping["field"]),
|
||||
},
|
||||
RequestTemplate: e.RequestTemplate,
|
||||
Capabilities: e.Capabilities,
|
||||
}
|
||||
}
|
||||
|
||||
// Compile 将 PromptIR 按协议配置编译为 Provider Request
|
||||
func Compile(ir *IR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) {
|
||||
if ir == nil || p == nil {
|
||||
return nil, fmt.Errorf("ir and protocol are required")
|
||||
}
|
||||
messages := mergeByOrder(ir, p.MergeOrder)
|
||||
messages = mapRoles(messages, p.RoleMapping)
|
||||
messages = mapContent(messages, p.ContentMapping)
|
||||
|
||||
return buildRequest(messages, p, chatModel), nil
|
||||
}
|
||||
|
||||
// mergeByOrder 按协议配置顺序拼接消息
|
||||
func mergeByOrder(ir *IR, order []string) []map[string]any {
|
||||
roleMap := map[string][]Segment{
|
||||
"system": ir.System,
|
||||
"history": ir.History,
|
||||
"user": ir.User,
|
||||
}
|
||||
|
||||
var messages []map[string]any
|
||||
for _, part := range order {
|
||||
for _, seg := range roleMap[part] {
|
||||
msg := map[string]any{"content": seg.Content}
|
||||
if part == "history" {
|
||||
msg["role"] = seg.Role
|
||||
} else {
|
||||
msg["role"] = part
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
// mapRoles 角色映射
|
||||
func mapRoles(messages []map[string]any, mapping map[string]string) []map[string]any {
|
||||
if len(mapping) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
for i, msg := range messages {
|
||||
role, ok := msg["role"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if mapped, exists := mapping[role]; exists {
|
||||
messages[i]["role"] = mapped
|
||||
}
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
|
||||
if cm.Field == "" || cm.Field == "content" {
|
||||
return messages
|
||||
}
|
||||
|
||||
for i, msg := range messages {
|
||||
if content, ok := msg["content"]; ok {
|
||||
delete(msg, "content")
|
||||
switch cm.Type {
|
||||
case "parts":
|
||||
messages[i]["parts"] = []map[string]any{{cm.Field: content}}
|
||||
default:
|
||||
messages[i][cm.Field] = content
|
||||
}
|
||||
}
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
// buildRequest 按 target_field 和 request_template 构建请求体
|
||||
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gateway.AsynchModel) map[string]any {
|
||||
if len(p.RequestTemplate) > 0 {
|
||||
return renderTemplate(p, messages, chatModel)
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
p.TargetField: messages,
|
||||
}
|
||||
}
|
||||
|
||||
// renderTemplate 模板渲染
|
||||
func renderTemplate(p *ProviderProtocol, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
|
||||
result := make(map[string]any, len(p.RequestTemplate)+1)
|
||||
for k, v := range p.RequestTemplate {
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
if chatModel != nil {
|
||||
result["model"] = chatModel.ModelName
|
||||
}
|
||||
result["messages"] = messages
|
||||
|
||||
if maxTokens := gconv.Int(p.Capabilities["max_tokens"]); maxTokens > 0 {
|
||||
result["max_tokens"] = maxTokens
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
135
service/prompt/prompt_user_form_batches.go
Normal file
135
service/prompt/prompt_user_form_batches.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"prompts-core/service/gateway"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/model/dto"
|
||||
)
|
||||
|
||||
// ProcessUserFormBatches 处理 UserForm 分批(按 token 大小拼接内容)
|
||||
func ProcessUserFormBatches(ctx context.Context, req *dto.ComposeMessagesReq, model *gateway.AsynchModel) (*dto.ComposeMessagesReq, int, error) {
|
||||
if model.TokenConfig == nil || len(req.UserForm) == 0 {
|
||||
return req, 1, nil
|
||||
}
|
||||
|
||||
availableWindow := util.GetAvailableWindow(model.TokenConfig)
|
||||
batches := splitUserFormByTokenSize(req.UserForm, availableWindow, model.TokenConfig)
|
||||
|
||||
if len(batches) <= 1 {
|
||||
return req, 1, nil
|
||||
}
|
||||
|
||||
newUserForm := buildBatchedUserForm(batches)
|
||||
|
||||
newReq := *req
|
||||
newReq.UserForm = newUserForm
|
||||
|
||||
g.Log().Infof(ctx, "[ProcessUserFormBatches] UserForm分批完成: 原始%d条 -> %d批 (按token大小拼接)",
|
||||
len(req.UserForm), len(batches))
|
||||
|
||||
return &newReq, len(batches), nil
|
||||
}
|
||||
|
||||
// buildBatchedUserForm 构建分批后的用户表单
|
||||
func buildBatchedUserForm(batches [][]map[string]any) []map[string]any {
|
||||
newUserForm := make([]map[string]any, 0, len(batches))
|
||||
|
||||
for i, batch := range batches {
|
||||
combinedText := combineBatchText(batch)
|
||||
newUserForm = append(newUserForm, map[string]any{
|
||||
"batch_index": i + 1,
|
||||
"total_batches": len(batches),
|
||||
"text": combinedText,
|
||||
"item_count": len(batch),
|
||||
})
|
||||
}
|
||||
|
||||
return newUserForm
|
||||
}
|
||||
|
||||
// combineBatchText 合并批次中的所有文本(合并所有字段的值)
|
||||
func combineBatchText(batch []map[string]any) string {
|
||||
var builder strings.Builder
|
||||
|
||||
for j, item := range batch {
|
||||
itemText := getItemText(item)
|
||||
if itemText == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if j > 0 {
|
||||
builder.WriteString("\n\n")
|
||||
}
|
||||
builder.WriteString(itemText)
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// splitUserFormByTokenSize 按 token 大小将 UserForm 内容拼接后分批
|
||||
func splitUserFormByTokenSize(userForm []map[string]any, maxTokens int, tokenConfig any) [][]map[string]any {
|
||||
if len(userForm) == 0 {
|
||||
return [][]map[string]any{}
|
||||
}
|
||||
|
||||
batches := make([][]map[string]any, 0)
|
||||
currentBatch := make([]map[string]any, 0)
|
||||
currentTokens := 0
|
||||
|
||||
for i, item := range userForm {
|
||||
itemText := getItemText(item)
|
||||
itemTokens := util.CalculateTokens(itemText, tokenConfig)
|
||||
|
||||
// 单个元素超过窗口,单独成一批
|
||||
if itemTokens > maxTokens {
|
||||
if len(currentBatch) > 0 {
|
||||
batches = append(batches, currentBatch)
|
||||
currentBatch = make([]map[string]any, 0)
|
||||
currentTokens = 0
|
||||
}
|
||||
batches = append(batches, []map[string]any{item})
|
||||
continue
|
||||
}
|
||||
|
||||
// 判断是否需要新开一批
|
||||
if currentTokens+itemTokens > maxTokens && len(currentBatch) > 0 {
|
||||
batches = append(batches, currentBatch)
|
||||
currentBatch = make([]map[string]any, 0)
|
||||
currentTokens = 0
|
||||
}
|
||||
|
||||
currentBatch = append(currentBatch, item)
|
||||
currentTokens += itemTokens
|
||||
|
||||
// 最后一批
|
||||
if i == len(userForm)-1 && len(currentBatch) > 0 {
|
||||
batches = append(batches, currentBatch)
|
||||
}
|
||||
}
|
||||
|
||||
return batches
|
||||
}
|
||||
|
||||
// getItemText 获取 item 中的所有文本内容(合并所有字段的值)
|
||||
func getItemText(item map[string]any) string {
|
||||
if len(item) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var parts []string
|
||||
for key, value := range item {
|
||||
// 跳过分批时添加的元数据字段
|
||||
if key == "batch_index" || key == "total_batches" || key == "item_count" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("%v", value))
|
||||
}
|
||||
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
@@ -1,92 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"prompts-core/dao"
|
||||
"prompts-core/model/dto"
|
||||
"prompts-core/model/entity"
|
||||
)
|
||||
|
||||
var Prompt = &promptService{}
|
||||
|
||||
type promptService struct{}
|
||||
|
||||
func (s *promptService) Create(ctx context.Context, req *dto.CreatePromptReq) (res *dto.CreatePromptRes, err error) {
|
||||
// promptInfo 兜底校验:必须可序列化为 JSON
|
||||
if req.PromptInfo == nil {
|
||||
return nil, errors.New("promptInfo不能为空")
|
||||
}
|
||||
if _, err := json.Marshal(req.PromptInfo); err != nil {
|
||||
return nil, errors.New("promptInfo不是合法JSON")
|
||||
}
|
||||
if req.ResponseJsonSchema == nil {
|
||||
return nil, errors.New("responseJsonSchema不能为空")
|
||||
}
|
||||
if _, err := json.Marshal(req.ResponseJsonSchema); err != nil {
|
||||
return nil, errors.New("responseJsonSchema不是合法JSON")
|
||||
}
|
||||
|
||||
m := &entity.PromptConfig{
|
||||
ModelTypeId: req.ModelTypeId,
|
||||
ModelType: req.ModelType,
|
||||
PromptInfo: req.PromptInfo,
|
||||
ResponseJsonSchema: req.ResponseJsonSchema,
|
||||
Enabled: 1,
|
||||
Version: req.Version,
|
||||
}
|
||||
|
||||
id, err := dao.Prompt.Insert(ctx, m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.CreatePromptRes{ID: id}, nil
|
||||
}
|
||||
|
||||
func (s *promptService) Update(ctx context.Context, req *dto.UpdatePromptReq) error {
|
||||
data := map[string]any{}
|
||||
if req.ModelTypeId != nil && *req.ModelTypeId > 0 {
|
||||
data[entity.PromptConfigCol.ModelTypeId] = *req.ModelTypeId
|
||||
}
|
||||
if req.ModelType != nil && *req.ModelType != "" {
|
||||
data[entity.PromptConfigCol.ModelType] = *req.ModelType
|
||||
}
|
||||
if req.PromptInfo != nil {
|
||||
if _, err := json.Marshal(req.PromptInfo); err != nil {
|
||||
return errors.New("promptInfo不是合法JSON")
|
||||
}
|
||||
data[entity.PromptConfigCol.PromptInfo] = req.PromptInfo
|
||||
}
|
||||
if req.ResponseJsonSchema != nil {
|
||||
if _, err := json.Marshal(req.ResponseJsonSchema); err != nil {
|
||||
return errors.New("responseJsonSchema不是合法JSON")
|
||||
}
|
||||
data[entity.PromptConfigCol.ResponseJsonSchema] = req.ResponseJsonSchema
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
data[entity.PromptConfigCol.Enabled] = *req.Enabled
|
||||
}
|
||||
if req.Version != nil {
|
||||
data[entity.PromptConfigCol.Version] = *req.Version
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return errors.New("无可更新字段")
|
||||
}
|
||||
_, err := dao.Prompt.UpdateByID(ctx, req.ID, data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *promptService) Delete(ctx context.Context, id int64) error {
|
||||
_, err := dao.Prompt.DeleteByID(ctx, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *promptService) Get(ctx context.Context, id int64) (*entity.PromptConfig, error) {
|
||||
return dao.Prompt.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *promptService) List(ctx context.Context, pageNum, pageSize int, modelTypeID *int, modelTypeLike string) (list []*entity.PromptConfig, total int64, err error) {
|
||||
return dao.Prompt.List(ctx, pageNum, pageSize, modelTypeID, modelTypeLike)
|
||||
}
|
||||
151
service/session/prompt_session_redis_service.go
Normal file
151
service/session/prompt_session_redis_service.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/model/dto"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
const (
|
||||
// RedisKeySessionHistory 会话历史缓存 key: session:history:{tenantId}:{sessionId}:{nodeId}
|
||||
RedisKeySessionHistory = "session:history:%d:%s:%s"
|
||||
)
|
||||
|
||||
// formatRedisKey 格式化 Redis key
|
||||
func formatRedisKey(tenantID uint64, sessionID, nodeID string) string {
|
||||
return fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID, nodeID)
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 写操作
|
||||
// ============================================
|
||||
|
||||
// SaveToRedis 保存一轮对话到 Redis ZSET
|
||||
func SaveToRedis(ctx context.Context, tenantID uint64, sessionID, nodeID string, round *dto.HistoryRound) error {
|
||||
key := formatRedisKey(tenantID, sessionID, nodeID)
|
||||
maxRounds := util.GetMaxRounds(ctx)
|
||||
expireSeconds := int64(util.GetExpireMinutes(ctx) * 60)
|
||||
|
||||
b, err := json.Marshal(round)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化会话数据失败: %w", err)
|
||||
}
|
||||
|
||||
score := float64(time.Now().UnixMilli())
|
||||
|
||||
if _, err = g.Redis().Do(ctx, "ZADD", key, score, string(b)); err != nil {
|
||||
return fmt.Errorf("ZADD失败: %w", err)
|
||||
}
|
||||
if _, err = g.Redis().Do(ctx, "ZREMRANGEBYRANK", key, 0, -(maxRounds + 1)); err != nil {
|
||||
return fmt.Errorf("裁剪失败: %w", err)
|
||||
}
|
||||
if _, err = g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
|
||||
return fmt.Errorf("设置过期失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteSessionHistory 删除整个 session 下所有 node 的缓存
|
||||
func DeleteSessionHistory(ctx context.Context, tenantID uint64, sessionID string) error {
|
||||
pattern := fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID, "*")
|
||||
keys, err := g.Redis().Do(ctx, "KEYS", pattern)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, key := range keys.Strings() {
|
||||
_, _ = g.Redis().Do(ctx, "DEL", key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteRedisMessages 批量删除指定 node 下的消息
|
||||
func DeleteRedisMessages(ctx context.Context, tenantID uint64, sessionID, nodeID string, msgIDs []int64) error {
|
||||
key := formatRedisKey(tenantID, sessionID, nodeID)
|
||||
for _, msgID := range msgIDs {
|
||||
cursor := "0"
|
||||
for {
|
||||
result, err := g.Redis().Do(ctx, "ZSCAN", key, cursor, "MATCH", fmt.Sprintf("*\"id\":%d*", msgID), "COUNT", 10)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[会话Redis] ZSCAN失败 msgID=%d err=%v", msgID, err)
|
||||
break
|
||||
}
|
||||
parts := result.Strings()
|
||||
if len(parts) < 2 {
|
||||
break
|
||||
}
|
||||
cursor = parts[0]
|
||||
for _, member := range parts[1:] {
|
||||
_, _ = g.Redis().Do(ctx, "ZREM", key, member)
|
||||
}
|
||||
if cursor == "0" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 读操作
|
||||
// ============================================
|
||||
|
||||
// GetFromRedis 从 Redis ZSET 获取会话历史
|
||||
func GetFromRedis(ctx context.Context, tenantID uint64, sessionID, nodeID string) ([]dto.HistoryRound, error) {
|
||||
key := formatRedisKey(tenantID, sessionID, nodeID)
|
||||
maxRounds := util.GetMaxRounds(ctx)
|
||||
|
||||
result, err := g.Redis().Do(ctx, "ZREVRANGE", key, 0, maxRounds-1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ZREVRANGE失败: %w", err)
|
||||
}
|
||||
|
||||
if result == nil || result.IsNil() {
|
||||
return []dto.HistoryRound{}, nil
|
||||
}
|
||||
|
||||
return parseRounds(result.Strings()), nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 解析
|
||||
// ============================================
|
||||
|
||||
func parseRounds(members []string) []dto.HistoryRound {
|
||||
rounds := make([]dto.HistoryRound, 0, len(members))
|
||||
for _, member := range members {
|
||||
var round dto.HistoryRound
|
||||
if err := json.Unmarshal([]byte(member), &round); err != nil {
|
||||
continue
|
||||
}
|
||||
if round.User != nil || round.Assistant != nil {
|
||||
rounds = append(rounds, round)
|
||||
}
|
||||
}
|
||||
return rounds
|
||||
}
|
||||
|
||||
func flattenRounds(rounds []dto.HistoryRound) []dto.FlatMessage {
|
||||
var messages []dto.FlatMessage
|
||||
for i := len(rounds) - 1; i >= 0; i-- {
|
||||
if rounds[i].User != nil && gconv.String(rounds[i].User["content"]) != "" {
|
||||
messages = append(messages, dto.FlatMessage{
|
||||
Role: gconv.String(rounds[i].User["role"]),
|
||||
Content: gconv.String(rounds[i].User["content"]),
|
||||
})
|
||||
}
|
||||
if rounds[i].Assistant != nil && gconv.String(rounds[i].Assistant["content"]) != "" {
|
||||
messages = append(messages, dto.FlatMessage{
|
||||
Role: gconv.String(rounds[i].Assistant["role"]),
|
||||
Content: gconv.String(rounds[i].Assistant["content"]),
|
||||
})
|
||||
}
|
||||
}
|
||||
return messages
|
||||
}
|
||||
191
service/session/prompt_session_service.go
Normal file
191
service/session/prompt_session_service.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||
"gitea.redpowerfuture.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/dao"
|
||||
"prompts-core/model/dto"
|
||||
"prompts-core/model/entity"
|
||||
)
|
||||
|
||||
// ============================================
|
||||
// 回调存储
|
||||
// ============================================
|
||||
|
||||
// Callback 会话回调
|
||||
func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
|
||||
req.Messages["role"] = "assistant"
|
||||
// 1) 更新 DB
|
||||
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
||||
ResponseContent: req.Messages,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
||||
return nil, fmt.Errorf("更新数据库失败: %w", err)
|
||||
}
|
||||
|
||||
// 2) 查询完整记录
|
||||
session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
||||
})
|
||||
if err != nil || session == nil {
|
||||
return nil, fmt.Errorf("会话不存在: epicycleId=%d", req.EpicycleId)
|
||||
}
|
||||
|
||||
// 3) entity → HistoryRound → 写入 Redis
|
||||
round := entityToHistoryRound(session)
|
||||
round.Assistant = req.Messages
|
||||
if err = SaveToRedis(ctx, session.TenantId, session.SessionId, session.NodeId, round); err != nil {
|
||||
return nil, fmt.Errorf("redis存储失败: %w", err)
|
||||
}
|
||||
|
||||
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d", session.SessionId, session.Id)
|
||||
return &dto.SessionCallbackRes{Status: true, SessionId: session.SessionId}, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 场景1:前端历史列表(按 creator)
|
||||
// ============================================
|
||||
|
||||
// GetHistoryList 获取历史列表
|
||||
func GetHistoryList(ctx context.Context, req *dto.GetHistoryListReq) (*dto.GetHistoryListRes, error) {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessions, total, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||
}, req.Page, req.Size)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DB获取历史列表失败: %w", err)
|
||||
}
|
||||
rounds := sessionsToHistoryRounds(sessions)
|
||||
return &dto.GetHistoryListRes{List: rounds, Total: total}, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 场景2:提示词拼接(按 sessionId + nodeId)
|
||||
// ============================================
|
||||
|
||||
// GetHistoryMessages 获取历史消息(Redis → DB → 异步回种)
|
||||
func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*dto.GetHistoryMessagesRes, error) {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 1) Redis
|
||||
if rounds, err := GetFromRedis(ctx, user.TenantId, req.SessionId, req.NodeId); err == nil && len(rounds) > 0 {
|
||||
g.Log().Debugf(ctx, "[历史消息] Redis命中 sessionId=%s count=%d", req.SessionId, len(rounds))
|
||||
return &dto.GetHistoryMessagesRes{Messages: flattenRounds(rounds)}, nil
|
||||
}
|
||||
|
||||
// 2) DB
|
||||
maxRounds := util.GetMaxRounds(ctx)
|
||||
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||
SessionId: req.SessionId,
|
||||
NodeId: req.NodeId,
|
||||
}, 1, maxRounds)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DB获取历史失败: %w", err)
|
||||
}
|
||||
if len(sessions) == 0 {
|
||||
return &dto.GetHistoryMessagesRes{Messages: []dto.FlatMessage{}}, nil
|
||||
}
|
||||
|
||||
// 3) 转换 + 异步回种
|
||||
rounds := sessionsToHistoryRounds(sessions)
|
||||
go asyncCacheToRedis(context.WithoutCancel(ctx), user.TenantId, req.SessionId, req.NodeId, rounds)
|
||||
|
||||
return &dto.GetHistoryMessagesRes{Messages: flattenRounds(rounds)}, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 删除
|
||||
// ============================================
|
||||
|
||||
// DeleteMessages 删除消息
|
||||
func DeleteMessages(ctx context.Context, req *dto.DeleteMessagesReq) (*dto.DeleteMessagesRes, error) {
|
||||
if len(req.MsgIds) == 0 {
|
||||
return &dto.DeleteMessagesRes{Ok: false}, fmt.Errorf("msgIds不能为空")
|
||||
}
|
||||
|
||||
user, _ := utils.GetUserInfo(ctx)
|
||||
|
||||
// 1) 批量查询
|
||||
sessions, _ := dao.ComposeSession.ListByIds(ctx, req.MsgIds, user.UserName, req.SessionId)
|
||||
|
||||
// 2) 批量删 DB
|
||||
_, _ = dao.ComposeSession.DeleteByIds(ctx, req.MsgIds, user.UserName, req.SessionId)
|
||||
|
||||
// 3) 按 nodeId 分组删 Redis
|
||||
for _, s := range sessions {
|
||||
_ = DeleteRedisMessages(ctx, user.TenantId, req.SessionId, s.NodeId, req.MsgIds)
|
||||
}
|
||||
return &dto.DeleteMessagesRes{Ok: true}, nil
|
||||
}
|
||||
|
||||
// DeleteSession 删除整个会话
|
||||
func DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (*dto.DeleteSessionRes, error) {
|
||||
// 1) 删 DB
|
||||
if _, err := dao.ComposeSession.Delete(ctx, &entity.ComposeSession{
|
||||
SessionId: req.SessionId,
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("DB删除失败: %w", err)
|
||||
}
|
||||
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 2) 删 Redis
|
||||
if err := DeleteSessionHistory(ctx, user.TenantId, req.SessionId); err != nil {
|
||||
g.Log().Warningf(ctx, "[删除会话] Redis删除失败 sessionId=%s err=%v", req.SessionId, err)
|
||||
}
|
||||
|
||||
return &dto.DeleteSessionRes{Ok: true}, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 转换方法(entity ↔ dto,集中管理)
|
||||
// ============================================
|
||||
|
||||
// entityToHistoryRound entity → HistoryRound
|
||||
func entityToHistoryRound(s *entity.ComposeSession) *dto.HistoryRound {
|
||||
return &dto.HistoryRound{
|
||||
Id: s.Id,
|
||||
SessionId: s.SessionId,
|
||||
NodeId: s.NodeId,
|
||||
CreatedAt: gconv.String(s.CreatedAt),
|
||||
UpdatedAt: gconv.String(s.UpdatedAt),
|
||||
User: s.RequestContent,
|
||||
Assistant: s.ResponseContent,
|
||||
}
|
||||
}
|
||||
|
||||
// sessionsToHistoryRounds 批量转换
|
||||
func sessionsToHistoryRounds(sessions []*entity.ComposeSession) []dto.HistoryRound {
|
||||
rounds := make([]dto.HistoryRound, 0, len(sessions))
|
||||
for _, s := range sessions {
|
||||
rounds = append(rounds, *entityToHistoryRound(s))
|
||||
}
|
||||
return rounds
|
||||
}
|
||||
|
||||
// asyncCacheToRedis 异步缓存到 Redis
|
||||
func asyncCacheToRedis(ctx context.Context, tenantID uint64, sessionID, nodeID string, rounds []dto.HistoryRound) {
|
||||
for i := range rounds {
|
||||
if rounds[i].User != nil || rounds[i].Assistant != nil {
|
||||
_ = SaveToRedis(ctx, tenantID, sessionID, nodeID, &rounds[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,114 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// ==================== Redis 操作 ====================
|
||||
|
||||
// saveToRedis 保存会话数据到Redis
|
||||
func (s *sessionService) saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error {
|
||||
key := fmt.Sprintf("chat:session:%s", sessionId)
|
||||
|
||||
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
||||
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
|
||||
expireTime := time.Duration(expireSeconds) * time.Second
|
||||
|
||||
data := map[string]any{
|
||||
"sessionId": sessionId,
|
||||
"requestContent": requestMessages,
|
||||
"responseContent": responseMessages,
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
|
||||
b, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化会话数据失败: %w", err)
|
||||
}
|
||||
|
||||
_, err = g.Redis().Do(ctx, "LPUSH", key, string(b))
|
||||
if err != nil {
|
||||
return fmt.Errorf("写入Redis失败: %w", err)
|
||||
}
|
||||
|
||||
_, err = g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1)
|
||||
if err != nil {
|
||||
return fmt.Errorf("裁剪Redis列表失败: %w", err)
|
||||
}
|
||||
|
||||
_, err = g.Redis().Do(ctx, "EXPIRE", key, int64(expireTime.Seconds()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("设置过期时间失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getFromRedis 从Redis获取会话历史
|
||||
func (s *sessionService) getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
||||
key := fmt.Sprintf("chat:session:%s", sessionId)
|
||||
|
||||
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("从Redis获取数据失败: %w", err)
|
||||
}
|
||||
|
||||
if result == nil || result.IsNil() {
|
||||
return []map[string]any{}, nil
|
||||
}
|
||||
|
||||
var sessions []map[string]any
|
||||
values := result.Strings()
|
||||
for _, str := range values {
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(str), &data); err != nil {
|
||||
g.Log().Warningf(ctx, "[会话] 解析Redis数据失败 err=%v", err)
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, data)
|
||||
}
|
||||
|
||||
// 反转(Redis 最新在前 → 时间正序)
|
||||
for i, j := 0, len(sessions)-1; i < j; i, j = i+1, j-1 {
|
||||
sessions[i], sessions[j] = sessions[j], sessions[i]
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
// GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用)
|
||||
func (s *sessionService) GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
||||
historyData, err := s.getFromRedis(ctx, sessionId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取历史会话失败: %w", err)
|
||||
}
|
||||
|
||||
if len(historyData) == 0 {
|
||||
return []map[string]any{}, nil
|
||||
}
|
||||
|
||||
var messages []map[string]any
|
||||
for _, round := range historyData {
|
||||
if reqMsgs, ok := round["requestContent"].([]interface{}); ok {
|
||||
for _, m := range reqMsgs {
|
||||
if msg, ok := m.(map[string]interface{}); ok {
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
if respMsgs, ok := round["responseContent"].([]interface{}); ok {
|
||||
for _, m := range respMsgs {
|
||||
if msg, ok := m.(map[string]interface{}); ok {
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
@@ -1,112 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"prompts-core/dao"
|
||||
"prompts-core/model/dto"
|
||||
"prompts-core/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var Session = &sessionService{}
|
||||
|
||||
type sessionService struct{}
|
||||
|
||||
func (s *sessionService) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *beans.ResponseEmpty, err error) {
|
||||
// 1. 解析AI返回的文本
|
||||
result, err := parseOutput(req.Text)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 更新数据库
|
||||
result["role"] = "assistant"
|
||||
_, err = dao.ComposeSession.Update(ctx, &entity.ComposeSession{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
||||
ResponseContent: result,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3. 获取当前轮次完整数据
|
||||
session, err := dao.ComposeSession.GetById(ctx, req.EpicycleId)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. 转换 json 并存入 Redis
|
||||
requestMessages := convertToMessages(session.RequestContent)
|
||||
responseMessages := convertToMessages(session.ResponseContent)
|
||||
|
||||
if err = s.saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil {
|
||||
g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v",
|
||||
session.SessionId, session.Id, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
|
||||
session.SessionId, session.Id, len(requestMessages), len(responseMessages))
|
||||
return &beans.ResponseEmpty{}, nil
|
||||
}
|
||||
|
||||
// GetHistoryMessages 获取历史信息
|
||||
func (s *sessionService) GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
||||
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
||||
|
||||
// 1. 先从 Redis 拿
|
||||
redisHistory, err := s.GetSessionHistoryForInference(ctx, sessionId)
|
||||
if err == nil && len(redisHistory) > 0 {
|
||||
return redisHistory, nil
|
||||
}
|
||||
|
||||
// 2. Redis 没有 → fallback DB
|
||||
sessions, err := dao.ComposeSession.GetListBySessionId(ctx, sessionId, maxRounds)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DB获取历史失败: %w", err)
|
||||
}
|
||||
|
||||
var messages []map[string]any
|
||||
|
||||
for _, session := range sessions {
|
||||
// request
|
||||
reqMsgs := convertToMessages(session.RequestContent)
|
||||
for _, m := range reqMsgs {
|
||||
role := gconv.String(m["role"])
|
||||
if role == "user" || role == "assistant" {
|
||||
messages = append(messages, m)
|
||||
}
|
||||
}
|
||||
|
||||
// response
|
||||
respMsgs := convertToMessages(session.ResponseContent)
|
||||
for _, m := range respMsgs {
|
||||
if m["role"] == nil {
|
||||
m["role"] = "assistant"
|
||||
}
|
||||
messages = append(messages, m)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 回写 Redis
|
||||
for _, session := range sessions {
|
||||
reqMsgs := convertToMessages(session.RequestContent)
|
||||
respMsgs := convertToMessages(session.ResponseContent)
|
||||
for i := range respMsgs {
|
||||
if respMsgs[i]["role"] == nil {
|
||||
respMsgs[i]["role"] = "assistant"
|
||||
}
|
||||
}
|
||||
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
|
||||
_ = s.saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs)
|
||||
}
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
// utils 工具函数
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// ============================================
|
||||
// json 相关处理
|
||||
// ============================================
|
||||
// parseOutput 解析模型输出为 JSON 格式
|
||||
func parseOutput(text string) (map[string]any, error) {
|
||||
j, err := gjson.LoadJson([]byte(text))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("解析模型输出失败: %w", err)
|
||||
}
|
||||
|
||||
return j.Map(), nil
|
||||
}
|
||||
|
||||
func convertToMessages(raw any) []map[string]any {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
j, err := gjson.LoadJson(gconv.Bytes(raw))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
// 1. 如果有 messages
|
||||
if j.Contains("messages") {
|
||||
return gconv.Maps(j.Get("messages").Array())
|
||||
}
|
||||
// 2. 否则当成单条 message
|
||||
return []map[string]any{
|
||||
j.Map(),
|
||||
}
|
||||
}
|
||||
|
||||
// isMessageValid 校验推理结果是否合法
|
||||
func isMessageValid(message map[string]any) bool {
|
||||
if message == nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func formToJSON(form map[string]any) string {
|
||||
if form == nil {
|
||||
return "{}"
|
||||
}
|
||||
b, _ := json.Marshal(form)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func mustMarshal(v any) string {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return "{}"
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
219
update.sql
219
update.sql
@@ -1,117 +1,134 @@
|
||||
-- prompts-core 核心表(pgsql)
|
||||
-- 说明:字段风格尽量与参考项目一致(tenant/creator/updater/created_at/updated_at/deleted_at)
|
||||
|
||||
-- prompts_model_prompt 模型提示词配置表
|
||||
CREATE TABLE IF NOT EXISTS prompts_model_prompt (
|
||||
-- 基础字段(与 common/db/gfdb 的 Hook 约定保持一致)
|
||||
id BIGINT PRIMARY KEY, -- 主键ID(非自增)
|
||||
tenant_id BIGINT NOT NULL DEFAULT 0, -- 租户ID
|
||||
creator VARCHAR(64) NOT NULL, -- 创建人
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 创建时间
|
||||
updater VARCHAR(64) NOT NULL, -- 更新人
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 更新时间
|
||||
deleted_at TIMESTAMP(6), -- 删除时间(软删)
|
||||
|
||||
-- 业务字段(按你当前的最小字段集)
|
||||
model_type_id INT NOT NULL DEFAULT 0, -- 模型分类ID
|
||||
model_type VARCHAR(64) NOT NULL, -- 模型类别
|
||||
prompt_info JSONB NOT NULL DEFAULT '{}'::jsonb, -- 提示词信息(JSON)
|
||||
response_json_schema JSONB NOT NULL DEFAULT '{}'::jsonb, -- 模型返回表单 JSON 格式约束
|
||||
enabled SMALLINT NOT NULL DEFAULT 1, -- 是否启用:1启用/0禁用
|
||||
version VARCHAR(64) NOT NULL DEFAULT '' -- 版本号(预留)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_prompts_model_prompt_tenant_id ON prompts_model_prompt(tenant_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_prompts_model_prompt_model_type_id ON prompts_model_prompt(model_type_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_prompts_model_prompt_model_type ON prompts_model_prompt(model_type);
|
||||
CREATE INDEX IF NOT EXISTS idx_prompts_model_prompt_enabled ON prompts_model_prompt(enabled);
|
||||
CREATE INDEX IF NOT EXISTS idx_prompts_model_prompt_deleted_at ON prompts_model_prompt(deleted_at);
|
||||
|
||||
COMMENT ON TABLE prompts_model_prompt IS '模型提示词配置表';
|
||||
COMMENT ON COLUMN prompts_model_prompt.id IS '主键ID(非自增)';
|
||||
COMMENT ON COLUMN prompts_model_prompt.tenant_id IS '租户ID';
|
||||
COMMENT ON COLUMN prompts_model_prompt.creator IS '创建人';
|
||||
COMMENT ON COLUMN prompts_model_prompt.created_at IS '创建时间';
|
||||
COMMENT ON COLUMN prompts_model_prompt.updater IS '更新人';
|
||||
COMMENT ON COLUMN prompts_model_prompt.updated_at IS '更新时间';
|
||||
COMMENT ON COLUMN prompts_model_prompt.deleted_at IS '删除时间(软删)';
|
||||
COMMENT ON COLUMN prompts_model_prompt.model_type_id IS '模型分类ID';
|
||||
COMMENT ON COLUMN prompts_model_prompt.model_type IS '模型类别';
|
||||
COMMENT ON COLUMN prompts_model_prompt.prompt_info IS '提示词信息(JSON)';
|
||||
COMMENT ON COLUMN prompts_model_prompt.response_json_schema IS '模型返回表单 JSON 格式约束';
|
||||
COMMENT ON COLUMN prompts_model_prompt.enabled IS '是否启用:1启用/0禁用';
|
||||
COMMENT ON COLUMN prompts_model_prompt.version IS '版本号(预留)';
|
||||
|
||||
-- prompts_compose_task 拼接提示词任务记录表
|
||||
-- prompts_compose_task 提示词任务记录表
|
||||
CREATE TABLE IF NOT EXISTS prompts_compose_task (
|
||||
id BIGINT PRIMARY KEY,
|
||||
tenant_id BIGINT NOT NULL DEFAULT 0,
|
||||
creator VARCHAR(64) NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updater VARCHAR(64) NOT NULL,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
deleted_at TIMESTAMP(6),
|
||||
id BIGINT PRIMARY KEY,
|
||||
tenant_id BIGINT NOT NULL DEFAULT 0,
|
||||
creator VARCHAR(64) NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updater VARCHAR(64) NOT NULL,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
deleted_at TIMESTAMP(6),
|
||||
|
||||
task_id VARCHAR(64) NOT NULL,
|
||||
model_name VARCHAR(128) NOT NULL DEFAULT '',
|
||||
skill_name VARCHAR(128) NOT NULL DEFAULT '',
|
||||
gateway_state INT NOT NULL DEFAULT 0,
|
||||
limit_words INT NOT NULL DEFAULT 0,
|
||||
request_payload JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
result_text TEXT NOT NULL DEFAULT '',
|
||||
messages JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
status VARCHAR(32) NOT NULL DEFAULT 'pending',
|
||||
error_message TEXT NOT NULL DEFAULT '',
|
||||
oss_file VARCHAR(1024) NOT NULL DEFAULT '',
|
||||
file_type VARCHAR(64) NOT NULL DEFAULT ''
|
||||
);
|
||||
task_id VARCHAR(64) NOT NULL,
|
||||
model_name VARCHAR(128) NOT NULL DEFAULT '',
|
||||
skill_name VARCHAR(128) NOT NULL DEFAULT '',
|
||||
build_type INT NOT NULL DEFAULT 0,
|
||||
callback_url VARCHAR(512) NOT NULL DEFAULT '',
|
||||
gateway_state INT NOT NULL DEFAULT 0,
|
||||
request_payload JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
result_text TEXT NOT NULL DEFAULT '',
|
||||
messages JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
status VARCHAR(32) NOT NULL DEFAULT 'pending',
|
||||
error_message TEXT NOT NULL DEFAULT '',
|
||||
oss_file VARCHAR(1024) NOT NULL DEFAULT '',
|
||||
file_type VARCHAR(64) NOT NULL DEFAULT ''
|
||||
);
|
||||
|
||||
-- 索引
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS uk_prompts_compose_task_task_id ON prompts_compose_task(task_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_prompts_compose_task_status ON prompts_compose_task(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_prompts_compose_task_deleted_at ON prompts_compose_task(deleted_at);
|
||||
|
||||
COMMENT ON TABLE prompts_compose_task IS '拼接提示词任务记录表';
|
||||
COMMENT ON COLUMN prompts_compose_task.task_id IS 'model-gateway 任务ID';
|
||||
COMMENT ON COLUMN prompts_compose_task.model_name IS '业务模型名称';
|
||||
COMMENT ON COLUMN prompts_compose_task.skill_name IS '技能名称';
|
||||
COMMENT ON COLUMN prompts_compose_task.gateway_state IS 'model-gateway 状态:0排队/1执行/2成功/3失败/4已下载';
|
||||
COMMENT ON COLUMN prompts_compose_task.limit_words IS '提示词限制字数';
|
||||
-- 注释
|
||||
COMMENT ON TABLE prompts_compose_task IS '提示词任务记录表';
|
||||
COMMENT ON COLUMN prompts_compose_task.id IS '主键ID';
|
||||
COMMENT ON COLUMN prompts_compose_task.tenant_id IS '租户ID';
|
||||
COMMENT ON COLUMN prompts_compose_task.creator IS '创建人';
|
||||
COMMENT ON COLUMN prompts_compose_task.created_at IS '创建时间';
|
||||
COMMENT ON COLUMN prompts_compose_task.updater IS '更新人';
|
||||
COMMENT ON COLUMN prompts_compose_task.updated_at IS '更新时间';
|
||||
COMMENT ON COLUMN prompts_compose_task.deleted_at IS '删除时间(软删)';
|
||||
COMMENT ON COLUMN prompts_compose_task.task_id IS 'model-gateway 任务ID';
|
||||
COMMENT ON COLUMN prompts_compose_task.model_name IS '业务模型名称';
|
||||
COMMENT ON COLUMN prompts_compose_task.skill_name IS '技能名称';
|
||||
COMMENT ON COLUMN prompts_compose_task.build_type IS '构建类型:0默认/1提示词构建/2节点构建';
|
||||
COMMENT ON COLUMN prompts_compose_task.callback_url IS '回调地址';
|
||||
COMMENT ON COLUMN prompts_compose_task.gateway_state IS 'model-gateway 状态:0排队/1执行/2成功/3失败/4已下载';
|
||||
COMMENT ON COLUMN prompts_compose_task.request_payload IS '发给 model-gateway 的请求内容';
|
||||
COMMENT ON COLUMN prompts_compose_task.result_text IS '回调返回的文本结果';
|
||||
COMMENT ON COLUMN prompts_compose_task.messages IS '最终解析后的 messages';
|
||||
COMMENT ON COLUMN prompts_compose_task.status IS '业务状态:pending/success/failed';
|
||||
COMMENT ON COLUMN prompts_compose_task.error_message IS '业务错误信息';
|
||||
COMMENT ON COLUMN prompts_compose_task.oss_file IS '网关返回的结果文件地址';
|
||||
COMMENT ON COLUMN prompts_compose_task.file_type IS '结果文件类型';
|
||||
COMMENT ON COLUMN prompts_compose_task.result_text IS '回调返回的文本结果';
|
||||
COMMENT ON COLUMN prompts_compose_task.messages IS '最终解析后的 messages';
|
||||
COMMENT ON COLUMN prompts_compose_task.status IS '业务状态:pending/success/failed';
|
||||
COMMENT ON COLUMN prompts_compose_task.error_message IS '业务错误信息';
|
||||
COMMENT ON COLUMN prompts_compose_task.oss_file IS '网关返回的结果文件地址';
|
||||
COMMENT ON COLUMN prompts_compose_task.file_type IS '结果文件类型';
|
||||
|
||||
|
||||
|
||||
-- prompts_compose_session 提示词历史会话表
|
||||
CREATE TABLE IF NOT EXISTS prompts_compose_session (
|
||||
id BIGINT PRIMARY KEY,
|
||||
tenant_id BIGINT NOT NULL DEFAULT 0,
|
||||
creator VARCHAR(64) NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updater VARCHAR(64) NOT NULL,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
deleted_at TIMESTAMP(6),
|
||||
id BIGINT NOT NULL,
|
||||
tenant_id BIGINT NOT NULL DEFAULT 0,
|
||||
creator VARCHAR(64) NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updater VARCHAR(64) NOT NULL,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
deleted_at TIMESTAMP(6),
|
||||
|
||||
session_id VARCHAR(64) NOT NULL,
|
||||
request_content JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
response_content JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
remark VARCHAR(500) NOT NULL DEFAULT ''
|
||||
session_id VARCHAR(64) NOT NULL,
|
||||
request_content JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
response_content JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
remark VARCHAR(500) NOT NULL DEFAULT ''
|
||||
);
|
||||
|
||||
-- 索引
|
||||
CREATE INDEX IF NOT EXISTS idx_prompts_compose_session_session_id ON prompts_compose_session(session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_prompts_compose_session_deleted_at ON prompts_compose_session(deleted_at);
|
||||
|
||||
COMMENT ON TABLE prompts_compose_session IS '提示词历史会话表';
|
||||
COMMENT ON COLUMN prompts_compose_session.id IS '主键ID(非自增)';
|
||||
COMMENT ON COLUMN prompts_compose_session.tenant_id IS '租户ID';
|
||||
COMMENT ON COLUMN prompts_compose_session.creator IS '创建人';
|
||||
COMMENT ON COLUMN prompts_compose_session.created_at IS '创建时间';
|
||||
COMMENT ON COLUMN prompts_compose_session.updater IS '更新人';
|
||||
COMMENT ON COLUMN prompts_compose_session.updated_at IS '更新时间';
|
||||
COMMENT ON COLUMN prompts_compose_session.deleted_at IS '删除时间(软删)';
|
||||
COMMENT ON COLUMN prompts_compose_session.session_id IS '会话ID';
|
||||
COMMENT ON COLUMN prompts_compose_session.request_content IS '请求内容(JSON格式)';
|
||||
-- 注释
|
||||
COMMENT ON TABLE prompts_compose_session IS '提示词历史会话表';
|
||||
COMMENT ON COLUMN prompts_compose_session.id IS '主键ID';
|
||||
COMMENT ON COLUMN prompts_compose_session.tenant_id IS '租户ID';
|
||||
COMMENT ON COLUMN prompts_compose_session.creator IS '创建人';
|
||||
COMMENT ON COLUMN prompts_compose_session.created_at IS '创建时间';
|
||||
COMMENT ON COLUMN prompts_compose_session.updater IS '更新人';
|
||||
COMMENT ON COLUMN prompts_compose_session.updated_at IS '更新时间';
|
||||
COMMENT ON COLUMN prompts_compose_session.deleted_at IS '删除时间(软删)';
|
||||
COMMENT ON COLUMN prompts_compose_session.session_id IS '会话ID';
|
||||
COMMENT ON COLUMN prompts_compose_session.request_content IS '请求内容(JSON格式)';
|
||||
COMMENT ON COLUMN prompts_compose_session.response_content IS '返回内容(JSON格式)';
|
||||
COMMENT ON COLUMN prompts_compose_session.remark IS '备注';
|
||||
COMMENT ON COLUMN prompts_compose_session.remark IS '备注';
|
||||
|
||||
|
||||
|
||||
-- prompts_provider_protocol 模型协议映射配置表
|
||||
CREATE TABLE IF NOT EXISTS prompts_provider_protocol (
|
||||
id BIGINT PRIMARY KEY,
|
||||
tenant_id BIGINT NOT NULL DEFAULT 0,
|
||||
creator VARCHAR(64) NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updater VARCHAR(64) NOT NULL,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
deleted_at TIMESTAMP(6),
|
||||
|
||||
provider_name VARCHAR(64) NOT NULL DEFAULT '',
|
||||
target_field VARCHAR(64) NOT NULL DEFAULT '',
|
||||
merge_order JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
role_mapping JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
content_mapping JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
capabilities JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
request_template JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
system_prompt_template TEXT NOT NULL DEFAULT '',
|
||||
user_prompt_template TEXT NOT NULL DEFAULT '',
|
||||
status INT NOT NULL DEFAULT 1,
|
||||
remark VARCHAR(500) NOT NULL DEFAULT ''
|
||||
);
|
||||
-- 索引
|
||||
CREATE INDEX IF NOT EXISTS idx_prompts_provider_protocol_provider_name ON prompts_provider_protocol(provider_name);
|
||||
CREATE INDEX IF NOT EXISTS idx_prompts_provider_protocol_status ON prompts_provider_protocol(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_prompts_provider_protocol_deleted_at ON prompts_provider_protocol(deleted_at);
|
||||
-- 注释
|
||||
COMMENT ON TABLE prompts_provider_protocol IS '模型协议映射配置表';
|
||||
COMMENT ON COLUMN prompts_provider_protocol.id IS '主键ID';
|
||||
COMMENT ON COLUMN prompts_provider_protocol.tenant_id IS '租户ID';
|
||||
COMMENT ON COLUMN prompts_provider_protocol.creator IS '创建人';
|
||||
COMMENT ON COLUMN prompts_provider_protocol.created_at IS '创建时间';
|
||||
COMMENT ON COLUMN prompts_provider_protocol.updater IS '更新人';
|
||||
COMMENT ON COLUMN prompts_provider_protocol.updated_at IS '更新时间';
|
||||
COMMENT ON COLUMN prompts_provider_protocol.deleted_at IS '删除时间(软删)';
|
||||
COMMENT ON COLUMN prompts_provider_protocol.provider_name IS '运营商名称(openai/deepseek/qwen/anthropic/gemini等)';
|
||||
COMMENT ON COLUMN prompts_provider_protocol.target_field IS '目标字段(messages/contents/prompt)';
|
||||
COMMENT ON COLUMN prompts_provider_protocol.merge_order IS 'Prompt IR 拼接顺序(system/history/user)';
|
||||
COMMENT ON COLUMN prompts_provider_protocol.role_mapping IS '角色映射(system/user/assistant -> provider role)';
|
||||
COMMENT ON COLUMN prompts_provider_protocol.content_mapping IS '内容字段映射(content/parts.text等)';
|
||||
COMMENT ON COLUMN prompts_provider_protocol.capabilities IS '协议能力配置(system/history/tools/stream等支持情况)';
|
||||
COMMENT ON COLUMN prompts_provider_protocol.request_template IS '请求模板(JSON结构模板)';
|
||||
COMMENT ON COLUMN prompts_provider_protocol.system_prompt_template IS '系统提示词模板';
|
||||
COMMENT ON COLUMN prompts_provider_protocol.status IS '状态:1启用/0禁用';
|
||||
COMMENT ON COLUMN prompts_provider_protocol.remark IS '备注';
|
||||
Reference in New Issue
Block a user