Compare commits

23 Commits

Author SHA1 Message Date
0d52b631b9 refactor(task): 重构任务服务和数据结构 2026-06-12 15:29:06 +08:00
c22d578e1a fix(task): 修复任务状态更新和超时处理问题 2026-06-11 11:27:15 +08:00
df26329836 feat(session): 重构会话服务支持节点维度的Redis缓存管理 2026-06-10 16:48:35 +08:00
40abf0f606 ci/cd调整 2026-06-10 16:32:42 +08:00
b69e7386e2 refactor(prompts-core): 重构代码结构和优化工具函数 2026-06-10 14:51:25 +08:00
1c1db7e30c feat(prompt): 实现历史消息注入功能和协议配置优化
- 在 handleCallbackSuccess 函数中新增获取协议配置逻辑
- 实现历史消息获取并在 rounds 中注入历史消息
- 添加 InjectHistory 函数实现按协议顺序合并历史消息
- 在 GetPromptText 接口中集成历史消息注入测试
- 更新 ProviderProtocol 实体中的 MergeOrder 类型为 []string
- 新增 Capabilities 字段支持最大 token 配置
- 修改 renderTemplate 函数接收协议对象参数
- 优化会话历史存储逻辑,提取用户消息内容进行记录
- 移除无用的注释代码 handleCallbackSuccess 处理回调成功
2026-06-10 10:16:58 +08:00
78114f99c7 feat(session): 重构会话管理和消息存储功能 2026-06-09 15:46:09 +08:00
9410199fbe feat(session): 重构会话管理和Redis缓存机制 2026-06-09 14:00:01 +08:00
1f9a2b9b5f Merge remote-tracking branch 'origin/dev' into dev 2026-06-08 18:02:27 +08:00
e1461cf0f0 feat: 重构异步模型字段并更新依赖 2026-06-08 18:01:54 +08:00
aa7804656f ci/cd调整 2026-06-08 15:37:12 +08:00
5494a0c480 ci/cd调整 2026-06-08 13:44:54 +08:00
qhd
ee6677c1f8 fix: 修复响应体解析逻辑并统一结构包装 2026-06-05 11:48:27 +08:00
de70d33115 refactor(prompt): 重构提示词构建服务和回调处理 2026-06-05 11:00:05 +08:00
b2cad4cac2 refactor(model-gateway): 重构代码结构并优化数据库查询 2026-06-03 18:37:18 +08:00
05cf1b9828 refactor(model): 移除异步模型相关实体和数据访问对象 2026-06-03 13:31:15 +08:00
3fa2896fc3 refactor(util): 重构映射工具函数并优化异步任务轮询逻辑 2026-06-03 13:30:39 +08:00
c11a9ad5c8 chore(deps): 初始化项目依赖配置 2026-06-02 20:28:06 +08:00
0bbaddace0 refactor(asynch): 重构异步模型配置和队列管理 2026-06-02 20:26:46 +08:00
1bcf8f6e10 feat(model): 添加流式配置支持并优化响应处理 2026-05-30 22:08:46 +08:00
55eb436639 refactor(service): 重构服务模块结构并优化模型配置 2026-05-29 17:54:19 +08:00
d74559ae74 refactor(task): 重构异步任务处理流程 2026-05-27 09:36:26 +08:00
2548ffc7ac feat: 新增模型扩展映射与查询配置字段 2026-05-23 18:08:09 +08:00
34 changed files with 1196 additions and 1461 deletions

24
Dockerfile Normal file
View File

@@ -0,0 +1,24 @@
# 阶段1: 构建
FROM golang:alpine AS builder
RUN apk add --no-cache git ca-certificates tzdata
ENV TZ=Asia/Shanghai
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
ENV GO111MODULE=on
ENV GOPROXY=https://goproxy.cn,direct
ENV CGO_ENABLED=0
ENV GOTOOLCHAIN=auto
WORKDIR /build
COPY . .
RUN go mod download && go mod tidy
RUN go build -ldflags="-s -w" -o main ./main.go
EXPOSE 3009
CMD ["./main"]

View File

@@ -2,20 +2,14 @@ package util
import (
"context"
"strings"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// GetServerPort 从配置获取服务端口
func GetServerPort(ctx context.Context) string {
address := g.Cfg().MustGet(ctx, "server.address", ":8080").String()
// address 格式如 ":3009",去掉冒号
if strings.HasPrefix(address, ":") {
return address[1:]
}
return "8080"
// GetServerName 获取服务名称
func GetServerName(ctx context.Context) string {
return g.Cfg().MustGet(ctx, "server.name", "").String()
}
// GetModelPrompt 获取请求模型的提示词
@@ -28,3 +22,13 @@ func GetModelPrompt(ctx context.Context, modelType int) string {
func GetBuildPrompt(ctx context.Context) string {
return g.Cfg().MustGet(ctx, "nodePrompts", "").String()
}
// GetMaxRounds 获取最大轮数配置
func GetMaxRounds(ctx context.Context) int {
return g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
}
// GetExpireMinutes 获取过期时间配置
func GetExpireMinutes(ctx context.Context) int {
return g.Cfg().MustGet(ctx, "session.expireMinutes", 30).Int()
}

View File

@@ -3,7 +3,7 @@ package util
import (
"context"
"gitea.com/red-future/common/utils"
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
)

View File

@@ -1,151 +1,81 @@
package util
import (
"encoding/json"
"fmt"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/util/gconv"
)
// ParseOutput 解析模型输出为 JSON 格式
func ParseOutput(text string) (map[string]any, error) {
j, err := gjson.LoadJson([]byte(text))
if err != nil {
return nil, fmt.Errorf("解析模型输出失败: %w", err)
// MergeConsult 将 consult 附件合并到模型生成的 messages 结构中
func MergeConsult(req map[string]any, messages map[string]any, extendMapping map[string]any) map[string]any {
if len(req) == 0 || len(messages) == 0 || len(extendMapping) == 0 {
return messages
}
return j.Map(), nil
}
// ConvertToMessages 将原始数据转换为消息列表
func ConvertToMessages(raw any) []map[string]any {
if raw == nil {
return nil
consult := gconv.Interfaces(req["consult"])
if len(consult) == 0 {
return messages
}
j, err := gjson.LoadJson(gconv.Bytes(raw))
if err != nil {
return nil
targetPath := gconv.String(extendMapping["target_content_path"])
templates := gconv.Map(extendMapping["attachment_templates"])
if targetPath == "" || len(templates) == 0 {
return messages
}
if j.Contains("messages") {
return gconv.Maps(j.Get("messages").Array())
msgJson := gjson.New(messages)
// rounds 路径修正
if !msgJson.Get("rounds.0").IsNil() {
targetPath = "rounds.0." + targetPath
}
return []map[string]any{j.Map()}
}
// FormToJSON 将表单数据转换为 JSON 字符串
func FormToJSON(form map[string]any) string {
if form == nil {
return "{}"
}
b, _ := json.Marshal(form)
return string(b)
}
// UserFormToJSON 将用户表单数据转换为 JSON 字符串
func UserFormToJSON(form []map[string]any) string {
if form == nil {
return "{}"
}
b, _ := json.Marshal(form)
return string(b)
}
// MustMarshal 将对象序列化为 JSON 字符串,失败时返回空对象
func MustMarshal(v any) string {
b, err := json.Marshal(v)
if err != nil {
return "{}"
}
return string(b)
}
// JSONPretty 将任意类型转为格式化的 JSON 字符串
func JSONPretty(v any) string {
if gv, ok := v.(*gvar.Var); ok {
v = gconv.Map(gv.String())
}
var tmp map[string]any
if err := gconv.Struct(v, &tmp); err != nil {
return gconv.String(v)
}
b, _ := json.MarshalIndent(tmp, "", " ")
return string(b)
}
// GvarToMap 将 *gvar.Var 类型转换为 map[string]any
func GvarToMap(v *gvar.Var) map[string]any {
if v == nil || v.IsNil() {
return nil
}
result := make(map[string]any)
// 方法1尝试获取 map 值
if m := v.Map(); len(m) > 0 {
return m
}
// 方法2尝试解析 JSON 字符串
str := v.String()
if str != "" && str != "<nil>" {
json.Unmarshal([]byte(str), &result)
if len(result) > 0 {
return result
// 遍历追加
for _, item := range consult {
itemJson := gjson.New(item)
itemType := itemJson.Get("type").String()
tmpl := gconv.Map(templates[itemType])
if itemType == "" || len(tmpl) == 0 {
continue
}
attachment := buildAttachment(tmpl, itemJson.Get("url").String())
if attachment == nil {
continue
}
idx := len(msgJson.Get(targetPath).Array())
_ = msgJson.Set(fmt.Sprintf("%s.%d", targetPath, idx), attachment)
}
// 方法3尝试获取 interface 再转换
if val := v.Val(); val != nil {
switch val.(type) {
return msgJson.Map()
}
func buildAttachment(tmpl map[string]any, url string) map[string]any {
typ := gconv.String(tmpl["type"])
if typ == "" || url == "" {
return nil
}
body := gconv.Map(tmpl["body"])
fillEmptyInPlace(body, url)
return map[string]any{
"type": typ,
typ: body,
}
}
func fillEmptyInPlace(m map[string]any, value string) {
for k, v := range m {
switch vv := v.(type) {
case string:
if vv == "" {
m[k] = value
}
case map[string]any:
return val.(map[string]any)
default:
data, _ := json.Marshal(val)
json.Unmarshal(data, &result)
fillEmptyInPlace(vv, value)
}
}
return result
}
// ParseJSONFieldFromGvar 专门处理 *gvar.Var 类型的 JSON 字段解析
func ParseJSONFieldFromGvar(source any, target any) {
if source == nil {
return
}
switch v := source.(type) {
case *gvar.Var:
if v.IsNil() {
return
}
// 尝试获取 map
if m := v.Map(); len(m) > 0 {
data, _ := json.Marshal(m)
json.Unmarshal(data, target)
return
}
// 尝试解析 JSON 字符串
str := v.String()
if str != "" && str != "<nil>" {
json.Unmarshal([]byte(str), target)
}
default:
// 其他类型走原来的逻辑
data, _ := json.Marshal(source)
json.Unmarshal(data, target)
}
}

57
common/util/mapping.go Normal file
View File

@@ -0,0 +1,57 @@
package util
import (
"strings"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/util/gconv"
)
// ReverseMap 映射 payload 到 mapping
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
jsonObj := gjson.New("{}")
for path, defaultValue := range mapping {
val := gjson.New(payload).Get(path)
if !val.IsNil() {
_ = jsonObj.Set(path, val.Val())
} else if defaultValue != nil {
_ = jsonObj.Set(path, defaultValue)
}
}
return jsonObj.Map()
}
// ExtractUserText 从 messages 中提取所有 user 文本
func ExtractUserText(messages map[string]any) map[string]any {
msgJson := gjson.New(messages)
msgs := msgJson.Get("rounds.0.messages")
if msgs.IsNil() {
msgs = msgJson.Get("messages")
}
var texts []string
for _, m := range msgs.Array() {
msg := gjson.New(m)
if msg.Get("role").String() != "user" {
continue
}
content := msg.Get("content").Val()
switch c := content.(type) {
case string:
texts = append(texts, c)
case []any:
for _, item := range c {
if m, ok := item.(map[string]any); ok {
if t := gconv.String(m["text"]); t != "" {
texts = append(texts, t)
}
}
}
}
}
return map[string]any{
"role": "user",
"content": strings.Join(texts, "\n"),
}
}

View File

@@ -1,130 +0,0 @@
package util
import (
"context"
"net"
"strings"
"github.com/gogf/gf/v2/frame/g"
)
// GetLocalIP 获取本机有效的局域网 IPv4 地址
func GetLocalIP() string {
addrs, err := net.InterfaceAddrs()
if err != nil {
return "127.0.0.1"
}
var validIPs []string
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok {
continue
}
ip := ipnet.IP
if isIPValid(ip) {
validIPs = append(validIPs, ip.String())
}
}
// 优先返回非 169.254.x.x 的 IP
for _, ip := range validIPs {
if !strings.HasPrefix(ip, "169.254.") {
return ip
}
}
// 其次返回 169.254.x.x最后的选择
if len(validIPs) > 0 {
return validIPs[0]
}
return "127.0.0.1"
}
// isIPValid 判断 IP 是否有效
func isIPValid(ip net.IP) bool {
// 不是 loopback (127.0.0.1)
if ip.IsLoopback() {
return false
}
// 是 IPv4
if ip.To4() == nil {
return false
}
// 不是链路本地地址 (169.254.0.0/16)
if ip[0] == 169 && ip[1] == 254 {
return false
}
// 不是组播地址
if ip.IsMulticast() {
return false
}
// 不是未指定地址 (0.0.0.0)
if ip.IsUnspecified() {
return false
}
return true
}
// GetLocalAddress 获取局域网地址IP:端口)
func GetLocalAddress(ctx context.Context) string {
ip := GetLocalIP()
port := GetServerPort(ctx)
if port == "80" || port == "443" {
return ip
}
return ip + ":" + port
}
// GetSchemaFromRequest 从当前请求中获取协议http/https
func GetSchemaFromRequest(ctx context.Context) string {
r := g.RequestFromCtx(ctx)
if r == nil {
return "http"
}
// 1. 代理场景X-Forwarded-Proto
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
return proto
}
// 2. 代理场景X-Forwarded-Scheme
if proto := r.Header.Get("X-Forwarded-Scheme"); proto != "" {
return proto
}
// 3. TLS 连接(直接 HTTPS
if r.TLS != nil {
return "https"
}
// 4. 默认 HTTP这行很重要
return "http" // ← 确保有这行
}
// GetLocalBaseURL 获取局域网基础 URL动态协议 + IP + 端口)
func GetLocalBaseURL(ctx context.Context) string {
schema := GetSchemaFromRequest(ctx)
addr := GetLocalAddress(ctx)
return schema + "://" + addr
}
// GetCallbackURL 获取回调地址(完整 URL
func GetCallbackURL(ctx context.Context, path string) string {
baseURL := GetLocalBaseURL(ctx)
// 确保 path 以 / 开头
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return baseURL + path
}

View File

@@ -112,41 +112,3 @@ nodePrompts: |
%s
上下文内容:
%s
#你是专业的JSON结构生成专家必须严格遵守以下全部规则。
# 【强制规则】
# 必须根据【输出结构】里面返回的JSON结构进行生成不得任何更改最终内容与输出结构返回一致
# 完整阅读所有文本、规则、表单内容,禁止跳读、漏读;
# 完整读取UserForm所有字段不得忽略任何字段
# 如果有skill相关内容必须完整的将内容拼接到system角色描述中
# 理解全部语义后再输出,禁止断章取义;
# UserForm所有字段内容必须完整拼接赋值到user角色描述中不得有任何遗漏。
# 【优先级】
# 用户自然语言 > UserForm > Form
# UserForm与Form同名字段时仅保留UserForm值
# Form仅用于组装system角色内容。
# 【表单处理】
# Form系统提示词、默认参数、基础配置 → 专属填充system角色
# UserForm用户业务输入、文案、配图数量、比例、prompt等 → 全部解析后拼接进user角色content
# 自动提取UserForm中每条文案的配图数量总图片数 = 各文案配图数累加求和用户没有相关数量必须默认1
# 图片尺寸为空时自动填充size=1024*1024。
# 【结构铁律】
# 严格沿用固定输出结构,不增删字段或修改层级;
# messages元素必须按结构返回
# 禁止将role对象转为字符串、禁止嵌套错乱
# 输出纯净JSON无多余转义符、无换行符、无额外字符
# 所有括号、引号必须成对闭合保证JSON合法。
# 【参数赋值】
# model固定沿用传入值
# 返回结构里面的参数,需要根据语意进行赋值,缺失补默认值;
# history历史信息必须结合UserForm里的内容对用户描述部分进行补充
# 从UserForm提取信息整合进user描述确保数量、尺寸、文案语义无遗漏。
# 【输出要求】
# 仅输出单行纯净JSON无任何解释、备注、Markdown或多余符号
# 完整合UserForm全部字段语义到user描述
# 生成后自检JSON语法、结构、数量错误则自动重新生成。
# 【输出结构】
# %s
# 【完整输入信息】
# %s
# 直接输出最终JSON

View File

@@ -9,4 +9,9 @@ const (
const (
BuildTypePrompt = 1 //提示词构建
BuildTypeNode = 2 //节点构建
BuildTypeStruct = 3 //结构构建
)
const (
ModelTypeInference = 100 // 推理模型
)

View File

@@ -3,7 +3,6 @@ package controller
import (
"context"
"prompts-core/model/dto"
promptService "prompts-core/service/prompt"
)

View File

@@ -1,18 +1,36 @@
// ============================================
// controller/session.go
// ============================================
package controller
import (
"context"
"prompts-core/model/dto"
promptService "prompts-core/service/prompt"
"prompts-core/model/dto"
sessionService "prompts-core/service/session"
)
type session struct{}
// Session 提示词会话控制器
var Session = new(session)
// SessionCallback 会话回调
func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) {
return promptService.SessionCallback(ctx, req)
return sessionService.Callback(ctx, req)
}
// GetHistoryList 获取历史列表(前端列表)
func (c *session) GetHistoryList(ctx context.Context, req *dto.GetHistoryListReq) (res *dto.GetHistoryListRes, err error) {
return sessionService.GetHistoryList(ctx, req)
}
// DeleteMessages 批量删除消息
func (c *session) DeleteMessages(ctx context.Context, req *dto.DeleteMessagesReq) (res *dto.DeleteMessagesRes, err error) {
return sessionService.DeleteMessages(ctx, req)
}
// DeleteSession 删除整个会话
func (c *session) DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (res *dto.DeleteSessionRes, err error) {
return sessionService.DeleteSession(ctx, req)
}

View File

@@ -5,8 +5,7 @@ import (
"prompts-core/consts/public"
"prompts-core/model/entity"
"gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
)
var ComposeSession = &composeSessionDao{}
@@ -15,13 +14,8 @@ type composeSessionDao struct{}
// Insert 插入
func (d *composeSessionDao) Insert(ctx context.Context, req *entity.ComposeSession) (id int64, err error) {
var m = new(entity.ComposeSession)
err = gconv.Struct(req, &m)
if err != nil {
return
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
Insert(m)
Insert(req)
if err != nil {
return
}
@@ -69,6 +63,7 @@ func (d *composeSessionDao) Get(ctx context.Context, req *entity.ComposeSession,
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
OmitEmpty().
Where(entity.ComposeSessionCol.Id, req.Id).
Where(entity.ComposeSessionCol.Creator, req.Creator).
Where(entity.ComposeSessionCol.SessionId, req.SessionId).
Fields(fields).One()
if err != nil {
@@ -86,6 +81,7 @@ func (d *composeSessionDao) Delete(ctx context.Context, req *entity.ComposeSessi
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
OmitEmpty().
Where(entity.ComposeSessionCol.Id, req.Id).
Where(entity.ComposeSessionCol.Creator, req.Creator).
Where(entity.ComposeSessionCol.SessionId, req.SessionId).
Delete()
if err != nil {
@@ -93,3 +89,36 @@ func (d *composeSessionDao) Delete(ctx context.Context, req *entity.ComposeSessi
}
return r.RowsAffected()
}
// ListByIds 根据 ID 列表批量查询
func (d *composeSessionDao) ListByIds(ctx context.Context, ids []int64, creator, sessionId string) (list []*entity.ComposeSession, err error) {
if len(ids) == 0 {
return nil, nil
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
WhereIn(entity.ComposeSessionCol.Id, ids).
Where(entity.ComposeSessionCol.Creator, creator).
Where(entity.ComposeSessionCol.SessionId, sessionId).
All()
if err != nil {
return nil, err
}
err = r.Structs(&list)
return
}
// DeleteByIds 批量删除编排会话
func (d *composeSessionDao) DeleteByIds(ctx context.Context, ids []int64, creator, sessionId string) (int64, error) {
if len(ids) == 0 {
return 0, nil
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
WhereIn(entity.ComposeSessionCol.Id, ids).
Where(entity.ComposeSessionCol.Creator, creator).
Where(entity.ComposeSessionCol.SessionId, sessionId).
Delete()
if err != nil {
return 0, err
}
return r.RowsAffected()
}

View File

@@ -5,7 +5,7 @@ import (
"prompts-core/consts/public"
"prompts-core/model/entity"
"gitea.com/red-future/common/db/gfdb"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
)

View File

@@ -1,28 +0,0 @@
package dao
import (
"context"
"prompts-core/consts/public"
"prompts-core/model/entity"
"gitea.com/red-future/common/db/gfdb"
)
var Model = &modelDao{}
type modelDao struct{}
// Get 获取模型
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
OmitEmpty().
Where(entity.AsynchModelCol.Creator, req.Creator).
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
Where(entity.AsynchModelCol.ModelName, req.ModelName).
Fields(fields).One()
if err != nil {
return
}
err = r.Struct(&m)
return
}

View File

@@ -5,7 +5,7 @@ import (
"prompts-core/consts/public"
"prompts-core/model/entity"
"gitea.com/red-future/common/db/gfdb"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
)

16
go.mod
View File

@@ -1,17 +1,12 @@
module prompts-core
go 1.26.0
go 1.26.1
require (
gitea.com/red-future/common v0.0.19
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.0
github.com/gogf/gf/v2 v2.10.0
)
require (
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
gitea.redpowerfuture.com/red-future/common v0.0.23
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.2
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.2
github.com/gogf/gf/v2 v2.10.2
)
require (
@@ -68,7 +63,6 @@ require (
github.com/r3labs/diff/v2 v2.15.1 // indirect
github.com/redis/go-redis/v9 v9.12.1 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/tidwall/gjson v1.19.0
github.com/tiger1103/gfast-token v1.0.10 // indirect
github.com/vcaesar/cedar v0.30.0 // indirect
github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect

22
go.sum
View File

@@ -1,6 +1,6 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
gitea.com/red-future/common v0.0.19 h1:9/WrfCFUCeFUYwuhBYF+JOQi5F5xuOy+gVnf2ZvHZu4=
gitea.com/red-future/common v0.0.19/go.mod h1:6/nqIucVzmjOyqDTIq71feYBXXFNBy0rFwzaQ0/Ueoo=
gitea.redpowerfuture.com/red-future/common v0.0.23 h1:xieoA00iKOCDm5SO9iXn+cSyMKBAlZwI0fuEVPWrHLg=
gitea.redpowerfuture.com/red-future/common v0.0.23/go.mod h1:50U1Xi+Ie56z09S5LQbZvaken0Mxv3OeS9LgR7U/ZRY=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
@@ -77,16 +77,16 @@ github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ4
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0 h1:39+jbTenm7KBj4hO2C8ANAxVHpX/7OuRDs1VcGC9ylA=
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0/go.mod h1:B0s0fVzn0W220E8UTpSGzrrGKsop5KcB90twBeLCiz0=
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.0 h1:N/F9CuDdUZLoM1nVRqrDE/33pDZuhVxpNY4wYdeIaBs=
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.0/go.mod h1:x6uoJGfZOtirIRQls8xUlYzC6f7T/eULPUa9er368X0=
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.2 h1:u8EpP24GkprogROnJ7htMov9Fc66pTP1eVYrWxiCYOs=
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.2/go.mod h1:GmvM3r8GVByVMi4RD2+MCs5+CfxVXPMeT8mVDkAaAXE=
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.2 h1:iTQegT+lEg/wDKvj2mi3W1wrdrwFarjokf88EXVVgu4=
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.2/go.mod h1:ZRw3GNz5cq4uYrW4TPSVyrYWaoqzujKdWro/AOcGBaE=
github.com/gogf/gf/contrib/registry/consul/v2 v2.9.5 h1:eUqwJ/qNH8lJ6yssiqskazgp1ACQuNU6zXlLOZVuXTQ=
github.com/gogf/gf/contrib/registry/consul/v2 v2.9.5/go.mod h1:sjQyMry9+0POYZCA6lHXBxO77WoNKkruJpRB4xKqk5k=
github.com/gogf/gf/contrib/trace/otlphttp/v2 v2.9.5 h1:tHUEZYB5GTqEYYVDYnlGobf1xISARKDE4KHVlgjwTec=
github.com/gogf/gf/contrib/trace/otlphttp/v2 v2.9.5/go.mod h1:cfzTn2HS9RDX8f5pUVkbGxUWcSosouqfNQ1G6cY0V88=
github.com/gogf/gf/v2 v2.10.0 h1:rzDROlyqGMe/eM6dCalSR8dZOuMIdLhmxKSH1DGhbFs=
github.com/gogf/gf/v2 v2.10.0/go.mod h1:Svl1N+E8G/QshU2DUbh/3J/AJauqCgUnxHurXWR4Qx0=
github.com/gogf/gf/v2 v2.10.2 h1:46IO0Uc8e85/FqdftJFskfDejJLBL0JBnGS5qOftUu8=
github.com/gogf/gf/v2 v2.10.2/go.mod h1:Svl1N+E8G/QshU2DUbh/3J/AJauqCgUnxHurXWR4Qx0=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
@@ -288,12 +288,6 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tidwall/gjson v1.19.0 h1:xwxm7n691Uf3u5OFjzngavjGTh55KX5q/9w9xHW88JU=
github.com/tidwall/gjson v1.19.0/go.mod h1:V37/opeE/JbLUOfH0QTXiNez2l0RUjYUhpT4szFQAfc=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tiger1103/gfast-token v1.0.10 h1:fNiBE/Dq5iTHvTGlCx3DmXa2o4hr0NtumFpffZ39k6s=
github.com/tiger1103/gfast-token v1.0.10/go.mod h1:a/21mxmj7zFeNvjhZSC0XpEAFHfb1aT2k6DXnufFU1s=
github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM=

View File

@@ -7,9 +7,9 @@ import (
"prompts-core/controller"
"syscall"
"gitea.com/red-future/common/http"
"gitea.com/red-future/common/jaeger"
_ "gitea.com/red-future/common/swagger"
"gitea.redpowerfuture.com/red-future/common/http"
"gitea.redpowerfuture.com/red-future/common/jaeger"
_ "gitea.redpowerfuture.com/red-future/common/swagger"
_ "github.com/gogf/gf/contrib/drivers/pgsql/v2"
_ "github.com/gogf/gf/contrib/nosql/redis/v2"
"github.com/gogf/gf/v2/frame/g"

View File

@@ -6,39 +6,32 @@ type ComposeMessagesReq struct {
g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schemaform 作为系统表单userForm 作为用户表单,结合 userFiles 调用 model-gateway并直接返回最终 messages"`
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"`
BuildType int `p:"buildType" json:"buildType" v:"required#buildType不能为空" dc:"构建类型"` //判断节点
SessionId string `p:"sessionId" json:"sessionId" v:"required#sessionId不能为空" dc:"会话ID"`
NodeId string `p:"nodeId" json:"nodeId" dc:"节点ID"`
SessionId string `p:"sessionId" json:"sessionId" dc:"会话ID"` //v:"required#sessionId不能为空"
Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"`
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
Form map[string]any `p:"form" json:"form" dc:"系统表单form 下所有字段都作为系统提示词来源"`
Form []map[string]any `p:"form" json:"form" dc:"系统表单form 下所有字段都作为系统提示词来源"`
UserForm []map[string]any `p:"userForm" json:"userForm" dc:"用户表单userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"`
Consult []ConsultItem `json:"consult" dc:"附件列表(图片/视频/音频)"`
SkillName string `p:"skillName" json:"skillName" dc:"技能名称"`
UserFiles []string `p:"userFiles" json:"userFiles" dc:"用户附件地址列表"`
}
// ConsultItem 单个附件
type ConsultItem struct {
Type string `json:"type" dc:"附件类型image/video/audio"`
Url string `json:"url" dc:"附件地址"`
}
type ComposeMessagesRes struct {
TaskId string `json:"taskId" dc:"任务ID"`
}
/*
Messages *MultiRoundResult `json:"messages,omitempty" dc:"最终消息数组"`
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
*/
// MultiRoundResult 多轮返回结果
type MultiRoundResult struct {
TotalRounds int `json:"total_rounds"` // 总轮数
Rounds []map[string]any `json:"rounds"` // 每轮详情(动态类型)
}
type CallbackReq struct {
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调callbackUrl/{bizName}"`
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
State int `json:"state" dc:"网关任务状态"`
OssFile string `json:"oss_file" dc:"结果文件地址"`
FileType string `json:"file_type" dc:"结果文件类型"`
Text string `json:"text" dc:"文本结果"`
ErrorMsg string `json:"error_msg" dc:"错误信息"`
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调callbackUrl/{bizName}"`
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
State int `json:"state" dc:"网关任务状态"`
OssFile string `json:"oss_file" dc:"结果文件地址"`
FileType string `json:"file_type" dc:"结果文件类型"`
ErrorMsg string `json:"error_msg" dc:"错误信息"`
}
type CallbackRes struct {
@@ -50,11 +43,11 @@ type GetComposeTaskReq struct {
}
type GetComposeTaskRes struct {
TaskId string `json:"taskId" dc:"任务ID"`
Status string `json:"status" dc:"业务状态"`
GatewayState int `json:"gatewayState" dc:"网关状态"`
ErrorMessage string `json:"errorMessage" dc:"错误信息"`
Messages any `json:"messages" dc:"最终消息数组"`
OssFile string `json:"ossFile" dc:"结果文件地址"`
FileType string `json:"fileType" dc:"结果文件类型"`
TaskId string `json:"taskId" dc:"任务ID"`
Status string `json:"status" dc:"业务状态"`
GatewayState int `json:"gatewayState" dc:"网关状态"`
ErrorMessage string `json:"errorMessage" dc:"错误信息"`
Messages map[string]any `json:"messages" dc:"最终消息数组"`
OssFile string `json:"ossFile" dc:"结果文件地址"`
FileType string `json:"fileType" dc:"结果文件类型"`
}

View File

@@ -2,11 +2,79 @@ package dto
import "github.com/gogf/gf/v2/frame/g"
type SessionCallbackReq struct {
g.Meta `path:"/sessionCallback" method:"post" tags:"提示词处理"`
Text string `json:"text" dc:"文本结果"`
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
// HistoryRound 一轮对话
type HistoryRound struct {
Id int64 `json:"id" dc:"记录ID"`
SessionId string `json:"sessionId" dc:"会话ID"`
NodeId string `json:"nodeId" dc:"节点ID"`
User map[string]any `json:"user" dc:"用户消息"`
Assistant map[string]any `json:"assistant" dc:"助手回复"`
CreatedAt string `json:"createdAt" dc:"创建时间"`
UpdatedAt string `json:"updatedAt" dc:"更新时间"`
}
type SessionCallbackRes struct {
// SessionCallbackReq 会话回调请求
type SessionCallbackReq struct {
g.Meta `path:"/callback" method:"post" tags:"会话管理" summary:"会话回调"`
Messages map[string]any `json:"messages" v:"required" dc:"消息数组"`
EpicycleId int64 `json:"epicycleId" v:"required" dc:"轮次ID"`
}
// SessionCallbackRes 会话回调响应
type SessionCallbackRes struct {
Status bool `json:"status" dc:"状态"`
SessionId string `json:"sessionId" dc:"会话ID"`
}
// GetHistoryListReq 获取历史列表请求(前端)
type GetHistoryListReq struct {
g.Meta `path:"/historyList" method:"get" tags:"会话管理" summary:"获取历史列表"`
Page int `json:"page" d:"1" dc:"页码"`
Size int `json:"size" d:"10" dc:"每页条数"`
}
// GetHistoryListRes 获取历史列表响应
type GetHistoryListRes struct {
List []HistoryRound `json:"list" dc:"历史列表"`
Total int `json:"total" dc:"总数"`
}
// GetHistoryMessagesReq 获取历史消息请求(提示词拼接)
type GetHistoryMessagesReq struct {
g.Meta `path:"/historyMessages" method:"get" tags:"会话管理" summary:"获取历史消息"`
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
NodeId string `json:"nodeId" dc:"节点ID"`
}
// GetHistoryMessagesRes 获取历史消息响应
type GetHistoryMessagesRes struct {
Messages []FlatMessage `json:"messages"`
}
type FlatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// DeleteMessagesReq 批量删除消息请求
type DeleteMessagesReq struct {
g.Meta `path:"/deleteMessages" method:"post" tags:"会话管理" summary:"批量删除消息"`
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
MsgIds []int64 `json:"msgIds" v:"required" dc:"消息ID列表"`
}
// DeleteMessagesRes 批量删除消息响应
type DeleteMessagesRes struct {
Ok bool `json:"ok" dc:"是否成功"`
}
// DeleteSessionReq 删除整个会话请求
type DeleteSessionReq struct {
g.Meta `path:"/deleteSession" method:"post" tags:"会话管理" summary:"删除整个会话"`
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
}
// DeleteSessionRes 删除整个会话响应
type DeleteSessionRes struct {
Ok bool `json:"ok" dc:"是否成功"`
}

View File

@@ -1,94 +0,0 @@
package entity
import "gitea.com/red-future/common/beans"
// AsynchModel 异步模型配置
type AsynchModel struct {
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
ModelType int `orm:"model_type" json:"modelType"`
BaseURL string `orm:"base_url" json:"baseUrl"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
HeadMsg string `orm:"head_msg" json:"headMsg"`
Form any `orm:"form_json" json:"form"`
RequestMapping any `orm:"request_mapping" json:"requestMapping"`
ResponseMapping any `orm:"response_mapping" json:"responseMapping"`
ResponseBody any `orm:"response_body" json:"responseBody"`
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
Prompt string `orm:"prompt" json:"prompt"`
IsPrivate *int `orm:"is_private" json:"isPrivate"`
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
ApiKey string `orm:"api_key" json:"apiKey"`
Enabled *int `orm:"enabled" json:"enabled"`
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
QueueLimit int `orm:"queue_limit" json:"queueLimit"`
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
ExpectedSeconds int `orm:"expected_seconds" json:"expectedSeconds"`
RetryTimes int `orm:"retry_times" json:"retryTimes"`
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"`
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
Remark string `orm:"remark" json:"remark"`
IsOwner *int `json:"isOwner" orm:"is_owner"`
OperatorName string `orm:"operator_name" json:"operatorName"`
TokenConfig any `orm:"token_config" json:"tokenConfig"`
}
type asynchModelCol struct {
beans.SQLBaseCol
ModelName string
ModelType string
BaseURL string
HttpMethod string
HeadMsg string
FormJSON string
RequestMapping string
ResponseMapping string
ResponseBody string
ResponseTokenField string
Prompt string
IsPrivate string
IsChatModel string
ApiKey string
Enabled string
MaxConcurrency string
QueueLimit string
TimeoutSeconds string
ExpectedSeconds string
RetryTimes string
RetryQueueMaxSecs string
AutoCleanSeconds string
Remark string
IsOwner string
OperatorName string
TokenConfig string
}
var AsynchModelCol = asynchModelCol{
SQLBaseCol: beans.DefSQLBaseCol,
ModelName: "model_name",
ModelType: "model_type",
BaseURL: "base_url",
HttpMethod: "http_method",
HeadMsg: "head_msg",
FormJSON: "form_json",
RequestMapping: "request_mapping",
ResponseMapping: "response_mapping",
ResponseBody: "response_body",
ResponseTokenField: "response_token_field",
Prompt: "prompt",
IsPrivate: "is_private",
IsChatModel: "is_chat_model",
ApiKey: "api_key",
Enabled: "enabled",
MaxConcurrency: "max_concurrency",
QueueLimit: "queue_limit",
TimeoutSeconds: "timeout_seconds",
ExpectedSeconds: "expected_seconds",
RetryTimes: "retry_times",
RetryQueueMaxSecs: "retry_queue_max_seconds",
AutoCleanSeconds: "auto_clean_seconds",
Remark: "remark",
IsOwner: "is_owner",
OperatorName: "operator_name",
TokenConfig: "token_config",
}

View File

@@ -1,18 +1,20 @@
package entity
import "gitea.com/red-future/common/beans"
import "gitea.redpowerfuture.com/red-future/common/beans"
type ComposeSession struct {
beans.SQLBaseDO `orm:",inline"`
SessionId string `orm:"session_id" json:"sessionId"`
RequestContent any `orm:"request_content" json:"requestContent"`
ResponseContent any `orm:"response_content" json:"responseContent"`
Remark string `orm:"remark" json:"remark"`
SessionId string `orm:"session_id" json:"sessionId"`
NodeId string `orm:"node_id" json:"nodeId"`
RequestContent map[string]any `orm:"request_content" json:"requestContent"`
ResponseContent map[string]any `orm:"response_content" json:"responseContent"`
Remark string `orm:"remark" json:"remark"`
}
type composeSessionCol struct {
beans.SQLBaseCol
SessionId string
NodeId string
RequestContent string
ResponseContent string
Remark string
@@ -21,6 +23,7 @@ type composeSessionCol struct {
var ComposeSessionCol = composeSessionCol{
SQLBaseCol: beans.DefSQLBaseCol,
SessionId: "session_id",
NodeId: "node_id",
RequestContent: "request_content",
ResponseContent: "response_content",
Remark: "remark",

View File

@@ -1,22 +1,21 @@
package entity
import "gitea.com/red-future/common/beans"
import "gitea.redpowerfuture.com/red-future/common/beans"
type ComposeTask struct {
beans.SQLBaseDO `orm:",inline"`
TaskId string `orm:"task_id" json:"taskId"`
ModelName string `orm:"model_name" json:"modelName"`
SkillName string `orm:"skill_name" json:"skillName"`
BuildType int `orm:"build_type" json:"buildType"`
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
GatewayState int `orm:"gateway_state" json:"gatewayState"`
RequestPayload any `orm:"request_payload" json:"requestPayload"`
ResultText string `orm:"result_text" json:"resultText"`
Messages any `orm:"messages" json:"messages"`
Status string `orm:"status" json:"status"`
ErrorMessage string `orm:"error_message" json:"errorMessage"`
OssFile string `orm:"oss_file" json:"ossFile"`
FileType string `orm:"file_type" json:"fileType"`
TaskId string `orm:"task_id" json:"taskId"`
ModelName string `orm:"model_name" json:"modelName"`
SkillName string `orm:"skill_name" json:"skillName"`
BuildType int `orm:"build_type" json:"buildType"`
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
GatewayState int `orm:"gateway_state" json:"gatewayState"`
RequestPayload map[string]any `orm:"request_payload" json:"requestPayload"`
ResultJson map[string]any `orm:"result_json" json:"resultJson"`
Status string `orm:"status" json:"status"`
ErrorMessage string `orm:"error_message" json:"errorMessage"`
OssFile string `orm:"oss_file" json:"ossFile"`
FileType string `orm:"file_type" json:"fileType"`
}
type composeTaskCol struct {
@@ -28,8 +27,7 @@ type composeTaskCol struct {
CallbackUrl string
GatewayState string
RequestPayload string
ResultText string
Messages string
ResultJson string
Status string
ErrorMessage string
OssFile string
@@ -45,8 +43,7 @@ var ComposeTaskCol = composeTaskCol{
CallbackUrl: "callback_url",
GatewayState: "gateway_state",
RequestPayload: "request_payload",
ResultText: "result_text",
Messages: "messages",
ResultJson: "result_json",
Status: "status",
ErrorMessage: "error_message",
OssFile: "oss_file",

View File

@@ -1,21 +1,21 @@
package entity
import "gitea.com/red-future/common/beans"
import "gitea.redpowerfuture.com/red-future/common/beans"
// ProviderProtocol 模型协议映射配置
type ProviderProtocol struct {
beans.SQLBaseDO `orm:",inherit"`
// 业务字段
ProviderName string `orm:"provider_name" json:"providerName"`
TargetField string `orm:"target_field" json:"targetField"`
MergeOrder any `orm:"merge_order" json:"mergeOrder"`
RoleMapping any `orm:"role_mapping" json:"roleMapping"`
ContentMapping any `orm:"content_mapping" json:"contentMapping"`
Capabilities any `orm:"capabilities" json:"capabilities"`
RequestTemplate any `orm:"request_template" json:"requestTemplate"`
SystemPromptTemplate string `orm:"system_prompt_template" json:"systemPromptTemplate"`
Status int `orm:"status" json:"status"`
Remark string `orm:"remark" json:"remark"`
ProviderName string `orm:"provider_name" json:"providerName"`
TargetField string `orm:"target_field" json:"targetField"`
MergeOrder []string `orm:"merge_order" json:"mergeOrder"`
RoleMapping map[string]any `orm:"role_mapping" json:"roleMapping"`
ContentMapping map[string]any `orm:"content_mapping" json:"contentMapping"`
Capabilities map[string]any `orm:"capabilities" json:"capabilities"`
RequestTemplate map[string]any `orm:"request_template" json:"requestTemplate"`
SystemPromptTemplate string `orm:"system_prompt_template" json:"systemPromptTemplate"`
Status int `orm:"status" json:"status"`
Remark string `orm:"remark" json:"remark"`
}
// providerProtocolCol 列名

View File

@@ -4,10 +4,14 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"prompts-core/common/util"
"prompts-core/model/entity"
"strings"
commonHttp "gitea.com/red-future/common/http"
"gitea.redpowerfuture.com/red-future/common/beans"
commonHttp "gitea.redpowerfuture.com/red-future/common/http"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
)
@@ -37,6 +41,70 @@ func CreateGatewayTask(ctx context.Context, payload map[string]any) (string, err
return req.TaskId, nil
}
type GetModelConfigResp struct {
Model *AsynchModel `json:"model"`
}
type AsynchModel struct {
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
ModelType int `orm:"model_type" json:"modelType"`
BaseURL string `orm:"base_url" json:"baseUrl"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
HeadMsg map[string]any `orm:"head_msg" json:"headMsg"`
Form []map[string]any `orm:"form_json" json:"form"`
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
ResponseBody string `orm:"response_body" json:"responseBody"`
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
IsPrivate *int `orm:"is_private" json:"isPrivate"`
IsChatModel int `orm:"is_chat_model" json:"isChatModel"`
CallModel int `orm:"call_model" json:"callModel"`
ApiKey string `orm:"api_key" json:"apiKey"`
Enabled *int `orm:"enabled" json:"enabled"`
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
RetryTimes int `orm:"retry_times" json:"retryTimes"`
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
IsOwner *int `json:"isOwner" orm:"is_owner"`
OperatorName string `orm:"operator_name" json:"operatorName"`
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
FirstFrame string `orm:"first_frame" json:"firstFrame"`
LastFrame string `orm:"last_frame" json:"lastFrame"`
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
}
// GetModelConfig 获取模型配置
func GetModelConfig(ctx context.Context, req *AsynchModel) (model *AsynchModel, err error) {
fullURL := "model-gateway/model/getModel"
// 拼接 query 参数
var params []string
if req.Creator != "" {
params = append(params, fmt.Sprintf("creator=%s", req.Creator))
}
if req.ModelName != "" {
params = append(params, fmt.Sprintf("modelName=%s", req.ModelName))
}
if req.IsChatModel != 0 {
params = append(params, fmt.Sprintf("isChatModel=%d", req.IsChatModel))
}
if len(params) > 0 {
fullURL += "?" + strings.Join(params, "&")
}
headers := util.ForwardHeaders(ctx)
var resp GetModelConfigResp
if err = commonHttp.Get(ctx, fullURL, headers, &resp, nil); err != nil {
return nil, fmt.Errorf("获取模型配置失败: %w", err)
}
if resp.Model == nil {
return nil, fmt.Errorf("模型不存在")
}
return resp.Model, nil
}
// GetTaskResultRes 任务结果响应
type GetTaskResultRes struct {
OssFile string `json:"ossFile" dc:"结果文件OSS地址"`
@@ -80,78 +148,48 @@ func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) {
// SendCallbackReq 发送回调的请求体
type SendCallbackReq struct {
TaskId string `json:"taskId"`
Status string `json:"status"`
Messages *MultiRoundResult `json:"messages,omitempty"`
EpicycleId int64 `json:"epicycleId"`
ErrorMsg string `json:"errorMsg,omitempty"`
}
type MultiRoundResult struct {
TotalRounds int `json:"total_rounds"` // 总轮数
Rounds []map[string]any `json:"rounds"` // 每轮详情(动态类型)
TaskId string `json:"taskId"`
Status string `json:"status"`
EpicycleId int64 `json:"epicycleId"`
ErrorMsg string `json:"errorMsg,omitempty"`
}
// SendCallback 向业务方发送回调
func SendCallback(ctx context.Context, composeTask *entity.ComposeTask) error {
func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycleId int64) error {
// 1. 检查回调地址
if composeTask.CallbackUrl == "" {
return fmt.Errorf("回调地址为空taskId=%s", composeTask.TaskId)
}
// 2. 构造请求体
req := SendCallbackReq{
TaskId: composeTask.TaskId,
Status: composeTask.Status,
Messages: parseMessagesToResult(composeTask.Messages), // 需要将 JSON 字符串转为结构体
ErrorMsg: composeTask.ErrorMessage,
TaskId: composeTask.TaskId,
Status: composeTask.Status,
ErrorMsg: composeTask.ErrorMessage,
EpicycleId: epicycleId,
}
// 3. 发送 POST 请求
headers := util.ForwardHeaders(ctx)
var resp struct{}
g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s 消息=%v",
composeTask.TaskId, composeTask.CallbackUrl, req.Messages)
g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s",
composeTask.TaskId, composeTask.CallbackUrl)
if err := commonHttp.Post(ctx, composeTask.CallbackUrl, headers, &resp, req); err != nil {
return fmt.Errorf("[回调业务] 发送失败 taskId=%s url=%s err=%w", composeTask.TaskId, composeTask.CallbackUrl, err)
}
g.Log().Infof(ctx, "[回调业务] 发送成功 taskId=%s 回调地址=%s", composeTask.TaskId, composeTask.CallbackUrl)
g.Log().Infof(ctx, "[回调业务] 发送成功 taskId=%s 回调地址=%s ", composeTask.TaskId, composeTask.CallbackUrl)
return nil
}
// parseMessagesToResult 将 any 类型的 Messages 转为 *MultiRoundResult
func parseMessagesToResult(messages any) *MultiRoundResult {
if messages == nil {
return nil
// DownloadFile 从 OSS 下载文件内容
func DownloadFile(ossURL string) ([]byte, error) {
resp, err := http.Get(ossURL)
if err != nil {
return nil, fmt.Errorf("下载OSS文件失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("下载OSS文件返回非200: %d", resp.StatusCode)
}
var result MultiRoundResult
switch v := messages.(type) {
case *MultiRoundResult:
return v
case MultiRoundResult:
return &v
case string:
if err := json.Unmarshal([]byte(v), &result); err != nil {
return nil
}
case []byte:
if err := json.Unmarshal(v, &result); err != nil {
return nil
}
case map[string]any:
// 通过 JSON 序列化再反序列化
data, _ := json.Marshal(v)
if err := json.Unmarshal(data, &result); err != nil {
return nil
}
default:
data, err := json.Marshal(v)
if err != nil {
return nil
}
if err = json.Unmarshal(data, &result); err != nil {
return nil
}
}
return &result
return io.ReadAll(resp.Body)
}

View File

@@ -2,9 +2,8 @@ package prompt
import (
"context"
"errors"
"fmt"
"prompts-core/consts/public"
"prompts-core/service/gateway"
"strings"
"prompts-core/common/util"
@@ -12,181 +11,141 @@ import (
"prompts-core/model/dto"
"prompts-core/model/entity"
"github.com/gogf/gf/v2/util/gconv"
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/encoding/gjson"
)
// buildInferenceRequest 构建推理请求
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
if err != nil {
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
}
ir := NewPromptIR()
switch req.BuildType {
case public.BuildTypePrompt:
return buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, history, ir, totalBatches)
case public.BuildTypeNode:
return buildNodeTypeRequest(ctx, req, chatModel, ir)
default:
return nil, errors.New("不支持的构建类型")
}
}
// buildPromptTypeRequest 构建提示词类型请求BuildType=1
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
systemPrompt := promptBuildWithRounds(ctx, req, aiModel, totalBatches)
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
//1) 构建系统提示词
systemPrompt := promptBuildWithRounds(ctx, chatModel, aiModel)
ir.AddSystem(systemPrompt)
for _, msg := range history {
role := gconv.String(msg["role"])
if role != "user" && role != "assistant" {
continue
}
ir.AddHistory(role, gconv.String(msg["content"]))
}
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType))
ir.AddUser(userPrompt)
//2) 检查整体内容是否超出窗口
if !checkOverallContent(ir, aiModel) {
availableWindow := util.GetAvailableWindow(aiModel.TokenConfig)
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
}
// 记录历史会话
_, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
RequestContent: ir.User,
})
return compileToProviderRequest(ctx, ir, chatModel)
return compileToProviderRequest(ctx, ir, chatModel, req)
}
// buildNodeTypeRequest 构建节点类型请求BuildType=2
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, ir *PromptIR) (map[string]any, error) {
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
ir.AddUser(NodeBuild(ctx, req))
return compileToProviderRequest(ctx, ir, chatModel, req)
}
return compileToProviderRequest(ctx, ir, chatModel)
// buildStructTypeRequest 构建结构体类型请求BuildType=3
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
customPrompt := gjson.New(req.UserForm).Get("0.prompt").String()
ir.AddSystem(customPrompt)
ir.AddUser(buildUserPrompt(ctx, req, ""))
return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt)
}
// compileToProviderRequest 编译为 Provider 请求
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *entity.AsynchModel) (map[string]any, error) {
func compileToProviderRequest(ctx context.Context, ir *IR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) {
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
if err != nil {
return nil, fmt.Errorf("获取协议配置失败: %w", err)
return nil, err
}
if protocol == nil {
return nil, errors.New("协议配置不存在")
return nil, fmt.Errorf("协议配置不存在或获取失败")
}
// 如果传了自定义提示词,替换掉协议模板
if len(customPrompt) > 0 && customPrompt[0] != "" {
protocol.SystemPromptTemplate = customPrompt[0] +
"【核心铁律】" +
"1.【技能内容skill相关】必须完整拼接到System提示词中作为System提示词的组成部分不得拆分到其他位置。"
}
providerReq, err := Compile(ir, protocol, chatModel)
if err != nil {
return nil, fmt.Errorf("编译请求失败: %w", err)
}
return map[string]any{
"modelName": chatModel.ModelName,
"bizName": "prompts-core",
"callbackUrl": util.GetCallbackURL(ctx, "/prompt/callback"),
"bizName": util.GetServerName(ctx),
"callbackUrl": utils.GetCallbackURL(ctx, "/prompt/callback"),
"requestPayload": providerReq,
"buildType": req.BuildType,
}, nil
}
// promptBuildWithRounds 构建系统提示词(包含轮次信息)
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel, totalRounds int) string {
// promptBuildWithRounds 构建提示词
func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel) string {
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: model.OperatorName,
ProviderName: chatModel.OperatorName,
Status: 1,
})
if err != nil || providerProtocol == nil {
return ""
}
outputJSON := util.JSONPretty(model.RequestMapping)
maxWindowSize := util.GetMaxWindowSize(model.TokenConfig)
availableWindow := util.GetAvailableWindow(model.TokenConfig)
userFormContent := buildUserFormContent(req.UserForm)
formInfo := fmt.Sprintf(`
【系统表单(系统提示词/参数)】
%s
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
%s
`, util.FormToJSON(req.Form), userFormContent)
inputInfo := fmt.Sprintf(`
目标模型: %s
%s
技能名称: %s
用户文件: %v
`, req.ModelName, formInfo, req.SkillName, req.UserFiles)
outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{})).MustToJsonIndentString()
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
req.ModelName, // %s 目标模型名称
maxWindowSize, // %d 最大窗口
availableWindow, // %d 可用窗口
totalRounds, // %d 数组长度(多轮输出要求)
totalRounds, // %d 数组长度(结构铁律)
outputJSON, // %s 输出结构
inputInfo, // %s 完整输入信息
totalRounds, // %d 数组长度(最后一行)
outputJSON, //【输出结构】 %s
)
}
// buildUserFormContent 构建用户表单内容字符串
func buildUserFormContent(userForm []map[string]any) string {
var builder strings.Builder
for _, item := range userForm {
builder.WriteString(fmt.Sprintf("%v\n", item))
}
return builder.String()
}
// checkOverallContent 检查整体内容是否超出窗口
func checkOverallContent(ir *PromptIR, model *entity.AsynchModel) bool {
func checkOverallContent(ir *IR, model *gateway.AsynchModel) bool {
fullContent := ir.String()
return util.CountToken(fullContent, model.TokenConfig)
}
// buildUserPrompt 构建用户提示词
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
userFormForPayload := prepareUserFormPayload(req.UserForm)
payload := map[string]any{
"model": req.ModelName,
"promptInfo": prompt,
"form": req.Form,
"userForm": userFormForPayload,
"userFiles": req.UserFiles,
"userFilesText": FetchFileTexts(ctx, req.UserFiles),
"skills": SkillMdContent(ctx, req.SkillName),
var b strings.Builder
b.WriteString(fmt.Sprintf("目标模型:%s\n", req.ModelName))
if prompt != "" {
b.WriteString(fmt.Sprintf("系统提示词:%s\n", prompt))
}
return util.MustMarshal(payload)
if skills := SkillMdContent(ctx, req.SkillName); skills != "" {
b.WriteString(fmt.Sprintf("技能内容:\n%s\n", skills))
}
if formText := buildUserFormText(req.Form); formText != "" {
b.WriteString(fmt.Sprintf("系统参数:\n%s\n", formText))
}
if userFormText := buildUserFormText(req.UserForm); userFormText != "" {
b.WriteString(fmt.Sprintf("用户需求:\n%s\n", userFormText))
}
if len(req.Consult) > 0 {
b.WriteString(fmt.Sprintf("参考附件:%s\n", gjson.New(req.Consult).String()))
}
if fileTexts := ExtractFileTexts(ctx, req.Consult); fileTexts != "" {
b.WriteString(fmt.Sprintf("附件内容:\n%s\n", fileTexts))
}
return b.String()
}
// prepareUserFormPayload 准备用户表单载荷
func prepareUserFormPayload(userForm []map[string]any) any {
if len(userForm) == 0 {
return nil
func buildUserFormText(form []map[string]any) string {
if len(form) == 0 {
return ""
}
if _, ok := userForm[0]["batch_index"]; ok {
return userForm
}
return mergeUserFormTexts(userForm)
}
// mergeUserFormTexts 合并 UserForm 中的所有文本内容
func mergeUserFormTexts(userForm []map[string]any) string {
var builder strings.Builder
for i, item := range userForm {
text := getItemText(item)
if i > 0 {
builder.WriteString("\n\n")
for _, item := range form {
for k, v := range item {
builder.WriteString(fmt.Sprintf("%s\n", k))
switch val := v.(type) {
case []any:
for i, elem := range val {
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
if m, ok := elem.(map[string]any); ok {
for mk, mv := range m {
builder.WriteString(fmt.Sprintf("%s%v ", mk, mv))
}
} else {
builder.WriteString(fmt.Sprint(elem))
}
builder.WriteString("\n")
}
default:
builder.WriteString(fmt.Sprintf(" %v\n", v))
}
}
builder.WriteString(text)
}
return builder.String()
return strings.TrimSpace(builder.String())
}
// NodeBuild 节点构建
@@ -195,9 +154,8 @@ func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string {
if promptTpl == "" {
return ""
}
formStr := util.FormToJSON(req.Form)
userFormStr := util.UserFormToJSON(req.UserForm)
return fmt.Sprintf(promptTpl, formStr, userFormStr)
return fmt.Sprintf(promptTpl,
gjson.New(req.Form).MustToJsonString(),
gjson.New(req.UserForm).MustToJsonString(),
)
}

View File

@@ -2,13 +2,9 @@ package prompt
import (
"context"
"encoding/json"
"errors"
"fmt"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
"prompts-core/service/session"
"prompts-core/common/util"
"prompts-core/consts/public"
@@ -16,49 +12,55 @@ import (
"prompts-core/model/dto"
"prompts-core/model/entity"
"prompts-core/service/gateway"
"gitea.redpowerfuture.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// ComposeMessages 核心拼接提示词主流程
func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
// 1) 获取模型信息
chatModel, aiModel, err := GetModelMessage(ctx, req)
if err != nil {
return nil, err
}
// 2) 校验用户表单
if err = validateUserForm(req, aiModel); err != nil {
return nil, err
}
switch req.BuildType {
case public.BuildTypePrompt:
return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建
case public.BuildTypeNode:
return handleNodeBuild(ctx, req, chatModel, aiModel) // 节点构建
default:
return nil, errors.New("BuildType 不支持")
}
return handleBuild(ctx, req, chatModel, aiModel)
}
// GetModelMessage 获取模型信息
func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) {
func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*gateway.AsynchModel, *gateway.AsynchModel, error) {
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
}
chatModel, err := getChatModel(ctx, userInfo.UserName)
if err != nil {
return nil, nil, err
chatModel, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
IsChatModel: 1,
})
if err != nil || chatModel == nil {
return nil, nil, errors.New("当前没有对话模型,请添加")
}
aiModel, err := getAIModel(ctx, userInfo.UserName, req.ModelName)
if err != nil {
return nil, nil, err
aiModel, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{TenantId: userInfo.TenantId, Creator: userInfo.UserName},
ModelName: req.ModelName,
})
if err != nil || aiModel == nil {
return nil, nil, errors.New("需要构建的模型不存在")
}
return chatModel, aiModel, nil
}
// validateUserForm 校验用户表单
func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) error {
func validateUserForm(req *dto.ComposeMessagesReq, model *gateway.AsynchModel) error {
if len(req.UserForm) == 0 {
return nil
}
@@ -72,274 +74,244 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) er
return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens可用窗口 %d tokens请精简后重试",
exceedTokens, availableWindow)
}
return nil
}
// handlePromptBuild 处理提示词构建BuildType=1
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
// 获取历史会话
history, err := GetHistoryMessages(ctx, req.SessionId)
// handleBuild 通用构建处理
func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
// 1) 处理表单分批
processedReq, _, err := ProcessUserFormBatches(ctx, req, aiModel)
if err != nil {
g.Log().Errorf(ctx, "获取历史会话失败: %v将不使用历史会话", err)
history = nil
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
}
// 2) 构建推理请求
ir := NewPromptIR()
var taskReq map[string]any
switch req.BuildType {
case public.BuildTypePrompt:
taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir)
case public.BuildTypeNode:
taskReq, err = buildNodeTypeRequest(ctx, req, chatModel, ir)
case public.BuildTypeStruct:
taskReq, err = buildStructTypeRequest(ctx, req, chatModel, ir)
default:
return nil, errors.New("不支持的构建类型")
}
// 调用推理模型
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, history)
if err != nil {
return nil, fmt.Errorf("调用推理模型失败: %w", err)
return nil, fmt.Errorf("构建推理请求失败: %w", err)
}
// 保存任务记录
if err = saveComposeTask(ctx, taskID, req); err != nil {
return nil, fmt.Errorf("保存任务记录失败: %w", err)
}
return &dto.ComposeMessagesRes{
TaskId: taskID,
}, nil
}
// handleNodeBuild 处理节点构建BuildType=2
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
// 3) 调用网关创建任务
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
if err != nil {
return nil, fmt.Errorf("调用推理模型失败: %w", err)
return nil, fmt.Errorf("创建网关任务失败: %w", err)
}
if taskID == "" {
return nil, errors.New("网关未返回taskId")
}
if err := saveComposeTask(ctx, taskID, req); err != nil {
return nil, fmt.Errorf("保存任务记录失败: %w", err)
}
return &dto.ComposeMessagesRes{
TaskId: taskID,
}, nil
}
// saveComposeTask 保存组合任务记录
func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessagesReq) error {
_, err := dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
// 4) 保存任务记录
if _, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
BuildType: req.BuildType,
CallbackUrl: req.CallbackUrl,
RequestPayload: util.MustMarshal(req),
RequestPayload: gconv.Map(req),
Status: public.ComposeStatusPending,
})
return err
}
// getChatModel 获取聊天模型
func getChatModel(ctx context.Context, userName string) (*entity.AsynchModel, error) {
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
IsChatModel: new(1),
})
if err != nil {
return nil, fmt.Errorf("查询聊天模型失败: %w", err)
}
if chatModel == nil {
return nil, errors.New("当前没有对话模型,请添加")
}
return chatModel, nil
}
// getAIModel 获取AI模型
func getAIModel(ctx context.Context, userName, modelName string) (*entity.AsynchModel, error) {
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
ModelName: modelName,
})
if err != nil {
return nil, fmt.Errorf("查询AI模型失败: %w", err)
}
if aiModel == nil {
return nil, fmt.Errorf("需要构建的模型 %s 不存在", modelName)
}
return aiModel, nil
}
// callInferenceModel 调用推理模型
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, idModel *entity.AsynchModel, history []map[string]any) (string, error) {
taskReq, err := buildInferenceRequest(ctx, req, chatModel, idModel, history)
if err != nil {
return "", fmt.Errorf("构建推理请求失败: %w", err)
}
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
if err != nil {
return "", fmt.Errorf("创建网关任务失败: %w", err)
}
if taskID == "" {
return "", errors.New("网关未返回taskId")
}
return taskID, nil
}
// createDefaultResult 创建默认结果
func createDefaultResult(data map[string]any) *dto.MultiRoundResult {
if data == nil {
data = make(map[string]any)
}
return &dto.MultiRoundResult{
TotalRounds: 1,
Rounds: []map[string]any{data},
}); err != nil {
return nil, err
}
return &dto.ComposeMessagesRes{TaskId: taskID}, nil
}
// Callback 回调处理
func Callback(ctx context.Context, req *dto.CallbackReq) error {
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
})
g.Log().Infof(ctx, "[开始回调处理] taskId=%s state=%d", req.TaskId, req.State)
// 1) 查询任务
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{TaskId: req.TaskId})
if err != nil {
return fmt.Errorf("查询任务失败: %w", err)
}
if composeTask == nil {
return fmt.Errorf("任务不存在: %s", req.TaskId)
}
//处理失败
if req.State == 3 {
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
GatewayState: req.State,
OssFile: req.OssFile,
FileType: req.FileType,
ResultText: req.Text,
})
// 用更新后的值发送回调
if composeTask.CallbackUrl != "" {
failedTask := &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
CallbackUrl: composeTask.CallbackUrl,
Messages: composeTask.Messages,
}
gateway.SendCallback(ctx, failedTask)
}
return err
}
//处理成功
if req.State == 2 {
// 1. 根据 BuildType 解析结果
var messages any
switch composeTask.BuildType {
case public.BuildTypePrompt: // 提示词构建解析
messages = parsePromptResult(req.Text)
case public.BuildTypeNode: // 节点构建解析
messages = parseNodeResult(req.Text)
default:
messages = req.Text
}
// 2. 更新数据库
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
Messages: messages,
GatewayState: req.State,
OssFile: req.OssFile,
FileType: req.FileType,
ResultText: req.Text,
})
// 2) 读取 OSS 文件内容
var ossContent []byte
if req.OssFile != "" {
ossContent, err = gateway.DownloadFile(req.OssFile)
if err != nil {
g.Log().Errorf(ctx, "[Callback] 更新成功状态失败 taskId=%s err=%v", req.TaskId, err)
return err
}
// 4. 发送回调给业务方
if composeTask.CallbackUrl != "" {
successTask := &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
Messages: messages,
CallbackUrl: composeTask.CallbackUrl,
}
gateway.SendCallback(ctx, successTask)
g.Log().Warningf(ctx, "[回调处理] 读取OSS失败 taskId=%s err=%v", req.TaskId, err)
}
}
// 3) 解析 OSS 内容为消息
var messages map[string]any
if len(ossContent) > 0 {
messages, _ = gjson.New(ossContent).Map(), nil
}
// 4) 处理失败
if req.State == 3 {
return handleCallbackFailed(ctx, req, composeTask, messages)
}
// 5) 处理成功
if req.State == 2 {
return handleCallbackSuccess(ctx, req, composeTask, messages)
}
return nil
}
// handleCallbackFailed 处理回调失败
func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask, messages map[string]any) error {
_, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
GatewayState: req.State,
OssFile: req.OssFile,
FileType: req.FileType,
ResultJson: messages,
})
if composeTask.CallbackUrl != "" {
composeTask.Status = public.ComposeStatusFailed
composeTask.ErrorMessage = req.ErrorMsg
_ = gateway.SendCallback(ctx, composeTask, 0)
}
return err
}
// parsePromptResult 解析提示词构建结果
func parsePromptResult(raw string) *dto.MultiRoundResult {
var wrapper map[string]any
if err := json.Unmarshal([]byte(raw), &wrapper); err != nil {
return createDefaultResult(map[string]any{"raw": raw})
// handleCallbackSuccess 处理回调成功
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask, messages map[string]any) error {
// 1) 获取模型配置
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
ModelName: composeTask.ModelName,
})
if err != nil {
return fmt.Errorf("查询模型失败: %w", err)
}
contentStr, ok := wrapper["content"].(string)
if !ok || contentStr == "" {
return createDefaultResult(wrapper)
}
// 2) 获取协议配置
protocol, _ := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: model.OperatorName,
Status: 1,
})
// 先尝试解析为数组
if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil {
return &dto.MultiRoundResult{
TotalRounds: len(roundsArray),
Rounds: roundsArray,
// 3) 获取历史消息 + 保存当前轮
payload := composeTask.RequestPayload
sessionId := gconv.String(payload["sessionId"])
nodeId := gconv.String(payload["nodeId"])
var history []dto.FlatMessage
var epicycleId int64
if sessionId != "" && nodeId != "" && model.ModelType == public.ModelTypeInference {
// 3.1 获取历史
h, _ := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
SessionId: sessionId,
NodeId: nodeId,
})
if h != nil {
history = h.Messages
}
// 3.2 保存当前轮(先存,下次查询就能拿到)
if userMsg := util.ExtractUserText(messages); userMsg != nil {
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
NodeId: nodeId,
SessionId: sessionId,
RequestContent: userMsg,
})
}
}
// 再尝试解析为单个对象
if singleRound := tryParseAsMap(contentStr); singleRound != nil {
return &dto.MultiRoundResult{
TotalRounds: 1,
Rounds: []map[string]any{singleRound},
// 4) 合并附加结构
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
// 5) 注入历史
if len(history) > 0 {
messages = InjectHistory(messages, history, protocol)
}
// 6) 更新数据库
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
GatewayState: req.State,
OssFile: req.OssFile,
FileType: req.FileType,
ResultJson: messages,
})
if err != nil {
return err
}
// 8) 回调业务方
if composeTask.CallbackUrl != "" {
composeTask.Status = public.ComposeStatusSuccess
composeTask.ResultJson = messages
_ = gateway.SendCallback(ctx, composeTask, epicycleId)
}
return nil
}
// InjectHistory 插入历史会话
func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protocol *entity.ProviderProtocol) map[string]any {
if protocol == nil || len(history) == 0 {
return roundsData
}
// 1) 提取第一轮的 messages
rounds := roundsData["rounds"].([]any)
firstRound := rounds[0].(map[string]any)
original := firstRound["messages"].([]any)
// 2) 按 merge_order 拼接
result := make([]any, 0, len(original)+len(history))
for _, part := range protocol.MergeOrder {
switch part {
case "system":
for _, m := range original {
msg := m.(map[string]any)
if gconv.String(msg["role"]) == "system" {
result = append(result, msg)
}
}
case "history":
if gconv.Bool(protocol.Capabilities["support_history"]) {
for _, msg := range history {
result = append(result, map[string]any{
"role": msg.Role,
"content": msg.Content, // 纯字符串,不转换
})
}
}
case "user":
for _, m := range original {
msg := m.(map[string]any)
if gconv.String(msg["role"]) == "user" {
result = append(result, msg)
}
}
}
}
return createDefaultResult(map[string]any{"content": contentStr})
}
func tryParseAsMapArray(jsonStr string) []map[string]any {
var arr []map[string]any
if err := json.Unmarshal([]byte(jsonStr), &arr); err != nil {
return nil
}
if len(arr) == 0 {
return nil
}
return arr
}
func tryParseAsMap(jsonStr string) map[string]any {
var obj map[string]any
if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil {
return nil
}
if len(obj) == 0 {
return nil
}
return obj
}
// parseNodeResult 解析节点构建结果
func parseNodeResult(raw string) *dto.MultiRoundResult {
var result map[string]any
if err := json.Unmarshal([]byte(raw), &result); err != nil {
return createDefaultResult(map[string]any{"raw": raw})
}
if contentStr, ok := result["content"].(string); ok && contentStr != "" {
var inner map[string]any
if err := json.Unmarshal([]byte(contentStr), &inner); err == nil {
result = inner
// 3) 角色映射
if len(protocol.RoleMapping) > 0 {
for _, m := range result {
msg := m.(map[string]any)
role := gconv.String(msg["role"])
if mapped, ok := protocol.RoleMapping[role]; ok {
msg["role"] = mapped
}
}
}
return &dto.MultiRoundResult{
TotalRounds: 1,
Rounds: []map[string]any{result},
}
// 4) 直接修改原对象
firstRound["messages"] = result
return roundsData
}
// GetComposeTask 查询任务结果
@@ -350,31 +322,10 @@ func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes,
if err != nil {
return nil, fmt.Errorf("查询任务失败: %w", err)
}
if record == nil {
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
}
messages := parseMessagesForResponse(record.Messages)
return &dto.GetComposeTaskRes{
TaskId: record.TaskId,
Status: record.Status,
ErrorMessage: record.ErrorMessage,
Messages: messages,
Messages: record.ResultJson,
}, nil
}
// parseMessagesForResponse 解析用于响应的消息
func parseMessagesForResponse(messages any) any {
str, ok := messages.(string)
if !ok || str == "" {
return messages
}
var parsed any
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
return parsed
}
return messages
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"prompts-core/model/dto"
"strings"
"time"
@@ -21,15 +22,25 @@ const (
bytesPerMB = 1024 * 1024
)
// FetchFileTexts 从 URL 列表取文件内容,支持 zip 内文件
func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
result := make(map[string]string)
// ExtractFileTexts 从 ConsultItem 列表中提取文件内容,返回拼接文本
func ExtractFileTexts(ctx context.Context, consult []dto.ConsultItem) string {
urls := make([]string, 0, len(consult))
for _, item := range consult {
if item.Url != "" {
urls = append(urls, item.Url)
}
}
return FetchFileTextsAsString(ctx, urls)
}
// FetchFileTextsAsString 从 URL 列表获取文件内容,拼接为字符串
func FetchFileTextsAsString(ctx context.Context, urls []string) string {
if len(urls) == 0 {
return result
return ""
}
client := createHTTPClient(ctx, "userFiles.httpTimeoutSec", 8)
var builder strings.Builder
for _, rawURL := range urls {
url := util.SanitizeURL(rawURL)
@@ -38,23 +49,19 @@ func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
}
if util.IsZipExtension(url) {
mergeMap(result, fetchZipFileTexts(ctx, client, url))
for _, text := range fetchZipFileTexts(ctx, client, url) {
builder.WriteString(text)
builder.WriteString("\n")
}
continue
}
if text := fetchAndCleanFileContent(ctx, client, url); text != "" {
result[url] = text
builder.WriteString(fmt.Sprintf("【文件:%s】\n%s\n", url, text))
}
}
return result
}
// mergeMap 合并 map
func mergeMap(dst, src map[string]string) {
for k, v := range src {
dst[k] = v
}
return builder.String()
}
// fetchAndCleanFileContent 获取并清理文件内容
@@ -182,10 +189,13 @@ func fetchFileContent(ctx context.Context, client *http.Client, url string) (str
return strings.TrimSpace(string(body)), nil
}
// SkillMdContent 根据 skillName 获取 zip 内所有 md 文件拼接内容
func SkillMdContent(ctx context.Context, skillName string) string {
if skillName == "" {
return ""
}
skillResp, err := gateway.GetSkillUser(ctx, skillName)
if err != nil {
g.Log().Warningf(ctx, "[SkillMd] GetSkillUser 失败: %v", err)
return ""
}
@@ -196,11 +206,13 @@ func SkillMdContent(ctx context.Context, skillName string) string {
zipBytes, err := downloadFile(client, fullUrl, maxSize)
if err != nil {
g.Log().Warningf(ctx, "[SkillMd] 下载失败 url=%s err=%v", fullUrl, err)
return ""
}
mdContents, err := extractMdFiles(ctx, zipBytes)
if err != nil || len(mdContents) == 0 {
g.Log().Warningf(ctx, "[SkillMd] 提取md失败 count=%d err=%v", len(mdContents), err)
return ""
}

View File

@@ -1,75 +0,0 @@
# Prompts-Core提示词核心服务
> 智能提示词构建与管理系统,支持多模态 AI 模型的提示词组装、会话管理和协议适配。
---
## 项目简介
**Prompts-Core** 是一个基于 Go 语言开发的提示词核心服务,作为 AI 应用层与模型网关之间的桥梁,负责将业务需求转换为标准化的模型请求。
### 核心价值
- **统一提示词管理**:集中化管理不同模型类型的提示词模板
- **智能会话维护**:基于 Redis + PostgreSQL 的双层会话存储
- **多协议适配**:支持 OpenAI、DeepSeek、Qwen、Gemini 等多种模型协议
- **文件处理能力**:自动提取文本文件和 ZIP 压缩包内容
- **技能系统集成**:支持从外部加载 Markdown 格式的技能描述
---
## 核心功能
### 1. 提示词构建引擎
#### 多模态支持
| 类型 | 说明 | 适用场景 |
|------|------|----------|
| Type 1 | 文字处理助手 | 文章撰写、文案优化、翻译等 |
| Type 2 | 图片处理助手 | 图像生成、风格迁移等 |
| Type 3 | 音频处理助手 | 语音合成、识别、降噪等 |
| Type 4 | 向量化处理助手 | 语义检索、知识索引等 |
| Type 5 | 全模态助手 | 跨模态转换、多模态融合等 |
#### 构建模式
- **BuildType 1提示词构建**:完整流程,包含系统提示词、历史会话、用户输入的智能组装
- **BuildType 2节点构建**:工作流路由决策,根据上下文选择节点 ID
#### 分批处理
当用户表单内容超出模型窗口限制时,自动按 Token 大小分批处理。
### 2. 会话管理系统
- **双层存储**Redis 缓存(最近 N 轮)+ PostgreSQL 持久化
- **自动管理**:最大轮数控制(默认 10 轮)、自动过期(默认 30 分钟)
### 3. 协议适配器
通过配置动态支持多种模型协议:
- 角色映射system/user/assistant → 目标协议角色
- 内容字段映射content → parts.text 等
- 消息顺序控制:灵活配置拼接顺序
- 请求模板渲染:支持占位符替换
### 4. 任务调度
- **异步流程**:创建网关任务 → 轮询等待 → 接收回调 → 返回结果
- **重试机制**:可配置最大重试次数(默认 3 次)
- **超时保护**:默认 300 秒超时
---
## 技术架构
### 技术栈
| 组件 | 版本 | 用途 |
|------|------|------|
| Go | 1.26.0 | 编程语言 |
| GoFrame | v2.10.0 | Web 框架 |
| PostgreSQL | - | 关系型数据库 |
| Redis | - | 缓存与会话存储 |
| Consul | - | 服务注册与发现 |
| Jaeger | - | 分布式链路追踪 |
### 架构图

View File

@@ -2,17 +2,18 @@ package prompt
import (
"context"
"encoding/json"
"fmt"
"prompts-core/common/util"
"prompts-core/service/gateway"
"strings"
"prompts-core/dao"
"prompts-core/model/entity"
"github.com/gogf/gf/v2/util/gconv"
)
// PromptIR 统一 Prompt 中间表示
type PromptIR struct {
// IR 统一 Prompt 中间表示
type IR struct {
System []Segment `json:"system"`
History []Segment `json:"history"`
User []Segment `json:"user"`
@@ -33,6 +34,7 @@ type ProviderProtocol struct {
ContentMapping ContentMapping `json:"content_mapping"`
RequestTemplate map[string]any `json:"request_template"`
SystemPromptTemplate string `json:"system_prompt_template"`
Capabilities map[string]any `json:"capabilities"`
}
// ContentMapping 内容字段映射
@@ -42,8 +44,8 @@ type ContentMapping struct {
}
// NewPromptIR 创建空 PromptIR
func NewPromptIR() *PromptIR {
return &PromptIR{
func NewPromptIR() *IR {
return &IR{
System: make([]Segment, 0),
History: make([]Segment, 0),
User: make([]Segment, 0),
@@ -51,7 +53,7 @@ func NewPromptIR() *PromptIR {
}
// String 返回 PromptIR 的完整内容字符串(用于 token 计算)
func (ir *PromptIR) String() string {
func (ir *IR) String() string {
var builder strings.Builder
for _, seg := range ir.System {
@@ -77,7 +79,7 @@ func (ir *PromptIR) String() string {
}
// GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算)
func (ir *PromptIR) GetTotalContent() string {
func (ir *IR) GetTotalContent() string {
var builder strings.Builder
for _, seg := range ir.System {
@@ -99,7 +101,7 @@ func (ir *PromptIR) GetTotalContent() string {
}
// AddSystem 添加系统提示
func (ir *PromptIR) AddSystem(content string) *PromptIR {
func (ir *IR) AddSystem(content string) *IR {
if content != "" {
ir.System = append(ir.System, Segment{Type: "text", Content: content})
}
@@ -107,7 +109,7 @@ func (ir *PromptIR) AddSystem(content string) *PromptIR {
}
// AddUser 添加用户消息
func (ir *PromptIR) AddUser(content string) *PromptIR {
func (ir *IR) AddUser(content string) *IR {
if content != "" {
ir.User = append(ir.User, Segment{Type: "text", Content: content})
}
@@ -115,7 +117,7 @@ func (ir *PromptIR) AddUser(content string) *PromptIR {
}
// AddHistory 添加历史消息
func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
func (ir *IR) AddHistory(role, content string) *IR {
if content != "" {
ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role})
}
@@ -123,7 +125,7 @@ func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
}
// ToMessages 转换为 OpenAI 兼容的 messages 格式MVP 默认)
func (ir *PromptIR) ToMessages() []map[string]any {
func (ir *IR) ToMessages() []map[string]any {
var messages []map[string]any
for _, seg := range ir.System {
@@ -164,21 +166,22 @@ func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderP
// parseProtocol 将 DB entity 转为编译用协议配置
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
p := &ProviderProtocol{
return &ProviderProtocol{
TargetField: e.TargetField,
SystemPromptTemplate: e.SystemPromptTemplate,
MergeOrder: e.MergeOrder,
RoleMapping: gconv.MapStrStr(e.RoleMapping),
ContentMapping: ContentMapping{
Type: gconv.String(e.ContentMapping["type"]),
Field: gconv.String(e.ContentMapping["field"]),
},
RequestTemplate: e.RequestTemplate,
Capabilities: e.Capabilities,
}
// 使用通用解析方法处理各个字段
util.ParseJSONFieldFromGvar(e.MergeOrder, &p.MergeOrder)
util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping)
util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping)
util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate)
return p
}
// Compile 将 PromptIR 按协议配置编译为 Provider Request
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (map[string]any, error) {
func Compile(ir *IR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) {
if ir == nil || p == nil {
return nil, fmt.Errorf("ir and protocol are required")
}
@@ -190,35 +193,25 @@ func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (
}
// mergeByOrder 按协议配置顺序拼接消息
func mergeByOrder(ir *PromptIR, order []string) []map[string]any {
var messages []map[string]any
for _, part := range order {
switch part {
case "system":
for _, seg := range ir.System {
messages = append(messages, map[string]any{
"role": "system",
"content": seg.Content,
})
}
case "history":
for _, seg := range ir.History {
messages = append(messages, map[string]any{
"role": seg.Role,
"content": seg.Content,
})
}
case "user":
for _, seg := range ir.User {
messages = append(messages, map[string]any{
"role": "user",
"content": seg.Content,
})
}
}
func mergeByOrder(ir *IR, order []string) []map[string]any {
roleMap := map[string][]Segment{
"system": ir.System,
"history": ir.History,
"user": ir.User,
}
var messages []map[string]any
for _, part := range order {
for _, seg := range roleMap[part] {
msg := map[string]any{"content": seg.Content}
if part == "history" {
msg["role"] = seg.Role
} else {
msg["role"] = part
}
messages = append(messages, msg)
}
}
return messages
}
@@ -242,29 +235,29 @@ func mapRoles(messages []map[string]any, mapping map[string]string) []map[string
return messages
}
// mapContent 内容字段映射
func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
for _, msg := range messages {
content := msg["content"]
delete(msg, "content")
switch cm.Type {
case "parts":
msg["parts"] = []map[string]any{
{cm.Field: content},
}
default:
msg[cm.Field] = content
}
if cm.Field == "" || cm.Field == "content" {
return messages
}
for i, msg := range messages {
if content, ok := msg["content"]; ok {
delete(msg, "content")
switch cm.Type {
case "parts":
messages[i]["parts"] = []map[string]any{{cm.Field: content}}
default:
messages[i][cm.Field] = content
}
}
}
return messages
}
// buildRequest 按 target_field 和 request_template 构建请求体
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *entity.AsynchModel) map[string]any {
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gateway.AsynchModel) map[string]any {
if len(p.RequestTemplate) > 0 {
return renderTemplate(p.RequestTemplate, messages, chatModel)
return renderTemplate(p, messages, chatModel)
}
return map[string]any{
@@ -272,20 +265,21 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *ent
}
}
// renderTemplate 简单的 {{key}} 模板替换
func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *entity.AsynchModel) map[string]any {
b, _ := json.Marshal(tmpl)
str := string(b)
if chatModel != nil {
str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`)
// renderTemplate 模板渲染
func renderTemplate(p *ProviderProtocol, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
result := make(map[string]any, len(p.RequestTemplate)+1)
for k, v := range p.RequestTemplate {
result[k] = v
}
msgBytes, _ := json.Marshal(messages)
str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes))
if chatModel != nil {
result["model"] = chatModel.ModelName
}
result["messages"] = messages
var result map[string]any
json.Unmarshal([]byte(str), &result)
if maxTokens := gconv.Int(p.Capabilities["max_tokens"]); maxTokens > 0 {
result["max_tokens"] = maxTokens
}
return result
}

View File

@@ -1,145 +0,0 @@
package prompt
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/gogf/gf/v2/frame/g"
)
const (
redisKeyPrefix = "chat:session:%s"
)
// saveToRedis 保存会话数据到Redis
func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error {
key := formatRedisKey(sessionId)
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
data := map[string]any{
"sessionId": sessionId,
"requestContent": requestMessages,
"responseContent": responseMessages,
"timestamp": time.Now().Unix(),
}
b, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("序列化会话数据失败: %w", err)
}
if err := executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil {
return err
}
return nil
}
// formatRedisKey 格式化Redis键
func formatRedisKey(sessionId string) string {
return fmt.Sprintf(redisKeyPrefix, sessionId)
}
// executeRedisCommands 执行Redis命令
func executeRedisCommands(ctx context.Context, key string, value string, maxRounds int, expireSeconds int64) error {
if _, err := g.Redis().Do(ctx, "LPUSH", key, value); err != nil {
return fmt.Errorf("写入Redis失败: %w", err)
}
if _, err := g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1); err != nil {
return fmt.Errorf("裁剪Redis列表失败: %w", err)
}
if _, err := g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
return fmt.Errorf("设置过期时间失败: %w", err)
}
return nil
}
// getFromRedis 从Redis获取会话历史
func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) {
key := formatRedisKey(sessionId)
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
if err != nil {
return nil, fmt.Errorf("从Redis获取数据失败: %w", err)
}
if result == nil || result.IsNil() {
return []map[string]any{}, nil
}
sessions := parseRedisSessions(ctx, result.Strings())
reverseSlice(sessions)
return sessions, nil
}
// parseRedisSessions 解析Redis会话数据
func parseRedisSessions(ctx context.Context, values []string) []map[string]any {
var sessions []map[string]any
for _, str := range values {
var data map[string]any
if err := json.Unmarshal([]byte(str), &data); err != nil {
g.Log().Warningf(ctx, "[会话] 解析Redis数据失败 err=%v", err)
continue
}
sessions = append(sessions, data)
}
return sessions
}
// reverseSlice 反转切片
func reverseSlice(s []map[string]any) {
for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {
s[i], s[j] = s[j], s[i]
}
}
// GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用)
func GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map[string]any, error) {
historyData, err := getFromRedis(ctx, sessionId)
if err != nil {
return nil, fmt.Errorf("获取历史会话失败: %w", err)
}
if len(historyData) == 0 {
return []map[string]any{}, nil
}
return flattenHistoryMessages(historyData), nil
}
// flattenHistoryMessages 扁平化历史消息
func flattenHistoryMessages(historyData []map[string]any) []map[string]any {
var messages []map[string]any
for _, round := range historyData {
appendMessagesFromField(round, "requestContent", &messages)
appendMessagesFromField(round, "responseContent", &messages)
}
return messages
}
// appendMessagesFromField 从指定字段追加消息
func appendMessagesFromField(data map[string]any, field string, messages *[]map[string]any) {
msgs, ok := data[field].([]interface{})
if !ok {
return
}
for _, m := range msgs {
if msg, ok := m.(map[string]interface{}); ok {
*messages = append(*messages, msg)
}
}
}

View File

@@ -1,165 +0,0 @@
package prompt
import (
"context"
"fmt"
"gitea.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
"prompts-core/common/util"
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
)
// SessionCallback 会话回调
func SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
result, err := util.ParseOutput(req.Text)
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, fmt.Errorf("解析模型输出失败: %w", err)
}
result["role"] = "assistant"
if err = updateSessionResponse(ctx, req.EpicycleId, result); err != nil {
return nil, err
}
session, err := getSessionById(ctx, req.EpicycleId)
if err != nil {
return nil, err
}
if err := saveSessionToRedis(ctx, session); err != nil {
return nil, err
}
requestMessages := util.ConvertToMessages(session.RequestContent)
responseMessages := util.ConvertToMessages(session.ResponseContent)
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
session.SessionId, session.Id, len(requestMessages), len(responseMessages))
return &dto.SessionCallbackRes{}, nil
}
// updateSessionResponse 更新会话响应
func updateSessionResponse(ctx context.Context, epicycleId int64, response any) error {
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: epicycleId},
ResponseContent: response,
})
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", epicycleId, err)
return fmt.Errorf("更新数据库失败: %w", err)
}
return nil
}
// getSessionById 根据ID获取会话
func getSessionById(ctx context.Context, epicycleId int64) (*entity.ComposeSession, error) {
session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: epicycleId},
})
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", epicycleId, err)
return nil, fmt.Errorf("获取会话数据失败: %w", err)
}
return session, nil
}
// saveSessionToRedis 保存会话到Redis
func saveSessionToRedis(ctx context.Context, session *entity.ComposeSession) error {
requestMessages := util.ConvertToMessages(session.RequestContent)
responseMessages := util.ConvertToMessages(session.ResponseContent)
if err := saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil {
g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v",
session.SessionId, session.Id, err)
return fmt.Errorf("Redis存储失败: %w", err)
}
return nil
}
// GetHistoryMessages 获取历史信息
func GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) {
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
redisHistory, err := GetSessionHistoryForInference(ctx, sessionId)
if err == nil && len(redisHistory) > 0 {
return redisHistory, nil
}
return getHistoryFromDatabase(ctx, sessionId, maxRounds)
}
// getHistoryFromDatabase 从数据库获取历史记录
func getHistoryFromDatabase(ctx context.Context, sessionId string, maxRounds int) ([]map[string]any, error) {
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
SessionId: sessionId,
}, 1, maxRounds)
if err != nil {
return nil, fmt.Errorf("DB获取历史失败: %w", err)
}
messages := extractMessagesFromSessions(sessions)
cacheSessionsToRedis(ctx, sessions)
return messages, nil
}
// extractMessagesFromSessions 从会话列表中提取消息
func extractMessagesFromSessions(sessions []*entity.ComposeSession) []map[string]any {
var messages []map[string]any
for _, session := range sessions {
appendRequestMessages(session.RequestContent, &messages)
appendResponseMessages(session.ResponseContent, &messages)
}
return messages
}
// appendRequestMessages 追加请求消息
func appendRequestMessages(requestContent any, messages *[]map[string]any) {
reqMsgs := util.ConvertToMessages(requestContent)
for _, m := range reqMsgs {
role := gconv.String(m["role"])
if role == "user" || role == "assistant" {
*messages = append(*messages, m)
}
}
}
// appendResponseMessages 追加响应消息
func appendResponseMessages(responseContent any, messages *[]map[string]any) {
respMsgs := util.ConvertToMessages(responseContent)
for _, m := range respMsgs {
if m["role"] == nil {
m["role"] = "assistant"
}
*messages = append(*messages, m)
}
}
// cacheSessionsToRedis 将会话缓存到Redis
func cacheSessionsToRedis(ctx context.Context, sessions []*entity.ComposeSession) {
for _, session := range sessions {
reqMsgs := util.ConvertToMessages(session.RequestContent)
respMsgs := util.ConvertToMessages(session.ResponseContent)
for i := range respMsgs {
if respMsgs[i]["role"] == nil {
respMsgs[i]["role"] = "assistant"
}
}
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
_ = saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs)
}
}
}

View File

@@ -3,17 +3,17 @@ package prompt
import (
"context"
"fmt"
"prompts-core/service/gateway"
"strings"
"github.com/gogf/gf/v2/frame/g"
"prompts-core/common/util"
"prompts-core/model/dto"
"prompts-core/model/entity"
)
// ProcessUserFormBatches 处理 UserForm 分批(按 token 大小拼接内容)
func ProcessUserFormBatches(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) (*dto.ComposeMessagesReq, int, error) {
func ProcessUserFormBatches(ctx context.Context, req *dto.ComposeMessagesReq, model *gateway.AsynchModel) (*dto.ComposeMessagesReq, int, error) {
if model.TokenConfig == nil || len(req.UserForm) == 0 {
return req, 1, nil
}

View File

@@ -0,0 +1,151 @@
package session
import (
"context"
"encoding/json"
"fmt"
"prompts-core/common/util"
"prompts-core/model/dto"
"time"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
const (
// RedisKeySessionHistory 会话历史缓存 key: session:history:{tenantId}:{sessionId}:{nodeId}
RedisKeySessionHistory = "session:history:%d:%s:%s"
)
// formatRedisKey 格式化 Redis key
func formatRedisKey(tenantID uint64, sessionID, nodeID string) string {
return fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID, nodeID)
}
// ============================================
// 写操作
// ============================================
// SaveToRedis 保存一轮对话到 Redis ZSET
func SaveToRedis(ctx context.Context, tenantID uint64, sessionID, nodeID string, round *dto.HistoryRound) error {
key := formatRedisKey(tenantID, sessionID, nodeID)
maxRounds := util.GetMaxRounds(ctx)
expireSeconds := int64(util.GetExpireMinutes(ctx) * 60)
b, err := json.Marshal(round)
if err != nil {
return fmt.Errorf("序列化会话数据失败: %w", err)
}
score := float64(time.Now().UnixMilli())
if _, err = g.Redis().Do(ctx, "ZADD", key, score, string(b)); err != nil {
return fmt.Errorf("ZADD失败: %w", err)
}
if _, err = g.Redis().Do(ctx, "ZREMRANGEBYRANK", key, 0, -(maxRounds + 1)); err != nil {
return fmt.Errorf("裁剪失败: %w", err)
}
if _, err = g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
return fmt.Errorf("设置过期失败: %w", err)
}
return nil
}
// DeleteSessionHistory 删除整个 session 下所有 node 的缓存
func DeleteSessionHistory(ctx context.Context, tenantID uint64, sessionID string) error {
pattern := fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID, "*")
keys, err := g.Redis().Do(ctx, "KEYS", pattern)
if err != nil {
return err
}
for _, key := range keys.Strings() {
_, _ = g.Redis().Do(ctx, "DEL", key)
}
return nil
}
// DeleteRedisMessages 批量删除指定 node 下的消息
func DeleteRedisMessages(ctx context.Context, tenantID uint64, sessionID, nodeID string, msgIDs []int64) error {
key := formatRedisKey(tenantID, sessionID, nodeID)
for _, msgID := range msgIDs {
cursor := "0"
for {
result, err := g.Redis().Do(ctx, "ZSCAN", key, cursor, "MATCH", fmt.Sprintf("*\"id\":%d*", msgID), "COUNT", 10)
if err != nil {
g.Log().Warningf(ctx, "[会话Redis] ZSCAN失败 msgID=%d err=%v", msgID, err)
break
}
parts := result.Strings()
if len(parts) < 2 {
break
}
cursor = parts[0]
for _, member := range parts[1:] {
_, _ = g.Redis().Do(ctx, "ZREM", key, member)
}
if cursor == "0" {
break
}
}
}
return nil
}
// ============================================
// 读操作
// ============================================
// GetFromRedis 从 Redis ZSET 获取会话历史
func GetFromRedis(ctx context.Context, tenantID uint64, sessionID, nodeID string) ([]dto.HistoryRound, error) {
key := formatRedisKey(tenantID, sessionID, nodeID)
maxRounds := util.GetMaxRounds(ctx)
result, err := g.Redis().Do(ctx, "ZREVRANGE", key, 0, maxRounds-1)
if err != nil {
return nil, fmt.Errorf("ZREVRANGE失败: %w", err)
}
if result == nil || result.IsNil() {
return []dto.HistoryRound{}, nil
}
return parseRounds(result.Strings()), nil
}
// ============================================
// 解析
// ============================================
func parseRounds(members []string) []dto.HistoryRound {
rounds := make([]dto.HistoryRound, 0, len(members))
for _, member := range members {
var round dto.HistoryRound
if err := json.Unmarshal([]byte(member), &round); err != nil {
continue
}
if round.User != nil || round.Assistant != nil {
rounds = append(rounds, round)
}
}
return rounds
}
func flattenRounds(rounds []dto.HistoryRound) []dto.FlatMessage {
var messages []dto.FlatMessage
for i := len(rounds) - 1; i >= 0; i-- {
if rounds[i].User != nil && gconv.String(rounds[i].User["content"]) != "" {
messages = append(messages, dto.FlatMessage{
Role: gconv.String(rounds[i].User["role"]),
Content: gconv.String(rounds[i].User["content"]),
})
}
if rounds[i].Assistant != nil && gconv.String(rounds[i].Assistant["content"]) != "" {
messages = append(messages, dto.FlatMessage{
Role: gconv.String(rounds[i].Assistant["role"]),
Content: gconv.String(rounds[i].Assistant["content"]),
})
}
}
return messages
}

View File

@@ -0,0 +1,191 @@
package session
import (
"context"
"fmt"
"gitea.redpowerfuture.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
"prompts-core/common/util"
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
)
// ============================================
// 回调存储
// ============================================
// Callback 会话回调
func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
req.Messages["role"] = "assistant"
// 1) 更新 DB
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
ResponseContent: req.Messages,
})
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, fmt.Errorf("更新数据库失败: %w", err)
}
// 2) 查询完整记录
session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
})
if err != nil || session == nil {
return nil, fmt.Errorf("会话不存在: epicycleId=%d", req.EpicycleId)
}
// 3) entity → HistoryRound → 写入 Redis
round := entityToHistoryRound(session)
round.Assistant = req.Messages
if err = SaveToRedis(ctx, session.TenantId, session.SessionId, session.NodeId, round); err != nil {
return nil, fmt.Errorf("redis存储失败: %w", err)
}
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d", session.SessionId, session.Id)
return &dto.SessionCallbackRes{Status: true, SessionId: session.SessionId}, nil
}
// ============================================
// 场景1前端历史列表按 creator
// ============================================
// GetHistoryList 获取历史列表
func GetHistoryList(ctx context.Context, req *dto.GetHistoryListReq) (*dto.GetHistoryListRes, error) {
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
sessions, total, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
}, req.Page, req.Size)
if err != nil {
return nil, fmt.Errorf("DB获取历史列表失败: %w", err)
}
rounds := sessionsToHistoryRounds(sessions)
return &dto.GetHistoryListRes{List: rounds, Total: total}, nil
}
// ============================================
// 场景2提示词拼接按 sessionId + nodeId
// ============================================
// GetHistoryMessages 获取历史消息Redis → DB → 异步回种)
func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*dto.GetHistoryMessagesRes, error) {
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
// 1) Redis
if rounds, err := GetFromRedis(ctx, user.TenantId, req.SessionId, req.NodeId); err == nil && len(rounds) > 0 {
g.Log().Debugf(ctx, "[历史消息] Redis命中 sessionId=%s count=%d", req.SessionId, len(rounds))
return &dto.GetHistoryMessagesRes{Messages: flattenRounds(rounds)}, nil
}
// 2) DB
maxRounds := util.GetMaxRounds(ctx)
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
SessionId: req.SessionId,
NodeId: req.NodeId,
}, 1, maxRounds)
if err != nil {
return nil, fmt.Errorf("DB获取历史失败: %w", err)
}
if len(sessions) == 0 {
return &dto.GetHistoryMessagesRes{Messages: []dto.FlatMessage{}}, nil
}
// 3) 转换 + 异步回种
rounds := sessionsToHistoryRounds(sessions)
go asyncCacheToRedis(context.WithoutCancel(ctx), user.TenantId, req.SessionId, req.NodeId, rounds)
return &dto.GetHistoryMessagesRes{Messages: flattenRounds(rounds)}, nil
}
// ============================================
// 删除
// ============================================
// DeleteMessages 删除消息
func DeleteMessages(ctx context.Context, req *dto.DeleteMessagesReq) (*dto.DeleteMessagesRes, error) {
if len(req.MsgIds) == 0 {
return &dto.DeleteMessagesRes{Ok: false}, fmt.Errorf("msgIds不能为空")
}
user, _ := utils.GetUserInfo(ctx)
// 1) 批量查询
sessions, _ := dao.ComposeSession.ListByIds(ctx, req.MsgIds, user.UserName, req.SessionId)
// 2) 批量删 DB
_, _ = dao.ComposeSession.DeleteByIds(ctx, req.MsgIds, user.UserName, req.SessionId)
// 3) 按 nodeId 分组删 Redis
for _, s := range sessions {
_ = DeleteRedisMessages(ctx, user.TenantId, req.SessionId, s.NodeId, req.MsgIds)
}
return &dto.DeleteMessagesRes{Ok: true}, nil
}
// DeleteSession 删除整个会话
func DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (*dto.DeleteSessionRes, error) {
// 1) 删 DB
if _, err := dao.ComposeSession.Delete(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
}); err != nil {
return nil, fmt.Errorf("DB删除失败: %w", err)
}
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
// 2) 删 Redis
if err := DeleteSessionHistory(ctx, user.TenantId, req.SessionId); err != nil {
g.Log().Warningf(ctx, "[删除会话] Redis删除失败 sessionId=%s err=%v", req.SessionId, err)
}
return &dto.DeleteSessionRes{Ok: true}, nil
}
// ============================================
// 转换方法entity ↔ dto集中管理
// ============================================
// entityToHistoryRound entity → HistoryRound
func entityToHistoryRound(s *entity.ComposeSession) *dto.HistoryRound {
return &dto.HistoryRound{
Id: s.Id,
SessionId: s.SessionId,
NodeId: s.NodeId,
CreatedAt: gconv.String(s.CreatedAt),
UpdatedAt: gconv.String(s.UpdatedAt),
User: s.RequestContent,
Assistant: s.ResponseContent,
}
}
// sessionsToHistoryRounds 批量转换
func sessionsToHistoryRounds(sessions []*entity.ComposeSession) []dto.HistoryRound {
rounds := make([]dto.HistoryRound, 0, len(sessions))
for _, s := range sessions {
rounds = append(rounds, *entityToHistoryRound(s))
}
return rounds
}
// asyncCacheToRedis 异步缓存到 Redis
func asyncCacheToRedis(ctx context.Context, tenantID uint64, sessionID, nodeID string, rounds []dto.HistoryRound) {
for i := range rounds {
if rounds[i].User != nil || rounds[i].Assistant != nil {
_ = SaveToRedis(ctx, tenantID, sessionID, nodeID, &rounds[i])
}
}
}