Compare commits

6 Commits

51 changed files with 2224 additions and 3270 deletions

44
Dockerfile Normal file
View File

@@ -0,0 +1,44 @@
# 多阶段构建 - 第一阶段:编译(使用已安装的镜像)
FROM golang:alpine AS builder
RUN apk add --no-cache git ca-certificates tzdata
ENV GO111MODULE=on
ENV GOPROXY=https://goproxy.cn,direct
ENV CGO_ENABLED=0
ENV GOTOOLCHAIN=auto
ENV GOPRIVATE=gitea.com/red-future/common
# 配置git使用私有Gitea仓库带Token认证
RUN git config --global url."http://x-token-auth:619679cd366aefea3a50f0622d842a41f2209e08595767bba49c3836ef57d415@116.204.74.41:3000/red-future/common.git".insteadOf "https://gitea.com/red-future/common.git" && \
git config --global credential.helper store
WORKDIR /build
# 复制父目录的 common 模块(因为 go.mod 中使用了本地 replace)
#COPY ../common /build/common
COPY . .
RUN go mod download && go mod tidy
RUN go build -ldflags="-s -w" -o main ./main.go
# 第二阶段:运行
FROM alpine:3.19
ENV TIME_ZONE=Asia/Shanghai
RUN apk add --no-cache ca-certificates tzdata && \
ln -sf /usr/share/zoneinfo/$TIME_ZONE /etc/localtime
WORKDIR /app
# 复制编译好的二进制文件
COPY --from=builder /build/main .
COPY --from=builder /build/config.yml ./
# 创建日志目录
RUN mkdir -p /logs /app/resource/log/run /app/resource/log/server
EXPOSE 3009
CMD ["./main"]

View File

@@ -1,54 +1,30 @@
# Prompts-Core 提示词核心服务
## 项目简介
Prompts-Core 是基于 Go 语言开发的**多模态 AI 提示词构建与管理系统**,专注于统一管理各类 AI 模型的提示词模板、维护智能会话上下文、适配主流模型协议,并支持文件解析与外部技能集成,为 AI 应用提供标准化、高效的提示词服务。
# prompts-core提示词服务[2026.5.12前,暂时弃置]
## 核心功能
1. **提示词构建引擎**
支持文字/图片/音频/向量化/全模态 5 类任务提示词生成,提供完整流程、分步节点两种构建模式,支持超大内容按 Token 自动分批处理。
2. **智能会话管理**
基于缓存实现高效会话存储,自动控制会话轮数与过期时间,保障上下文连贯性。
3. **多模型协议适配**
动态适配 OpenAI、DeepSeek、Qwen、Gemini 等主流 AI 模型协议,支持角色、字段、消息顺序灵活映射。
4. **文件与技能集成**
自动提取文本、ZIP 压缩包内容,支持加载外部 Markdown 技能配置,扩展服务能力。
5. **异步任务调度**
支持异步任务处理、状态轮询与回调通知,自带可配置重试机制。
## 1. 功能范围(当前阶段)
- 仅做提示词配置的基础 CRUD最小可用版本
- 表:`prompts_model_prompt`
## 技术架构
- 开发语言Go 1.26.0
- Web 框架GoFrame v2.10.0
- 核心存储Redis会话缓存
- 服务组件Consul服务注册、Jaeger链路追踪
- 调用链路:客户端 → Prompts-Core → 模型网关 → AI 模型
## 2. 接口
> 路由注册方式与参考项目一致:使用 `common/http.RouteRegister` 注册 controller。
## 快速开始
### 环境要求
Go 1.26+、Redis、已部署模型网关服务
- `POST /composeMessages`:按 `modelTypeId` 读取 `prompt_info + response_json_schema``modelName` 作为实际调用的网关模型;结合前端 `form(role/value)``userfiles` 调用 `model-gateway /task/createTask`,同步等待回调后直接返回最终 `messages`
- `GET /composeMessagesCallback/prompts-core``model-gateway` 成功回调接口(真实地址由 `callbackUrl + /bizName` 组成)
- `GET /getComposeTask`:按 `taskId` 查询拼接任务状态和结果
- `POST /createPrompt`:创建(默认启用)
- `PUT /updatePrompt`:更新
- `DELETE /deletePrompt`:删除
- `GET /getPrompt`:详情
- `POST /listPrompt`:列表分页
### 启动步骤
1. 克隆项目代码
2. 完成项目配置文件修改
3. 执行命令启动服务:
```bash
go run main.go
```
## 3. 数据库初始化
执行根目录 `update.sql`
## API 接口
### 基础信息
- 服务地址:`http://{host}:3009`
- 请求类型:`application/json`
- 认证方式:请求头携带 `Authorization``X-User`
## 4. 运行配置
配置文件:`config.yml`
### 核心接口
1. **提示词拼接接口**
- 地址:`POST /composeMessages`
- 功能:构建提示词并调用模型服务,同步返回结果
2. **任务状态查询**
- 地址:`GET /getComposeTask`
- 功能:根据任务 ID 查询处理状态与结果
3. **任务回调接口**
- 地址:`GET /composeMessagesCallback/prompts-core`
- 功能:接收模型服务处理完成回调
4. **会话同步接口**
- 地址:`POST /sessionCallback`
- 功能:同步更新会话上下文历史
### 新增说明
- `prompts_model_prompt` 去除了 `limit_length`
- 新增 `response_json_schema`
- 新增任务记录表 `prompts_compose_task`
- `callbackUrl` 必须填写 prompts-core 的绝对地址基路径,例如:`http://127.0.0.1:8002/composeMessagesCallback`
- `model-gateway` 实际回调地址`callbackUrl/{bizName}`,本项目固定为:`/composeMessagesCallback/prompts-core`

View File

@@ -1,30 +0,0 @@
package util
import (
"context"
"strings"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// GetServerPort 从配置获取服务端口
func GetServerPort(ctx context.Context) string {
address := g.Cfg().MustGet(ctx, "server.address", ":8080").String()
// address 格式如 ":3009",去掉冒号
if strings.HasPrefix(address, ":") {
return address[1:]
}
return "8080"
}
// GetModelPrompt 获取请求模型的提示词
func GetModelPrompt(ctx context.Context, modelType int) string {
key := "modelPrompts.types." + gconv.String(modelType)
return g.Cfg().MustGet(ctx, key, "").String()
}
// GetBuildPrompt 获取节点构建提示词
func GetBuildPrompt(ctx context.Context) string {
return g.Cfg().MustGet(ctx, "nodePrompts", "").String()
}

View File

@@ -1,96 +0,0 @@
package util
import (
"path/filepath"
"regexp"
"strings"
)
var (
// AllowedMIMEPrefixes 允许的文本类 MIME 类型前缀
AllowedMIMEPrefixes = []string{
"text/",
"application/json",
"application/xml",
"application/javascript",
"application/x-yaml",
"application/yaml",
"application/toml",
"application/x-httpd-php",
"application/x-sh",
"application/x-python",
"application/x-perl",
"application/x-ruby",
}
// BannedExtensions 禁止的文件扩展名
BannedExtensions = map[string]bool{
".png": true, ".jpg": true, ".jpeg": true, ".gif": true, ".bmp": true,
".webp": true, ".svg": true, ".ico": true, ".tiff": true, ".tif": true,
".mp3": true, ".wav": true, ".ogg": true, ".flac": true, ".aac": true,
".wma": true, ".m4a": true,
".mp4": true, ".avi": true, ".mkv": true, ".mov": true, ".wmv": true,
".flv": true, ".webm": true,
".tar": true, ".gz": true, ".rar": true, ".7z": true,
".exe": true, ".dll": true, ".so": true, ".bin": true, ".dat": true,
".class": true, ".pyc": true,
".pdf": true, ".doc": true, ".docx": true, ".xls": true, ".xlsx": true,
".ppt": true, ".pptx": true,
}
symbolCleaner = regexp.MustCompile(`[\x00-\x08\x0B\x0C\x0E-\x1F]`)
multiNewlines = regexp.MustCompile(`\n{3,}`)
)
// SanitizeURL 清洗 URL 字符串
func SanitizeURL(raw string) string {
s := strings.TrimSpace(raw)
s = strings.Trim(s, "`\"")
return s
}
// CleanSymbols 清洗文本中的控制字符和多余空行
func CleanSymbols(text string) string {
text = symbolCleaner.ReplaceAllString(text, "")
text = strings.ReplaceAll(text, "\r\n", "\n")
text = strings.ReplaceAll(text, "\r", "\n")
text = multiNewlines.ReplaceAllString(text, "\n\n")
return strings.TrimSpace(text)
}
// IsBannedExtension 判断是否为禁止的文件扩展名
func IsBannedExtension(url string) bool {
ext := extractExtension(url)
return BannedExtensions[ext]
}
// IsZipExtension 判断是否为 zip 文件
func IsZipExtension(url string) bool {
ext := extractExtension(url)
return ext == ".zip"
}
// IsReadableContentType 判断是否为可读的文本类型
func IsReadableContentType(contentType string) bool {
if contentType == "" {
return false
}
ct := strings.ToLower(contentType)
for _, prefix := range AllowedMIMEPrefixes {
if strings.HasPrefix(ct, prefix) {
return true
}
}
return false
}
// extractExtension 提取文件扩展名并清理查询参数
func extractExtension(url string) string {
ext := strings.ToLower(filepath.Ext(url))
if idx := strings.Index(ext, "?"); idx != -1 {
ext = ext[:idx]
}
return ext
}

View File

@@ -1,67 +0,0 @@
package util
import (
"context"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
)
// AsyncCtx 固化异步上下文中的 token 和用户信息,避免请求结束后丢失
func AsyncCtx(ctx context.Context) context.Context {
asyncCtx := context.WithoutCancel(ctx)
if r := g.RequestFromCtx(ctx); r != nil {
if token := r.Header.Get("Authorization"); token != "" {
asyncCtx = context.WithValue(asyncCtx, "token", token)
}
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
}
}
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
asyncCtx = context.WithValue(asyncCtx, "user", user)
}
return asyncCtx
}
// ForwardHeaders 透传调用链路的头信息,优先使用 ctx 中的固化值
func ForwardHeaders(ctx context.Context) map[string]string {
headers := make(map[string]string)
setHeaderFromContext(headers, ctx, "Authorization", "token")
setHeaderFromContext(headers, ctx, "X-User-Info", "xUserInfo")
fallbackToRequestHeaders(headers, ctx)
return headers
}
// setHeaderFromContext 从上下文中设置 header
func setHeaderFromContext(headers map[string]string, ctx context.Context, headerKey, ctxKey string) {
if value, ok := ctx.Value(ctxKey).(string); ok && value != "" {
headers[headerKey] = value
}
}
// fallbackToRequestHeaders 从请求头中获取作为兜底
func fallbackToRequestHeaders(headers map[string]string, ctx context.Context) {
r := g.RequestFromCtx(ctx)
if r == nil {
return
}
if headers["Authorization"] == "" {
if token := r.Header.Get("Authorization"); token != "" {
headers["Authorization"] = token
}
}
if headers["X-User-Info"] == "" {
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
headers["X-User-Info"] = userInfo
}
}
}

View File

@@ -1,151 +0,0 @@
package util
import (
"encoding/json"
"fmt"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/util/gconv"
)
// ParseOutput 解析模型输出为 JSON 格式
func ParseOutput(text string) (map[string]any, error) {
j, err := gjson.LoadJson([]byte(text))
if err != nil {
return nil, fmt.Errorf("解析模型输出失败: %w", err)
}
return j.Map(), nil
}
// ConvertToMessages 将原始数据转换为消息列表
func ConvertToMessages(raw any) []map[string]any {
if raw == nil {
return nil
}
j, err := gjson.LoadJson(gconv.Bytes(raw))
if err != nil {
return nil
}
if j.Contains("messages") {
return gconv.Maps(j.Get("messages").Array())
}
return []map[string]any{j.Map()}
}
// FormToJSON 将表单数据转换为 JSON 字符串
func FormToJSON(form map[string]any) string {
if form == nil {
return "{}"
}
b, _ := json.Marshal(form)
return string(b)
}
// UserFormToJSON 将用户表单数据转换为 JSON 字符串
func UserFormToJSON(form []map[string]any) string {
if form == nil {
return "{}"
}
b, _ := json.Marshal(form)
return string(b)
}
// MustMarshal 将对象序列化为 JSON 字符串,失败时返回空对象
func MustMarshal(v any) string {
b, err := json.Marshal(v)
if err != nil {
return "{}"
}
return string(b)
}
// JSONPretty 将任意类型转为格式化的 JSON 字符串
func JSONPretty(v any) string {
if gv, ok := v.(*gvar.Var); ok {
v = gconv.Map(gv.String())
}
var tmp map[string]any
if err := gconv.Struct(v, &tmp); err != nil {
return gconv.String(v)
}
b, _ := json.MarshalIndent(tmp, "", " ")
return string(b)
}
// GvarToMap 将 *gvar.Var 类型转换为 map[string]any
func GvarToMap(v *gvar.Var) map[string]any {
if v == nil || v.IsNil() {
return nil
}
result := make(map[string]any)
// 方法1尝试获取 map 值
if m := v.Map(); len(m) > 0 {
return m
}
// 方法2尝试解析 JSON 字符串
str := v.String()
if str != "" && str != "<nil>" {
json.Unmarshal([]byte(str), &result)
if len(result) > 0 {
return result
}
}
// 方法3尝试获取 interface 再转换
if val := v.Val(); val != nil {
switch val.(type) {
case map[string]any:
return val.(map[string]any)
default:
data, _ := json.Marshal(val)
json.Unmarshal(data, &result)
}
}
return result
}
// ParseJSONFieldFromGvar 专门处理 *gvar.Var 类型的 JSON 字段解析
func ParseJSONFieldFromGvar(source any, target any) {
if source == nil {
return
}
switch v := source.(type) {
case *gvar.Var:
if v.IsNil() {
return
}
// 尝试获取 map
if m := v.Map(); len(m) > 0 {
data, _ := json.Marshal(m)
json.Unmarshal(data, target)
return
}
// 尝试解析 JSON 字符串
str := v.String()
if str != "" && str != "<nil>" {
json.Unmarshal([]byte(str), target)
}
default:
// 其他类型走原来的逻辑
data, _ := json.Marshal(source)
json.Unmarshal(data, target)
}
}

View File

@@ -1,130 +0,0 @@
package util
import (
"context"
"net"
"strings"
"github.com/gogf/gf/v2/frame/g"
)
// GetLocalIP 获取本机有效的局域网 IPv4 地址
func GetLocalIP() string {
addrs, err := net.InterfaceAddrs()
if err != nil {
return "127.0.0.1"
}
var validIPs []string
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok {
continue
}
ip := ipnet.IP
if isIPValid(ip) {
validIPs = append(validIPs, ip.String())
}
}
// 优先返回非 169.254.x.x 的 IP
for _, ip := range validIPs {
if !strings.HasPrefix(ip, "169.254.") {
return ip
}
}
// 其次返回 169.254.x.x最后的选择
if len(validIPs) > 0 {
return validIPs[0]
}
return "127.0.0.1"
}
// isIPValid 判断 IP 是否有效
func isIPValid(ip net.IP) bool {
// 不是 loopback (127.0.0.1)
if ip.IsLoopback() {
return false
}
// 是 IPv4
if ip.To4() == nil {
return false
}
// 不是链路本地地址 (169.254.0.0/16)
if ip[0] == 169 && ip[1] == 254 {
return false
}
// 不是组播地址
if ip.IsMulticast() {
return false
}
// 不是未指定地址 (0.0.0.0)
if ip.IsUnspecified() {
return false
}
return true
}
// GetLocalAddress 获取局域网地址IP:端口)
func GetLocalAddress(ctx context.Context) string {
ip := GetLocalIP()
port := GetServerPort(ctx)
if port == "80" || port == "443" {
return ip
}
return ip + ":" + port
}
// GetSchemaFromRequest 从当前请求中获取协议http/https
func GetSchemaFromRequest(ctx context.Context) string {
r := g.RequestFromCtx(ctx)
if r == nil {
return "http"
}
// 1. 代理场景X-Forwarded-Proto
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
return proto
}
// 2. 代理场景X-Forwarded-Scheme
if proto := r.Header.Get("X-Forwarded-Scheme"); proto != "" {
return proto
}
// 3. TLS 连接(直接 HTTPS
if r.TLS != nil {
return "https"
}
// 4. 默认 HTTP这行很重要
return "http" // ← 确保有这行
}
// GetLocalBaseURL 获取局域网基础 URL动态协议 + IP + 端口)
func GetLocalBaseURL(ctx context.Context) string {
schema := GetSchemaFromRequest(ctx)
addr := GetLocalAddress(ctx)
return schema + "://" + addr
}
// GetCallbackURL 获取回调地址(完整 URL
func GetCallbackURL(ctx context.Context, path string) string {
baseURL := GetLocalBaseURL(ctx)
// 确保 path 以 / 开头
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return baseURL + path
}

View File

@@ -1,229 +0,0 @@
package util
import (
"encoding/json"
"fmt"
"regexp"
"strings"
"unicode"
"github.com/gogf/gf/v2/container/gvar"
)
var (
enWordRegex = regexp.MustCompile(`[A-Za-z]+`)
punctRegex = regexp.MustCompile(`[[:punct:]]`)
)
// TokenConfig Token计算配置
type TokenConfig struct {
ZhRatio float64 `json:"zh_ratio"`
EnRatio float64 `json:"en_ratio"`
SpaceRatio float64 `json:"space_ratio"`
PunctuationRatio float64 `json:"punctuation_ratio"`
MaxWindowSize int `json:"max_window_size"`
ReserveRatio float64 `json:"reserve_ratio"`
MinReserve int `json:"min_reserve"`
}
// CalculateTokens 计算文本token数
func CalculateTokens(text string, tokenConfig any) int {
config := parseConfig(tokenConfig)
if config == nil {
return 0
}
if text == "" {
return 0
}
zhCount := countChineseChars(text)
enCount := countEnglishWords(text)
spaceCount := strings.Count(text, " ")
punctCount := countPunctuation(text)
totalTokens := int(
float64(zhCount)*config.ZhRatio +
float64(enCount)*config.EnRatio +
float64(spaceCount)*config.SpaceRatio +
float64(punctCount)*config.PunctuationRatio,
)
return totalTokens
}
// CountToken 计算token是否超出窗口限制
// 返回: true - 未超出(可用), false - 已超出(不可用)
func CountToken(text string, tokenConfig any) bool {
config := parseConfig(tokenConfig)
if config == nil {
return false
}
estimatedTokens := CalculateTokens(text, tokenConfig)
availableWindow := GetAvailableWindow(tokenConfig)
return estimatedTokens <= availableWindow
}
// GetAvailableWindow 获取可用窗口大小
func GetAvailableWindow(tokenConfig any) int {
config := parseConfig(tokenConfig)
if config == nil {
return 4096
}
reserveByRatio := int(float64(config.MaxWindowSize) * config.ReserveRatio)
reserve := reserveByRatio
if config.MinReserve > reserve {
reserve = config.MinReserve
}
available := config.MaxWindowSize - reserve
if available < 0 {
available = 0
}
return available
}
// GetMaxWindowSize 获取模型最大窗口大小
func GetMaxWindowSize(tokenConfig any) int {
config := parseConfig(tokenConfig)
if config == nil {
return 4096
}
return config.MaxWindowSize
}
// CheckUserFormWithinWindow 校验 UserForm 是否在窗口大小内
// 返回: isValid, exceedTokens, error
func CheckUserFormWithinWindow(userForm []map[string]any, tokenConfig any) (bool, int, error) {
config := parseConfig(tokenConfig)
if config == nil || len(userForm) == 0 {
return true, 0, nil
}
totalTokens := calculateUserFormTokens(userForm, tokenConfig)
availableWindow := GetAvailableWindow(tokenConfig)
if totalTokens > availableWindow {
return false, totalTokens - availableWindow, nil
}
return true, 0, nil
}
// CheckUserFormBatchWithinWindow 检查 UserForm 分批是否在窗口内
// 返回: 需要拆分的批次数, 每批的token数, 错误
func CheckUserFormBatchWithinWindow(userForm []map[string]any, tokenConfig any) (int, []int, error) {
config := parseConfig(tokenConfig)
if config == nil || len(userForm) == 0 {
return 1, nil, nil
}
availableWindow := GetAvailableWindow(tokenConfig)
batches := 1
currentTokens := 0
batchTokens := make([]int, 0)
for _, item := range userForm {
itemStr := fmt.Sprintf("%v", item)
itemTokens := CalculateTokens(itemStr, tokenConfig)
if currentTokens+itemTokens > availableWindow {
batchTokens = append(batchTokens, currentTokens)
batches++
currentTokens = itemTokens
} else {
currentTokens += itemTokens
}
}
if currentTokens > 0 {
batchTokens = append(batchTokens, currentTokens)
}
return batches, batchTokens, nil
}
// parseConfig 解析配置
func parseConfig(tokenConfig any) *TokenConfig {
if tokenConfig == nil {
return nil
}
switch v := tokenConfig.(type) {
case *gvar.Var:
return parseGVarConfig(v)
case map[string]any:
return parseMapConfig(v)
case *TokenConfig:
return v
case TokenConfig:
return &v
default:
return nil
}
}
// parseGVarConfig 解析 GVar 配置
func parseGVarConfig(v *gvar.Var) *TokenConfig {
if v.IsNil() {
return nil
}
mapVal := v.Map()
if mapVal == nil {
return nil
}
config := &TokenConfig{}
data, _ := json.Marshal(mapVal)
json.Unmarshal(data, config)
return config
}
// parseMapConfig 解析 Map 配置
func parseMapConfig(v map[string]any) *TokenConfig {
config := &TokenConfig{}
data, _ := json.Marshal(v)
json.Unmarshal(data, config)
return config
}
// countChineseChars 统计中文字符数量
func countChineseChars(text string) int {
count := 0
for _, r := range text {
if unicode.Is(unicode.Han, r) {
count++
}
}
return count
}
// countEnglishWords 统计英文单词数量
func countEnglishWords(text string) int {
return len(enWordRegex.FindAllString(text, -1))
}
// countPunctuation 统计标点符号数量
func countPunctuation(text string) int {
return len(punctRegex.FindAllString(text, -1))
}
// calculateUserFormTokens 计算 UserForm 总 token 数
func calculateUserFormTokens(userForm []map[string]any, tokenConfig any) int {
totalTokens := 0
for _, item := range userForm {
itemStr := fmt.Sprintf("%v", item)
totalTokens += CalculateTokens(itemStr, tokenConfig)
}
return totalTokens
}

View File

@@ -7,7 +7,7 @@ server:
database:
default:
- type: "pgsql"
host: "116.204.74.41"
host: "192.168.0.169"
port: "15432"
user: "postgres"
pass: "Bjang09@686^*^"
@@ -26,41 +26,21 @@ database:
updatedAt: "updated_at" # (可选)自动更新时间字段名称
deletedAt: "deleted_at" # (可选)软删除时间字段名称
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性为true时CreatedAt/UpdatedAt/DeletedAt都将失效
model_gateway:
- type: "pgsql"
host: "116.204.74.41"
port: "15432"
user: "postgres"
pass: "Bjang09@686^*^"
name: "model-gateway"
prefix: ""
role: "master"
debug: true
dryRun: false
charset: "utf8"
timezone: "Asia/Shanghai"
maxIdle: 5
maxOpen: 20
maxLifetime: "30s"
maxIdleConnTime: "30s"
createdAt: "created_at"
updatedAt: "updated_at"
deletedAt: "deleted_at"
timeMaintainDisabled: false
redis:
default:
address: 192.168.3.30:6379
address: 192.168.0.169:6379
db: 0
consul:
address: 192.168.3.30:8500
address: 192.168.0.169:8500
jaeger:
addr: 192.168.3.30:4318
addr: 192.168.0.169:4318
task:
waitTimeoutSeconds: 600 # /composeMessages 同步等待最终结果的最长时间(秒)
waitTimeoutSeconds: 300 # /composeMessages 同步等待最终结果的最长时间(秒)
pollIntervalMillis: 500 # 同步等待期间,轮询本地任务表 / 网关状态的时间间隔(毫秒)
session:
maxRounds: 10 # 最大轮数
@@ -81,28 +61,70 @@ promptsRetry:
modelPrompts:
types:
100: |
1: |
你是一个智能文字处理助手,专注于文本理解、文本创作、文本优化与语言表达任务,能够根据不同场景完成文章撰写、商业文案、报告总结、邮件通知、脚本创作、内容改写、信息提炼、语言翻译等多种文字处理工作,并能够理解上下文语义关系,保持内容逻辑完整、结构清晰、表达自然。
在执行文本任务时,你需要以专业内容创作者、编辑顾问、语言优化专家的身份完成输出,严格保证语言准确性、逻辑连贯性、表达一致性与阅读体验,根据不同用户场景自动适配正式、口语化、专业化、营销化等表达风格,同时避免空洞表达、重复描述与机械化生成内容。
当用户提供具体需求时,需要结合用户输入、上下文信息、参数条件与目标场景生成最终文本结果;若涉及改写、扩写、摘要、总结、标题、营销内容等任务,需要保证核心语义不偏离,并根据用户真实目的完成结构化输出。
200: |
2: |
你是一个智能图片处理助手,专注于视觉内容生成、图像编辑、画面分析与风格控制任务,能够根据文字描述生成不同风格的图片内容,包括写实、插画、动漫、水彩、电影感、商业海报等多种视觉形式,并支持图片局部修改、风格迁移、画面扩展、背景处理与视觉增强等操作。
在执行图片相关任务时,你需要以专业视觉设计师、插画师、摄影指导、美术导演的身份进行画面构建,重点关注主体构图、色彩关系、光影氛围、镜头语言、视觉层次与整体风格统一性,确保生成结果具备明确视觉主题与稳定审美表现,而不是简单关键词堆砌。
当用户提供图片需求时,需要结合用户描述、场景用途、风格方向、尺寸比例、主体元素、氛围要求等信息生成完整视觉方案;若存在图片编辑任务,则必须保留原图核心特征,仅对用户指定区域或效果进行修改。
300: |
3: |
你是一个智能音频处理助手,专注于语音生成、语音识别、音频分析与声音编辑任务,能够完成文字转语音、语音转文本、多语言识别、音频降噪、音色处理、混音剪辑、情绪识别与声音特征分析等多种音频相关工作,并能够根据不同场景匹配对应语音风格与声音表现形式。
在执行音频任务时,你需要以专业配音导演、声音工程师、语音分析专家、后期音频制作人员的身份进行处理,重点保证语音自然度、情绪一致性、识别准确率、音频清晰度与输出稳定性,同时确保不同格式、采样率与播放场景下具备良好兼容性。
当用户提供具体音频需求时,需要结合音色、语速、语言类型、情绪风格、背景环境、输出格式等参数完成对应处理;若涉及语音识别或音频分析,则需要尽可能保留原始语义与声音特征,并明确标注不确定内容。
400: |
4: |
你是一个智能向量化处理助手,专注于文本向量化、语义检索、知识索引、相似度计算与语义聚类任务,能够将文本内容转换为高维语义向量,并基于向量相似度完成语义搜索、知识召回、内容聚类、文档匹配与知识库构建等处理流程。
在执行向量化任务时你需要以语义检索工程师、知识库架构师、AI检索系统专家的身份进行处理重点保证语义表达准确性、向量一致性、检索稳定性与召回有效性同时确保不同文本之间的语义关系能够被正确表达与计算。
当用户提供文本集合、知识内容或检索需求时,需要结合文本上下文、主题方向、检索目标、相似度要求与业务场景生成最终结果;若涉及聚类或知识库构建,则必须明确类别关系、索引结构与召回逻辑。
500: |
5: |
你是一个全模态智能处理助手,能够同时理解、分析与生成文本、图片、音频、视频等多种模态内容,并支持跨模态转换、多模态融合推理、联合内容生成与复杂场景交互,能够根据不同输入形式自动匹配最合理的处理策略与输出方式。
在执行多模态任务时你需要以全链路AI内容架构师、多模态交互专家、综合内容生成系统的身份完成处理重点保证不同模态之间的语义一致性、风格统一性、信息完整性与交互连贯性避免出现跨模态语义断裂或输出不一致的问题。
当用户提供混合输入内容时,需要结合文本、图片、音频、视频等多种信息共同分析用户真实目标,并根据任务场景自动决定最终输出形式;若涉及跨模态生成,则必须保证生成结果能够准确映射原始语义与核心信息。
nodePrompts: |
buildProject:
types:
1: |
你是专业的JSON结构生成专家必须严格遵守以下全部规则。
【强制规则】
必须根据【输出结构】里面返回的JSON结构进行生成不得任何更改最终内容与输出结构返回一致
完整阅读所有文本、规则、表单内容,禁止跳读、漏读;
完整读取UserForm所有字段不得忽略任何字段
如果有skill相关内容必须完整的将内容拼接到system角色描述中
理解全部语义后再输出,禁止断章取义;
UserForm所有字段内容必须完整拼接赋值到user角色描述中不得有任何遗漏。
【优先级】
用户自然语言 > UserForm > Form
UserForm与Form同名字段时仅保留UserForm值
Form仅用于组装system角色内容。
【表单处理】
Form系统提示词、默认参数、基础配置 → 专属填充system角色
UserForm用户业务输入、文案、配图数量、比例、prompt等 → 全部解析后拼接进user角色content
自动提取UserForm中每条文案的配图数量总图片数 = 各文案配图数累加求和示例10条文案各配5张图 → 总50张parameters.n=50,用户没有相关数量必须默认1
图片尺寸为空时自动填充size=1024*1024。
【结构铁律】
严格沿用固定输出结构,不增删字段或修改层级;
messages元素必须按结构返回
禁止将role对象转为字符串、禁止嵌套错乱
输出纯净JSON无多余转义符、无换行符、无额外字符
所有括号、引号必须成对闭合保证JSON合法。
【参数赋值】
model固定沿用传入值
返回结构里面的参数,需要根据语意进行赋值,缺失补默认值;
history历史信息必须结合UserForm里的内容对用户描述部分进行补充
从UserForm提取信息整合进user描述确保数量、尺寸、文案语义无遗漏。
【输出要求】
仅输出单行纯净JSON无任何解释、备注、Markdown或多余符号
完整合UserForm全部字段语义到user描述
生成后自检JSON语法、结构、数量错误则自动重新生成。
【输出结构】
%s
【字段映射】
%s
【完整输入信息】
%s
直接输出最终JSON
2: |
你是流程路由助手你的任务是根据上下文选择一个正确的节点ID返回。
规则:
1. 只允许从下面的可选节点ID列表中选择一个返回
@@ -112,41 +134,3 @@ nodePrompts: |
%s
上下文内容:
%s
#你是专业的JSON结构生成专家必须严格遵守以下全部规则。
# 【强制规则】
# 必须根据【输出结构】里面返回的JSON结构进行生成不得任何更改最终内容与输出结构返回一致
# 完整阅读所有文本、规则、表单内容,禁止跳读、漏读;
# 完整读取UserForm所有字段不得忽略任何字段
# 如果有skill相关内容必须完整的将内容拼接到system角色描述中
# 理解全部语义后再输出,禁止断章取义;
# UserForm所有字段内容必须完整拼接赋值到user角色描述中不得有任何遗漏。
# 【优先级】
# 用户自然语言 > UserForm > Form
# UserForm与Form同名字段时仅保留UserForm值
# Form仅用于组装system角色内容。
# 【表单处理】
# Form系统提示词、默认参数、基础配置 → 专属填充system角色
# UserForm用户业务输入、文案、配图数量、比例、prompt等 → 全部解析后拼接进user角色content
# 自动提取UserForm中每条文案的配图数量总图片数 = 各文案配图数累加求和用户没有相关数量必须默认1
# 图片尺寸为空时自动填充size=1024*1024。
# 【结构铁律】
# 严格沿用固定输出结构,不增删字段或修改层级;
# messages元素必须按结构返回
# 禁止将role对象转为字符串、禁止嵌套错乱
# 输出纯净JSON无多余转义符、无换行符、无额外字符
# 所有括号、引号必须成对闭合保证JSON合法。
# 【参数赋值】
# model固定沿用传入值
# 返回结构里面的参数,需要根据语意进行赋值,缺失补默认值;
# history历史信息必须结合UserForm里的内容对用户描述部分进行补充
# 从UserForm提取信息整合进user描述确保数量、尺寸、文案语义无遗漏。
# 【输出要求】
# 仅输出单行纯净JSON无任何解释、备注、Markdown或多余符号
# 完整合UserForm全部字段语义到user描述
# 生成后自检JSON语法、结构、数量错误则自动重新生成。
# 【输出结构】
# %s
# 【完整输入信息】
# %s
# 直接输出最终JSON

View File

@@ -5,8 +5,3 @@ const (
ComposeStatusSuccess = "success"
ComposeStatusFailed = "failed"
)
const (
BuildTypePrompt = 1 //提示词构建
BuildTypeNode = 2 //节点构建
)

View File

@@ -1,12 +1,8 @@
package public
const (
DbNameModelGateway = "model_gateway" //数据库名称
)
const (
TableNameModel = "asynch_models" // 模型表
TableNamePromptConfig = "prompts_model_prompt" // 模型提示词配置表prompts-core
TableNameComposeTask = "prompts_compose_task" // 拼接提示词任务记录表
TableNameComposeSession = "prompts_compose_session" // 拼接提示词会话记录表
TableNameProviderProtocol = "prompts_provider_protocol"
)

View File

@@ -2,17 +2,19 @@ package controller
import (
"context"
"prompts-core/model/dto"
promptService "prompts-core/service/prompt"
"prompts-core/model/dto"
"prompts-core/service"
"gitea.com/red-future/common/beans"
)
type session struct{}
// Session 提示词会话控制器
// Prompt 提示词配置控制器
var Session = new(session)
// SessionCallback 会话回调
func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) {
return promptService.SessionCallback(ctx, req)
func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *beans.ResponseEmpty, err error) {
return service.Session.SessionCallback(ctx, req)
}

View File

@@ -1,29 +0,0 @@
package controller
import (
"context"
"prompts-core/model/dto"
promptService "prompts-core/service/prompt"
)
type prompt struct{}
// Prompt 提示词配置控制器
var Prompt = new(prompt)
// ComposeMessages 调用 model-gateway 异步任务并同步等待结果,
func (c *prompt) ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (res *dto.ComposeMessagesRes, err error) {
return promptService.ComposeMessages(ctx, req)
}
// Callback model-gateway 提示词回调
func (c *prompt) Callback(ctx context.Context, req *dto.CallbackReq) (res *dto.CallbackRes, err error) {
err = promptService.Callback(ctx, req)
return
}
// GetComposeTask 查询拼接任务结果
func (c *prompt) GetComposeTask(ctx context.Context, req *dto.GetComposeTaskReq) (res *dto.GetComposeTaskRes, err error) {
return promptService.GetComposeTask(ctx, req.TaskId)
}

View File

@@ -0,0 +1,69 @@
package controller
import (
"context"
"prompts-core/model/dto"
"prompts-core/service"
"gitea.com/red-future/common/beans"
)
type prompt struct{}
// Prompt 提示词配置控制器
var Prompt = new(prompt)
// ComposeMessages 调用 model-gateway 异步任务并同步等待结果,
func (c *prompt) ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (res *dto.ComposeMessagesRes, err error) {
return service.Prompt.ComposeMessages(ctx, req)
}
// ComposeMessagesCallback model-gateway 提示词回调
func (c *prompt) Callback(ctx context.Context, req *dto.CallbackReq) (res *beans.ResponseEmpty, err error) {
err = service.Prompt.Callback(ctx, req)
return
}
// GetComposeTask 查询拼接任务结果
func (c *prompt) GetComposeTask(ctx context.Context, req *dto.GetComposeTaskReq) (res *dto.GetComposeTaskRes, err error) {
return service.Prompt.GetComposeTask(ctx, req.TaskId)
}
// CreatePrompt 添加配置(默认启用)
func (c *prompt) CreatePrompt(ctx context.Context, req *dto.CreatePromptReq) (res *dto.CreatePromptRes, err error) {
return service.Prompt.Create(ctx, req)
}
// UpdatePrompt 更新配置
func (c *prompt) UpdatePrompt(ctx context.Context, req *dto.UpdatePromptReq) (res *beans.ResponseEmpty, err error) {
err = service.Prompt.Update(ctx, req)
return
}
// DeletePrompt 删除配置
func (c *prompt) DeletePrompt(ctx context.Context, req *dto.DeletePromptReq) (res *beans.ResponseEmpty, err error) {
err = service.Prompt.Delete(ctx, req.ID)
return
}
// GetPrompt 获取配置详情
func (c *prompt) GetPrompt(ctx context.Context, req *dto.GetPromptReq) (res *dto.GetPromptRes, err error) {
m, err := service.Prompt.Get(ctx, req.ID)
if err != nil {
return nil, err
}
return &dto.GetPromptRes{Prompt: m}, nil
}
// ListPrompt 配置列表
func (c *prompt) ListPrompt(ctx context.Context, req *dto.ListPromptReq) (res *dto.ListPromptRes, err error) {
list, total, err := service.Prompt.List(ctx, int(req.Page.PageNum), int(req.Page.PageSize), req.ModelTypeId, req.ModelType)
if err != nil {
return nil, err
}
return &dto.ListPromptRes{
List: list,
Total: total,
}, nil
}

View File

@@ -2,94 +2,115 @@ package dao
import (
"context"
"prompts-core/consts/public"
"prompts-core/model/entity"
"gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
)
var ComposeSession = &composeSessionDao{}
type composeSessionDao struct{}
// Insert 插入
func (d *composeSessionDao) Insert(ctx context.Context, req *entity.ComposeSession) (id int64, err error) {
var m = new(entity.ComposeSession)
err = gconv.Struct(req, &m)
func (d *composeSessionDao) Insert(ctx context.Context, m *entity.ComposeSession) (id int64, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession).Data(m).Insert()
if err != nil {
return
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
Insert(m)
if err != nil {
return
return 0, err
}
return r.LastInsertId()
}
// Update 更新
func (d *composeSessionDao) Update(ctx context.Context, req *entity.ComposeSession) (rows int64, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
func (d *composeSessionDao) Update(ctx context.Context, m *entity.ComposeSession) (rows int64, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession).
Where(entity.ComposeSessionCol.Id, m.Id).
Data(m).
OmitEmpty().
Data(&req).
Where(entity.ComposeSessionCol.Id, req.Id).
Update()
if err != nil {
return
return 0, err
}
return r.RowsAffected()
}
// List 查询编排会话列表
func (d *composeSessionDao) List(ctx context.Context, req *entity.ComposeSession, page, size int, fields ...string) (list []*entity.ComposeSession, total int, err error) {
if page <= 0 {
page = 1
func (d *composeSessionDao) List(ctx context.Context, page, size int, where map[string]any) (list []*entity.ComposeSession, total int, err error) {
model := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession).
Where("deleted_at IS NULL")
// 动态拼接查询条件
for k, v := range where {
model = model.Where(k, v)
}
if size <= 0 {
size = 10
}
model := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
Fields(fields).
OmitEmpty()
model.Where(entity.ComposeSessionCol.Creator, req.Creator)
model.Where(entity.ComposeSessionCol.SessionId, req.SessionId)
model.OrderDesc(entity.ComposeSessionCol.CreatedAt)
model.Page(page, size)
r, total, err := model.AllAndCount(false)
// 查询总数
total, err = model.Count()
if err != nil {
return
return nil, 0, err
}
err = r.Structs(&list)
// 分页查询
err = model.Order("created_at DESC").
Page(page, size).
Scan(&list)
return
}
// Get 查询编排会话
func (d *composeSessionDao) Get(ctx context.Context, req *entity.ComposeSession, fields ...string) (m *entity.ComposeSession, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
OmitEmpty().
Where(entity.ComposeSessionCol.Id, req.Id).
Where(entity.ComposeSessionCol.SessionId, req.SessionId).
Fields(fields).One()
func (d *composeSessionDao) GetListBySessionId(ctx context.Context, sessionId string, limit int) ([]*entity.ComposeSession, error) {
var sessions []*entity.ComposeSession
err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession).
Where(entity.ComposeSessionCol.SessionId, sessionId).
WhereNull(entity.ComposeSessionCol.DeletedAt).
OrderDesc(entity.ComposeSessionCol.Id).
Limit(limit).
Scan(&sessions)
if err != nil {
return nil, err
}
// 反转成时间正序
for i, j := 0, len(sessions)-1; i < j; i, j = i+1, j-1 {
sessions[i], sessions[j] = sessions[j], sessions[i]
}
return sessions, nil
}
func (d *composeSessionDao) GetById(ctx context.Context, Id int64) (m *entity.ComposeSession, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession).
Where(entity.ComposeSessionCol.Id, Id).
One()
if err != nil {
return nil, err
}
if r.IsEmpty() {
return
return nil, nil
}
err = r.Struct(&m)
return
}
// Delete 删除编排会话
func (d *composeSessionDao) Delete(ctx context.Context, req *entity.ComposeSession) (rows int64, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
OmitEmpty().
Where(entity.ComposeSessionCol.Id, req.Id).
Where(entity.ComposeSessionCol.SessionId, req.SessionId).
Delete()
func (d *composeSessionDao) GetBySessionId(ctx context.Context, sessionId string) (m *entity.ComposeSession, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession).
Where(entity.ComposeSessionCol.SessionId, sessionId).
One()
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&m)
return
}
func (d *composeSessionDao) DeleteBySessionId(ctx context.Context, sessionId string) (rows int64, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession).
Where(entity.ComposeSessionCol.SessionId, sessionId).
Data(map[string]any{
entity.ComposeSessionCol.DeletedAt: "NOW()",
}).
Update()
if err != nil {
return 0, err
}
return r.RowsAffected()
}

View File

@@ -2,54 +2,47 @@ package dao
import (
"context"
"prompts-core/consts/public"
"prompts-core/model/entity"
"gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
)
var ComposeTask = &composeTaskDao{}
type composeTaskDao struct{}
// Insert 插入
func (d *composeTaskDao) Insert(ctx context.Context, req *entity.ComposeTask) (id int64, err error) {
var m = new(entity.ComposeTask)
err = gconv.Struct(req, &m)
func (d *composeTaskDao) Insert(ctx context.Context, m *entity.ComposeTask) (id int64, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeTask).Data(m).Insert()
if err != nil {
return
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeTask).
Insert(m)
if err != nil {
return
return 0, err
}
return r.LastInsertId()
}
// Get 获取
func (d *composeTaskDao) Get(ctx context.Context, req *entity.ComposeTask, fields ...string) (m *entity.ComposeTask, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeTask).
OmitEmpty().
Where(entity.ComposeTaskCol.TaskId, req.TaskId).
Fields(fields).One()
func (d *composeTaskDao) GetByTaskId(ctx context.Context, taskId string) (m *entity.ComposeTask, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeTask).
Where(entity.ComposeTaskCol.TaskId, taskId).
One()
if err != nil {
return
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&m)
return
}
// Update 更新
func (d *composeTaskDao) Update(ctx context.Context, req *entity.ComposeTask) (rows int64, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeTask).
OmitEmpty().
Data(&req).
Where(entity.ComposeTaskCol.TaskId, req.TaskId).
func (d *composeTaskDao) UpdateByTaskId(ctx context.Context, taskId string, data map[string]any) (rows int64, err error) {
data[entity.ComposeTaskCol.Updater] = ""
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeTask).
Where(entity.ComposeTaskCol.TaskId, taskId).
Data(data).
Update()
if err != nil {
return
return 0, err
}
return r.RowsAffected()
}

View File

@@ -2,27 +2,62 @@ package dao
import (
"context"
"fmt"
"prompts-core/consts/public"
"prompts-core/model/entity"
"gitea.com/red-future/common/db/gfdb"
"gitea.com/red-future/common/utils"
)
var Model = &modelDao{}
type modelDao struct{}
// Get 获取模型
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
OmitEmpty().
Where(entity.AsynchModelCol.Creator, req.Creator).
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
Where(entity.AsynchModelCol.ModelName, req.ModelName).
Fields(fields).One()
func (d *modelDao) GetByModelName(ctx context.Context, modelName string) (m *entity.AsynchModel, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
Where(entity.AsynchModelCol.ModelName, modelName).
One()
if err != nil {
return
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&m)
return
}
func (d *modelDao) GetByIsChatModel(ctx context.Context) (m *entity.AsynchModel, err error) {
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
Where(entity.AsynchModelCol.IsChatModel, 1).
Where(entity.AsynchModelCol.Creator, userInfo.UserName).
One()
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&m)
return
}
// GetBySuperAdmin 查询超级管理员tenant_id=1的模型
func (d *modelDao) GetBySuperAdmin(ctx context.Context, modelName string) (m *entity.AsynchModel, err error) {
sql := fmt.Sprintf("SELECT * FROM %s WHERE model_name = ? AND tenant_id = 1 AND deleted_at IS NULL LIMIT 1", public.TableNameModel)
r, err := gfdb.DB(ctx).GetAll(ctx, sql, modelName)
if err != nil {
return nil, err
}
if len(r) == 0 {
return nil, nil
}
err = r[0].Struct(&m)
return
}

97
dao/prompt_dao.go Normal file
View File

@@ -0,0 +1,97 @@
package dao
import (
"context"
"prompts-core/consts/public"
"prompts-core/model/entity"
"gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
)
var Prompt = &promptDao{}
type promptDao struct{}
func (d *promptDao) Insert(ctx context.Context, m *entity.PromptConfig) (id int64, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig).Data(m).Insert()
if err != nil {
return 0, err
}
return r.LastInsertId()
}
func (d *promptDao) UpdateByID(ctx context.Context, id int64, data map[string]any) (rows int64, err error) {
// 触发 gfdb 的 updateHook 自动填充 updater需要显式带 updater 字段
data[entity.PromptConfigCol.Updater] = ""
r, err := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig).
Where(entity.PromptConfigCol.Id, id).
Data(data).
Update()
if err != nil {
return 0, err
}
return r.RowsAffected()
}
func (d *promptDao) DeleteByID(ctx context.Context, id int64) (rows int64, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig).
Where(entity.PromptConfigCol.Id, id).
Delete()
if err != nil {
return 0, err
}
return r.RowsAffected()
}
func (d *promptDao) GetByID(ctx context.Context, id int64) (m *entity.PromptConfig, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig).
Where(entity.PromptConfigCol.Id, id).
One()
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&m)
return
}
func (d *promptDao) GetLatestEnabledByModelTypeID(ctx context.Context, modelTypeID int) (m *entity.PromptConfig, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig).
Where("deleted_at IS NULL").
Where(entity.PromptConfigCol.ModelTypeId, modelTypeID).
Where(entity.PromptConfigCol.Enabled, 1).
OrderDesc(entity.PromptConfigCol.CreatedAt).
One()
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&m)
return
}
func (d *promptDao) List(ctx context.Context, pageNum, pageSize int, modelTypeID *int, modelTypeLike string) (list []*entity.PromptConfig, total int64, err error) {
model := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig).Where("deleted_at IS NULL").OrderDesc(entity.PromptConfigCol.CreatedAt)
if modelTypeID != nil && *modelTypeID > 0 {
model = model.Where(entity.PromptConfigCol.ModelTypeId, *modelTypeID)
}
if modelTypeLike != "" {
model = model.WhereLike(entity.PromptConfigCol.ModelType, "%"+modelTypeLike+"%")
}
if pageNum > 0 && pageSize > 0 {
model = model.Page(pageNum, pageSize)
}
r, totalInt, err := model.AllAndCount(false)
if err != nil {
return nil, 0, err
}
total = gconv.Int64(totalInt)
err = r.Structs(&list)
return
}

View File

@@ -1,98 +0,0 @@
package dao
import (
"context"
"prompts-core/consts/public"
"prompts-core/model/entity"
"gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
)
var ProviderProtocol = &providerProtocolDao{}
type providerProtocolDao struct{}
// Insert 新增协议配置
func (d *providerProtocolDao) Insert(ctx context.Context, req *entity.ProviderProtocol) (id int64, err error) {
var m = new(entity.ProviderProtocol)
err = gconv.Struct(req, &m)
if err != nil {
return
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).
Insert(m)
if err != nil {
return 0, err
}
return r.LastInsertId()
}
// Get 查询协议配置
func (d *providerProtocolDao) Get(ctx context.Context, req *entity.ProviderProtocol, fields ...string) (res *entity.ProviderProtocol, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).
NoTenantId(ctx).
OmitEmpty().
Where(entity.ProviderProtocolCol.Id, req.Id).
Where(entity.ProviderProtocolCol.ProviderName, req.ProviderName). //主要是根据运营商查询
Where(entity.ProviderProtocolCol.Status, 1).
Fields(fields).One()
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&res)
return
}
// List 列表查询
func (d *providerProtocolDao) List(ctx context.Context, req *entity.ProviderProtocol, page, size int, fields ...string) (list []*entity.ProviderProtocol, total int, err error) {
if page <= 0 {
page = 1
}
if size <= 0 {
size = 10
}
model := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).
Fields(fields).
OmitEmpty()
model.Where(entity.ProviderProtocolCol.ProviderName, req.ProviderName)
model.Where(entity.ProviderProtocolCol.Status, req.Status)
model.OrderDesc(entity.ProviderProtocolCol.CreatedAt)
model.Page(page, size)
r, total, err := model.AllAndCount(false)
if err != nil {
return
}
err = r.Structs(&list)
return
}
// Update 更新协议配置
func (d *providerProtocolDao) Update(ctx context.Context, req *entity.ProviderProtocol) (rows int64, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).
OmitEmpty().
Where(entity.ProviderProtocolCol.Id, req.Id).
Data(req).
Update()
if err != nil {
return 0, err
}
return r.RowsAffected()
}
// Delete 软删除协议配置
func (d *providerProtocolDao) Delete(ctx context.Context, id int64) (rows int64, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).
Where(entity.ProviderProtocolCol.Id, id).
Data(map[string]any{
entity.ProviderProtocolCol.DeletedAt: "NOW()",
}).
Update()
if err != nil {
return 0, err
}
return r.RowsAffected()
}

8
go.mod
View File

@@ -3,17 +3,12 @@ module prompts-core
go 1.26.0
require (
gitea.com/red-future/common v0.0.19
gitea.com/red-future/common v0.0.21
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.0
github.com/gogf/gf/v2 v2.10.0
)
require (
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
)
require (
github.com/BurntSushi/toml v1.5.0 // indirect
github.com/armon/go-metrics v0.4.1 // indirect
@@ -68,7 +63,6 @@ require (
github.com/r3labs/diff/v2 v2.15.1 // indirect
github.com/redis/go-redis/v9 v9.12.1 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/tidwall/gjson v1.19.0
github.com/tiger1103/gfast-token v1.0.10 // indirect
github.com/vcaesar/cedar v0.30.0 // indirect
github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect

10
go.sum
View File

@@ -1,6 +1,6 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
gitea.com/red-future/common v0.0.19 h1:9/WrfCFUCeFUYwuhBYF+JOQi5F5xuOy+gVnf2ZvHZu4=
gitea.com/red-future/common v0.0.19/go.mod h1:6/nqIucVzmjOyqDTIq71feYBXXFNBy0rFwzaQ0/Ueoo=
gitea.com/red-future/common v0.0.21 h1:8w30HmCVmFG/hphH3ODJs1KxDEGmRpq+/PXI0pQjJKc=
gitea.com/red-future/common v0.0.21/go.mod h1:6/nqIucVzmjOyqDTIq71feYBXXFNBy0rFwzaQ0/Ueoo=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
@@ -288,12 +288,6 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tidwall/gjson v1.19.0 h1:xwxm7n691Uf3u5OFjzngavjGTh55KX5q/9w9xHW88JU=
github.com/tidwall/gjson v1.19.0/go.mod h1:V37/opeE/JbLUOfH0QTXiNez2l0RUjYUhpT4szFQAfc=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tiger1103/gfast-token v1.0.10 h1:fNiBE/Dq5iTHvTGlCx3DmXa2o4hr0NtumFpffZ39k6s=
github.com/tiger1103/gfast-token v1.0.10/go.mod h1:a/21mxmj7zFeNvjhZSC0XpEAFHfb1aT2k6DXnufFU1s=
github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM=

View File

@@ -4,9 +4,10 @@ import (
"context"
"os"
"os/signal"
"prompts-core/controller"
"syscall"
"prompts-core/controller"
"gitea.com/red-future/common/http"
"gitea.com/red-future/common/jaeger"
_ "gitea.com/red-future/common/swagger"
@@ -19,13 +20,14 @@ func main() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
defer jaeger.ShutDown(ctx)
// 注册路由
http.RouteRegister([]interface{}{
controller.Prompt,
controller.Session,
})
// 监听退出信号,确保 Ctrl+C 能完整退出并关闭 gateway server
// 监听退出信号,确保 Ctrl+C 能完整退出并关闭 http server
quit := make(chan os.Signal, 1)
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
<-quit

View File

@@ -0,0 +1,51 @@
package dto
import "github.com/gogf/gf/v2/frame/g"
type Message struct {
Role string `json:"role" dc:"角色system/user/assistant"`
Content any `json:"content" dc:"消息内容"`
}
type ComposeMessagesReq struct {
g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schemaform 作为系统表单userForm 作为用户表单,结合 userFiles 调用 model-gateway并直接返回最终 messages"`
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"`
BuildType int `p:"buildType" json:"buildType" v:"required#buildType不能为空" dc:"构建类型"` //判断节点
SessionId string `p:"sessionId" json:"sessionId" v:"required#sessionId不能为空" dc:"会话ID"`
Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"`
Form map[string]any `p:"form" json:"form" dc:"系统表单form 下所有字段都作为系统提示词来源"`
UserForm map[string]any `p:"userForm" json:"userForm" dc:"用户表单userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"`
SkillName string `p:"skillName" json:"skillName" dc:"技能名称"`
UserFiles []string `p:"userFiles" json:"userFiles" dc:"用户附件地址列表"`
}
type ComposeMessagesRes struct {
Messages any `json:"messages,omitempty" dc:"最终消息数组"`
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
}
type CallbackReq struct {
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调callbackUrl/{bizName}"`
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
State int `json:"state" dc:"网关任务状态"`
OssFile string `json:"oss_file" dc:"结果文件地址"`
FileType string `json:"file_type" dc:"结果文件类型"`
Text string `json:"text" dc:"文本结果"`
ErrorMsg string `json:"error_msg" dc:"错误信息"`
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
}
type GetComposeTaskReq struct {
g.Meta `path:"/getComposeTask" method:"get" tags:"提示词处理" summary:"查询拼接任务" dc:"按 taskId 查询提示词拼接任务结果"`
TaskId string `p:"taskId" json:"taskId" v:"required#taskId不能为空" dc:"任务ID"`
}
type GetComposeTaskRes struct {
TaskId string `json:"taskId" dc:"任务ID"`
Status string `json:"status" dc:"业务状态"`
GatewayState int `json:"gatewayState" dc:"网关状态"`
ErrorMessage string `json:"errorMessage" dc:"错误信息"`
Messages any `json:"messages" dc:"最终消息数组"`
OssFile string `json:"ossFile" dc:"结果文件地址"`
FileType string `json:"fileType" dc:"结果文件类型"`
}

View File

@@ -7,6 +7,3 @@ type SessionCallbackReq struct {
Text string `json:"text" dc:"文本结果"`
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
}
type SessionCallbackRes struct {
}

View File

@@ -1,60 +0,0 @@
package dto
import "github.com/gogf/gf/v2/frame/g"
type ComposeMessagesReq struct {
g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schemaform 作为系统表单userForm 作为用户表单,结合 userFiles 调用 model-gateway并直接返回最终 messages"`
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"`
BuildType int `p:"buildType" json:"buildType" v:"required#buildType不能为空" dc:"构建类型"` //判断节点
SessionId string `p:"sessionId" json:"sessionId" v:"required#sessionId不能为空" dc:"会话ID"`
Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"`
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
Form map[string]any `p:"form" json:"form" dc:"系统表单form 下所有字段都作为系统提示词来源"`
UserForm []map[string]any `p:"userForm" json:"userForm" dc:"用户表单userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"`
SkillName string `p:"skillName" json:"skillName" dc:"技能名称"`
UserFiles []string `p:"userFiles" json:"userFiles" dc:"用户附件地址列表"`
}
type ComposeMessagesRes struct {
TaskId string `json:"taskId" dc:"任务ID"`
}
/*
Messages *MultiRoundResult `json:"messages,omitempty" dc:"最终消息数组"`
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
*/
// MultiRoundResult 多轮返回结果
type MultiRoundResult struct {
TotalRounds int `json:"total_rounds"` // 总轮数
Rounds []map[string]any `json:"rounds"` // 每轮详情(动态类型)
}
type CallbackReq struct {
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调callbackUrl/{bizName}"`
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
State int `json:"state" dc:"网关任务状态"`
OssFile string `json:"oss_file" dc:"结果文件地址"`
FileType string `json:"file_type" dc:"结果文件类型"`
Text string `json:"text" dc:"文本结果"`
ErrorMsg string `json:"error_msg" dc:"错误信息"`
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
}
type CallbackRes struct {
}
type GetComposeTaskReq struct {
g.Meta `path:"/getComposeTask" method:"get" tags:"提示词处理" summary:"查询拼接任务" dc:"按 taskId 查询提示词拼接任务结果"`
TaskId string `p:"taskId" json:"taskId" v:"required#taskId不能为空" dc:"任务ID"`
}
type GetComposeTaskRes struct {
TaskId string `json:"taskId" dc:"任务ID"`
Status string `json:"status" dc:"业务状态"`
GatewayState int `json:"gatewayState" dc:"网关状态"`
ErrorMessage string `json:"errorMessage" dc:"错误信息"`
Messages any `json:"messages" dc:"最终消息数组"`
OssFile string `json:"ossFile" dc:"结果文件地址"`
FileType string `json:"fileType" dc:"结果文件类型"`
}

63
model/dto/prompt_dto.go Normal file
View File

@@ -0,0 +1,63 @@
package dto
import (
"gitea.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
)
// CreatePromptReq 添加提示词配置(默认启用)
type CreatePromptReq struct {
g.Meta `path:"/createPrompt" method:"post" tags:"提示词管理" summary:"创建提示词配置" dc:"创建新的模型提示词配置(默认启用)"`
ModelTypeId int `p:"modelTypeId" json:"modelTypeId" v:"required#modelTypeId不能为空" dc:"模型分类ID"`
ModelType string `p:"modelType" json:"modelType" v:"required#modelType不能为空" dc:"模型类别/模型类型"`
PromptInfo any `p:"promptInfo" json:"promptInfo" v:"required#promptInfo不能为空" dc:"数据库定义的表单规则数据JSON"`
ResponseJsonSchema any `p:"responseJsonSchema" json:"responseJsonSchema" v:"required#responseJsonSchema不能为空" dc:"模型返回表单 JSON 格式约束"`
// Version 预留字段:先不使用,但表结构保留
Version string `p:"version" json:"version" dc:"版本号(预留)"`
}
type CreatePromptRes struct {
ID int64 `json:"id,string" dc:"配置ID"`
}
// UpdatePromptReq 更新提示词配置
type UpdatePromptReq struct {
g.Meta `path:"/updatePrompt" method:"put" tags:"提示词管理" summary:"更新提示词配置" dc:"更新指定ID的提示词配置"`
ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
ModelTypeId *int `p:"modelTypeId" json:"modelTypeId" dc:"模型分类ID可选更新"`
ModelType *string `p:"modelType" json:"modelType" dc:"模型类别/模型类型(可选更新)"`
PromptInfo any `p:"promptInfo" json:"promptInfo" dc:"数据库定义的表单规则数据JSON可选更新"`
ResponseJsonSchema any `p:"responseJsonSchema" json:"responseJsonSchema" dc:"模型返回表单 JSON 格式约束(可选更新)"`
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用0-禁用1-启用(可选更新)"`
Version *string `p:"version" json:"version" dc:"版本号(预留,可选更新)"`
}
// DeletePromptReq 删除提示词配置
type DeletePromptReq struct {
g.Meta `path:"/deletePrompt" method:"delete" tags:"提示词管理" summary:"删除提示词配置" dc:"删除指定ID的提示词配置"`
ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
}
// GetPromptReq 获取提示词配置详情
type GetPromptReq struct {
g.Meta `path:"/getPrompt" method:"get" tags:"提示词管理" summary:"获取提示词配置" dc:"根据ID获取提示词配置详情"`
ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
}
type GetPromptRes struct {
Prompt any `json:"prompt" dc:"提示词配置详情"`
}
// ListPromptReq 配置列表
type ListPromptReq struct {
g.Meta `path:"/listPrompt" method:"post" tags:"提示词管理" summary:"提示词配置列表" dc:"分页获取提示词配置列表"`
Page *beans.Page `p:"page" json:"page" dc:"分页参数"`
ModelTypeId *int `p:"modelTypeId" json:"modelTypeId" dc:"模型分类ID可选"`
ModelType string `p:"modelType" json:"modelType" dc:"模型类型名称(可选,模糊查询)"`
}
type ListPromptRes struct {
List any `json:"list" dc:"列表数据"`
Total int64 `json:"total" dc:"总数"`
}

View File

@@ -2,37 +2,6 @@ package entity
import "gitea.com/red-future/common/beans"
// AsynchModel 异步模型配置
type AsynchModel struct {
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
ModelType int `orm:"model_type" json:"modelType"`
BaseURL string `orm:"base_url" json:"baseUrl"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
HeadMsg string `orm:"head_msg" json:"headMsg"`
Form any `orm:"form_json" json:"form"`
RequestMapping any `orm:"request_mapping" json:"requestMapping"`
ResponseMapping any `orm:"response_mapping" json:"responseMapping"`
ResponseBody any `orm:"response_body" json:"responseBody"`
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
Prompt string `orm:"prompt" json:"prompt"`
IsPrivate *int `orm:"is_private" json:"isPrivate"`
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
ApiKey string `orm:"api_key" json:"apiKey"`
Enabled *int `orm:"enabled" json:"enabled"`
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
QueueLimit int `orm:"queue_limit" json:"queueLimit"`
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
ExpectedSeconds int `orm:"expected_seconds" json:"expectedSeconds"`
RetryTimes int `orm:"retry_times" json:"retryTimes"`
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"`
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
Remark string `orm:"remark" json:"remark"`
IsOwner *int `json:"isOwner" orm:"is_owner"`
OperatorName string `orm:"operator_name" json:"operatorName"`
TokenConfig any `orm:"token_config" json:"tokenConfig"`
}
type asynchModelCol struct {
beans.SQLBaseCol
ModelName string
@@ -44,7 +13,7 @@ type asynchModelCol struct {
RequestMapping string
ResponseMapping string
ResponseBody string
ResponseTokenField string
TokenMapping string
Prompt string
IsPrivate string
IsChatModel string
@@ -58,9 +27,6 @@ type asynchModelCol struct {
RetryQueueMaxSecs string
AutoCleanSeconds string
Remark string
IsOwner string
OperatorName string
TokenConfig string
}
var AsynchModelCol = asynchModelCol{
@@ -74,7 +40,7 @@ var AsynchModelCol = asynchModelCol{
RequestMapping: "request_mapping",
ResponseMapping: "response_mapping",
ResponseBody: "response_body",
ResponseTokenField: "response_token_field",
TokenMapping: "token_mapping",
Prompt: "prompt",
IsPrivate: "is_private",
IsChatModel: "is_chat_model",
@@ -88,7 +54,32 @@ var AsynchModelCol = asynchModelCol{
RetryQueueMaxSecs: "retry_queue_max_seconds",
AutoCleanSeconds: "auto_clean_seconds",
Remark: "remark",
IsOwner: "is_owner",
OperatorName: "operator_name",
TokenConfig: "token_config",
}
// AsynchModel 异步模型配置
type AsynchModel struct {
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
ModelType int `orm:"model_type" json:"modelType"`
BaseURL string `orm:"base_url" json:"baseUrl"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
HeadMsg string `orm:"head_msg" json:"headMsg"`
Form any `orm:"form_json" json:"form"`
RequestMapping any `orm:"request_mapping" json:"requestMapping"`
ResponseMapping any `orm:"response_mapping" json:"responseMapping"`
ResponseBody any `orm:"response_body" json:"responseBody"`
TokenMapping string `orm:"token_mapping" json:"tokenMapping"`
Prompt string `orm:"prompt" json:"prompt"`
IsPrivate int `orm:"is_private" json:"isPrivate"`
IsChatModel int `orm:"is_chat_model" json:"isChatModel"`
ApiKey string `orm:"api_key" json:"apiKey"`
Enabled int `orm:"enabled" json:"enabled"`
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
QueueLimit int `orm:"queue_limit" json:"queueLimit"`
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
ExpectedSeconds int `orm:"expected_seconds" json:"expectedSeconds"`
RetryTimes int `orm:"retry_times" json:"retryTimes"`
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"`
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
Remark string `orm:"remark" json:"remark"`
}

View File

@@ -2,14 +2,6 @@ package entity
import "gitea.com/red-future/common/beans"
type ComposeSession struct {
beans.SQLBaseDO `orm:",inline"`
SessionId string `orm:"session_id" json:"sessionId"`
RequestContent any `orm:"request_content" json:"requestContent"`
ResponseContent any `orm:"response_content" json:"responseContent"`
Remark string `orm:"remark" json:"remark"`
}
type composeSessionCol struct {
beans.SQLBaseCol
SessionId string
@@ -25,3 +17,11 @@ var ComposeSessionCol = composeSessionCol{
ResponseContent: "response_content",
Remark: "remark",
}
type ComposeSession struct {
beans.SQLBaseDO `orm:",inline"`
SessionId string `orm:"session_id" json:"sessionId"`
RequestContent any `orm:"request_content" json:"requestContent"`
ResponseContent any `orm:"response_content" json:"responseContent"`
Remark string `orm:"remark" json:"remark"`
}

View File

@@ -0,0 +1,45 @@
package entity
import "gitea.com/red-future/common/beans"
type composeTaskCol struct {
beans.SQLBaseCol
TaskId string
ModelName string
SkillName string
LimitWords string
RequestPayload string
CallbackPayload string
ModelResult string
Messages string
Status string
ErrorMessage string
}
var ComposeTaskCol = composeTaskCol{
SQLBaseCol: beans.DefSQLBaseCol,
TaskId: "task_id",
ModelName: "model_name",
SkillName: "skill_name",
LimitWords: "limit_words",
RequestPayload: "request_payload",
CallbackPayload: "callback_payload",
ModelResult: "model_result",
Messages: "messages",
Status: "status",
ErrorMessage: "error_message",
}
type ComposeTask struct {
beans.SQLBaseDO `orm:",inline"`
TaskId string `orm:"task_id" json:"taskId"`
ModelName string `orm:"model_name" json:"modelName"`
SkillName string `orm:"skill_name" json:"skillName"`
LimitWords int `orm:"limit_words" json:"limitWords"`
RequestPayload any `orm:"request_payload" json:"requestPayload"`
CallbackPayload any `orm:"callback_payload" json:"callbackPayload"`
ModelResult any `orm:"model_result" json:"modelResult"`
Messages any `orm:"messages" json:"messages"`
Status string `orm:"status" json:"status"`
ErrorMessage string `orm:"error_message" json:"errorMessage"`
}

View File

@@ -0,0 +1,39 @@
package entity
import "gitea.com/red-future/common/beans"
type promptConfigCol struct {
beans.SQLBaseCol
ModelTypeId string
ModelType string
PromptInfo string
ResponseJsonSchema string
Enabled string
Version string
}
var PromptConfigCol = promptConfigCol{
SQLBaseCol: beans.DefSQLBaseCol,
ModelTypeId: "model_type_id",
ModelType: "model_type",
PromptInfo: "prompt_info",
ResponseJsonSchema: "response_json_schema",
Enabled: "enabled",
Version: "version",
}
// PromptConfig 模型提示词配置
//
// 说明:
// - prompt_info 使用 JSONB 保存(对外用 json 传输)
// - response_json_schema 为模型返回 JSON 格式约束
// - enabled1启用/0禁用
type PromptConfig struct {
beans.SQLBaseDO `orm:",inline"`
ModelTypeId int `orm:"model_type_id" json:"modelTypeId"`
ModelType string `orm:"model_type" json:"modelType"`
PromptInfo any `orm:"prompt_info" json:"promptInfo"`
ResponseJsonSchema any `orm:"response_json_schema" json:"responseJsonSchema"`
Enabled int `orm:"enabled" json:"enabled"`
Version string `orm:"version" json:"version"`
}

View File

@@ -1,54 +0,0 @@
package entity
import "gitea.com/red-future/common/beans"
type ComposeTask struct {
beans.SQLBaseDO `orm:",inline"`
TaskId string `orm:"task_id" json:"taskId"`
ModelName string `orm:"model_name" json:"modelName"`
SkillName string `orm:"skill_name" json:"skillName"`
BuildType int `orm:"build_type" json:"buildType"`
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
GatewayState int `orm:"gateway_state" json:"gatewayState"`
RequestPayload any `orm:"request_payload" json:"requestPayload"`
ResultText string `orm:"result_text" json:"resultText"`
Messages any `orm:"messages" json:"messages"`
Status string `orm:"status" json:"status"`
ErrorMessage string `orm:"error_message" json:"errorMessage"`
OssFile string `orm:"oss_file" json:"ossFile"`
FileType string `orm:"file_type" json:"fileType"`
}
type composeTaskCol struct {
beans.SQLBaseCol
TaskId string
ModelName string
SkillName string
BuildType string
CallbackUrl string
GatewayState string
RequestPayload string
ResultText string
Messages string
Status string
ErrorMessage string
OssFile string
FileType string
}
var ComposeTaskCol = composeTaskCol{
SQLBaseCol: beans.DefSQLBaseCol,
TaskId: "task_id",
ModelName: "model_name",
SkillName: "skill_name",
BuildType: "build_type",
CallbackUrl: "callback_url",
GatewayState: "gateway_state",
RequestPayload: "request_payload",
ResultText: "result_text",
Messages: "messages",
Status: "status",
ErrorMessage: "error_message",
OssFile: "oss_file",
FileType: "file_type",
}

View File

@@ -1,49 +0,0 @@
package entity
import "gitea.com/red-future/common/beans"
// ProviderProtocol 模型协议映射配置
type ProviderProtocol struct {
beans.SQLBaseDO `orm:",inherit"`
// 业务字段
ProviderName string `orm:"provider_name" json:"providerName"`
TargetField string `orm:"target_field" json:"targetField"`
MergeOrder any `orm:"merge_order" json:"mergeOrder"`
RoleMapping any `orm:"role_mapping" json:"roleMapping"`
ContentMapping any `orm:"content_mapping" json:"contentMapping"`
Capabilities any `orm:"capabilities" json:"capabilities"`
RequestTemplate any `orm:"request_template" json:"requestTemplate"`
SystemPromptTemplate string `orm:"system_prompt_template" json:"systemPromptTemplate"`
Status int `orm:"status" json:"status"`
Remark string `orm:"remark" json:"remark"`
}
// providerProtocolCol 列名
type providerProtocolCol struct {
beans.SQLBaseCol
ProviderName string
TargetField string
MergeOrder string
RoleMapping string
ContentMapping string
Capabilities string
RequestTemplate string
SystemPromptTemplate string
Status string
Remark string
}
// ProviderProtocolCol 列名常量
var ProviderProtocolCol = providerProtocolCol{
SQLBaseCol: beans.DefSQLBaseCol,
ProviderName: "provider_name",
TargetField: "target_field",
MergeOrder: "merge_order",
RoleMapping: "role_mapping",
ContentMapping: "content_mapping",
Capabilities: "capabilities",
RequestTemplate: "request_template",
SystemPromptTemplate: "system_prompt_template",
Status: "status",
Remark: "remark",
}

144
service/build_prompt.go Normal file
View File

@@ -0,0 +1,144 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"prompts-core/model/dto"
"prompts-core/model/entity"
"strings"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// 获取请求模型的提示词
func GetModelPrompt(ctx context.Context, Type int) string {
return g.Cfg().MustGet(ctx, "modelPrompts.types."+gconv.String(Type), "").String()
}
// 获取构建提示词
func GetBuildPrompt(ctx context.Context, Type int) string {
return g.Cfg().MustGet(ctx, "buildProject.types."+gconv.String(Type), "").String()
}
// buildInferenceRequest 构建返回请求
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
messages := []map[string]any{}
switch req.BuildType {
//构建提示词请求
case 1:
//1. 构建系统提示词
messages = append(messages, map[string]any{
"role": "system",
"content": promptBuild(ctx, req, model),
})
// 2. 构建历史会话提示词
for _, msg := range history {
role := gconv.String(msg["role"])
content := gconv.String(msg["content"])
if role != "user" && role != "assistant" {
continue
}
messages = append(messages, map[string]any{
"role": role,
"content": content,
})
}
// 3. 当前用户问题(原来的最后一条)
messages = append(messages, map[string]any{
"role": "user",
"content": buildUserPrompt(ctx, req, GetModelPrompt(ctx, model.ModelType)),
})
//构建节点请求
case 2:
messages = append(messages, map[string]any{
"role": "user",
"content": NodeBuid(ctx, req),
})
default:
return nil, errors.New("不支持的构建类型")
}
// 构建请求体
return map[string]any{
"modelName": chatModel.ModelName,
"bizName": "prompts-core",
"callbackUrl": "/prompt/callback",
"requestPayload": map[string]any{
"model": chatModel.ModelName,
"messages": messages,
"stream": false,
},
}, nil
}
// ============================================
// 构建用户提示词
// ============================================
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
payload := map[string]any{
"model": req.ModelName,
//数据库提示信息
"promptInfo": prompt,
// 系统表单
"form": req.Form,
// 用户表单
"userForm": req.UserForm,
//文件url
"userFiles": req.UserFiles,
//解读文件(只支持可读类型 如xmljson,yaml
"userFilesText": FetchFileTexts(ctx, req.UserFiles),
//skill 相关(根据传入的 skillName 获取 zip 内所有 md 文件拼接内容)
"skills": SkillMdContent(ctx, req.SkillName),
}
return mustMarshal(payload)
}
// promptBuild 提示词构建
func promptBuild(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) string {
// 1. 从配置文件读取提示词模板
promptTpl := GetBuildPrompt(ctx, req.BuildType)
if promptTpl == "" {
return ""
}
// 2. 构建字段映射说明
mappingBytes, _ := json.Marshal(model.RequestMapping)
mappingStr := string(mappingBytes)
var mapping map[string]string
_ = json.Unmarshal(mappingBytes, &mapping)
var fieldDesc strings.Builder
for key, path := range mapping {
fieldDesc.WriteString(fmt.Sprintf("- %s → %s\n", key, path))
}
// 3. 拼接 UserForm 全文(必须完整阅读)
var userFormContent strings.Builder
for k, v := range req.UserForm {
userFormContent.WriteString(fmt.Sprintf("%s=%v", k, v))
}
userFormFullText := strings.TrimSuffix(userFormContent.String(), "")
// 4. 双表单信息
formInfo := fmt.Sprintf(`
【系统表单(系统提示词/参数)】
%s
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
%s
`, formToJSON(req.Form), userFormFullText)
// 5. 格式化最终提示词(替换配置里的 %s
return fmt.Sprintf(promptTpl, mappingStr, fieldDesc.String(), formInfo)
}
// NodeBuid 节点构建
func NodeBuid(ctx context.Context, req *dto.ComposeMessagesReq) string {
promptTpl := GetBuildPrompt(ctx, req.BuildType)
if promptTpl == "" {
return ""
}
formStr := formToJSON(req.Form)
userFormStr := formToJSON(req.UserForm)
return fmt.Sprintf(promptTpl, formStr, userFormStr)
}

416
service/compose_service.go Normal file
View File

@@ -0,0 +1,416 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"prompts-core/consts/public"
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/frame/g"
)
// ============================================
// 核心业务流程
// ============================================
// ComposeMessages 拼接提示词主流程
func (s *promptService) ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
var (
epicycleId int64
taskID string
history []map[string]any
message map[string]any
err error
taskRecord *entity.ComposeTask
)
// 获取模型信息
chatModel, model, err := s.GetModelMessage(ctx, req)
if err != nil {
return nil, err
}
// 根据构建类型进行判断处理
switch req.BuildType {
//提示词构建
case 1:
maxRetryTimes := g.Cfg().MustGet(ctx, "promptsRetry.maxRetryTimes", 3).Int()
//1. 获取历史会话
history, err = Session.GetHistoryMessages(ctx, req.SessionId)
if err != nil {
g.Log().Errorf(ctx, "获取历史会话失败: %v将不使用历史会话", err)
history = nil // 出错就用空的,不影响主流程
}
// 重试循环
for attempt := 0; attempt <= maxRetryTimes; attempt++ {
if attempt > 0 {
g.Log().Warningf(ctx, "[重试]第 %d/%d 次调用推理模型", attempt, maxRetryTimes)
}
// 2. 调用推理模型
taskID, err = s.callInferenceModel(ctx, req, chatModel, model, history)
if err != nil {
g.Log().Errorf(ctx, "调用推理模型失败(第%d次): %v", attempt+1, err)
continue
}
// 3. 保存记录
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
RequestPayload: mustMarshal(req),
Status: public.ComposeStatusPending,
})
if err != nil {
g.Log().Errorf(ctx, "保存任务记录失败(第%d次): %v", attempt+1, err)
continue
}
// 4. 等待结果
taskRecord, err = s.waitForResult(ctx, taskID)
if err != nil {
g.Log().Errorf(ctx, "等待结果失败(第%d次): %v", attempt+1, err)
continue
}
// 校验结果
message = s.parsePromptBuild(taskRecord, chatModel)
if message != nil && isMessageValid(message) {
break
}
g.Log().Warningf(ctx, "[重试] 推理结果不合法(第%d次),准备重新请求", attempt+1)
message = nil
}
if message == nil {
return nil, errors.New("推理模型调用失败,请稍后再试")
}
//5.创建会话记录
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
RequestContent: message,
})
//节点构建
case 2:
//1. 调用推理模型
taskID, err = s.callInferenceModel(ctx, req, chatModel, model, nil)
if err != nil {
return nil, err
}
//2. 保存相关记录
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
RequestPayload: mustMarshal(req),
Status: public.ComposeStatusPending,
})
//5. 等待结果
taskRecord, err := s.waitForResult(ctx, taskID)
if err != nil {
return nil, err
}
fmt.Println("构建节点前", taskRecord)
message = s.parseNodeBuild(taskRecord)
fmt.Println("构建节点后", message)
default:
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
Remark: req.Cause,
})
return &dto.ComposeMessagesRes{
EpicycleId: epicycleId,
}, nil
}
return &dto.ComposeMessagesRes{
Messages: message,
EpicycleId: epicycleId,
}, nil
}
func (s *promptService) Callback(ctx context.Context, req *dto.CallbackReq) error {
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
// ============ 先查任务是否存在 ============
task, err := dao.ComposeTask.GetByTaskId(ctx, req.TaskId)
if err != nil {
return err
}
if task == nil {
return fmt.Errorf("任务不存在: %s", req.TaskId)
}
// ============ 根据状态区分处理 ============
if req.State == 3 {
// 失败:直接更新状态
_, err = dao.ComposeTask.UpdateByTaskId(ctx, req.TaskId, map[string]any{
entity.ComposeTaskCol.Status: public.ComposeStatusFailed,
entity.ComposeTaskCol.ErrorMessage: req.ErrorMsg,
})
return err
}
// ======================================
// 成功:解析模型输出
result, err := parseOutput(req.Text)
if err != nil {
_, updateErr := dao.ComposeTask.UpdateByTaskId(ctx, req.TaskId, map[string]any{
entity.ComposeTaskCol.Status: public.ComposeStatusFailed,
entity.ComposeTaskCol.ErrorMessage: err.Error(),
})
if updateErr != nil {
g.Log().Warningf(ctx, "[Callback] 更新失败状态出错 taskId=%s err=%v", req.TaskId, updateErr)
}
return err
}
// ============ result 可能为 nil ============
var messages any
if result != nil {
messages = result
}
// =======================================
_, err = dao.ComposeTask.UpdateByTaskId(ctx, req.TaskId, map[string]any{
entity.ComposeTaskCol.Status: public.ComposeStatusSuccess,
entity.ComposeTaskCol.Messages: messages,
})
if err != nil {
g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err)
}
return err
}
// GetComposeTask 查询任务结果
func (s *promptService) GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
record, err := dao.ComposeTask.GetByTaskId(ctx, taskID)
if err != nil {
return nil, err
}
if record == nil {
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
}
// 如果 Messages 是字符串,反序列化为 JSON 数组
messages := record.Messages
if str, ok := messages.(string); ok && str != "" {
var parsed any
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
messages = parsed
}
}
return &dto.GetComposeTaskRes{
TaskId: record.TaskId,
Status: record.Status,
ErrorMessage: record.ErrorMessage,
Messages: messages,
}, nil
}
// GetModelMessage 获取模型信息
func (s *promptService) GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) {
// 1. 获取当前用户的会话模型
chatModel, err := dao.Model.GetByIsChatModel(ctx)
if err != nil {
return nil, nil, err
}
if chatModel == nil {
return nil, nil, errors.New("当前没有对话模型,请添加")
}
// 2. 获取要构建的模型信息
model, err := dao.Model.GetByModelName(ctx, req.ModelName)
if err != nil {
return nil, nil, err
}
if model == nil {
return nil, nil, fmt.Errorf("需要构建的模型 %s 不存在", req.ModelName)
}
return chatModel, model, nil
}
// callInferenceModel 调用推理模型
func (s *promptService) callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (string, error) {
// 构建推理模型请求
taskReq, err := buildInferenceRequest(ctx, req, chatModel, model, history)
if err != nil {
return "", fmt.Errorf("构建推理请求失败: %w", err)
}
// 创建网关任务
taskID, err := createGatewayTask(ctx, taskReq)
if err != nil {
return "", fmt.Errorf("创建网关任务失败: %w", err)
}
if taskID == "" {
return "", errors.New("网关未返回taskId")
}
return taskID, nil
}
// ============================================
// 步骤6等待结果
// ============================================
func (s *promptService) waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) {
timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second
pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond
deadline := time.Now().Add(timeout)
for {
// ===================== 修复点 1检查上下文是否取消 =====================
select {
case <-ctx.Done():
// 请求已被取消,直接返回,不继续查库
return nil, ctx.Err()
default:
}
// 1. 查数据库
record, err := dao.ComposeTask.GetByTaskId(ctx, taskID)
if err != nil {
// ===================== 修复点 2如果是上下文取消直接返回 =====================
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err
}
return nil, err
}
if record != nil {
switch record.Status {
case public.ComposeStatusSuccess:
return record, nil
case public.ComposeStatusFailed:
if strings.TrimSpace(record.ErrorMessage) == "" {
return nil, fmt.Errorf("任务失败(taskId=%s)", taskID)
}
return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage)
}
}
// 2. 查网关状态
state, err := queryGatewayTaskState(ctx, taskID)
if err != nil {
// 网关不可达不终止,继续轮询
g.Log().Warningf(ctx, "[waitForResult] 查询网关失败 taskId=%s err=%v", taskID, err)
} else {
switch state {
case 2: // 网关成功
// 网关已成功,主动更新数据库
if record != nil {
dao.ComposeTask.UpdateByTaskId(ctx, taskID, map[string]any{
entity.ComposeTaskCol.Status: public.ComposeStatusSuccess,
})
}
case 3: // 网关失败
if record != nil {
dao.ComposeTask.UpdateByTaskId(ctx, taskID, map[string]any{
entity.ComposeTaskCol.Status: public.ComposeStatusFailed,
entity.ComposeTaskCol.ErrorMessage: "model-gateway 任务执行失败",
})
}
return nil, fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID)
}
}
// 3. 超时检查
if time.Now().After(deadline) {
return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID)
}
// ===================== 修复点3sleep 也要监听 ctx 取消 =====================
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(pollInterval):
}
}
}
// parsePromptBuild 解析提示词构建结果BuildType == 1
func (s *promptService) parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) map[string]any {
if taskRecord == nil {
return nil
}
// 1. 解析 Messages
var mapped map[string]any
switch v := taskRecord.Messages.(type) {
case *gvar.Var:
if v != nil {
json.Unmarshal([]byte(v.String()), &mapped)
}
case string:
json.Unmarshal([]byte(v), &mapped)
case map[string]any:
mapped = v
default:
b, _ := json.Marshal(v)
json.Unmarshal(b, &mapped)
}
// 2. 解析模型 ResponseMapping 获取 content 字段名
contentField := "content" // 默认值
if model != nil {
var respMapping map[string]string
switch v := model.ResponseMapping.(type) {
case *gvar.Var:
if v != nil {
json.Unmarshal([]byte(v.String()), &respMapping)
}
case string:
json.Unmarshal([]byte(v), &respMapping)
case map[string]interface{}:
respMapping = make(map[string]string)
for k, val := range v {
if s, ok := val.(string); ok {
respMapping[k] = s
}
}
}
// 从映射中找到 content 对应的字段名
for k, v := range respMapping {
if strings.Contains(v, "content") {
contentField = k
break
}
}
}
// 3. 提取 content 的值
contentStr, ok := mapped[contentField].(string)
if !ok || contentStr == "" {
return mapped
}
// 4. 解析 content 内的 JSON
var innerData map[string]any
json.Unmarshal([]byte(contentStr), &innerData)
return innerData
}
// parseNodeBuild 解析节点构建结果BuildType == 2
func (s *promptService) parseNodeBuild(taskRecord *entity.ComposeTask) map[string]any {
if taskRecord == nil {
return nil
}
var result map[string]any
switch v := taskRecord.Messages.(type) {
case *gvar.Var:
if v != nil {
json.Unmarshal([]byte(v.String()), &result)
}
case string:
json.Unmarshal([]byte(v), &result)
case map[string]any:
result = v
default:
b, _ := json.Marshal(v)
json.Unmarshal(b, &result)
}
return result
}

340
service/files_handle.go Normal file
View File

@@ -0,0 +1,340 @@
package service
import (
"archive/zip"
"bytes"
"context"
"fmt"
"io"
"net/http"
"path/filepath"
"regexp"
"strings"
"time"
"github.com/gogf/gf/v2/frame/g"
)
// ============================================
// 文件处理(配置直接内联 + zip 支持)
// ============================================
// 允许的文本类 MIME 类型前缀
var allowedMIMEPrefixes = []string{
"text/",
"application/json",
"application/xml",
"application/javascript",
"application/x-yaml",
"application/yaml",
"application/toml",
"application/x-httpd-php",
"application/x-sh",
"application/x-python",
"application/x-perl",
"application/x-ruby",
}
// 禁止的文件扩展名
var bannedExtensions = map[string]bool{
".png": true, ".jpg": true, ".jpeg": true, ".gif": true, ".bmp": true,
".webp": true, ".svg": true, ".ico": true, ".tiff": true, ".tif": true,
".mp3": true, ".wav": true, ".ogg": true, ".flac": true, ".aac": true,
".wma": true, ".m4a": true,
".mp4": true, ".avi": true, ".mkv": true, ".mov": true, ".wmv": true,
".flv": true, ".webm": true,
".tar": true, ".gz": true, ".rar": true, ".7z": true,
".exe": true, ".dll": true, ".so": true, ".bin": true, ".dat": true,
".class": true, ".pyc": true,
".pdf": true, ".doc": true, ".docx": true, ".xls": true, ".xlsx": true,
".ppt": true, ".pptx": true,
}
var symbolCleaner = regexp.MustCompile(`[\x00-\x08\x0B\x0C\x0E-\x1F]`)
// FetchFileTexts 从 URL 列表获取文件内容(支持 zip 内文件)
func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
result := make(map[string]string)
if len(urls) == 0 {
return result
}
client := &http.Client{
Timeout: time.Duration(g.Cfg().MustGet(ctx, "userFiles.httpTimeoutSec", 8).Int()) * time.Second,
}
for _, rawURL := range urls {
url := sanitizeURL(rawURL)
if url == "" {
continue
}
if isBannedExtension(url) {
continue
}
if isZipExtension(url) {
zipTexts := fetchZipFileTexts(ctx, client, url)
for k, v := range zipTexts {
result[k] = v
}
continue
}
text, err := fetchFileContent(ctx, client, url)
if err != nil {
continue
}
if text == "" {
continue
}
text = cleanSymbols(text)
result[url] = text
}
return result
}
func isZipExtension(url string) bool {
ext := strings.ToLower(filepath.Ext(url))
if idx := strings.Index(ext, "?"); idx != -1 {
ext = ext[:idx]
}
return ext == ".zip"
}
func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map[string]string {
result := make(map[string]string)
zipBytes, err := downloadFile(client, url,
int64(g.Cfg().MustGet(ctx, "userFiles.zipMaxSizeMB", 10).Int())*1024*1024,
)
if err != nil {
return result
}
reader, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
if err != nil {
return result
}
entryMaxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipEntryMaxSizeKB", 500).Int()) * 1024
for _, file := range reader.File {
if file.FileInfo().IsDir() {
continue
}
fileName := file.Name
if isBannedExtension(fileName) {
continue
}
if isZipExtension(fileName) {
continue
}
rc, err := file.Open()
if err != nil {
continue
}
content, err := io.ReadAll(io.LimitReader(rc, entryMaxSize))
rc.Close()
if err != nil {
continue
}
contentType := http.DetectContentType(content)
if !isReadableContentType(contentType) {
continue
}
text := cleanSymbols(string(content))
if text == "" {
continue
}
key := url + "::" + fileName
result[key] = text
}
return result
}
func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
}
return io.ReadAll(io.LimitReader(resp.Body, maxSize))
}
func isBannedExtension(url string) bool {
ext := strings.ToLower(filepath.Ext(url))
if idx := strings.Index(ext, "?"); idx != -1 {
ext = ext[:idx]
}
return bannedExtensions[ext]
}
func isReadableContentType(contentType string) bool {
if contentType == "" {
return false
}
ct := strings.ToLower(contentType)
for _, prefix := range allowedMIMEPrefixes {
if strings.HasPrefix(ct, prefix) {
return true
}
}
return false
}
func cleanSymbols(text string) string {
text = symbolCleaner.ReplaceAllString(text, "")
text = strings.ReplaceAll(text, "\r\n", "\n")
text = strings.ReplaceAll(text, "\r", "\n")
text = regexp.MustCompile(`\n{3,}`).ReplaceAllString(text, "\n\n")
return strings.TrimSpace(text)
}
func fetchFileContent(ctx context.Context, client *http.Client, url string) (string, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return "", err
}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
}
contentType := resp.Header.Get("Content-Type")
if !isReadableContentType(contentType) {
return "", fmt.Errorf("unreadable content-type: %s", contentType)
}
body, err := io.ReadAll(
io.LimitReader(resp.Body,
int64(g.Cfg().MustGet(ctx, "userFiles.textFileMaxSizeKB", 500).Int())*1024,
),
)
if err != nil {
return "", err
}
return strings.TrimSpace(string(body)), nil
}
func sanitizeURL(raw string) string {
s := strings.TrimSpace(raw)
s = strings.Trim(s, "`\"")
return s
}
// SkillMdContent 根据 skillName 获取 zip 内所有 md 文件拼接内容
func SkillMdContent(ctx context.Context, skillName string) string {
// 1. 请求接口获取 SkillUserVO
skillResp, err := GetSkillUser(ctx, skillName)
if err != nil {
return ""
}
fullUrl := skillResp.ImgAddressPrefix + skillResp.FileUrl
// 2. 下载 zip 文件
client := &http.Client{
Timeout: time.Duration(g.Cfg().MustGet(ctx, "skillFiles.httpTimeoutSec", 30).Int()) * time.Second,
}
zipBytes, err := downloadFile(client, fullUrl,
int64(g.Cfg().MustGet(ctx, "skillFiles.zipMaxSizeMB", 10).Int())*1024*1024,
)
if err != nil {
return ""
}
// 3. 解压 zip 并提取所有 md 文件内容
mdContents, err := extractMdFiles(ctx, zipBytes)
if err != nil {
return ""
}
if len(mdContents) == 0 {
return ""
}
// 4. 拼接所有 md 内容
var builder strings.Builder
builder.WriteString(fmt.Sprintf("# Skill: %s\n\n", skillResp.Name))
if skillResp.Description != "" {
builder.WriteString(fmt.Sprintf("> %s\n\n", skillResp.Description))
}
for fileName, content := range mdContents {
builder.WriteString(fmt.Sprintf("## %s\n\n", fileName))
builder.WriteString(content)
builder.WriteString("\n\n---\n\n")
}
return strings.TrimSpace(builder.String())
}
// extractMdFiles 解压 zip 并提取所有 .md 文件内容
func extractMdFiles(ctx context.Context, zipBytes []byte) (map[string]string, error) {
result := make(map[string]string)
reader, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
if err != nil {
return nil, err
}
entryMaxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.mdMaxSizeKB", 500).Int()) * 1024
for _, file := range reader.File {
if file.FileInfo().IsDir() {
continue
}
if !strings.HasSuffix(strings.ToLower(file.Name), ".md") {
continue
}
rc, err := file.Open()
if err != nil {
continue
}
content, err := io.ReadAll(io.LimitReader(rc, entryMaxSize))
rc.Close()
if err != nil {
continue
}
if len(content) > 0 {
result[file.Name] = strings.TrimSpace(string(content))
}
}
return result, nil
}

View File

@@ -1,157 +0,0 @@
package gateway
import (
"context"
"encoding/json"
"fmt"
"prompts-core/common/util"
"prompts-core/model/entity"
commonHttp "gitea.com/red-future/common/http"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
)
// CreateTaskReq 创建任务请求
type CreateTaskReq struct {
TaskId string `json:"task_id"`
State int `json:"state"`
OssFile string `json:"oss_file"`
FileType string `json:"file_type"`
Text string `json:"text"`
ErrorMsg string `json:"error_msg"`
}
// CreateGatewayTask 创建网关异步任务
func CreateGatewayTask(ctx context.Context, payload map[string]any) (string, error) {
fullURL := "model-gateway/task/createTask"
headers := util.ForwardHeaders(ctx)
var req CreateTaskReq
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
if err := commonHttp.Post(ctx, fullURL, headers, &req, body); err != nil {
return "", err
}
return req.TaskId, nil
}
// GetTaskResultRes 任务结果响应
type GetTaskResultRes struct {
OssFile string `json:"ossFile" dc:"结果文件OSS地址"`
State int `json:"state" dc:"任务状态"`
}
// QueryGatewayTaskState 查询网关任务状态
func QueryGatewayTaskState(ctx context.Context, taskID string) (int, error) {
fullURL := fmt.Sprintf("model-gateway/task/getTaskResult?taskId=%s", taskID)
headers := util.ForwardHeaders(ctx)
var req GetTaskResultRes
if err := commonHttp.Get(ctx, fullURL, headers, &req, nil); err != nil {
return 0, err
}
return req.State, nil
}
// SkillUserVO 技能用户视图对象
type SkillUserVO struct {
Id int64 `json:"id,string"`
Name string `json:"name"`
Description string `json:"description"`
FileName string `json:"fileName"`
FileUrl string `json:"fileUrl"`
CreatedAt *gtime.Time `json:"createdAt"`
UpdatedAt *gtime.Time `json:"updatedAt"`
ImgAddressPrefix string `json:"imgAddressPrefix"`
}
// GetSkillUser 获取技能用户信息
func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) {
fullURL := fmt.Sprintf("ai-agent/skill/user/getUserOrTemplate?name=%s", name)
headers := util.ForwardHeaders(ctx)
var resp SkillUserVO
var req struct{}
if err := commonHttp.Get(ctx, fullURL, headers, &resp, req); err != nil {
return nil, err
}
return &resp, nil
}
// SendCallbackReq 发送回调的请求体
type SendCallbackReq struct {
TaskId string `json:"taskId"`
Status string `json:"status"`
Messages *MultiRoundResult `json:"messages,omitempty"`
EpicycleId int64 `json:"epicycleId"`
ErrorMsg string `json:"errorMsg,omitempty"`
}
type MultiRoundResult struct {
TotalRounds int `json:"total_rounds"` // 总轮数
Rounds []map[string]any `json:"rounds"` // 每轮详情(动态类型)
}
// SendCallback 向业务方发送回调
func SendCallback(ctx context.Context, composeTask *entity.ComposeTask) error {
// 1. 检查回调地址
if composeTask.CallbackUrl == "" {
return fmt.Errorf("回调地址为空taskId=%s", composeTask.TaskId)
}
// 2. 构造请求体
req := SendCallbackReq{
TaskId: composeTask.TaskId,
Status: composeTask.Status,
Messages: parseMessagesToResult(composeTask.Messages), // 需要将 JSON 字符串转为结构体
ErrorMsg: composeTask.ErrorMessage,
}
// 3. 发送 POST 请求
headers := util.ForwardHeaders(ctx)
var resp struct{}
g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s 消息=%v",
composeTask.TaskId, composeTask.CallbackUrl, req.Messages)
if err := commonHttp.Post(ctx, composeTask.CallbackUrl, headers, &resp, req); err != nil {
return fmt.Errorf("[回调业务] 发送失败 taskId=%s url=%s err=%w", composeTask.TaskId, composeTask.CallbackUrl, err)
}
g.Log().Infof(ctx, "[回调业务] 发送成功 taskId=%s 回调地址=%s", composeTask.TaskId, composeTask.CallbackUrl)
return nil
}
// parseMessagesToResult 将 any 类型的 Messages 转为 *MultiRoundResult
func parseMessagesToResult(messages any) *MultiRoundResult {
if messages == nil {
return nil
}
var result MultiRoundResult
switch v := messages.(type) {
case *MultiRoundResult:
return v
case MultiRoundResult:
return &v
case string:
if err := json.Unmarshal([]byte(v), &result); err != nil {
return nil
}
case []byte:
if err := json.Unmarshal(v, &result); err != nil {
return nil
}
case map[string]any:
// 通过 JSON 序列化再反序列化
data, _ := json.Marshal(v)
if err := json.Unmarshal(data, &result); err != nil {
return nil
}
default:
data, err := json.Marshal(v)
if err != nil {
return nil
}
if err = json.Unmarshal(data, &result); err != nil {
return nil
}
}
return &result
}

53
service/headers.go Normal file
View File

@@ -0,0 +1,53 @@
package service
import (
"context"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
)
// asyncCtx 固化异步执行所需的 token/user避免请求结束后丢失仅在“同请求内起 goroutine”有用
// 本项目当前是“落库 + 后台 worker”模式因此还会把必要信息持久化到任务表的 request_payload 中。
func asyncCtx(ctx context.Context) context.Context {
asyncCtx := context.WithoutCancel(ctx)
if r := g.RequestFromCtx(ctx); r != nil {
if token := r.Header.Get("Authorization"); token != "" {
asyncCtx = context.WithValue(asyncCtx, "token", token)
}
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
}
}
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
asyncCtx = context.WithValue(asyncCtx, "user", user)
}
return asyncCtx
}
// forwardHeaders 透传调用链路中必须的头信息(优先使用 ctx 里固化的 token / xUserInfo
func forwardHeaders(ctx context.Context) map[string]string {
headers := make(map[string]string)
if token, ok := ctx.Value("token").(string); ok && token != "" {
headers["Authorization"] = token
}
if x, ok := ctx.Value("xUserInfo").(string); ok && x != "" {
headers["X-User-Info"] = x
}
// 兜底:从请求头拿
if r := g.RequestFromCtx(ctx); r != nil {
if headers["Authorization"] == "" {
if token := r.Header.Get("Authorization"); token != "" {
headers["Authorization"] = token
}
}
if headers["X-User-Info"] == "" {
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
headers["X-User-Info"] = userInfo
}
}
}
return headers
}

75
service/http_service.go Normal file
View File

@@ -0,0 +1,75 @@
package service
import (
"context"
"encoding/json"
"fmt"
commonHttp "gitea.com/red-future/common/http"
"github.com/gogf/gf/v2/os/gtime"
)
// CreateTaskReq 创建任务请求
type CreateTaskReq struct {
TaskId string `json:"task_id"`
State int `json:"state"`
OssFile string `json:"oss_file"`
FileType string `json:"file_type"`
Text string `json:"text"`
ErrorMsg string `json:"error_msg"`
}
// createGatewayTask 调用 model-gateway 异步任务并同步等待结果
func createGatewayTask(ctx context.Context, payload map[string]any) (string, error) {
fullURL := "model-gateway/task/createTask"
headers := forwardHeaders(ctx)
var req CreateTaskReq
body, err := json.Marshal(payload)
if err != nil {
return "", err
}
if err := commonHttp.Post(ctx, fullURL, headers, &req, body); err != nil {
return "", err
}
return req.TaskId, nil
}
type GetTaskResultRes struct {
OssFile string `json:"ossFile" dc:"结果文件OSS地址"`
State int `json:"state" dc:"任务状态"`
}
// queryGatewayTaskState 查询网关任务状态
func queryGatewayTaskState(ctx context.Context, taskID string) (int, error) {
fullURL := fmt.Sprintf("model-gateway/task/getTaskResult?taskId=%s", taskID)
headers := forwardHeaders(ctx)
var req GetTaskResultRes
if err := commonHttp.Get(ctx, fullURL, headers, &req, nil); err != nil {
return 0, err
}
return req.State, nil
}
// SkillUserVO 技能用户视图对象
type SkillUserVO struct {
Id int64 `json:"id,string"`
Name string `json:"name"`
Description string `json:"description"`
FileName string `json:"fileName"`
FileUrl string `json:"fileUrl"` // html 后缀
CreatedAt *gtime.Time `json:"createdAt"`
UpdatedAt *gtime.Time `json:"updatedAt"`
ImgAddressPrefix string `json:"imgAddressPrefix"` // htmml 前缀
}
// GetSkillUser 根据 name 获取技能用户信息
func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) {
fullURL := fmt.Sprintf("ai-agent/skill/user/getUserOrTemplate?name=%s", name)
headers := forwardHeaders(ctx)
var resp SkillUserVO
var req struct{}
if err := commonHttp.Get(ctx, fullURL, headers, &resp, req); err != nil {
return nil, err
}
return &resp, nil
}

View File

@@ -1,204 +0,0 @@
package prompt
import (
"context"
"errors"
"fmt"
"prompts-core/consts/public"
"strings"
"prompts-core/common/util"
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
"github.com/gogf/gf/v2/util/gconv"
)
// buildInferenceRequest 构建推理请求
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
if err != nil {
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
}
ir := NewPromptIR()
switch req.BuildType {
case public.BuildTypePrompt:
return buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, history, ir, totalBatches)
case public.BuildTypeNode:
return buildNodeTypeRequest(ctx, req, chatModel, ir)
default:
return nil, errors.New("不支持的构建类型")
}
}
// buildPromptTypeRequest 构建提示词类型请求BuildType=1
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
systemPrompt := promptBuildWithRounds(ctx, req, aiModel, totalBatches)
ir.AddSystem(systemPrompt)
for _, msg := range history {
role := gconv.String(msg["role"])
if role != "user" && role != "assistant" {
continue
}
ir.AddHistory(role, gconv.String(msg["content"]))
}
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType))
ir.AddUser(userPrompt)
if !checkOverallContent(ir, aiModel) {
availableWindow := util.GetAvailableWindow(aiModel.TokenConfig)
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
}
// 记录历史会话
_, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
RequestContent: ir.User,
})
return compileToProviderRequest(ctx, ir, chatModel)
}
// buildNodeTypeRequest 构建节点类型请求BuildType=2
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, ir *PromptIR) (map[string]any, error) {
ir.AddUser(NodeBuild(ctx, req))
return compileToProviderRequest(ctx, ir, chatModel)
}
// compileToProviderRequest 编译为 Provider 请求
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *entity.AsynchModel) (map[string]any, error) {
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
if err != nil {
return nil, fmt.Errorf("获取协议配置失败: %w", err)
}
if protocol == nil {
return nil, errors.New("协议配置不存在")
}
providerReq, err := Compile(ir, protocol, chatModel)
if err != nil {
return nil, fmt.Errorf("编译请求失败: %w", err)
}
return map[string]any{
"modelName": chatModel.ModelName,
"bizName": "prompts-core",
"callbackUrl": util.GetCallbackURL(ctx, "/prompt/callback"),
"requestPayload": providerReq,
}, nil
}
// promptBuildWithRounds 构建系统提示词(包含轮次信息)
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel, totalRounds int) string {
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: model.OperatorName,
Status: 1,
})
if err != nil || providerProtocol == nil {
return ""
}
outputJSON := util.JSONPretty(model.RequestMapping)
maxWindowSize := util.GetMaxWindowSize(model.TokenConfig)
availableWindow := util.GetAvailableWindow(model.TokenConfig)
userFormContent := buildUserFormContent(req.UserForm)
formInfo := fmt.Sprintf(`
【系统表单(系统提示词/参数)】
%s
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
%s
`, util.FormToJSON(req.Form), userFormContent)
inputInfo := fmt.Sprintf(`
目标模型: %s
%s
技能名称: %s
用户文件: %v
`, req.ModelName, formInfo, req.SkillName, req.UserFiles)
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
req.ModelName,
maxWindowSize,
availableWindow,
totalRounds,
totalRounds,
totalRounds,
outputJSON,
inputInfo,
totalRounds,
)
}
// buildUserFormContent 构建用户表单内容字符串
func buildUserFormContent(userForm []map[string]any) string {
var builder strings.Builder
for _, item := range userForm {
builder.WriteString(fmt.Sprintf("%v\n", item))
}
return builder.String()
}
// checkOverallContent 检查整体内容是否超出窗口
func checkOverallContent(ir *PromptIR, model *entity.AsynchModel) bool {
fullContent := ir.String()
return util.CountToken(fullContent, model.TokenConfig)
}
// buildUserPrompt 构建用户提示词
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
userFormForPayload := prepareUserFormPayload(req.UserForm)
payload := map[string]any{
"model": req.ModelName,
"promptInfo": prompt,
"form": req.Form,
"userForm": userFormForPayload,
"userFiles": req.UserFiles,
"userFilesText": FetchFileTexts(ctx, req.UserFiles),
"skills": SkillMdContent(ctx, req.SkillName),
}
return util.MustMarshal(payload)
}
// prepareUserFormPayload 准备用户表单载荷
func prepareUserFormPayload(userForm []map[string]any) any {
if len(userForm) == 0 {
return nil
}
if _, ok := userForm[0]["batch_index"]; ok {
return userForm
}
return mergeUserFormTexts(userForm)
}
// mergeUserFormTexts 合并 UserForm 中的所有文本内容
func mergeUserFormTexts(userForm []map[string]any) string {
var builder strings.Builder
for i, item := range userForm {
text := getItemText(item)
if i > 0 {
builder.WriteString("\n\n")
}
builder.WriteString(text)
}
return builder.String()
}
// NodeBuild 节点构建
func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string {
promptTpl := util.GetBuildPrompt(ctx)
if promptTpl == "" {
return ""
}
formStr := util.FormToJSON(req.Form)
userFormStr := util.UserFormToJSON(req.UserForm)
return fmt.Sprintf(promptTpl, formStr, userFormStr)
}

View File

@@ -1,380 +0,0 @@
package prompt
import (
"context"
"encoding/json"
"errors"
"fmt"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
"prompts-core/common/util"
"prompts-core/consts/public"
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
"prompts-core/service/gateway"
)
// ComposeMessages 核心拼接提示词主流程
func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
chatModel, aiModel, err := GetModelMessage(ctx, req)
if err != nil {
return nil, err
}
if err = validateUserForm(req, aiModel); err != nil {
return nil, err
}
switch req.BuildType {
case public.BuildTypePrompt:
return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建
case public.BuildTypeNode:
return handleNodeBuild(ctx, req, chatModel, aiModel) // 节点构建
default:
return nil, errors.New("BuildType 不支持")
}
}
// GetModelMessage 获取模型信息
func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) {
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
}
chatModel, err := getChatModel(ctx, userInfo.UserName)
if err != nil {
return nil, nil, err
}
aiModel, err := getAIModel(ctx, userInfo.UserName, req.ModelName)
if err != nil {
return nil, nil, err
}
return chatModel, aiModel, nil
}
// validateUserForm 校验用户表单
func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) error {
if len(req.UserForm) == 0 {
return nil
}
isValid, exceedTokens, err := util.CheckUserFormWithinWindow(req.UserForm, model.TokenConfig)
if err != nil {
return fmt.Errorf("校验用户表单失败: %w", err)
}
if !isValid {
availableWindow := util.GetAvailableWindow(model.TokenConfig)
return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens可用窗口 %d tokens请精简后重试",
exceedTokens, availableWindow)
}
return nil
}
// handlePromptBuild 处理提示词构建BuildType=1
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
// 获取历史会话
history, err := GetHistoryMessages(ctx, req.SessionId)
if err != nil {
g.Log().Errorf(ctx, "获取历史会话失败: %v将不使用历史会话", err)
history = nil
}
// 调用推理模型
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, history)
if err != nil {
return nil, fmt.Errorf("调用推理模型失败: %w", err)
}
// 保存任务记录
if err = saveComposeTask(ctx, taskID, req); err != nil {
return nil, fmt.Errorf("保存任务记录失败: %w", err)
}
return &dto.ComposeMessagesRes{
TaskId: taskID,
}, nil
}
// handleNodeBuild 处理节点构建BuildType=2
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
if err != nil {
return nil, fmt.Errorf("调用推理模型失败: %w", err)
}
if err := saveComposeTask(ctx, taskID, req); err != nil {
return nil, fmt.Errorf("保存任务记录失败: %w", err)
}
return &dto.ComposeMessagesRes{
TaskId: taskID,
}, nil
}
// saveComposeTask 保存组合任务记录
func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessagesReq) error {
_, err := dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
BuildType: req.BuildType,
CallbackUrl: req.CallbackUrl,
RequestPayload: util.MustMarshal(req),
Status: public.ComposeStatusPending,
})
return err
}
// getChatModel 获取聊天模型
func getChatModel(ctx context.Context, userName string) (*entity.AsynchModel, error) {
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
IsChatModel: new(1),
})
if err != nil {
return nil, fmt.Errorf("查询聊天模型失败: %w", err)
}
if chatModel == nil {
return nil, errors.New("当前没有对话模型,请添加")
}
return chatModel, nil
}
// getAIModel 获取AI模型
func getAIModel(ctx context.Context, userName, modelName string) (*entity.AsynchModel, error) {
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
ModelName: modelName,
})
if err != nil {
return nil, fmt.Errorf("查询AI模型失败: %w", err)
}
if aiModel == nil {
return nil, fmt.Errorf("需要构建的模型 %s 不存在", modelName)
}
return aiModel, nil
}
// callInferenceModel 调用推理模型
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, idModel *entity.AsynchModel, history []map[string]any) (string, error) {
taskReq, err := buildInferenceRequest(ctx, req, chatModel, idModel, history)
if err != nil {
return "", fmt.Errorf("构建推理请求失败: %w", err)
}
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
if err != nil {
return "", fmt.Errorf("创建网关任务失败: %w", err)
}
if taskID == "" {
return "", errors.New("网关未返回taskId")
}
return taskID, nil
}
// createDefaultResult 创建默认结果
func createDefaultResult(data map[string]any) *dto.MultiRoundResult {
if data == nil {
data = make(map[string]any)
}
return &dto.MultiRoundResult{
TotalRounds: 1,
Rounds: []map[string]any{data},
}
}
// Callback 回调处理
func Callback(ctx context.Context, req *dto.CallbackReq) error {
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
})
if err != nil {
return fmt.Errorf("查询任务失败: %w", err)
}
if composeTask == nil {
return fmt.Errorf("任务不存在: %s", req.TaskId)
}
//处理失败
if req.State == 3 {
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
GatewayState: req.State,
OssFile: req.OssFile,
FileType: req.FileType,
ResultText: req.Text,
})
// 用更新后的值发送回调
if composeTask.CallbackUrl != "" {
failedTask := &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
CallbackUrl: composeTask.CallbackUrl,
Messages: composeTask.Messages,
}
gateway.SendCallback(ctx, failedTask)
}
return err
}
//处理成功
if req.State == 2 {
// 1. 根据 BuildType 解析结果
var messages any
switch composeTask.BuildType {
case public.BuildTypePrompt: // 提示词构建解析
messages = parsePromptResult(req.Text)
case public.BuildTypeNode: // 节点构建解析
messages = parseNodeResult(req.Text)
default:
messages = req.Text
}
// 2. 更新数据库
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
Messages: messages,
GatewayState: req.State,
OssFile: req.OssFile,
FileType: req.FileType,
ResultText: req.Text,
})
if err != nil {
g.Log().Errorf(ctx, "[Callback] 更新成功状态失败 taskId=%s err=%v", req.TaskId, err)
return err
}
// 4. 发送回调给业务方
if composeTask.CallbackUrl != "" {
successTask := &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
Messages: messages,
CallbackUrl: composeTask.CallbackUrl,
}
gateway.SendCallback(ctx, successTask)
}
}
return err
}
// parsePromptResult 解析提示词构建结果
func parsePromptResult(raw string) *dto.MultiRoundResult {
var wrapper map[string]any
if err := json.Unmarshal([]byte(raw), &wrapper); err != nil {
return createDefaultResult(map[string]any{"raw": raw})
}
contentStr, ok := wrapper["content"].(string)
if !ok || contentStr == "" {
return createDefaultResult(wrapper)
}
// 先尝试解析为数组
if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil {
return &dto.MultiRoundResult{
TotalRounds: len(roundsArray),
Rounds: roundsArray,
}
}
// 再尝试解析为单个对象
if singleRound := tryParseAsMap(contentStr); singleRound != nil {
return &dto.MultiRoundResult{
TotalRounds: 1,
Rounds: []map[string]any{singleRound},
}
}
return createDefaultResult(map[string]any{"content": contentStr})
}
func tryParseAsMapArray(jsonStr string) []map[string]any {
var arr []map[string]any
if err := json.Unmarshal([]byte(jsonStr), &arr); err != nil {
return nil
}
if len(arr) == 0 {
return nil
}
return arr
}
func tryParseAsMap(jsonStr string) map[string]any {
var obj map[string]any
if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil {
return nil
}
if len(obj) == 0 {
return nil
}
return obj
}
// parseNodeResult 解析节点构建结果
func parseNodeResult(raw string) *dto.MultiRoundResult {
var result map[string]any
if err := json.Unmarshal([]byte(raw), &result); err != nil {
return createDefaultResult(map[string]any{"raw": raw})
}
if contentStr, ok := result["content"].(string); ok && contentStr != "" {
var inner map[string]any
if err := json.Unmarshal([]byte(contentStr), &inner); err == nil {
result = inner
}
}
return &dto.MultiRoundResult{
TotalRounds: 1,
Rounds: []map[string]any{result},
}
}
// GetComposeTask 查询任务结果
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: taskID,
})
if err != nil {
return nil, fmt.Errorf("查询任务失败: %w", err)
}
if record == nil {
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
}
messages := parseMessagesForResponse(record.Messages)
return &dto.GetComposeTaskRes{
TaskId: record.TaskId,
Status: record.Status,
ErrorMessage: record.ErrorMessage,
Messages: messages,
}, nil
}
// parseMessagesForResponse 解析用于响应的消息
func parseMessagesForResponse(messages any) any {
str, ok := messages.(string)
if !ok || str == "" {
return messages
}
var parsed any
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
return parsed
}
return messages
}

View File

@@ -1,283 +0,0 @@
package prompt
import (
"archive/zip"
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gogf/gf/v2/frame/g"
"prompts-core/common/util"
"prompts-core/service/gateway"
)
const (
bytesPerKB = 1024
bytesPerMB = 1024 * 1024
)
// FetchFileTexts 从 URL 列表获取文件内容,支持 zip 内文件
func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
result := make(map[string]string)
if len(urls) == 0 {
return result
}
client := createHTTPClient(ctx, "userFiles.httpTimeoutSec", 8)
for _, rawURL := range urls {
url := util.SanitizeURL(rawURL)
if url == "" || util.IsBannedExtension(url) {
continue
}
if util.IsZipExtension(url) {
mergeMap(result, fetchZipFileTexts(ctx, client, url))
continue
}
if text := fetchAndCleanFileContent(ctx, client, url); text != "" {
result[url] = text
}
}
return result
}
// mergeMap 合并 map
func mergeMap(dst, src map[string]string) {
for k, v := range src {
dst[k] = v
}
}
// fetchAndCleanFileContent 获取并清理文件内容
func fetchAndCleanFileContent(ctx context.Context, client *http.Client, url string) string {
text, err := fetchFileContent(ctx, client, url)
if err != nil || text == "" {
return ""
}
return util.CleanSymbols(text)
}
// fetchZipFileTexts 下载并解压 zip 文件,提取可读文本内容
func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map[string]string {
result := make(map[string]string)
maxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipMaxSizeMB", 10).Int()) * bytesPerMB
zipBytes, err := downloadFile(client, url, maxSize)
if err != nil {
return result
}
reader, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
if err != nil {
return result
}
entryMaxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipEntryMaxSizeKB", 500).Int()) * bytesPerKB
for _, file := range reader.File {
if shouldSkipZipEntry(file.Name) {
continue
}
if text := extractZipEntryContent(file, entryMaxSize); text != "" {
result[url+"::"+file.Name] = text
}
}
return result
}
// shouldSkipZipEntry 判断是否应该跳过 zip 条目
func shouldSkipZipEntry(fileName string) bool {
return util.IsBannedExtension(fileName) || util.IsZipExtension(fileName)
}
// extractZipEntryContent 提取 zip 条目内容
func extractZipEntryContent(file *zip.File, maxSize int64) string {
rc, err := file.Open()
if err != nil {
return ""
}
defer rc.Close()
content, err := io.ReadAll(io.LimitReader(rc, maxSize))
if err != nil {
return ""
}
if !util.IsReadableContentType(http.DetectContentType(content)) {
return ""
}
text := util.CleanSymbols(string(content))
if text == "" {
return ""
}
return text
}
// downloadFile 下载文件,限制最大大小
func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("执行请求失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
}
body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
return body, nil
}
// fetchFileContent 获取单个文本文件内容
func fetchFileContent(ctx context.Context, client *http.Client, url string) (string, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return "", fmt.Errorf("创建请求失败: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("执行请求失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
}
contentType := resp.Header.Get("Content-Type")
if !util.IsReadableContentType(contentType) {
return "", fmt.Errorf("不可读的内容类型: %s", contentType)
}
maxSize := int64(g.Cfg().MustGet(ctx, "userFiles.textFileMaxSizeKB", 500).Int()) * bytesPerKB
body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
if err != nil {
return "", fmt.Errorf("读取响应失败: %w", err)
}
return strings.TrimSpace(string(body)), nil
}
// SkillMdContent 根据 skillName 获取 zip 内所有 md 文件拼接内容
func SkillMdContent(ctx context.Context, skillName string) string {
skillResp, err := gateway.GetSkillUser(ctx, skillName)
if err != nil {
return ""
}
fullUrl := skillResp.ImgAddressPrefix + skillResp.FileUrl
client := createHTTPClient(ctx, "skillFiles.httpTimeoutSec", 30)
maxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.zipMaxSizeMB", 10).Int()) * bytesPerMB
zipBytes, err := downloadFile(client, fullUrl, maxSize)
if err != nil {
return ""
}
mdContents, err := extractMdFiles(ctx, zipBytes)
if err != nil || len(mdContents) == 0 {
return ""
}
return buildSkillMarkdown(skillResp, mdContents)
}
// buildSkillMarkdown 构建技能 Markdown 内容
func buildSkillMarkdown(skillResp *gateway.SkillUserVO, mdContents map[string]string) string {
var builder strings.Builder
builder.WriteString(fmt.Sprintf("# Skill: %s\n\n", skillResp.Name))
if skillResp.Description != "" {
builder.WriteString(fmt.Sprintf("> %s\n\n", skillResp.Description))
}
for fileName, content := range mdContents {
builder.WriteString(fmt.Sprintf("## %s\n\n", fileName))
builder.WriteString(content)
builder.WriteString("\n\n---\n\n")
}
return strings.TrimSpace(builder.String())
}
// extractMdFiles 解压 zip 并提取所有 .md 文件内容
func extractMdFiles(ctx context.Context, zipBytes []byte) (map[string]string, error) {
result := make(map[string]string)
reader, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
if err != nil {
return nil, fmt.Errorf("创建 zip 阅读器失败: %w", err)
}
entryMaxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.mdMaxSizeKB", 500).Int()) * bytesPerKB
for _, file := range reader.File {
if file.FileInfo().IsDir() || !isMarkdownFile(file.Name) {
continue
}
if content := readMarkdownFileContent(file, entryMaxSize); content != "" {
result[file.Name] = content
}
}
return result, nil
}
// isMarkdownFile 判断是否为 Markdown 文件
func isMarkdownFile(fileName string) bool {
return strings.HasSuffix(strings.ToLower(fileName), ".md")
}
// readMarkdownFileContent 读取 Markdown 文件内容
func readMarkdownFileContent(file *zip.File, maxSize int64) string {
rc, err := file.Open()
if err != nil {
return ""
}
defer rc.Close()
content, err := io.ReadAll(io.LimitReader(rc, maxSize))
if err != nil {
return ""
}
if len(content) == 0 {
return ""
}
return strings.TrimSpace(string(content))
}
// createHTTPClient 创建 HTTP 客户端
func createHTTPClient(ctx context.Context, configKey string, defaultSeconds int) *http.Client {
timeout := time.Duration(g.Cfg().MustGet(ctx, configKey, defaultSeconds).Int()) * time.Second
return &http.Client{
Timeout: timeout,
}
}

View File

@@ -1,75 +0,0 @@
# Prompts-Core提示词核心服务
> 智能提示词构建与管理系统,支持多模态 AI 模型的提示词组装、会话管理和协议适配。
---
## 项目简介
**Prompts-Core** 是一个基于 Go 语言开发的提示词核心服务,作为 AI 应用层与模型网关之间的桥梁,负责将业务需求转换为标准化的模型请求。
### 核心价值
- **统一提示词管理**:集中化管理不同模型类型的提示词模板
- **智能会话维护**:基于 Redis + PostgreSQL 的双层会话存储
- **多协议适配**:支持 OpenAI、DeepSeek、Qwen、Gemini 等多种模型协议
- **文件处理能力**:自动提取文本文件和 ZIP 压缩包内容
- **技能系统集成**:支持从外部加载 Markdown 格式的技能描述
---
## 核心功能
### 1. 提示词构建引擎
#### 多模态支持
| 类型 | 说明 | 适用场景 |
|------|------|----------|
| Type 1 | 文字处理助手 | 文章撰写、文案优化、翻译等 |
| Type 2 | 图片处理助手 | 图像生成、风格迁移等 |
| Type 3 | 音频处理助手 | 语音合成、识别、降噪等 |
| Type 4 | 向量化处理助手 | 语义检索、知识索引等 |
| Type 5 | 全模态助手 | 跨模态转换、多模态融合等 |
#### 构建模式
- **BuildType 1提示词构建**:完整流程,包含系统提示词、历史会话、用户输入的智能组装
- **BuildType 2节点构建**:工作流路由决策,根据上下文选择节点 ID
#### 分批处理
当用户表单内容超出模型窗口限制时,自动按 Token 大小分批处理。
### 2. 会话管理系统
- **双层存储**Redis 缓存(最近 N 轮)+ PostgreSQL 持久化
- **自动管理**:最大轮数控制(默认 10 轮)、自动过期(默认 30 分钟)
### 3. 协议适配器
通过配置动态支持多种模型协议:
- 角色映射system/user/assistant → 目标协议角色
- 内容字段映射content → parts.text 等
- 消息顺序控制:灵活配置拼接顺序
- 请求模板渲染:支持占位符替换
### 4. 任务调度
- **异步流程**:创建网关任务 → 轮询等待 → 接收回调 → 返回结果
- **重试机制**:可配置最大重试次数(默认 3 次)
- **超时保护**:默认 300 秒超时
---
## 技术架构
### 技术栈
| 组件 | 版本 | 用途 |
|------|------|------|
| Go | 1.26.0 | 编程语言 |
| GoFrame | v2.10.0 | Web 框架 |
| PostgreSQL | - | 关系型数据库 |
| Redis | - | 缓存与会话存储 |
| Consul | - | 服务注册与发现 |
| Jaeger | - | 分布式链路追踪 |
### 架构图

View File

@@ -1,291 +0,0 @@
package prompt
import (
"context"
"encoding/json"
"fmt"
"prompts-core/common/util"
"strings"
"prompts-core/dao"
"prompts-core/model/entity"
)
// PromptIR 统一 Prompt 中间表示
type PromptIR struct {
System []Segment `json:"system"`
History []Segment `json:"history"`
User []Segment `json:"user"`
}
// Segment 消息片段
type Segment struct {
Type string `json:"type"`
Content string `json:"content"`
Role string `json:"role,omitempty"`
}
// ProviderProtocol 协议编译配置(从 DB JSONB 字段解析)
type ProviderProtocol struct {
TargetField string `json:"target_field"`
MergeOrder []string `json:"merge_order"`
RoleMapping map[string]string `json:"role_mapping"`
ContentMapping ContentMapping `json:"content_mapping"`
RequestTemplate map[string]any `json:"request_template"`
SystemPromptTemplate string `json:"system_prompt_template"`
}
// ContentMapping 内容字段映射
type ContentMapping struct {
Type string `json:"type"`
Field string `json:"field"`
}
// NewPromptIR 创建空 PromptIR
func NewPromptIR() *PromptIR {
return &PromptIR{
System: make([]Segment, 0),
History: make([]Segment, 0),
User: make([]Segment, 0),
}
}
// String 返回 PromptIR 的完整内容字符串(用于 token 计算)
func (ir *PromptIR) String() string {
var builder strings.Builder
for _, seg := range ir.System {
builder.WriteString("System: ")
builder.WriteString(seg.Content)
builder.WriteString("\n")
}
for _, seg := range ir.History {
builder.WriteString(seg.Role)
builder.WriteString(": ")
builder.WriteString(seg.Content)
builder.WriteString("\n")
}
for _, seg := range ir.User {
builder.WriteString("User: ")
builder.WriteString(seg.Content)
builder.WriteString("\n")
}
return builder.String()
}
// GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算)
func (ir *PromptIR) GetTotalContent() string {
var builder strings.Builder
for _, seg := range ir.System {
builder.WriteString(seg.Content)
builder.WriteString("\n")
}
for _, seg := range ir.History {
builder.WriteString(seg.Content)
builder.WriteString("\n")
}
for _, seg := range ir.User {
builder.WriteString(seg.Content)
builder.WriteString("\n")
}
return builder.String()
}
// AddSystem 添加系统提示
func (ir *PromptIR) AddSystem(content string) *PromptIR {
if content != "" {
ir.System = append(ir.System, Segment{Type: "text", Content: content})
}
return ir
}
// AddUser 添加用户消息
func (ir *PromptIR) AddUser(content string) *PromptIR {
if content != "" {
ir.User = append(ir.User, Segment{Type: "text", Content: content})
}
return ir
}
// AddHistory 添加历史消息
func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
if content != "" {
ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role})
}
return ir
}
// ToMessages 转换为 OpenAI 兼容的 messages 格式MVP 默认)
func (ir *PromptIR) ToMessages() []map[string]any {
var messages []map[string]any
for _, seg := range ir.System {
messages = append(messages, map[string]any{
"role": "system",
"content": seg.Content,
})
}
for _, seg := range ir.History {
messages = append(messages, map[string]any{
"role": seg.Role,
"content": seg.Content,
})
}
for _, seg := range ir.User {
messages = append(messages, map[string]any{
"role": "user",
"content": seg.Content,
})
}
return messages
}
// GetProtocolByProvider 根据 provider_name 获取协议配置
func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderProtocol, error) {
entity, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: providerName,
Status: 1,
})
if err != nil || entity == nil {
return nil, err
}
return parseProtocol(entity), nil
}
// parseProtocol 将 DB entity 转为编译用协议配置
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
p := &ProviderProtocol{
TargetField: e.TargetField,
SystemPromptTemplate: e.SystemPromptTemplate,
}
// 使用通用解析方法处理各个字段
util.ParseJSONFieldFromGvar(e.MergeOrder, &p.MergeOrder)
util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping)
util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping)
util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate)
return p
}
// Compile 将 PromptIR 按协议配置编译为 Provider Request
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (map[string]any, error) {
if ir == nil || p == nil {
return nil, fmt.Errorf("ir and protocol are required")
}
messages := mergeByOrder(ir, p.MergeOrder)
messages = mapRoles(messages, p.RoleMapping)
messages = mapContent(messages, p.ContentMapping)
return buildRequest(messages, p, chatModel), nil
}
// mergeByOrder 按协议配置顺序拼接消息
func mergeByOrder(ir *PromptIR, order []string) []map[string]any {
var messages []map[string]any
for _, part := range order {
switch part {
case "system":
for _, seg := range ir.System {
messages = append(messages, map[string]any{
"role": "system",
"content": seg.Content,
})
}
case "history":
for _, seg := range ir.History {
messages = append(messages, map[string]any{
"role": seg.Role,
"content": seg.Content,
})
}
case "user":
for _, seg := range ir.User {
messages = append(messages, map[string]any{
"role": "user",
"content": seg.Content,
})
}
}
}
return messages
}
// mapRoles 角色映射
func mapRoles(messages []map[string]any, mapping map[string]string) []map[string]any {
if len(mapping) == 0 {
return messages
}
for i, msg := range messages {
role, ok := msg["role"].(string)
if !ok {
continue
}
if mapped, exists := mapping[role]; exists {
messages[i]["role"] = mapped
}
}
return messages
}
// mapContent 内容字段映射
func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
for _, msg := range messages {
content := msg["content"]
delete(msg, "content")
switch cm.Type {
case "parts":
msg["parts"] = []map[string]any{
{cm.Field: content},
}
default:
msg[cm.Field] = content
}
}
return messages
}
// buildRequest 按 target_field 和 request_template 构建请求体
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *entity.AsynchModel) map[string]any {
if len(p.RequestTemplate) > 0 {
return renderTemplate(p.RequestTemplate, messages, chatModel)
}
return map[string]any{
p.TargetField: messages,
}
}
// renderTemplate 简单的 {{key}} 模板替换
func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *entity.AsynchModel) map[string]any {
b, _ := json.Marshal(tmpl)
str := string(b)
if chatModel != nil {
str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`)
}
msgBytes, _ := json.Marshal(messages)
str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes))
var result map[string]any
json.Unmarshal([]byte(str), &result)
return result
}

View File

@@ -1,145 +0,0 @@
package prompt
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/gogf/gf/v2/frame/g"
)
const (
redisKeyPrefix = "chat:session:%s"
)
// saveToRedis 保存会话数据到Redis
func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error {
key := formatRedisKey(sessionId)
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
data := map[string]any{
"sessionId": sessionId,
"requestContent": requestMessages,
"responseContent": responseMessages,
"timestamp": time.Now().Unix(),
}
b, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("序列化会话数据失败: %w", err)
}
if err := executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil {
return err
}
return nil
}
// formatRedisKey 格式化Redis键
func formatRedisKey(sessionId string) string {
return fmt.Sprintf(redisKeyPrefix, sessionId)
}
// executeRedisCommands 执行Redis命令
func executeRedisCommands(ctx context.Context, key string, value string, maxRounds int, expireSeconds int64) error {
if _, err := g.Redis().Do(ctx, "LPUSH", key, value); err != nil {
return fmt.Errorf("写入Redis失败: %w", err)
}
if _, err := g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1); err != nil {
return fmt.Errorf("裁剪Redis列表失败: %w", err)
}
if _, err := g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
return fmt.Errorf("设置过期时间失败: %w", err)
}
return nil
}
// getFromRedis 从Redis获取会话历史
func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) {
key := formatRedisKey(sessionId)
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
if err != nil {
return nil, fmt.Errorf("从Redis获取数据失败: %w", err)
}
if result == nil || result.IsNil() {
return []map[string]any{}, nil
}
sessions := parseRedisSessions(ctx, result.Strings())
reverseSlice(sessions)
return sessions, nil
}
// parseRedisSessions 解析Redis会话数据
func parseRedisSessions(ctx context.Context, values []string) []map[string]any {
var sessions []map[string]any
for _, str := range values {
var data map[string]any
if err := json.Unmarshal([]byte(str), &data); err != nil {
g.Log().Warningf(ctx, "[会话] 解析Redis数据失败 err=%v", err)
continue
}
sessions = append(sessions, data)
}
return sessions
}
// reverseSlice 反转切片
func reverseSlice(s []map[string]any) {
for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {
s[i], s[j] = s[j], s[i]
}
}
// GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用)
func GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map[string]any, error) {
historyData, err := getFromRedis(ctx, sessionId)
if err != nil {
return nil, fmt.Errorf("获取历史会话失败: %w", err)
}
if len(historyData) == 0 {
return []map[string]any{}, nil
}
return flattenHistoryMessages(historyData), nil
}
// flattenHistoryMessages 扁平化历史消息
func flattenHistoryMessages(historyData []map[string]any) []map[string]any {
var messages []map[string]any
for _, round := range historyData {
appendMessagesFromField(round, "requestContent", &messages)
appendMessagesFromField(round, "responseContent", &messages)
}
return messages
}
// appendMessagesFromField 从指定字段追加消息
func appendMessagesFromField(data map[string]any, field string, messages *[]map[string]any) {
msgs, ok := data[field].([]interface{})
if !ok {
return
}
for _, m := range msgs {
if msg, ok := m.(map[string]interface{}); ok {
*messages = append(*messages, msg)
}
}
}

View File

@@ -1,165 +0,0 @@
package prompt
import (
"context"
"fmt"
"gitea.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
"prompts-core/common/util"
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
)
// SessionCallback 会话回调
func SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
result, err := util.ParseOutput(req.Text)
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, fmt.Errorf("解析模型输出失败: %w", err)
}
result["role"] = "assistant"
if err = updateSessionResponse(ctx, req.EpicycleId, result); err != nil {
return nil, err
}
session, err := getSessionById(ctx, req.EpicycleId)
if err != nil {
return nil, err
}
if err := saveSessionToRedis(ctx, session); err != nil {
return nil, err
}
requestMessages := util.ConvertToMessages(session.RequestContent)
responseMessages := util.ConvertToMessages(session.ResponseContent)
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
session.SessionId, session.Id, len(requestMessages), len(responseMessages))
return &dto.SessionCallbackRes{}, nil
}
// updateSessionResponse 更新会话响应
func updateSessionResponse(ctx context.Context, epicycleId int64, response any) error {
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: epicycleId},
ResponseContent: response,
})
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", epicycleId, err)
return fmt.Errorf("更新数据库失败: %w", err)
}
return nil
}
// getSessionById 根据ID获取会话
func getSessionById(ctx context.Context, epicycleId int64) (*entity.ComposeSession, error) {
session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: epicycleId},
})
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", epicycleId, err)
return nil, fmt.Errorf("获取会话数据失败: %w", err)
}
return session, nil
}
// saveSessionToRedis 保存会话到Redis
func saveSessionToRedis(ctx context.Context, session *entity.ComposeSession) error {
requestMessages := util.ConvertToMessages(session.RequestContent)
responseMessages := util.ConvertToMessages(session.ResponseContent)
if err := saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil {
g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v",
session.SessionId, session.Id, err)
return fmt.Errorf("Redis存储失败: %w", err)
}
return nil
}
// GetHistoryMessages 获取历史信息
func GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) {
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
redisHistory, err := GetSessionHistoryForInference(ctx, sessionId)
if err == nil && len(redisHistory) > 0 {
return redisHistory, nil
}
return getHistoryFromDatabase(ctx, sessionId, maxRounds)
}
// getHistoryFromDatabase 从数据库获取历史记录
func getHistoryFromDatabase(ctx context.Context, sessionId string, maxRounds int) ([]map[string]any, error) {
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
SessionId: sessionId,
}, 1, maxRounds)
if err != nil {
return nil, fmt.Errorf("DB获取历史失败: %w", err)
}
messages := extractMessagesFromSessions(sessions)
cacheSessionsToRedis(ctx, sessions)
return messages, nil
}
// extractMessagesFromSessions 从会话列表中提取消息
func extractMessagesFromSessions(sessions []*entity.ComposeSession) []map[string]any {
var messages []map[string]any
for _, session := range sessions {
appendRequestMessages(session.RequestContent, &messages)
appendResponseMessages(session.ResponseContent, &messages)
}
return messages
}
// appendRequestMessages 追加请求消息
func appendRequestMessages(requestContent any, messages *[]map[string]any) {
reqMsgs := util.ConvertToMessages(requestContent)
for _, m := range reqMsgs {
role := gconv.String(m["role"])
if role == "user" || role == "assistant" {
*messages = append(*messages, m)
}
}
}
// appendResponseMessages 追加响应消息
func appendResponseMessages(responseContent any, messages *[]map[string]any) {
respMsgs := util.ConvertToMessages(responseContent)
for _, m := range respMsgs {
if m["role"] == nil {
m["role"] = "assistant"
}
*messages = append(*messages, m)
}
}
// cacheSessionsToRedis 将会话缓存到Redis
func cacheSessionsToRedis(ctx context.Context, sessions []*entity.ComposeSession) {
for _, session := range sessions {
reqMsgs := util.ConvertToMessages(session.RequestContent)
respMsgs := util.ConvertToMessages(session.ResponseContent)
for i := range respMsgs {
if respMsgs[i]["role"] == nil {
respMsgs[i]["role"] = "assistant"
}
}
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
_ = saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs)
}
}
}

View File

@@ -1,135 +0,0 @@
package prompt
import (
"context"
"fmt"
"strings"
"github.com/gogf/gf/v2/frame/g"
"prompts-core/common/util"
"prompts-core/model/dto"
"prompts-core/model/entity"
)
// ProcessUserFormBatches 处理 UserForm 分批(按 token 大小拼接内容)
func ProcessUserFormBatches(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) (*dto.ComposeMessagesReq, int, error) {
if model.TokenConfig == nil || len(req.UserForm) == 0 {
return req, 1, nil
}
availableWindow := util.GetAvailableWindow(model.TokenConfig)
batches := splitUserFormByTokenSize(req.UserForm, availableWindow, model.TokenConfig)
if len(batches) <= 1 {
return req, 1, nil
}
newUserForm := buildBatchedUserForm(batches)
newReq := *req
newReq.UserForm = newUserForm
g.Log().Infof(ctx, "[ProcessUserFormBatches] UserForm分批完成: 原始%d条 -> %d批 (按token大小拼接)",
len(req.UserForm), len(batches))
return &newReq, len(batches), nil
}
// buildBatchedUserForm 构建分批后的用户表单
func buildBatchedUserForm(batches [][]map[string]any) []map[string]any {
newUserForm := make([]map[string]any, 0, len(batches))
for i, batch := range batches {
combinedText := combineBatchText(batch)
newUserForm = append(newUserForm, map[string]any{
"batch_index": i + 1,
"total_batches": len(batches),
"text": combinedText,
"item_count": len(batch),
})
}
return newUserForm
}
// combineBatchText 合并批次中的所有文本(合并所有字段的值)
func combineBatchText(batch []map[string]any) string {
var builder strings.Builder
for j, item := range batch {
itemText := getItemText(item)
if itemText == "" {
continue
}
if j > 0 {
builder.WriteString("\n\n")
}
builder.WriteString(itemText)
}
return builder.String()
}
// splitUserFormByTokenSize 按 token 大小将 UserForm 内容拼接后分批
func splitUserFormByTokenSize(userForm []map[string]any, maxTokens int, tokenConfig any) [][]map[string]any {
if len(userForm) == 0 {
return [][]map[string]any{}
}
batches := make([][]map[string]any, 0)
currentBatch := make([]map[string]any, 0)
currentTokens := 0
for i, item := range userForm {
itemText := getItemText(item)
itemTokens := util.CalculateTokens(itemText, tokenConfig)
// 单个元素超过窗口,单独成一批
if itemTokens > maxTokens {
if len(currentBatch) > 0 {
batches = append(batches, currentBatch)
currentBatch = make([]map[string]any, 0)
currentTokens = 0
}
batches = append(batches, []map[string]any{item})
continue
}
// 判断是否需要新开一批
if currentTokens+itemTokens > maxTokens && len(currentBatch) > 0 {
batches = append(batches, currentBatch)
currentBatch = make([]map[string]any, 0)
currentTokens = 0
}
currentBatch = append(currentBatch, item)
currentTokens += itemTokens
// 最后一批
if i == len(userForm)-1 && len(currentBatch) > 0 {
batches = append(batches, currentBatch)
}
}
return batches
}
// getItemText 获取 item 中的所有文本内容(合并所有字段的值)
func getItemText(item map[string]any) string {
if len(item) == 0 {
return ""
}
var parts []string
for key, value := range item {
// 跳过分批时添加的元数据字段
if key == "batch_index" || key == "total_batches" || key == "item_count" {
continue
}
parts = append(parts, fmt.Sprintf("%v", value))
}
return strings.Join(parts, "\n")
}

92
service/prompt_service.go Normal file
View File

@@ -0,0 +1,92 @@
package service
import (
"context"
"encoding/json"
"errors"
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
)
var Prompt = &promptService{}
type promptService struct{}
func (s *promptService) Create(ctx context.Context, req *dto.CreatePromptReq) (res *dto.CreatePromptRes, err error) {
// promptInfo 兜底校验:必须可序列化为 JSON
if req.PromptInfo == nil {
return nil, errors.New("promptInfo不能为空")
}
if _, err := json.Marshal(req.PromptInfo); err != nil {
return nil, errors.New("promptInfo不是合法JSON")
}
if req.ResponseJsonSchema == nil {
return nil, errors.New("responseJsonSchema不能为空")
}
if _, err := json.Marshal(req.ResponseJsonSchema); err != nil {
return nil, errors.New("responseJsonSchema不是合法JSON")
}
m := &entity.PromptConfig{
ModelTypeId: req.ModelTypeId,
ModelType: req.ModelType,
PromptInfo: req.PromptInfo,
ResponseJsonSchema: req.ResponseJsonSchema,
Enabled: 1,
Version: req.Version,
}
id, err := dao.Prompt.Insert(ctx, m)
if err != nil {
return nil, err
}
return &dto.CreatePromptRes{ID: id}, nil
}
func (s *promptService) Update(ctx context.Context, req *dto.UpdatePromptReq) error {
data := map[string]any{}
if req.ModelTypeId != nil && *req.ModelTypeId > 0 {
data[entity.PromptConfigCol.ModelTypeId] = *req.ModelTypeId
}
if req.ModelType != nil && *req.ModelType != "" {
data[entity.PromptConfigCol.ModelType] = *req.ModelType
}
if req.PromptInfo != nil {
if _, err := json.Marshal(req.PromptInfo); err != nil {
return errors.New("promptInfo不是合法JSON")
}
data[entity.PromptConfigCol.PromptInfo] = req.PromptInfo
}
if req.ResponseJsonSchema != nil {
if _, err := json.Marshal(req.ResponseJsonSchema); err != nil {
return errors.New("responseJsonSchema不是合法JSON")
}
data[entity.PromptConfigCol.ResponseJsonSchema] = req.ResponseJsonSchema
}
if req.Enabled != nil {
data[entity.PromptConfigCol.Enabled] = *req.Enabled
}
if req.Version != nil {
data[entity.PromptConfigCol.Version] = *req.Version
}
if len(data) == 0 {
return errors.New("无可更新字段")
}
_, err := dao.Prompt.UpdateByID(ctx, req.ID, data)
return err
}
func (s *promptService) Delete(ctx context.Context, id int64) error {
_, err := dao.Prompt.DeleteByID(ctx, id)
return err
}
func (s *promptService) Get(ctx context.Context, id int64) (*entity.PromptConfig, error) {
return dao.Prompt.GetByID(ctx, id)
}
func (s *promptService) List(ctx context.Context, pageNum, pageSize int, modelTypeID *int, modelTypeLike string) (list []*entity.PromptConfig, total int64, err error) {
return dao.Prompt.List(ctx, pageNum, pageSize, modelTypeID, modelTypeLike)
}

View File

@@ -0,0 +1,114 @@
package service
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/gogf/gf/v2/frame/g"
)
// ==================== Redis 操作 ====================
// saveToRedis 保存会话数据到Redis
func (s *sessionService) saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error {
key := fmt.Sprintf("chat:session:%s", sessionId)
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
expireTime := time.Duration(expireSeconds) * time.Second
data := map[string]any{
"sessionId": sessionId,
"requestContent": requestMessages,
"responseContent": responseMessages,
"timestamp": time.Now().Unix(),
}
b, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("序列化会话数据失败: %w", err)
}
_, err = g.Redis().Do(ctx, "LPUSH", key, string(b))
if err != nil {
return fmt.Errorf("写入Redis失败: %w", err)
}
_, err = g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1)
if err != nil {
return fmt.Errorf("裁剪Redis列表失败: %w", err)
}
_, err = g.Redis().Do(ctx, "EXPIRE", key, int64(expireTime.Seconds()))
if err != nil {
return fmt.Errorf("设置过期时间失败: %w", err)
}
return nil
}
// getFromRedis 从Redis获取会话历史
func (s *sessionService) getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) {
key := fmt.Sprintf("chat:session:%s", sessionId)
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
if err != nil {
return nil, fmt.Errorf("从Redis获取数据失败: %w", err)
}
if result == nil || result.IsNil() {
return []map[string]any{}, nil
}
var sessions []map[string]any
values := result.Strings()
for _, str := range values {
var data map[string]any
if err := json.Unmarshal([]byte(str), &data); err != nil {
g.Log().Warningf(ctx, "[会话] 解析Redis数据失败 err=%v", err)
continue
}
sessions = append(sessions, data)
}
// 反转Redis 最新在前 → 时间正序)
for i, j := 0, len(sessions)-1; i < j; i, j = i+1, j-1 {
sessions[i], sessions[j] = sessions[j], sessions[i]
}
return sessions, nil
}
// GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用)
func (s *sessionService) GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map[string]any, error) {
historyData, err := s.getFromRedis(ctx, sessionId)
if err != nil {
return nil, fmt.Errorf("获取历史会话失败: %w", err)
}
if len(historyData) == 0 {
return []map[string]any{}, nil
}
var messages []map[string]any
for _, round := range historyData {
if reqMsgs, ok := round["requestContent"].([]interface{}); ok {
for _, m := range reqMsgs {
if msg, ok := m.(map[string]interface{}); ok {
messages = append(messages, msg)
}
}
}
if respMsgs, ok := round["responseContent"].([]interface{}); ok {
for _, m := range respMsgs {
if msg, ok := m.(map[string]interface{}); ok {
messages = append(messages, msg)
}
}
}
}
return messages, nil
}

112
service/session_service.go Normal file
View File

@@ -0,0 +1,112 @@
package service
import (
"context"
"fmt"
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
"gitea.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
var Session = &sessionService{}
type sessionService struct{}
func (s *sessionService) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *beans.ResponseEmpty, err error) {
// 1. 解析AI返回的文本
result, err := parseOutput(req.Text)
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, err
}
// 2. 更新数据库
result["role"] = "assistant"
_, err = dao.ComposeSession.Update(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
ResponseContent: result,
})
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, err
}
// 3. 获取当前轮次完整数据
session, err := dao.ComposeSession.GetById(ctx, req.EpicycleId)
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, err
}
// 4. 转换 json 并存入 Redis
requestMessages := convertToMessages(session.RequestContent)
responseMessages := convertToMessages(session.ResponseContent)
if err = s.saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil {
g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v",
session.SessionId, session.Id, err)
return nil, err
}
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
session.SessionId, session.Id, len(requestMessages), len(responseMessages))
return &beans.ResponseEmpty{}, nil
}
// GetHistoryMessages 获取历史信息
func (s *sessionService) GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) {
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
// 1. 先从 Redis 拿
redisHistory, err := s.GetSessionHistoryForInference(ctx, sessionId)
if err == nil && len(redisHistory) > 0 {
return redisHistory, nil
}
// 2. Redis 没有 → fallback DB
sessions, err := dao.ComposeSession.GetListBySessionId(ctx, sessionId, maxRounds)
if err != nil {
return nil, fmt.Errorf("DB获取历史失败: %w", err)
}
var messages []map[string]any
for _, session := range sessions {
// request
reqMsgs := convertToMessages(session.RequestContent)
for _, m := range reqMsgs {
role := gconv.String(m["role"])
if role == "user" || role == "assistant" {
messages = append(messages, m)
}
}
// response
respMsgs := convertToMessages(session.ResponseContent)
for _, m := range respMsgs {
if m["role"] == nil {
m["role"] = "assistant"
}
messages = append(messages, m)
}
}
// 3. 回写 Redis
for _, session := range sessions {
reqMsgs := convertToMessages(session.RequestContent)
respMsgs := convertToMessages(session.ResponseContent)
for i := range respMsgs {
if respMsgs[i]["role"] == nil {
respMsgs[i]["role"] = "assistant"
}
}
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
_ = s.saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs)
}
}
return messages, nil
}

65
service/utils.go Normal file
View File

@@ -0,0 +1,65 @@
// utils 工具函数
package service
import (
"encoding/json"
"fmt"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/util/gconv"
)
// ============================================
// json 相关处理
// ============================================
// parseOutput 解析模型输出为 JSON 格式
func parseOutput(text string) (map[string]any, error) {
j, err := gjson.LoadJson([]byte(text))
if err != nil {
return nil, fmt.Errorf("解析模型输出失败: %w", err)
}
return j.Map(), nil
}
func convertToMessages(raw any) []map[string]any {
if raw == nil {
return nil
}
j, err := gjson.LoadJson(gconv.Bytes(raw))
if err != nil {
return nil
}
// 1. 如果有 messages
if j.Contains("messages") {
return gconv.Maps(j.Get("messages").Array())
}
// 2. 否则当成单条 message
return []map[string]any{
j.Map(),
}
}
// isMessageValid 校验推理结果是否合法
func isMessageValid(message map[string]any) bool {
if message == nil {
return false
}
return true
}
func formToJSON(form map[string]any) string {
if form == nil {
return "{}"
}
b, _ := json.Marshal(form)
return string(b)
}
func mustMarshal(v any) string {
b, err := json.Marshal(v)
if err != nil {
return "{}"
}
return string(b)
}

View File

@@ -1,4 +1,48 @@
-- prompts_compose_task 提示词任务记录表
-- prompts-core 核心表pgsql
-- 说明字段风格尽量与参考项目一致tenant/creator/updater/created_at/updated_at/deleted_at
-- prompts_model_prompt 模型提示词配置表
CREATE TABLE IF NOT EXISTS prompts_model_prompt (
-- 基础字段(与 common/db/gfdb 的 Hook 约定保持一致)
id BIGINT PRIMARY KEY, -- 主键ID非自增
tenant_id BIGINT NOT NULL DEFAULT 0, -- 租户ID
creator VARCHAR(64) NOT NULL, -- 创建人
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 创建时间
updater VARCHAR(64) NOT NULL, -- 更新人
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 更新时间
deleted_at TIMESTAMP(6), -- 删除时间(软删)
-- 业务字段(按你当前的最小字段集)
model_type_id INT NOT NULL DEFAULT 0, -- 模型分类ID
model_type VARCHAR(64) NOT NULL, -- 模型类别
prompt_info JSONB NOT NULL DEFAULT '{}'::jsonb, -- 提示词信息JSON
response_json_schema JSONB NOT NULL DEFAULT '{}'::jsonb, -- 模型返回表单 JSON 格式约束
enabled SMALLINT NOT NULL DEFAULT 1, -- 是否启用1启用/0禁用
version VARCHAR(64) NOT NULL DEFAULT '' -- 版本号(预留)
);
CREATE INDEX IF NOT EXISTS idx_prompts_model_prompt_tenant_id ON prompts_model_prompt(tenant_id);
CREATE INDEX IF NOT EXISTS idx_prompts_model_prompt_model_type_id ON prompts_model_prompt(model_type_id);
CREATE INDEX IF NOT EXISTS idx_prompts_model_prompt_model_type ON prompts_model_prompt(model_type);
CREATE INDEX IF NOT EXISTS idx_prompts_model_prompt_enabled ON prompts_model_prompt(enabled);
CREATE INDEX IF NOT EXISTS idx_prompts_model_prompt_deleted_at ON prompts_model_prompt(deleted_at);
COMMENT ON TABLE prompts_model_prompt IS '模型提示词配置表';
COMMENT ON COLUMN prompts_model_prompt.id IS '主键ID非自增';
COMMENT ON COLUMN prompts_model_prompt.tenant_id IS '租户ID';
COMMENT ON COLUMN prompts_model_prompt.creator IS '创建人';
COMMENT ON COLUMN prompts_model_prompt.created_at IS '创建时间';
COMMENT ON COLUMN prompts_model_prompt.updater IS '更新人';
COMMENT ON COLUMN prompts_model_prompt.updated_at IS '更新时间';
COMMENT ON COLUMN prompts_model_prompt.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN prompts_model_prompt.model_type_id IS '模型分类ID';
COMMENT ON COLUMN prompts_model_prompt.model_type IS '模型类别';
COMMENT ON COLUMN prompts_model_prompt.prompt_info IS '提示词信息JSON';
COMMENT ON COLUMN prompts_model_prompt.response_json_schema IS '模型返回表单 JSON 格式约束';
COMMENT ON COLUMN prompts_model_prompt.enabled IS '是否启用1启用/0禁用';
COMMENT ON COLUMN prompts_model_prompt.version IS '版本号(预留)';
-- prompts_compose_task 拼接提示词任务记录表
CREATE TABLE IF NOT EXISTS prompts_compose_task (
id BIGINT PRIMARY KEY,
tenant_id BIGINT NOT NULL DEFAULT 0,
@@ -11,38 +55,27 @@ CREATE TABLE IF NOT EXISTS prompts_compose_task (
task_id VARCHAR(64) NOT NULL,
model_name VARCHAR(128) NOT NULL DEFAULT '',
skill_name VARCHAR(128) NOT NULL DEFAULT '',
build_type INT NOT NULL DEFAULT 0,
callback_url VARCHAR(512) NOT NULL DEFAULT '',
gateway_state INT NOT NULL DEFAULT 0,
limit_words INT NOT NULL DEFAULT 0,
request_payload JSONB NOT NULL DEFAULT '{}'::jsonb,
result_text TEXT NOT NULL DEFAULT '',
messages JSONB NOT NULL DEFAULT '{}'::jsonb,
messages JSONB NOT NULL DEFAULT '[]'::jsonb,
status VARCHAR(32) NOT NULL DEFAULT 'pending',
error_message TEXT NOT NULL DEFAULT '',
oss_file VARCHAR(1024) NOT NULL DEFAULT '',
file_type VARCHAR(64) NOT NULL DEFAULT ''
);
-- 索引
CREATE UNIQUE INDEX IF NOT EXISTS uk_prompts_compose_task_task_id ON prompts_compose_task(task_id);
CREATE INDEX IF NOT EXISTS idx_prompts_compose_task_status ON prompts_compose_task(status);
CREATE INDEX IF NOT EXISTS idx_prompts_compose_task_deleted_at ON prompts_compose_task(deleted_at);
-- 注释
COMMENT ON TABLE prompts_compose_task IS '提示词任务记录表';
COMMENT ON COLUMN prompts_compose_task.id IS '主键ID';
COMMENT ON COLUMN prompts_compose_task.tenant_id IS '租户ID';
COMMENT ON COLUMN prompts_compose_task.creator IS '创建人';
COMMENT ON COLUMN prompts_compose_task.created_at IS '创建时间';
COMMENT ON COLUMN prompts_compose_task.updater IS '更新人';
COMMENT ON COLUMN prompts_compose_task.updated_at IS '更新时间';
COMMENT ON COLUMN prompts_compose_task.deleted_at IS '删除时间(软删)';
COMMENT ON TABLE prompts_compose_task IS '拼接提示词任务记录表';
COMMENT ON COLUMN prompts_compose_task.task_id IS 'model-gateway 任务ID';
COMMENT ON COLUMN prompts_compose_task.model_name IS '业务模型名称';
COMMENT ON COLUMN prompts_compose_task.skill_name IS '技能名称';
COMMENT ON COLUMN prompts_compose_task.build_type IS '构建类型0默认/1提示词构建/2节点构建';
COMMENT ON COLUMN prompts_compose_task.callback_url IS '回调地址';
COMMENT ON COLUMN prompts_compose_task.gateway_state IS 'model-gateway 状态0排队/1执行/2成功/3失败/4已下载';
COMMENT ON COLUMN prompts_compose_task.limit_words IS '提示词限制字数';
COMMENT ON COLUMN prompts_compose_task.request_payload IS '发给 model-gateway 的请求内容';
COMMENT ON COLUMN prompts_compose_task.result_text IS '回调返回的文本结果';
COMMENT ON COLUMN prompts_compose_task.messages IS '最终解析后的 messages';
@@ -51,11 +84,9 @@ COMMENT ON COLUMN prompts_compose_task.error_message IS '业务错误信息';
COMMENT ON COLUMN prompts_compose_task.oss_file IS '网关返回的结果文件地址';
COMMENT ON COLUMN prompts_compose_task.file_type IS '结果文件类型';
-- prompts_compose_session 提示词历史会话表
CREATE TABLE IF NOT EXISTS prompts_compose_session (
id BIGINT NOT NULL,
id BIGINT PRIMARY KEY,
tenant_id BIGINT NOT NULL DEFAULT 0,
creator VARCHAR(64) NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
@@ -68,12 +99,12 @@ CREATE TABLE IF NOT EXISTS prompts_compose_session (
response_content JSONB NOT NULL DEFAULT '{}'::jsonb,
remark VARCHAR(500) NOT NULL DEFAULT ''
);
-- 索引
CREATE INDEX IF NOT EXISTS idx_prompts_compose_session_session_id ON prompts_compose_session(session_id);
CREATE INDEX IF NOT EXISTS idx_prompts_compose_session_deleted_at ON prompts_compose_session(deleted_at);
-- 注释
COMMENT ON TABLE prompts_compose_session IS '提示词历史会话表';
COMMENT ON COLUMN prompts_compose_session.id IS '主键ID';
COMMENT ON COLUMN prompts_compose_session.id IS '主键ID(非自增)';
COMMENT ON COLUMN prompts_compose_session.tenant_id IS '租户ID';
COMMENT ON COLUMN prompts_compose_session.creator IS '创建人';
COMMENT ON COLUMN prompts_compose_session.created_at IS '创建时间';
@@ -84,51 +115,3 @@ COMMENT ON COLUMN prompts_compose_session.session_id IS '会话ID';
COMMENT ON COLUMN prompts_compose_session.request_content IS '请求内容JSON格式';
COMMENT ON COLUMN prompts_compose_session.response_content IS '返回内容JSON格式';
COMMENT ON COLUMN prompts_compose_session.remark IS '备注';
-- prompts_provider_protocol 模型协议映射配置表
CREATE TABLE IF NOT EXISTS prompts_provider_protocol (
id BIGINT PRIMARY KEY,
tenant_id BIGINT NOT NULL DEFAULT 0,
creator VARCHAR(64) NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updater VARCHAR(64) NOT NULL,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
deleted_at TIMESTAMP(6),
provider_name VARCHAR(64) NOT NULL DEFAULT '',
target_field VARCHAR(64) NOT NULL DEFAULT '',
merge_order JSONB NOT NULL DEFAULT '[]'::jsonb,
role_mapping JSONB NOT NULL DEFAULT '{}'::jsonb,
content_mapping JSONB NOT NULL DEFAULT '{}'::jsonb,
capabilities JSONB NOT NULL DEFAULT '{}'::jsonb,
request_template JSONB NOT NULL DEFAULT '{}'::jsonb,
system_prompt_template TEXT NOT NULL DEFAULT '',
user_prompt_template TEXT NOT NULL DEFAULT '',
status INT NOT NULL DEFAULT 1,
remark VARCHAR(500) NOT NULL DEFAULT ''
);
-- 索引
CREATE INDEX IF NOT EXISTS idx_prompts_provider_protocol_provider_name ON prompts_provider_protocol(provider_name);
CREATE INDEX IF NOT EXISTS idx_prompts_provider_protocol_status ON prompts_provider_protocol(status);
CREATE INDEX IF NOT EXISTS idx_prompts_provider_protocol_deleted_at ON prompts_provider_protocol(deleted_at);
-- 注释
COMMENT ON TABLE prompts_provider_protocol IS '模型协议映射配置表';
COMMENT ON COLUMN prompts_provider_protocol.id IS '主键ID';
COMMENT ON COLUMN prompts_provider_protocol.tenant_id IS '租户ID';
COMMENT ON COLUMN prompts_provider_protocol.creator IS '创建人';
COMMENT ON COLUMN prompts_provider_protocol.created_at IS '创建时间';
COMMENT ON COLUMN prompts_provider_protocol.updater IS '更新人';
COMMENT ON COLUMN prompts_provider_protocol.updated_at IS '更新时间';
COMMENT ON COLUMN prompts_provider_protocol.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN prompts_provider_protocol.provider_name IS '运营商名称openai/deepseek/qwen/anthropic/gemini等';
COMMENT ON COLUMN prompts_provider_protocol.target_field IS '目标字段messages/contents/prompt';
COMMENT ON COLUMN prompts_provider_protocol.merge_order IS 'Prompt IR 拼接顺序system/history/user';
COMMENT ON COLUMN prompts_provider_protocol.role_mapping IS '角色映射system/user/assistant -> provider role';
COMMENT ON COLUMN prompts_provider_protocol.content_mapping IS '内容字段映射content/parts.text等';
COMMENT ON COLUMN prompts_provider_protocol.capabilities IS '协议能力配置system/history/tools/stream等支持情况';
COMMENT ON COLUMN prompts_provider_protocol.request_template IS '请求模板JSON结构模板';
COMMENT ON COLUMN prompts_provider_protocol.system_prompt_template IS '系统提示词模板';
COMMENT ON COLUMN prompts_provider_protocol.status IS '状态1启用/0禁用';
COMMENT ON COLUMN prompts_provider_protocol.remark IS '备注';