Compare commits

15 Commits

Author SHA1 Message Date
0d52b631b9 refactor(task): 重构任务服务和数据结构 2026-06-12 15:29:06 +08:00
c22d578e1a fix(task): 修复任务状态更新和超时处理问题 2026-06-11 11:27:15 +08:00
df26329836 feat(session): 重构会话服务支持节点维度的Redis缓存管理 2026-06-10 16:48:35 +08:00
40abf0f606 ci/cd调整 2026-06-10 16:32:42 +08:00
b69e7386e2 refactor(prompts-core): 重构代码结构和优化工具函数 2026-06-10 14:51:25 +08:00
1c1db7e30c feat(prompt): 实现历史消息注入功能和协议配置优化
- 在 handleCallbackSuccess 函数中新增获取协议配置逻辑
- 实现历史消息获取并在 rounds 中注入历史消息
- 添加 InjectHistory 函数实现按协议顺序合并历史消息
- 在 GetPromptText 接口中集成历史消息注入测试
- 更新 ProviderProtocol 实体中的 MergeOrder 类型为 []string
- 新增 Capabilities 字段支持最大 token 配置
- 修改 renderTemplate 函数接收协议对象参数
- 优化会话历史存储逻辑,提取用户消息内容进行记录
- 移除无用的注释代码 handleCallbackSuccess 处理回调成功
2026-06-10 10:16:58 +08:00
78114f99c7 feat(session): 重构会话管理和消息存储功能 2026-06-09 15:46:09 +08:00
9410199fbe feat(session): 重构会话管理和Redis缓存机制 2026-06-09 14:00:01 +08:00
1f9a2b9b5f Merge remote-tracking branch 'origin/dev' into dev 2026-06-08 18:02:27 +08:00
e1461cf0f0 feat: 重构异步模型字段并更新依赖 2026-06-08 18:01:54 +08:00
aa7804656f ci/cd调整 2026-06-08 15:37:12 +08:00
5494a0c480 ci/cd调整 2026-06-08 13:44:54 +08:00
qhd
ee6677c1f8 fix: 修复响应体解析逻辑并统一结构包装 2026-06-05 11:48:27 +08:00
de70d33115 refactor(prompt): 重构提示词构建服务和回调处理 2026-06-05 11:00:05 +08:00
b2cad4cac2 refactor(model-gateway): 重构代码结构并优化数据库查询 2026-06-03 18:37:18 +08:00
28 changed files with 947 additions and 1267 deletions

24
Dockerfile Normal file
View File

@@ -0,0 +1,24 @@
# 阶段1: 构建
FROM golang:alpine AS builder
RUN apk add --no-cache git ca-certificates tzdata
ENV TZ=Asia/Shanghai
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
ENV GO111MODULE=on
ENV GOPROXY=https://goproxy.cn,direct
ENV CGO_ENABLED=0
ENV GOTOOLCHAIN=auto
WORKDIR /build
COPY . .
RUN go mod download && go mod tidy
RUN go build -ldflags="-s -w" -o main ./main.go
EXPOSE 3009
CMD ["./main"]

View File

@@ -2,7 +2,6 @@ package util
import (
"context"
"strings"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
@@ -13,16 +12,6 @@ func GetServerName(ctx context.Context) string {
return g.Cfg().MustGet(ctx, "server.name", "").String()
}
// GetServerPort 从配置获取服务端口
func GetServerPort(ctx context.Context) string {
address := g.Cfg().MustGet(ctx, "server.address", ":8080").String()
// address 格式如 ":3009",去掉冒号
if strings.HasPrefix(address, ":") {
return address[1:]
}
return "8080"
}
// GetModelPrompt 获取请求模型的提示词
func GetModelPrompt(ctx context.Context, modelType int) string {
key := "modelPrompts.types." + gconv.String(modelType)
@@ -33,3 +22,13 @@ func GetModelPrompt(ctx context.Context, modelType int) string {
func GetBuildPrompt(ctx context.Context) string {
return g.Cfg().MustGet(ctx, "nodePrompts", "").String()
}
// GetMaxRounds 获取最大轮数配置
func GetMaxRounds(ctx context.Context) int {
return g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
}
// GetExpireMinutes 获取过期时间配置
func GetExpireMinutes(ctx context.Context) int {
return g.Cfg().MustGet(ctx, "session.expireMinutes", 30).Int()
}

View File

@@ -3,7 +3,7 @@ package util
import (
"context"
"gitea.com/red-future/common/utils"
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
)

View File

@@ -1,233 +1,81 @@
package util
import (
"encoding/json"
"strconv"
"fmt"
"github.com/gogf/gf/v2/container/gvar"
gfgjson "github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/util/gconv"
tGjson "github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertToMessages 将原始数据转换为消息列表
func ConvertToMessages(raw any) []map[string]any {
if raw == nil {
return nil
}
j := gfgjson.New(raw)
messages := j.Get("messages")
if !messages.IsNil() {
return gconv.Maps(messages.Val())
}
return []map[string]any{j.Map()}
}
// FormToJSON 将表单数据转换为 JSON 字符串
func FormToJSON(form []map[string]any) string {
if form == nil {
return "[]"
}
b, _ := json.Marshal(form)
return string(b)
}
// UserFormToJSON 将用户表单数据转换为 JSON 字符串
func UserFormToJSON(form []map[string]any) string {
if form == nil {
return "{}"
}
b, _ := json.Marshal(form)
return string(b)
}
// MustMarshalToMap 将对象序列化为 map[string]any失败时返回空 map
func MustMarshalToMap(v any) map[string]any {
b, err := json.Marshal(v)
if err != nil {
return make(map[string]any)
}
var m map[string]any
json.Unmarshal(b, &m)
return m
}
// JSONPretty 将任意类型转为格式化的 JSON 字符串
func JSONPretty(v any) string {
if gv, ok := v.(*gvar.Var); ok {
v = gconv.Map(gv.String())
}
var tmp map[string]any
if err := gconv.Struct(v, &tmp); err != nil {
return gconv.String(v)
}
b, _ := json.MarshalIndent(tmp, "", " ")
return string(b)
}
// ParseJSONFieldFromGvar 专门处理 *gvar.Var 类型的 JSON 字段解析
func ParseJSONFieldFromGvar(source any, target any) {
if source == nil {
return
}
switch v := source.(type) {
case *gvar.Var:
if v.IsNil() {
return
}
// 尝试获取 map
if m := v.Map(); len(m) > 0 {
data, _ := json.Marshal(m)
json.Unmarshal(data, target)
return
}
// 尝试解析 JSON 字符串
str := v.String()
if str != "" && str != "<nil>" {
json.Unmarshal([]byte(str), target)
}
default:
// 其他类型走原来的逻辑
data, _ := json.Marshal(source)
json.Unmarshal(data, target)
}
}
// MergeConsult 将 consult 附件合并到模型生成的 messages 结构中。
//
// 参数说明:
// - req: 请求参数 map需包含 "consult" 字段,值为 []any每个元素是 {"type":"xxx","url":"..."}
// - messages: 模型生成的返回结构(如 rounds[...].messages[...].content 数组)
// - extendMapping: 附加映射配置,格式:
// {"attachments": {"image": {"template": {...}, "target_path": "...", "field_mapping": {...}}, ...}}
//
// 返回值:合并后的完整 map。
// MergeConsult 将 consult 附件合并到模型生成的 messages 结构中
func MergeConsult(req map[string]any, messages map[string]any, extendMapping map[string]any) map[string]any {
if len(req) == 0 || len(messages) == 0 || len(extendMapping) == 0 {
return messages
}
reqJSON, _ := json.Marshal(req)
msgJSON, _ := json.Marshal(messages)
extJSON, _ := json.Marshal(extendMapping)
reqStr := string(reqJSON)
msgStr := string(msgJSON)
extStr := string(extJSON)
// 获取 consult 数组
consultResult := tGjson.Get(reqStr, "consult")
if !consultResult.Exists() || !consultResult.IsArray() {
consult := gconv.Interfaces(req["consult"])
if len(consult) == 0 {
return messages
}
// 获取 attachments 配置
attachmentsResult := tGjson.Get(extStr, "attachments")
if !attachmentsResult.Exists() || !attachmentsResult.IsObject() {
targetPath := gconv.String(extendMapping["target_content_path"])
templates := gconv.Map(extendMapping["attachment_templates"])
if targetPath == "" || len(templates) == 0 {
return messages
}
consultArr := consultResult.Array()
attachmentsMap := attachmentsResult.Map()
msgJson := gjson.New(messages)
for _, consultItem := range consultArr {
if !consultItem.IsObject() {
// rounds 路径修正
if !msgJson.Get("rounds.0").IsNil() {
targetPath = "rounds.0." + targetPath
}
// 遍历追加
for _, item := range consult {
itemJson := gjson.New(item)
itemType := itemJson.Get("type").String()
tmpl := gconv.Map(templates[itemType])
if itemType == "" || len(tmpl) == 0 {
continue
}
itemType := consultItem.Get("type").String()
if itemType == "" {
attachment := buildAttachment(tmpl, itemJson.Get("url").String())
if attachment == nil {
continue
}
// 查找对应类型的附件配置
attachResult, ok := attachmentsMap[itemType]
if !ok || !attachResult.IsObject() {
continue
}
idx := len(msgJson.Get(targetPath).Array())
_ = msgJson.Set(fmt.Sprintf("%s.%d", targetPath, idx), attachment)
}
// 获取模板
templateResult := attachResult.Get("template")
if !templateResult.Exists() || !templateResult.IsObject() {
continue
}
return msgJson.Map()
}
// 深拷贝模板
filledTemplateStr := templateResult.Raw
func buildAttachment(tmpl map[string]any, url string) map[string]any {
typ := gconv.String(tmpl["type"])
if typ == "" || url == "" {
return nil
}
// 应用字段映射
fieldMappingResult := attachResult.Get("field_mapping")
if fieldMappingResult.Exists() && fieldMappingResult.IsObject() {
fieldMapping := fieldMappingResult.Map()
for fieldPath, valueSource := range fieldMapping {
sourceKey := valueSource.String()
valueResult := consultItem.Get(sourceKey)
if valueResult.Exists() {
var err error
filledTemplateStr, err = sjson.SetRaw(filledTemplateStr, fieldPath, valueResult.Raw)
if err != nil {
continue
}
}
body := gconv.Map(tmpl["body"])
fillEmptyInPlace(body, url)
return map[string]any{
"type": typ,
typ: body,
}
}
func fillEmptyInPlace(m map[string]any, value string) {
for k, v := range m {
switch vv := v.(type) {
case string:
if vv == "" {
m[k] = value
}
}
// 获取目标路径
targetPath := attachResult.Get("target_path").String()
if targetPath == "" {
continue
}
// 检查目标路径是否存在且为数组
targetResult := tGjson.Get(msgStr, targetPath)
if !targetResult.Exists() || !targetResult.IsArray() {
continue
}
// 追加到数组末尾
arrLen := len(targetResult.Array())
appendPath := targetPath + "." + strconv.Itoa(arrLen)
var err error
msgStr, err = sjson.SetRaw(msgStr, appendPath, filledTemplateStr)
if err != nil {
continue
case map[string]any:
fillEmptyInPlace(vv, value)
}
}
// 转回 map[string]any
var result map[string]any
if err := json.Unmarshal([]byte(msgStr), &result); err != nil {
return messages
}
return result
}
// GetUserMessage 获取用户消息
func GetUserMessage(taskReq map[string]any) map[string]any {
// 先取 requestPayload
rp, ok := taskReq["requestPayload"].(map[string]any)
if !ok {
return nil
}
// 再取 messages
messages, ok := rp["messages"].([]any)
if !ok {
return nil
}
for _, msg := range messages {
m, ok := msg.(map[string]any)
if ok && m["role"] == "user" {
return m
}
}
return nil
}

View File

@@ -1,7 +1,6 @@
package util
import (
"net/url"
"strings"
"github.com/gogf/gf/v2/encoding/gjson"
@@ -22,86 +21,37 @@ func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
return jsonObj.Map()
}
// MapResponsePayload 映射模型响应为标准格式
func MapResponsePayload(mapping map[string]any, responseBytes []byte) ([]byte, error) {
if len(mapping) == 0 {
return responseBytes, nil
// ExtractUserText 从 messages 中提取所有 user 文本
func ExtractUserText(messages map[string]any) map[string]any {
msgJson := gjson.New(messages)
msgs := msgJson.Get("rounds.0.messages")
if msgs.IsNil() {
msgs = msgJson.Get("messages")
}
responseJson := gjson.New(responseBytes)
resultJson := gjson.New("{}")
for standardField, modelPath := range mapping {
path := gconv.String(modelPath)
if path == "" {
var texts []string
for _, m := range msgs.Array() {
msg := gjson.New(m)
if msg.Get("role").String() != "user" {
continue
}
val := responseJson.Get(path)
if val.IsNil() {
continue
}
resultJson.Set(standardField, val.Val())
}
return []byte(resultJson.String()), nil
}
// ParseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
// 示例:
// - X-API-Key:qwen3-tts-key,operation:true,count:123
// - X-API-Key:"qwen3-tts-key",operation:"true"
//
// 说明:
// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。
// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。
func ParseHeadMsgHeaders(headMsg string) map[string]string {
headMsg = strings.TrimSpace(headMsg)
if headMsg == "" {
return nil
}
out := map[string]string{}
parts := strings.Split(headMsg, ",")
for _, p := range parts {
p = strings.TrimSpace(p)
if p == "" {
continue
}
// HeaderName:HeaderValue推荐 / HeaderName=HeaderValue兼容
if strings.Contains(p, ":") {
kv := strings.SplitN(p, ":", 2)
k := strings.TrimSpace(kv[0])
v := strings.TrimSpace(kv[1])
v = strings.Trim(v, "\"")
if k != "" && v != "" {
out[k] = v
content := msg.Get("content").Val()
switch c := content.(type) {
case string:
texts = append(texts, c)
case []any:
for _, item := range c {
if m, ok := item.(map[string]any); ok {
if t := gconv.String(m["text"]); t != "" {
texts = append(texts, t)
}
}
}
continue
}
if strings.Contains(p, "=") {
kv := strings.SplitN(p, "=", 2)
k := strings.TrimSpace(kv[0])
v := strings.TrimSpace(kv[1])
v = strings.Trim(v, "\"")
if k != "" && v != "" {
out[k] = v
}
continue
}
}
if len(out) == 0 {
return nil
}
return out
}
// PayloadToQuery 将 payload 转为 url.Values
func PayloadToQuery(payload map[string]any) (url.Values, error) {
q := url.Values{}
for k, v := range payload {
if v == nil {
continue
}
q.Set(k, gconv.String(v))
return map[string]any{
"role": "user",
"content": strings.Join(texts, "\n"),
}
return q, nil
}

View File

@@ -1,130 +0,0 @@
package util
import (
"context"
"net"
"strings"
"github.com/gogf/gf/v2/frame/g"
)
// GetLocalIP 获取本机有效的局域网 IPv4 地址
func GetLocalIP() string {
addrs, err := net.InterfaceAddrs()
if err != nil {
return "127.0.0.1"
}
var validIPs []string
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok {
continue
}
ip := ipnet.IP
if isIPValid(ip) {
validIPs = append(validIPs, ip.String())
}
}
// 优先返回非 169.254.x.x 的 IP
for _, ip := range validIPs {
if !strings.HasPrefix(ip, "169.254.") {
return ip
}
}
// 其次返回 169.254.x.x最后的选择
if len(validIPs) > 0 {
return validIPs[0]
}
return "127.0.0.1"
}
// isIPValid 判断 IP 是否有效
func isIPValid(ip net.IP) bool {
// 不是 loopback (127.0.0.1)
if ip.IsLoopback() {
return false
}
// 是 IPv4
if ip.To4() == nil {
return false
}
// 不是链路本地地址 (169.254.0.0/16)
if ip[0] == 169 && ip[1] == 254 {
return false
}
// 不是组播地址
if ip.IsMulticast() {
return false
}
// 不是未指定地址 (0.0.0.0)
if ip.IsUnspecified() {
return false
}
return true
}
// GetLocalAddress 获取局域网地址IP:端口)
func GetLocalAddress(ctx context.Context) string {
ip := GetLocalIP()
port := GetServerPort(ctx)
if port == "80" || port == "443" {
return ip
}
return ip + ":" + port
}
// GetSchemaFromRequest 从当前请求中获取协议http/https
func GetSchemaFromRequest(ctx context.Context) string {
r := g.RequestFromCtx(ctx)
if r == nil {
return "http"
}
// 1. 代理场景X-Forwarded-Proto
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
return proto
}
// 2. 代理场景X-Forwarded-Scheme
if proto := r.Header.Get("X-Forwarded-Scheme"); proto != "" {
return proto
}
// 3. TLS 连接(直接 HTTPS
if r.TLS != nil {
return "https"
}
// 4. 默认 HTTP这行很重要
return "http" // ← 确保有这行
}
// GetLocalBaseURL 获取局域网基础 URL动态协议 + IP + 端口)
func GetLocalBaseURL(ctx context.Context) string {
schema := GetSchemaFromRequest(ctx)
addr := GetLocalAddress(ctx)
return schema + "://" + addr
}
// GetCallbackURL 获取回调地址(完整 URL
func GetCallbackURL(ctx context.Context, path string) string {
baseURL := GetLocalBaseURL(ctx)
// 确保 path 以 / 开头
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return baseURL + path
}

View File

@@ -112,41 +112,3 @@ nodePrompts: |
%s
上下文内容:
%s
#你是专业的JSON结构生成专家必须严格遵守以下全部规则。
# 【强制规则】
# 必须根据【输出结构】里面返回的JSON结构进行生成不得任何更改最终内容与输出结构返回一致
# 完整阅读所有文本、规则、表单内容,禁止跳读、漏读;
# 完整读取UserForm所有字段不得忽略任何字段
# 如果有skill相关内容必须完整的将内容拼接到system角色描述中
# 理解全部语义后再输出,禁止断章取义;
# UserForm所有字段内容必须完整拼接赋值到user角色描述中不得有任何遗漏。
# 【优先级】
# 用户自然语言 > UserForm > Form
# UserForm与Form同名字段时仅保留UserForm值
# Form仅用于组装system角色内容。
# 【表单处理】
# Form系统提示词、默认参数、基础配置 → 专属填充system角色
# UserForm用户业务输入、文案、配图数量、比例、prompt等 → 全部解析后拼接进user角色content
# 自动提取UserForm中每条文案的配图数量总图片数 = 各文案配图数累加求和用户没有相关数量必须默认1
# 图片尺寸为空时自动填充size=1024*1024。
# 【结构铁律】
# 严格沿用固定输出结构,不增删字段或修改层级;
# messages元素必须按结构返回
# 禁止将role对象转为字符串、禁止嵌套错乱
# 输出纯净JSON无多余转义符、无换行符、无额外字符
# 所有括号、引号必须成对闭合保证JSON合法。
# 【参数赋值】
# model固定沿用传入值
# 返回结构里面的参数,需要根据语意进行赋值,缺失补默认值;
# history历史信息必须结合UserForm里的内容对用户描述部分进行补充
# 从UserForm提取信息整合进user描述确保数量、尺寸、文案语义无遗漏。
# 【输出要求】
# 仅输出单行纯净JSON无任何解释、备注、Markdown或多余符号
# 完整合UserForm全部字段语义到user描述
# 生成后自检JSON语法、结构、数量错误则自动重新生成。
# 【输出结构】
# %s
# 【完整输入信息】
# %s
# 直接输出最终JSON

View File

@@ -9,4 +9,9 @@ const (
const (
BuildTypePrompt = 1 //提示词构建
BuildTypeNode = 2 //节点构建
BuildTypeStruct = 3 //结构构建
)
const (
ModelTypeInference = 100 // 推理模型
)

View File

@@ -2,17 +2,8 @@ package controller
import (
"context"
"prompts-core/common/util"
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
"prompts-core/service/gateway"
promptService "prompts-core/service/prompt"
"gitea.com/red-future/common/beans"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
)
type prompt struct{}
@@ -35,31 +26,3 @@ func (c *prompt) Callback(ctx context.Context, req *dto.CallbackReq) (res *dto.C
func (c *prompt) GetComposeTask(ctx context.Context, req *dto.GetComposeTaskReq) (res *dto.GetComposeTaskRes, err error) {
return promptService.GetComposeTask(ctx, req.TaskId)
}
func (c *prompt) Text(ctx context.Context, req *dto.TextReq) (res *dto.TextRes, err error) {
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: "c58c9296-994f-4e83-8285-1daebf3c492d",
})
if err != nil {
return
}
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
ModelName: composeTask.ModelName,
})
if err != nil {
return
}
message := promptService.ParsePromptResult(composeTask.ResultText)
// 加这两行
g.Log().Infof(ctx, "[Text] RequestPayload.consult: %v", composeTask.RequestPayload["consult"])
g.Log().Infof(ctx, "[Text] ExtendMapping: %v", model.ExtendMapping)
messages := util.MergeConsult(composeTask.RequestPayload, message, model.ExtendMapping)
g.Log().Infof(ctx, "[Text] MergeConsult 结果 rounds[0].messages[0].content: %v",
gjson.New(messages).Get("rounds.0.messages.0.content"))
res = &dto.TextRes{
Messages: messages,
}
return
}

View File

@@ -1,18 +1,36 @@
// ============================================
// controller/session.go
// ============================================
package controller
import (
"context"
"prompts-core/model/dto"
"prompts-core/model/dto"
sessionService "prompts-core/service/session"
)
type session struct{}
// Session 提示词会话控制器
var Session = new(session)
// SessionCallback 会话回调
func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) {
return sessionService.Callback(ctx, req)
}
// GetHistoryList 获取历史列表(前端列表)
func (c *session) GetHistoryList(ctx context.Context, req *dto.GetHistoryListReq) (res *dto.GetHistoryListRes, err error) {
return sessionService.GetHistoryList(ctx, req)
}
// DeleteMessages 批量删除消息
func (c *session) DeleteMessages(ctx context.Context, req *dto.DeleteMessagesReq) (res *dto.DeleteMessagesRes, err error) {
return sessionService.DeleteMessages(ctx, req)
}
// DeleteSession 删除整个会话
func (c *session) DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (res *dto.DeleteSessionRes, err error) {
return sessionService.DeleteSession(ctx, req)
}

View File

@@ -5,8 +5,7 @@ import (
"prompts-core/consts/public"
"prompts-core/model/entity"
"gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
)
var ComposeSession = &composeSessionDao{}
@@ -15,13 +14,8 @@ type composeSessionDao struct{}
// Insert 插入
func (d *composeSessionDao) Insert(ctx context.Context, req *entity.ComposeSession) (id int64, err error) {
var m = new(entity.ComposeSession)
err = gconv.Struct(req, &m)
if err != nil {
return
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
Insert(m)
Insert(req)
if err != nil {
return
}
@@ -69,6 +63,7 @@ func (d *composeSessionDao) Get(ctx context.Context, req *entity.ComposeSession,
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
OmitEmpty().
Where(entity.ComposeSessionCol.Id, req.Id).
Where(entity.ComposeSessionCol.Creator, req.Creator).
Where(entity.ComposeSessionCol.SessionId, req.SessionId).
Fields(fields).One()
if err != nil {
@@ -86,6 +81,7 @@ func (d *composeSessionDao) Delete(ctx context.Context, req *entity.ComposeSessi
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
OmitEmpty().
Where(entity.ComposeSessionCol.Id, req.Id).
Where(entity.ComposeSessionCol.Creator, req.Creator).
Where(entity.ComposeSessionCol.SessionId, req.SessionId).
Delete()
if err != nil {
@@ -93,3 +89,36 @@ func (d *composeSessionDao) Delete(ctx context.Context, req *entity.ComposeSessi
}
return r.RowsAffected()
}
// ListByIds 根据 ID 列表批量查询
func (d *composeSessionDao) ListByIds(ctx context.Context, ids []int64, creator, sessionId string) (list []*entity.ComposeSession, err error) {
if len(ids) == 0 {
return nil, nil
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
WhereIn(entity.ComposeSessionCol.Id, ids).
Where(entity.ComposeSessionCol.Creator, creator).
Where(entity.ComposeSessionCol.SessionId, sessionId).
All()
if err != nil {
return nil, err
}
err = r.Structs(&list)
return
}
// DeleteByIds 批量删除编排会话
func (d *composeSessionDao) DeleteByIds(ctx context.Context, ids []int64, creator, sessionId string) (int64, error) {
if len(ids) == 0 {
return 0, nil
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
WhereIn(entity.ComposeSessionCol.Id, ids).
Where(entity.ComposeSessionCol.Creator, creator).
Where(entity.ComposeSessionCol.SessionId, sessionId).
Delete()
if err != nil {
return 0, err
}
return r.RowsAffected()
}

View File

@@ -5,7 +5,7 @@ import (
"prompts-core/consts/public"
"prompts-core/model/entity"
"gitea.com/red-future/common/db/gfdb"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
)

View File

@@ -5,7 +5,7 @@ import (
"prompts-core/consts/public"
"prompts-core/model/entity"
"gitea.com/red-future/common/db/gfdb"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
)

6
go.mod
View File

@@ -3,12 +3,10 @@ module prompts-core
go 1.26.1
require (
gitea.com/red-future/common v0.0.20
gitea.redpowerfuture.com/red-future/common v0.0.23
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.2
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.2
github.com/gogf/gf/v2 v2.10.2
github.com/tidwall/gjson v1.19.0
github.com/tidwall/sjson v1.2.5
)
require (
@@ -65,8 +63,6 @@ require (
github.com/r3labs/diff/v2 v2.15.1 // indirect
github.com/redis/go-redis/v9 v9.12.1 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tiger1103/gfast-token v1.0.10 // indirect
github.com/vcaesar/cedar v0.30.0 // indirect
github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect

13
go.sum
View File

@@ -1,6 +1,6 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
gitea.com/red-future/common v0.0.20 h1:KlKINnJFmOVkDzgkptEAFsdpMUZb0zK9BTdiXRxVfAo=
gitea.com/red-future/common v0.0.20/go.mod h1:6/nqIucVzmjOyqDTIq71feYBXXFNBy0rFwzaQ0/Ueoo=
gitea.redpowerfuture.com/red-future/common v0.0.23 h1:xieoA00iKOCDm5SO9iXn+cSyMKBAlZwI0fuEVPWrHLg=
gitea.redpowerfuture.com/red-future/common v0.0.23/go.mod h1:50U1Xi+Ie56z09S5LQbZvaken0Mxv3OeS9LgR7U/ZRY=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
@@ -288,15 +288,6 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.19.0 h1:xwxm7n691Uf3u5OFjzngavjGTh55KX5q/9w9xHW88JU=
github.com/tidwall/gjson v1.19.0/go.mod h1:V37/opeE/JbLUOfH0QTXiNez2l0RUjYUhpT4szFQAfc=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/tiger1103/gfast-token v1.0.10 h1:fNiBE/Dq5iTHvTGlCx3DmXa2o4hr0NtumFpffZ39k6s=
github.com/tiger1103/gfast-token v1.0.10/go.mod h1:a/21mxmj7zFeNvjhZSC0XpEAFHfb1aT2k6DXnufFU1s=
github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM=

View File

@@ -7,9 +7,9 @@ import (
"prompts-core/controller"
"syscall"
"gitea.com/red-future/common/http"
"gitea.com/red-future/common/jaeger"
_ "gitea.com/red-future/common/swagger"
"gitea.redpowerfuture.com/red-future/common/http"
"gitea.redpowerfuture.com/red-future/common/jaeger"
_ "gitea.redpowerfuture.com/red-future/common/swagger"
_ "github.com/gogf/gf/contrib/drivers/pgsql/v2"
_ "github.com/gogf/gf/contrib/nosql/redis/v2"
"github.com/gogf/gf/v2/frame/g"

View File

@@ -6,7 +6,8 @@ type ComposeMessagesReq struct {
g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schemaform 作为系统表单userForm 作为用户表单,结合 userFiles 调用 model-gateway并直接返回最终 messages"`
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"`
BuildType int `p:"buildType" json:"buildType" v:"required#buildType不能为空" dc:"构建类型"` //判断节点
SessionId string `p:"sessionId" json:"sessionId" dc:"会话ID"` //v:"required#sessionId不能为空"
NodeId string `p:"nodeId" json:"nodeId" dc:"节点ID"`
SessionId string `p:"sessionId" json:"sessionId" dc:"会话ID"` //v:"required#sessionId不能为空"
Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"`
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
Form []map[string]any `p:"form" json:"form" dc:"系统表单form 下所有字段都作为系统提示词来源"`
@@ -21,25 +22,16 @@ type ConsultItem struct {
Url string `json:"url" dc:"附件地址"`
}
type ComposeMessagesRes struct {
TaskId string `json:"taskId" dc:"任务ID"`
EpicycleId int64 `json:"epicycle_id" dc:"轮次ID"`
}
// MultiRoundResult 多轮返回结果
type MultiRoundResult struct {
TotalRounds int `json:"total_rounds"` // 总轮数
Rounds []map[string]any `json:"rounds"` // 每轮详情(动态类型)
TaskId string `json:"taskId" dc:"任务ID"`
}
type CallbackReq struct {
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调callbackUrl/{bizName}"`
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
State int `json:"state" dc:"网关任务状态"`
OssFile string `json:"oss_file" dc:"结果文件地址"`
FileType string `json:"file_type" dc:"结果文件类型"`
Messages map[string]any `json:"messages" dc:"消息数组"`
ErrorMsg string `json:"error_msg" dc:"错误信息"`
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调callbackUrl/{bizName}"`
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
State int `json:"state" dc:"网关任务状态"`
OssFile string `json:"oss_file" dc:"结果文件地址"`
FileType string `json:"file_type" dc:"结果文件类型"`
ErrorMsg string `json:"error_msg" dc:"错误信息"`
}
type CallbackRes struct {
@@ -51,19 +43,11 @@ type GetComposeTaskReq struct {
}
type GetComposeTaskRes struct {
TaskId string `json:"taskId" dc:"任务ID"`
Status string `json:"status" dc:"业务状态"`
GatewayState int `json:"gatewayState" dc:"网关状态"`
ErrorMessage string `json:"errorMessage" dc:"错误信息"`
Messages any `json:"messages" dc:"最终消息数组"`
OssFile string `json:"ossFile" dc:"结果文件地址"`
FileType string `json:"fileType" dc:"结果文件类型"`
}
type TextReq struct {
g.Meta `path:"/text" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schemaform 作为系统表单userForm 作为用户表单,结合 userFiles 调用 model-gateway并直接返回最终 messages"`
}
type TextRes struct {
Messages any `json:"messages" dc:"文本结果"`
TaskId string `json:"taskId" dc:"任务ID"`
Status string `json:"status" dc:"业务状态"`
GatewayState int `json:"gatewayState" dc:"网关状态"`
ErrorMessage string `json:"errorMessage" dc:"错误信息"`
Messages map[string]any `json:"messages" dc:"最终消息数组"`
OssFile string `json:"ossFile" dc:"结果文件地址"`
FileType string `json:"fileType" dc:"结果文件类型"`
}

View File

@@ -2,13 +2,79 @@ package dto
import "github.com/gogf/gf/v2/frame/g"
type SessionCallbackReq struct {
g.Meta `path:"/sessionCallback" method:"post" tags:"提示词处理"`
Messages map[string]any `json:"messages" dc:"消息数组"`
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
// HistoryRound 一轮对话
type HistoryRound struct {
Id int64 `json:"id" dc:"记录ID"`
SessionId string `json:"sessionId" dc:"会话ID"`
NodeId string `json:"nodeId" dc:"节点ID"`
User map[string]any `json:"user" dc:"用户消息"`
Assistant map[string]any `json:"assistant" dc:"助手回复"`
CreatedAt string `json:"createdAt" dc:"创建时间"`
UpdatedAt string `json:"updatedAt" dc:"更新时间"`
}
// SessionCallbackReq 会话回调请求
type SessionCallbackReq struct {
g.Meta `path:"/callback" method:"post" tags:"会话管理" summary:"会话回调"`
Messages map[string]any `json:"messages" v:"required" dc:"消息数组"`
EpicycleId int64 `json:"epicycleId" v:"required" dc:"轮次ID"`
}
// SessionCallbackRes 会话回调响应
type SessionCallbackRes struct {
Status bool `json:"status" dc:"状态"`
SessionId string `json:"sessionId" dc:"会话ID"`
}
// GetHistoryListReq 获取历史列表请求(前端)
type GetHistoryListReq struct {
g.Meta `path:"/historyList" method:"get" tags:"会话管理" summary:"获取历史列表"`
Page int `json:"page" d:"1" dc:"页码"`
Size int `json:"size" d:"10" dc:"每页条数"`
}
// GetHistoryListRes 获取历史列表响应
type GetHistoryListRes struct {
List []HistoryRound `json:"list" dc:"历史列表"`
Total int `json:"total" dc:"总数"`
}
// GetHistoryMessagesReq 获取历史消息请求(提示词拼接)
type GetHistoryMessagesReq struct {
g.Meta `path:"/historyMessages" method:"get" tags:"会话管理" summary:"获取历史消息"`
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
NodeId string `json:"nodeId" dc:"节点ID"`
}
// GetHistoryMessagesRes 获取历史消息响应
type GetHistoryMessagesRes struct {
Messages []FlatMessage `json:"messages"`
}
type FlatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// DeleteMessagesReq 批量删除消息请求
type DeleteMessagesReq struct {
g.Meta `path:"/deleteMessages" method:"post" tags:"会话管理" summary:"批量删除消息"`
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
MsgIds []int64 `json:"msgIds" v:"required" dc:"消息ID列表"`
}
// DeleteMessagesRes 批量删除消息响应
type DeleteMessagesRes struct {
Ok bool `json:"ok" dc:"是否成功"`
}
// DeleteSessionReq 删除整个会话请求
type DeleteSessionReq struct {
g.Meta `path:"/deleteSession" method:"post" tags:"会话管理" summary:"删除整个会话"`
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
}
// DeleteSessionRes 删除整个会话响应
type DeleteSessionRes struct {
Ok bool `json:"ok" dc:"是否成功"`
}

View File

@@ -1,10 +1,11 @@
package entity
import "gitea.com/red-future/common/beans"
import "gitea.redpowerfuture.com/red-future/common/beans"
type ComposeSession struct {
beans.SQLBaseDO `orm:",inline"`
SessionId string `orm:"session_id" json:"sessionId"`
NodeId string `orm:"node_id" json:"nodeId"`
RequestContent map[string]any `orm:"request_content" json:"requestContent"`
ResponseContent map[string]any `orm:"response_content" json:"responseContent"`
Remark string `orm:"remark" json:"remark"`
@@ -13,6 +14,7 @@ type ComposeSession struct {
type composeSessionCol struct {
beans.SQLBaseCol
SessionId string
NodeId string
RequestContent string
ResponseContent string
Remark string
@@ -21,6 +23,7 @@ type composeSessionCol struct {
var ComposeSessionCol = composeSessionCol{
SQLBaseCol: beans.DefSQLBaseCol,
SessionId: "session_id",
NodeId: "node_id",
RequestContent: "request_content",
ResponseContent: "response_content",
Remark: "remark",

View File

@@ -1,6 +1,6 @@
package entity
import "gitea.com/red-future/common/beans"
import "gitea.redpowerfuture.com/red-future/common/beans"
type ComposeTask struct {
beans.SQLBaseDO `orm:",inline"`
@@ -11,8 +11,7 @@ type ComposeTask struct {
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
GatewayState int `orm:"gateway_state" json:"gatewayState"`
RequestPayload map[string]any `orm:"request_payload" json:"requestPayload"`
ResultText map[string]any `orm:"result_text" json:"resultText"`
Messages map[string]any `orm:"messages" json:"messages"`
ResultJson map[string]any `orm:"result_json" json:"resultJson"`
Status string `orm:"status" json:"status"`
ErrorMessage string `orm:"error_message" json:"errorMessage"`
OssFile string `orm:"oss_file" json:"ossFile"`
@@ -28,8 +27,7 @@ type composeTaskCol struct {
CallbackUrl string
GatewayState string
RequestPayload string
ResultText string
Messages string
ResultJson string
Status string
ErrorMessage string
OssFile string
@@ -45,8 +43,7 @@ var ComposeTaskCol = composeTaskCol{
CallbackUrl: "callback_url",
GatewayState: "gateway_state",
RequestPayload: "request_payload",
ResultText: "result_text",
Messages: "messages",
ResultJson: "result_json",
Status: "status",
ErrorMessage: "error_message",
OssFile: "oss_file",

View File

@@ -1,21 +1,21 @@
package entity
import "gitea.com/red-future/common/beans"
import "gitea.redpowerfuture.com/red-future/common/beans"
// ProviderProtocol 模型协议映射配置
type ProviderProtocol struct {
beans.SQLBaseDO `orm:",inherit"`
// 业务字段
ProviderName string `orm:"provider_name" json:"providerName"`
TargetField string `orm:"target_field" json:"targetField"`
MergeOrder any `orm:"merge_order" json:"mergeOrder"`
RoleMapping any `orm:"role_mapping" json:"roleMapping"`
ContentMapping any `orm:"content_mapping" json:"contentMapping"`
Capabilities any `orm:"capabilities" json:"capabilities"`
RequestTemplate any `orm:"request_template" json:"requestTemplate"`
SystemPromptTemplate string `orm:"system_prompt_template" json:"systemPromptTemplate"`
Status int `orm:"status" json:"status"`
Remark string `orm:"remark" json:"remark"`
ProviderName string `orm:"provider_name" json:"providerName"`
TargetField string `orm:"target_field" json:"targetField"`
MergeOrder []string `orm:"merge_order" json:"mergeOrder"`
RoleMapping map[string]any `orm:"role_mapping" json:"roleMapping"`
ContentMapping map[string]any `orm:"content_mapping" json:"contentMapping"`
Capabilities map[string]any `orm:"capabilities" json:"capabilities"`
RequestTemplate map[string]any `orm:"request_template" json:"requestTemplate"`
SystemPromptTemplate string `orm:"system_prompt_template" json:"systemPromptTemplate"`
Status int `orm:"status" json:"status"`
Remark string `orm:"remark" json:"remark"`
}
// providerProtocolCol 列名

View File

@@ -4,11 +4,14 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"prompts-core/common/util"
"prompts-core/model/entity"
"strings"
"gitea.com/red-future/common/beans"
commonHttp "gitea.com/red-future/common/http"
"gitea.redpowerfuture.com/red-future/common/beans"
commonHttp "gitea.redpowerfuture.com/red-future/common/http"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
)
@@ -56,8 +59,7 @@ type AsynchModel struct {
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
IsPrivate *int `orm:"is_private" json:"isPrivate"`
IsChatModel int `orm:"is_chat_model" json:"isChatModel"`
IsAsync *int `orm:"is_async" json:"isAsync"`
IsStream *int `orm:"is_stream" json:"isStream"`
CallModel int `orm:"call_model" json:"callModel"`
ApiKey string `orm:"api_key" json:"apiKey"`
Enabled *int `orm:"enabled" json:"enabled"`
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
@@ -77,16 +79,28 @@ type AsynchModel struct {
// GetModelConfig 获取模型配置
func GetModelConfig(ctx context.Context, req *AsynchModel) (model *AsynchModel, err error) {
fmt.Println("req参数", req)
fullURL := fmt.Sprintf("model-gateway/model/getModel?creator=%s&modelName=%s&isChatModel=%d",
req.Creator, req.ModelName, req.IsChatModel)
fullURL := "model-gateway/model/getModel"
// 拼接 query 参数
var params []string
if req.Creator != "" {
params = append(params, fmt.Sprintf("creator=%s", req.Creator))
}
if req.ModelName != "" {
params = append(params, fmt.Sprintf("modelName=%s", req.ModelName))
}
if req.IsChatModel != 0 {
params = append(params, fmt.Sprintf("isChatModel=%d", req.IsChatModel))
}
if len(params) > 0 {
fullURL += "?" + strings.Join(params, "&")
}
headers := util.ForwardHeaders(ctx)
var resp GetModelConfigResp
if err = commonHttp.Get(ctx, fullURL, headers, &resp, nil); err != nil {
return nil, fmt.Errorf("获取模型配置失败: %w", err)
}
if resp.Model == nil {
return nil, fmt.Errorf("模型不存在: creator=%s modelName=%s isChatModel=%d", req.Creator, req.ModelName, req.IsChatModel)
return nil, fmt.Errorf("模型不存在")
}
return resp.Model, nil
}
@@ -134,78 +148,48 @@ func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) {
// SendCallbackReq 发送回调的请求体
type SendCallbackReq struct {
TaskId string `json:"taskId"`
Status string `json:"status"`
Messages *MultiRoundResult `json:"messages,omitempty"`
EpicycleId int64 `json:"epicycleId"`
ErrorMsg string `json:"errorMsg,omitempty"`
}
type MultiRoundResult struct {
TotalRounds int `json:"total_rounds"` // 总轮数
Rounds []map[string]any `json:"rounds"` // 每轮详情(动态类型)
TaskId string `json:"taskId"`
Status string `json:"status"`
EpicycleId int64 `json:"epicycleId"`
ErrorMsg string `json:"errorMsg,omitempty"`
}
// SendCallback 向业务方发送回调
func SendCallback(ctx context.Context, composeTask *entity.ComposeTask) error {
func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycleId int64) error {
// 1. 检查回调地址
if composeTask.CallbackUrl == "" {
return fmt.Errorf("回调地址为空taskId=%s", composeTask.TaskId)
}
// 2. 构造请求体
req := SendCallbackReq{
TaskId: composeTask.TaskId,
Status: composeTask.Status,
Messages: parseMessagesToResult(composeTask.Messages), // 需要将 JSON 字符串转为结构体
ErrorMsg: composeTask.ErrorMessage,
TaskId: composeTask.TaskId,
Status: composeTask.Status,
ErrorMsg: composeTask.ErrorMessage,
EpicycleId: epicycleId,
}
// 3. 发送 POST 请求
headers := util.ForwardHeaders(ctx)
var resp struct{}
g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s 消息=%v",
composeTask.TaskId, composeTask.CallbackUrl, req.Messages)
g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s",
composeTask.TaskId, composeTask.CallbackUrl)
if err := commonHttp.Post(ctx, composeTask.CallbackUrl, headers, &resp, req); err != nil {
return fmt.Errorf("[回调业务] 发送失败 taskId=%s url=%s err=%w", composeTask.TaskId, composeTask.CallbackUrl, err)
}
g.Log().Infof(ctx, "[回调业务] 发送成功 taskId=%s 回调地址=%s", composeTask.TaskId, composeTask.CallbackUrl)
g.Log().Infof(ctx, "[回调业务] 发送成功 taskId=%s 回调地址=%s ", composeTask.TaskId, composeTask.CallbackUrl)
return nil
}
// parseMessagesToResult 将 any 类型的 Messages 转为 *MultiRoundResult
func parseMessagesToResult(messages any) *MultiRoundResult {
if messages == nil {
return nil
// DownloadFile 从 OSS 下载文件内容
func DownloadFile(ossURL string) ([]byte, error) {
resp, err := http.Get(ossURL)
if err != nil {
return nil, fmt.Errorf("下载OSS文件失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("下载OSS文件返回非200: %d", resp.StatusCode)
}
var result MultiRoundResult
switch v := messages.(type) {
case *MultiRoundResult:
return v
case MultiRoundResult:
return &v
case string:
if err := json.Unmarshal([]byte(v), &result); err != nil {
return nil
}
case []byte:
if err := json.Unmarshal(v, &result); err != nil {
return nil
}
case map[string]any:
// 通过 JSON 序列化再反序列化
data, _ := json.Marshal(v)
if err := json.Unmarshal(data, &result); err != nil {
return nil
}
default:
data, err := json.Marshal(v)
if err != nil {
return nil
}
if err = json.Unmarshal(data, &result); err != nil {
return nil
}
}
return &result
return io.ReadAll(resp.Body)
}

View File

@@ -2,9 +2,7 @@ package prompt
import (
"context"
"errors"
"fmt"
"prompts-core/consts/public"
"prompts-core/service/gateway"
"strings"
@@ -13,92 +11,69 @@ import (
"prompts-core/model/dto"
"prompts-core/model/entity"
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/util/gconv"
)
// UserPromptPayload 用户提示词请求体
type UserPromptPayload struct {
Model string `json:"model"`
PromptInfo string `json:"promptInfo"`
Form any `json:"form"`
UserForm any `json:"userForm"`
Consult []dto.ConsultItem `json:"consult"`
UserFilesText map[string]string `json:"userFilesText"`
Skills string `json:"skills"`
BuildType int `json:"buildType"`
}
// buildInferenceRequest 构建推理请求
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, history []map[string]any) (map[string]any, error) {
//1) 处理表单分批
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
if err != nil {
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
}
ir := NewPromptIR()
switch req.BuildType {
case public.BuildTypePrompt:
return buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, history, ir, totalBatches)
case public.BuildTypeNode:
return buildNodeTypeRequest(ctx, req, chatModel, ir)
default:
return nil, errors.New("不支持的构建类型")
}
}
// buildPromptTypeRequest 构建提示词类型请求BuildType=1
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
//1) 构建系统提示词
systemPrompt := promptBuildWithRounds(ctx, req, chatModel, aiModel, totalBatches)
systemPrompt := promptBuildWithRounds(ctx, chatModel, aiModel)
ir.AddSystem(systemPrompt)
//2) 构建历史对话
for _, msg := range history {
role := gconv.String(msg["role"])
if role != "user" && role != "assistant" {
continue
}
ir.AddHistory(role, gconv.String(msg["content"]))
}
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType))
ir.AddUser(userPrompt)
//2) 检查整体内容是否超出窗口
if !checkOverallContent(ir, aiModel) {
availableWindow := util.GetAvailableWindow(aiModel.TokenConfig)
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
}
return compileToProviderRequest(ctx, ir, chatModel)
return compileToProviderRequest(ctx, ir, chatModel, req)
}
// buildNodeTypeRequest 构建节点类型请求BuildType=2
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
ir.AddUser(NodeBuild(ctx, req))
return compileToProviderRequest(ctx, ir, chatModel)
return compileToProviderRequest(ctx, ir, chatModel, req)
}
// buildStructTypeRequest 构建结构体类型请求BuildType=3
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
customPrompt := gjson.New(req.UserForm).Get("0.prompt").String()
ir.AddSystem(customPrompt)
ir.AddUser(buildUserPrompt(ctx, req, ""))
return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt)
}
// compileToProviderRequest 编译为 Provider 请求
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel) (map[string]any, error) {
func compileToProviderRequest(ctx context.Context, ir *IR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) {
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
if err != nil {
return nil, fmt.Errorf("获取协议配置失败: %w", err)
return nil, err
}
if protocol == nil {
return nil, errors.New("协议配置不存在")
return nil, fmt.Errorf("协议配置不存在或获取失败")
}
// 如果传了自定义提示词,替换掉协议模板
if len(customPrompt) > 0 && customPrompt[0] != "" {
protocol.SystemPromptTemplate = customPrompt[0] +
"【核心铁律】" +
"1.【技能内容skill相关】必须完整拼接到System提示词中作为System提示词的组成部分不得拆分到其他位置。"
}
providerReq, err := Compile(ir, protocol, chatModel)
if err != nil {
return nil, fmt.Errorf("编译请求失败: %w", err)
}
return map[string]any{
"modelName": chatModel.ModelName,
"bizName": util.GetServerName(ctx),
"callbackUrl": util.GetCallbackURL(ctx, "/prompt/callback"),
"callbackUrl": utils.GetCallbackURL(ctx, "/prompt/callback"),
"requestPayload": providerReq,
"buildType": req.BuildType,
}, nil
}
// promptBuildWithRounds 构建系统提示词
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, batches int) string {
// promptBuildWithRounds 构建提示词
func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel) string {
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: chatModel.OperatorName,
Status: 1,
@@ -106,89 +81,71 @@ func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, cha
if err != nil || providerProtocol == nil {
return ""
}
outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{}))
maxWindowSize := util.GetMaxWindowSize(chatModel.TokenConfig)
availableWindow := util.GetAvailableWindow(chatModel.TokenConfig)
formContent := buildUserFormContent(req.Form)
userFormContent := buildUserFormContent(req.UserForm)
formInfo := fmt.Sprintf(`
【系统表单(系统提示词/参数)】
%s
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
%s
`, formContent, userFormContent)
inputInfo := fmt.Sprintf(`
目标模型: %s
%s
技能名称: %s
用户文件: %v
`, req.ModelName, formInfo, req.SkillName, req.Consult)
outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{})).MustToJsonIndentString()
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
req.ModelName, // %s 目标模型名称
maxWindowSize, // %d 最大窗口
availableWindow, // %d 可用窗口
outputJSON, // %s 输出结构
inputInfo, // %s 完整输入信息
outputJSON, //【输出结构】 %s
)
}
// buildUserFormContent 构建用户表单内容字符串
func buildUserFormContent(userForm []map[string]any) string {
var builder strings.Builder
for _, item := range userForm {
builder.WriteString(fmt.Sprintf("%v\n", item))
}
return builder.String()
}
// checkOverallContent 检查整体内容是否超出窗口
func checkOverallContent(ir *PromptIR, model *gateway.AsynchModel) bool {
func checkOverallContent(ir *IR, model *gateway.AsynchModel) bool {
fullContent := ir.String()
return util.CountToken(fullContent, model.TokenConfig)
}
// buildUserPrompt 构建用户提示词
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
payload := UserPromptPayload{
Model: req.ModelName,
PromptInfo: prompt,
Form: prepareUserFormPayload(req.Form),
UserForm: prepareUserFormPayload(req.UserForm),
Consult: req.Consult,
UserFilesText: ExtractFileTexts(ctx, req.Consult),
Skills: SkillMdContent(ctx, req.SkillName),
BuildType: req.BuildType,
var b strings.Builder
b.WriteString(fmt.Sprintf("目标模型:%s\n", req.ModelName))
if prompt != "" {
b.WriteString(fmt.Sprintf("系统提示词:%s\n", prompt))
}
return gjson.New(payload).String()
if skills := SkillMdContent(ctx, req.SkillName); skills != "" {
b.WriteString(fmt.Sprintf("技能内容:\n%s\n", skills))
}
if formText := buildUserFormText(req.Form); formText != "" {
b.WriteString(fmt.Sprintf("系统参数:\n%s\n", formText))
}
if userFormText := buildUserFormText(req.UserForm); userFormText != "" {
b.WriteString(fmt.Sprintf("用户需求:\n%s\n", userFormText))
}
if len(req.Consult) > 0 {
b.WriteString(fmt.Sprintf("参考附件:%s\n", gjson.New(req.Consult).String()))
}
if fileTexts := ExtractFileTexts(ctx, req.Consult); fileTexts != "" {
b.WriteString(fmt.Sprintf("附件内容:\n%s\n", fileTexts))
}
return b.String()
}
// prepareUserFormPayload 准备用户表单载荷
func prepareUserFormPayload(userForm []map[string]any) any {
if len(userForm) == 0 {
return nil
func buildUserFormText(form []map[string]any) string {
if len(form) == 0 {
return ""
}
if _, ok := userForm[0]["batch_index"]; ok {
return userForm
}
return mergeUserFormTexts(userForm)
}
// mergeUserFormTexts 合并 UserForm 中的所有文本内容
func mergeUserFormTexts(userForm []map[string]any) string {
var builder strings.Builder
for i, item := range userForm {
text := getItemText(item)
if i > 0 {
builder.WriteString("\n\n")
for _, item := range form {
for k, v := range item {
builder.WriteString(fmt.Sprintf("%s\n", k))
switch val := v.(type) {
case []any:
for i, elem := range val {
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
if m, ok := elem.(map[string]any); ok {
for mk, mv := range m {
builder.WriteString(fmt.Sprintf("%s%v ", mk, mv))
}
} else {
builder.WriteString(fmt.Sprint(elem))
}
builder.WriteString("\n")
}
default:
builder.WriteString(fmt.Sprintf(" %v\n", v))
}
}
builder.WriteString(text)
}
return builder.String()
return strings.TrimSpace(builder.String())
}
// NodeBuild 节点构建
@@ -197,9 +154,8 @@ func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string {
if promptTpl == "" {
return ""
}
formStr := util.FormToJSON(req.Form)
userFormStr := util.UserFormToJSON(req.UserForm)
return fmt.Sprintf(promptTpl, formStr, userFormStr)
return fmt.Sprintf(promptTpl,
gjson.New(req.Form).MustToJsonString(),
gjson.New(req.UserForm).MustToJsonString(),
)
}

View File

@@ -2,43 +2,36 @@ package prompt
import (
"context"
"encoding/json"
"errors"
"fmt"
"prompts-core/service/session"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
"prompts-core/common/util"
"prompts-core/consts/public"
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
"prompts-core/service/gateway"
"gitea.redpowerfuture.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// ComposeMessages 核心拼接提示词主流程
func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
//1) 获取模型信息
// 1) 获取模型信息
chatModel, aiModel, err := GetModelMessage(ctx, req)
if err != nil {
return nil, err
}
//2) 校验用户表单
// 2) 校验用户表单
if err = validateUserForm(req, aiModel); err != nil {
return nil, err
}
//3) 处理不同类型
switch req.BuildType {
case public.BuildTypePrompt:
return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建
case public.BuildTypeNode:
return handleNodeBuild(ctx, req, chatModel, aiModel) // 节点构建
default:
return nil, errors.New("BuildType 不支持")
}
return handleBuild(ctx, req, chatModel, aiModel)
}
// GetModelMessage 获取模型信息
@@ -51,24 +44,19 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*gateway
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
IsChatModel: 1,
})
if err != nil {
return nil, nil, err
}
if chatModel == nil {
if err != nil || chatModel == nil {
return nil, nil, errors.New("当前没有对话模型,请添加")
}
aiModels, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
aiModel, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{TenantId: userInfo.TenantId, Creator: userInfo.UserName},
ModelName: req.ModelName,
})
if err != nil {
return nil, nil, err
}
if aiModels == nil {
if err != nil || aiModel == nil {
return nil, nil, errors.New("需要构建的模型不存在")
}
return chatModel, aiModels, nil
return chatModel, aiModel, nil
}
// validateUserForm 校验用户表单
@@ -89,103 +77,115 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *gateway.AsynchModel) e
return nil
}
// handlePromptBuild 处理提示词构建BuildType=1
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
// 获取历史会话
history, err := session.GetHistoryMessages(ctx, req.SessionId)
// handleBuild 通用构建处理
func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
// 1) 处理表单分批
processedReq, _, err := ProcessUserFormBatches(ctx, req, aiModel)
if err != nil {
g.Log().Errorf(ctx, "获取历史会话失败: %v将不使用历史会话", err)
history = nil
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
}
// 调用推理模型
taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, history)
if err != nil {
return nil, fmt.Errorf("调用推理模型失败: %w", err)
}
// 保存任务记录
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
BuildType: req.BuildType,
CallbackUrl: req.CallbackUrl,
RequestPayload: util.MustMarshalToMap(req),
Status: public.ComposeStatusPending,
})
if err != nil {
return nil, fmt.Errorf("保存任务记录失败: %w", err)
// 2) 构建推理请求
ir := NewPromptIR()
var taskReq map[string]any
switch req.BuildType {
case public.BuildTypePrompt:
taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir)
case public.BuildTypeNode:
taskReq, err = buildNodeTypeRequest(ctx, req, chatModel, ir)
case public.BuildTypeStruct:
taskReq, err = buildStructTypeRequest(ctx, req, chatModel, ir)
default:
return nil, errors.New("不支持的构建类型")
}
if err != nil {
return nil, fmt.Errorf("构建推理请求失败: %w", err)
}
return &dto.ComposeMessagesRes{
TaskId: taskID,
EpicycleId: id,
}, nil
}
// handleNodeBuild 处理节点构建BuildType=2
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
if err != nil {
return nil, fmt.Errorf("调用推理模型失败: %w", err)
}
// 保存任务记录
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
BuildType: req.BuildType,
CallbackUrl: req.CallbackUrl,
RequestPayload: util.MustMarshalToMap(req),
Status: public.ComposeStatusPending,
})
if err != nil {
return nil, fmt.Errorf("保存任务记录失败: %w", err)
}
return &dto.ComposeMessagesRes{
TaskId: taskID,
EpicycleId: id,
}, nil
}
// callInferenceModel 调用推理模型
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, history []map[string]any) (string, int64, error) {
taskReq, err := buildInferenceRequest(ctx, req, chatModel, aiModel, history)
if err != nil {
return "", 0, fmt.Errorf("构建推理请求失败: %w", err)
}
id := int64(0)
if req.SessionId != "" {
id, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
RequestContent: util.GetUserMessage(taskReq),
})
if err != nil {
return "", 0, fmt.Errorf("保存历史会话失败: %w", err)
}
}
// 3) 调用网关创建任务
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
if err != nil {
return "", 0, fmt.Errorf("创建网关任务失败: %w", err)
return nil, fmt.Errorf("创建网关任务失败: %w", err)
}
if taskID == "" {
return "", 0, errors.New("网关未返回taskId")
return nil, errors.New("网关未返回taskId")
}
return taskID, id, nil
// 4) 保存任务记录
if _, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
BuildType: req.BuildType,
CallbackUrl: req.CallbackUrl,
RequestPayload: gconv.Map(req),
Status: public.ComposeStatusPending,
}); err != nil {
return nil, err
}
return &dto.ComposeMessagesRes{TaskId: taskID}, nil
}
// Callback 回调处理
func Callback(ctx context.Context, req *dto.CallbackReq) error {
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Messages))
// 查询任务
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
})
g.Log().Infof(ctx, "[开始回调处理] taskId=%s state=%d", req.TaskId, req.State)
// 1) 查询任务
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{TaskId: req.TaskId})
if err != nil {
return fmt.Errorf("查询任务失败: %w", err)
}
// 2) 读取 OSS 文件内容
var ossContent []byte
if req.OssFile != "" {
ossContent, err = gateway.DownloadFile(req.OssFile)
if err != nil {
g.Log().Warningf(ctx, "[回调处理] 读取OSS失败 taskId=%s err=%v", req.TaskId, err)
}
}
// 3) 解析 OSS 内容为消息
var messages map[string]any
if len(ossContent) > 0 {
messages, _ = gjson.New(ossContent).Map(), nil
}
// 4) 处理失败
if req.State == 3 {
return handleCallbackFailed(ctx, req, composeTask, messages)
}
// 5) 处理成功
if req.State == 2 {
return handleCallbackSuccess(ctx, req, composeTask, messages)
}
return nil
}
// handleCallbackFailed 处理回调失败
func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask, messages map[string]any) error {
_, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
GatewayState: req.State,
OssFile: req.OssFile,
FileType: req.FileType,
ResultJson: messages,
})
if composeTask.CallbackUrl != "" {
composeTask.Status = public.ComposeStatusFailed
composeTask.ErrorMessage = req.ErrorMsg
_ = gateway.SendCallback(ctx, composeTask, 0)
}
return err
}
// handleCallbackSuccess 处理回调成功
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask, messages map[string]any) error {
// 1) 获取模型配置
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
ModelName: composeTask.ModelName,
@@ -193,133 +193,125 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
if err != nil {
return fmt.Errorf("查询模型失败: %w", err)
}
//处理失败
if req.State == 3 {
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
GatewayState: req.State,
OssFile: req.OssFile,
FileType: req.FileType,
ResultText: req.Messages,
// 2) 获取协议配置
protocol, _ := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: model.OperatorName,
Status: 1,
})
// 3) 获取历史消息 + 保存当前轮
payload := composeTask.RequestPayload
sessionId := gconv.String(payload["sessionId"])
nodeId := gconv.String(payload["nodeId"])
var history []dto.FlatMessage
var epicycleId int64
if sessionId != "" && nodeId != "" && model.ModelType == public.ModelTypeInference {
// 3.1 获取历史
h, _ := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
SessionId: sessionId,
NodeId: nodeId,
})
// 用更新后的值发送回调
if composeTask.CallbackUrl != "" {
failedTask := &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
CallbackUrl: composeTask.CallbackUrl,
Messages: composeTask.Messages,
}
gateway.SendCallback(ctx, failedTask)
if h != nil {
history = h.Messages
}
// 3.2 保存当前轮(先存,下次查询就能拿到)
if userMsg := util.ExtractUserText(messages); userMsg != nil {
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
NodeId: nodeId,
SessionId: sessionId,
RequestContent: userMsg,
})
}
}
// 4) 合并附加结构
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
// 5) 注入历史
if len(history) > 0 {
messages = InjectHistory(messages, history, protocol)
}
// 6) 更新数据库
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
GatewayState: req.State,
OssFile: req.OssFile,
FileType: req.FileType,
ResultJson: messages,
})
if err != nil {
return err
}
//处理成功
if req.State == 2 {
// 1. 根据 BuildType 解析结果
var messages map[string]any
switch composeTask.BuildType {
case public.BuildTypePrompt: // 提示词构建解析
messages = ParsePromptResult(req.Messages)
case public.BuildTypeNode: // 节点构建解析
messages = ParseNodeResult(req.Messages)
default:
messages = req.Messages
}
// 2. 处理附加字段
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
// 3. 更新数据库
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
Messages: messages,
GatewayState: req.State,
OssFile: req.OssFile,
FileType: req.FileType,
ResultText: req.Messages,
})
if err != nil {
g.Log().Errorf(ctx, "[Callback] 更新成功状态失败 taskId=%s err=%v", req.TaskId, err)
return err
}
// 4. 发送回调给业务方
if composeTask.CallbackUrl != "" {
successTask := &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
Messages: messages,
CallbackUrl: composeTask.CallbackUrl,
// 8) 回调业务方
if composeTask.CallbackUrl != "" {
composeTask.Status = public.ComposeStatusSuccess
composeTask.ResultJson = messages
_ = gateway.SendCallback(ctx, composeTask, epicycleId)
}
return nil
}
// InjectHistory 插入历史会话
func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protocol *entity.ProviderProtocol) map[string]any {
if protocol == nil || len(history) == 0 {
return roundsData
}
// 1) 提取第一轮的 messages
rounds := roundsData["rounds"].([]any)
firstRound := rounds[0].(map[string]any)
original := firstRound["messages"].([]any)
// 2) 按 merge_order 拼接
result := make([]any, 0, len(original)+len(history))
for _, part := range protocol.MergeOrder {
switch part {
case "system":
for _, m := range original {
msg := m.(map[string]any)
if gconv.String(msg["role"]) == "system" {
result = append(result, msg)
}
}
gateway.SendCallback(ctx, successTask)
}
}
return err
}
// ParsePromptResult 解析提示词构建结果
func ParsePromptResult(raw map[string]any) map[string]any {
contentStr, ok := raw["content"].(string)
if !ok || contentStr == "" {
return raw
}
if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil {
return map[string]any{
"total_rounds": len(roundsArray),
"rounds": roundsArray,
}
}
if singleRound := tryParseAsMap(contentStr); singleRound != nil {
return map[string]any{
"total_rounds": 1,
"rounds": []map[string]any{singleRound},
}
}
return map[string]any{"content": contentStr}
}
func tryParseAsMapArray(jsonStr string) []map[string]any {
var arr []map[string]any
if err := json.Unmarshal([]byte(jsonStr), &arr); err != nil {
return nil
}
if len(arr) == 0 {
return nil
}
return arr
}
func tryParseAsMap(jsonStr string) map[string]any {
var obj map[string]any
if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil {
return nil
}
if len(obj) == 0 {
return nil
}
return obj
}
func ParseNodeResult(raw map[string]any) map[string]any {
contentStr, ok := raw["content"].(string)
if ok && contentStr != "" {
var inner map[string]any
if err := json.Unmarshal([]byte(contentStr), &inner); err == nil {
return map[string]any{
"total_rounds": 1,
"rounds": []map[string]any{inner},
case "history":
if gconv.Bool(protocol.Capabilities["support_history"]) {
for _, msg := range history {
result = append(result, map[string]any{
"role": msg.Role,
"content": msg.Content, // 纯字符串,不转换
})
}
}
case "user":
for _, m := range original {
msg := m.(map[string]any)
if gconv.String(msg["role"]) == "user" {
result = append(result, msg)
}
}
}
}
return map[string]any{
"total_rounds": 1,
"rounds": []map[string]any{raw},
// 3) 角色映射
if len(protocol.RoleMapping) > 0 {
for _, m := range result {
msg := m.(map[string]any)
role := gconv.String(msg["role"])
if mapped, ok := protocol.RoleMapping[role]; ok {
msg["role"] = mapped
}
}
}
// 4) 直接修改原对象
firstRound["messages"] = result
return roundsData
}
// GetComposeTask 查询任务结果
@@ -330,31 +322,10 @@ func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes,
if err != nil {
return nil, fmt.Errorf("查询任务失败: %w", err)
}
if record == nil {
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
}
messages := parseMessagesForResponse(record.Messages)
return &dto.GetComposeTaskRes{
TaskId: record.TaskId,
Status: record.Status,
ErrorMessage: record.ErrorMessage,
Messages: messages,
Messages: record.ResultJson,
}, nil
}
// parseMessagesForResponse 解析用于响应的消息
func parseMessagesForResponse(messages any) any {
str, ok := messages.(string)
if !ok || str == "" {
return messages
}
var parsed any
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
return parsed
}
return messages
}

View File

@@ -22,26 +22,25 @@ const (
bytesPerMB = 1024 * 1024
)
// ExtractFileTexts 从 ConsultItem 列表中提取文件内容
func ExtractFileTexts(ctx context.Context, consult []dto.ConsultItem) map[string]string {
// ExtractFileTexts 从 ConsultItem 列表中提取文件内容,返回拼接文本
func ExtractFileTexts(ctx context.Context, consult []dto.ConsultItem) string {
urls := make([]string, 0, len(consult))
for _, item := range consult {
if item.Url != "" {
urls = append(urls, item.Url)
}
}
return FetchFileTexts(ctx, urls)
return FetchFileTextsAsString(ctx, urls)
}
// FetchFileTexts 从 URL 列表获取文件内容,支持 zip 内文件
func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
result := make(map[string]string)
// FetchFileTextsAsString 从 URL 列表获取文件内容,拼接为字符串
func FetchFileTextsAsString(ctx context.Context, urls []string) string {
if len(urls) == 0 {
return result
return ""
}
client := createHTTPClient(ctx, "userFiles.httpTimeoutSec", 8)
var builder strings.Builder
for _, rawURL := range urls {
url := util.SanitizeURL(rawURL)
@@ -50,23 +49,19 @@ func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
}
if util.IsZipExtension(url) {
mergeMap(result, fetchZipFileTexts(ctx, client, url))
for _, text := range fetchZipFileTexts(ctx, client, url) {
builder.WriteString(text)
builder.WriteString("\n")
}
continue
}
if text := fetchAndCleanFileContent(ctx, client, url); text != "" {
result[url] = text
builder.WriteString(fmt.Sprintf("【文件:%s】\n%s\n", url, text))
}
}
return result
}
// mergeMap 合并 map
func mergeMap(dst, src map[string]string) {
for k, v := range src {
dst[k] = v
}
return builder.String()
}
// fetchAndCleanFileContent 获取并清理文件内容
@@ -195,6 +190,9 @@ func fetchFileContent(ctx context.Context, client *http.Client, url string) (str
}
func SkillMdContent(ctx context.Context, skillName string) string {
if skillName == "" {
return ""
}
skillResp, err := gateway.GetSkillUser(ctx, skillName)
if err != nil {
g.Log().Warningf(ctx, "[SkillMd] GetSkillUser 失败: %v", err)

View File

@@ -2,18 +2,18 @@ package prompt
import (
"context"
"encoding/json"
"fmt"
"prompts-core/common/util"
"prompts-core/service/gateway"
"strings"
"prompts-core/dao"
"prompts-core/model/entity"
"github.com/gogf/gf/v2/util/gconv"
)
// PromptIR 统一 Prompt 中间表示
type PromptIR struct {
// IR 统一 Prompt 中间表示
type IR struct {
System []Segment `json:"system"`
History []Segment `json:"history"`
User []Segment `json:"user"`
@@ -34,6 +34,7 @@ type ProviderProtocol struct {
ContentMapping ContentMapping `json:"content_mapping"`
RequestTemplate map[string]any `json:"request_template"`
SystemPromptTemplate string `json:"system_prompt_template"`
Capabilities map[string]any `json:"capabilities"`
}
// ContentMapping 内容字段映射
@@ -43,8 +44,8 @@ type ContentMapping struct {
}
// NewPromptIR 创建空 PromptIR
func NewPromptIR() *PromptIR {
return &PromptIR{
func NewPromptIR() *IR {
return &IR{
System: make([]Segment, 0),
History: make([]Segment, 0),
User: make([]Segment, 0),
@@ -52,7 +53,7 @@ func NewPromptIR() *PromptIR {
}
// String 返回 PromptIR 的完整内容字符串(用于 token 计算)
func (ir *PromptIR) String() string {
func (ir *IR) String() string {
var builder strings.Builder
for _, seg := range ir.System {
@@ -78,7 +79,7 @@ func (ir *PromptIR) String() string {
}
// GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算)
func (ir *PromptIR) GetTotalContent() string {
func (ir *IR) GetTotalContent() string {
var builder strings.Builder
for _, seg := range ir.System {
@@ -100,7 +101,7 @@ func (ir *PromptIR) GetTotalContent() string {
}
// AddSystem 添加系统提示
func (ir *PromptIR) AddSystem(content string) *PromptIR {
func (ir *IR) AddSystem(content string) *IR {
if content != "" {
ir.System = append(ir.System, Segment{Type: "text", Content: content})
}
@@ -108,7 +109,7 @@ func (ir *PromptIR) AddSystem(content string) *PromptIR {
}
// AddUser 添加用户消息
func (ir *PromptIR) AddUser(content string) *PromptIR {
func (ir *IR) AddUser(content string) *IR {
if content != "" {
ir.User = append(ir.User, Segment{Type: "text", Content: content})
}
@@ -116,7 +117,7 @@ func (ir *PromptIR) AddUser(content string) *PromptIR {
}
// AddHistory 添加历史消息
func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
func (ir *IR) AddHistory(role, content string) *IR {
if content != "" {
ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role})
}
@@ -124,7 +125,7 @@ func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
}
// ToMessages 转换为 OpenAI 兼容的 messages 格式MVP 默认)
func (ir *PromptIR) ToMessages() []map[string]any {
func (ir *IR) ToMessages() []map[string]any {
var messages []map[string]any
for _, seg := range ir.System {
@@ -165,21 +166,22 @@ func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderP
// parseProtocol 将 DB entity 转为编译用协议配置
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
p := &ProviderProtocol{
return &ProviderProtocol{
TargetField: e.TargetField,
SystemPromptTemplate: e.SystemPromptTemplate,
MergeOrder: e.MergeOrder,
RoleMapping: gconv.MapStrStr(e.RoleMapping),
ContentMapping: ContentMapping{
Type: gconv.String(e.ContentMapping["type"]),
Field: gconv.String(e.ContentMapping["field"]),
},
RequestTemplate: e.RequestTemplate,
Capabilities: e.Capabilities,
}
// 使用通用解析方法处理各个字段
util.ParseJSONFieldFromGvar(e.MergeOrder, &p.MergeOrder)
util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping)
util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping)
util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate)
return p
}
// Compile 将 PromptIR 按协议配置编译为 Provider Request
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) {
func Compile(ir *IR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) {
if ir == nil || p == nil {
return nil, fmt.Errorf("ir and protocol are required")
}
@@ -191,35 +193,25 @@ func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *gateway.AsynchModel)
}
// mergeByOrder 按协议配置顺序拼接消息
func mergeByOrder(ir *PromptIR, order []string) []map[string]any {
var messages []map[string]any
for _, part := range order {
switch part {
case "system":
for _, seg := range ir.System {
messages = append(messages, map[string]any{
"role": "system",
"content": seg.Content,
})
}
case "history":
for _, seg := range ir.History {
messages = append(messages, map[string]any{
"role": seg.Role,
"content": seg.Content,
})
}
case "user":
for _, seg := range ir.User {
messages = append(messages, map[string]any{
"role": "user",
"content": seg.Content,
})
}
}
func mergeByOrder(ir *IR, order []string) []map[string]any {
roleMap := map[string][]Segment{
"system": ir.System,
"history": ir.History,
"user": ir.User,
}
var messages []map[string]any
for _, part := range order {
for _, seg := range roleMap[part] {
msg := map[string]any{"content": seg.Content}
if part == "history" {
msg["role"] = seg.Role
} else {
msg["role"] = part
}
messages = append(messages, msg)
}
}
return messages
}
@@ -243,29 +235,29 @@ func mapRoles(messages []map[string]any, mapping map[string]string) []map[string
return messages
}
// mapContent 内容字段映射
func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
for _, msg := range messages {
content := msg["content"]
delete(msg, "content")
switch cm.Type {
case "parts":
msg["parts"] = []map[string]any{
{cm.Field: content},
}
default:
msg[cm.Field] = content
}
if cm.Field == "" || cm.Field == "content" {
return messages
}
for i, msg := range messages {
if content, ok := msg["content"]; ok {
delete(msg, "content")
switch cm.Type {
case "parts":
messages[i]["parts"] = []map[string]any{{cm.Field: content}}
default:
messages[i][cm.Field] = content
}
}
}
return messages
}
// buildRequest 按 target_field 和 request_template 构建请求体
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gateway.AsynchModel) map[string]any {
if len(p.RequestTemplate) > 0 {
return renderTemplate(p.RequestTemplate, messages, chatModel)
return renderTemplate(p, messages, chatModel)
}
return map[string]any{
@@ -273,20 +265,21 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gat
}
}
// renderTemplate 简单的 {{key}} 模板替换
func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
b, _ := json.Marshal(tmpl)
str := string(b)
if chatModel != nil {
str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`)
// renderTemplate 模板渲染
func renderTemplate(p *ProviderProtocol, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
result := make(map[string]any, len(p.RequestTemplate)+1)
for k, v := range p.RequestTemplate {
result[k] = v
}
msgBytes, _ := json.Marshal(messages)
str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes))
if chatModel != nil {
result["model"] = chatModel.ModelName
}
result["messages"] = messages
var result map[string]any
json.Unmarshal([]byte(str), &result)
if maxTokens := gconv.Int(p.Capabilities["max_tokens"]); maxTokens > 0 {
result["max_tokens"] = maxTokens
}
return result
}

View File

@@ -4,139 +4,148 @@ import (
"context"
"encoding/json"
"fmt"
"prompts-core/model/entity"
"prompts-core/common/util"
"prompts-core/model/dto"
"time"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
const (
redisKeyPrefix = "chat:session:%s"
// RedisKeySessionHistory 会话历史缓存 key: session:history:{tenantId}:{sessionId}:{nodeId}
RedisKeySessionHistory = "session:history:%d:%s:%s"
)
// formatRedisKey 格式化Redis
func formatRedisKey(sessionId string) string {
return fmt.Sprintf(redisKeyPrefix, sessionId)
// formatRedisKey 格式化 Redis key
func formatRedisKey(tenantID uint64, sessionID, nodeID string) string {
return fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID, nodeID)
}
// saveToRedis 保存会话数据到Redis
func saveToRedis(ctx context.Context, session *entity.ComposeSession) error {
key := formatRedisKey(session.SessionId)
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
data := map[string]any{
"sessionId": session.SessionId,
"requestContent": session.RequestContent,
"responseContent": session.ResponseContent,
"timestamp": time.Now().Unix(),
}
b, err := json.Marshal(data)
// ============================================
// 写操作
// ============================================
// SaveToRedis 保存一轮对话到 Redis ZSET
func SaveToRedis(ctx context.Context, tenantID uint64, sessionID, nodeID string, round *dto.HistoryRound) error {
key := formatRedisKey(tenantID, sessionID, nodeID)
maxRounds := util.GetMaxRounds(ctx)
expireSeconds := int64(util.GetExpireMinutes(ctx) * 60)
b, err := json.Marshal(round)
if err != nil {
return fmt.Errorf("序列化会话数据失败: %w", err)
}
if err = executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil {
score := float64(time.Now().UnixMilli())
if _, err = g.Redis().Do(ctx, "ZADD", key, score, string(b)); err != nil {
return fmt.Errorf("ZADD失败: %w", err)
}
if _, err = g.Redis().Do(ctx, "ZREMRANGEBYRANK", key, 0, -(maxRounds + 1)); err != nil {
return fmt.Errorf("裁剪失败: %w", err)
}
if _, err = g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
return fmt.Errorf("设置过期失败: %w", err)
}
return nil
}
// DeleteSessionHistory 删除整个 session 下所有 node 的缓存
func DeleteSessionHistory(ctx context.Context, tenantID uint64, sessionID string) error {
pattern := fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID, "*")
keys, err := g.Redis().Do(ctx, "KEYS", pattern)
if err != nil {
return err
}
for _, key := range keys.Strings() {
_, _ = g.Redis().Do(ctx, "DEL", key)
}
return nil
}
// executeRedisCommands 执行Redis命令
func executeRedisCommands(ctx context.Context, key string, value string, maxRounds int, expireSeconds int64) error {
if _, err := g.Redis().Do(ctx, "LPUSH", key, value); err != nil {
return fmt.Errorf("写入Redis失败: %w", err)
// DeleteRedisMessages 批量删除指定 node 下的消息
func DeleteRedisMessages(ctx context.Context, tenantID uint64, sessionID, nodeID string, msgIDs []int64) error {
key := formatRedisKey(tenantID, sessionID, nodeID)
for _, msgID := range msgIDs {
cursor := "0"
for {
result, err := g.Redis().Do(ctx, "ZSCAN", key, cursor, "MATCH", fmt.Sprintf("*\"id\":%d*", msgID), "COUNT", 10)
if err != nil {
g.Log().Warningf(ctx, "[会话Redis] ZSCAN失败 msgID=%d err=%v", msgID, err)
break
}
parts := result.Strings()
if len(parts) < 2 {
break
}
cursor = parts[0]
for _, member := range parts[1:] {
_, _ = g.Redis().Do(ctx, "ZREM", key, member)
}
if cursor == "0" {
break
}
}
}
if _, err := g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1); err != nil {
return fmt.Errorf("裁剪Redis列表失败: %w", err)
}
if _, err := g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
return fmt.Errorf("设置过期时间失败: %w", err)
}
return nil
}
// getFromRedis 从Redis获取会话历史
func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) {
key := formatRedisKey(sessionId)
// ============================================
// 读操作
// ============================================
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
// GetFromRedis 从 Redis ZSET 获取会话历史
func GetFromRedis(ctx context.Context, tenantID uint64, sessionID, nodeID string) ([]dto.HistoryRound, error) {
key := formatRedisKey(tenantID, sessionID, nodeID)
maxRounds := util.GetMaxRounds(ctx)
result, err := g.Redis().Do(ctx, "ZREVRANGE", key, 0, maxRounds-1)
if err != nil {
return nil, fmt.Errorf("从Redis获取数据失败: %w", err)
return nil, fmt.Errorf("ZREVRANGE失败: %w", err)
}
if result == nil || result.IsNil() {
return []map[string]any{}, nil
return []dto.HistoryRound{}, nil
}
sessions := parseRedisSessions(ctx, result.Strings())
reverseSlice(sessions)
return sessions, nil
return parseRounds(result.Strings()), nil
}
// parseRedisSessions 解析Redis会话数据
func parseRedisSessions(ctx context.Context, values []string) []map[string]any {
var sessions []map[string]any
// ============================================
// 解析
// ============================================
for _, str := range values {
var data map[string]any
if err := json.Unmarshal([]byte(str), &data); err != nil {
g.Log().Warningf(ctx, "[会话] 解析Redis数据失败 err=%v", err)
func parseRounds(members []string) []dto.HistoryRound {
rounds := make([]dto.HistoryRound, 0, len(members))
for _, member := range members {
var round dto.HistoryRound
if err := json.Unmarshal([]byte(member), &round); err != nil {
continue
}
sessions = append(sessions, data)
}
return sessions
}
// reverseSlice 反转切片
func reverseSlice(s []map[string]any) {
for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {
s[i], s[j] = s[j], s[i]
}
}
// GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用)
func GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map[string]any, error) {
historyData, err := getFromRedis(ctx, sessionId)
if err != nil {
return nil, fmt.Errorf("获取历史会话失败: %w", err)
}
if len(historyData) == 0 {
return []map[string]any{}, nil
}
return flattenHistoryMessages(historyData), nil
}
// flattenHistoryMessages 扁平化历史消息
func flattenHistoryMessages(historyData []map[string]any) []map[string]any {
var messages []map[string]any
for _, round := range historyData {
appendMessagesFromField(round, "requestContent", &messages)
appendMessagesFromField(round, "responseContent", &messages)
}
return messages
}
// appendMessagesFromField 从指定字段追加消息
func appendMessagesFromField(data map[string]any, field string, messages *[]map[string]any) {
msgs, ok := data[field].([]interface{})
if !ok {
return
}
for _, m := range msgs {
if msg, ok := m.(map[string]interface{}); ok {
*messages = append(*messages, msg)
if round.User != nil || round.Assistant != nil {
rounds = append(rounds, round)
}
}
return rounds
}
func flattenRounds(rounds []dto.HistoryRound) []dto.FlatMessage {
var messages []dto.FlatMessage
for i := len(rounds) - 1; i >= 0; i-- {
if rounds[i].User != nil && gconv.String(rounds[i].User["content"]) != "" {
messages = append(messages, dto.FlatMessage{
Role: gconv.String(rounds[i].User["role"]),
Content: gconv.String(rounds[i].User["content"]),
})
}
if rounds[i].Assistant != nil && gconv.String(rounds[i].Assistant["content"]) != "" {
messages = append(messages, dto.FlatMessage{
Role: gconv.String(rounds[i].Assistant["role"]),
Content: gconv.String(rounds[i].Assistant["content"]),
})
}
}
return messages
}

View File

@@ -4,7 +4,8 @@ import (
"context"
"fmt"
"gitea.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
@@ -14,9 +15,14 @@ import (
"prompts-core/model/entity"
)
// ============================================
// 回调存储
// ============================================
// Callback 会话回调
func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
req.Messages["role"] = "assistant"
// 1) 更新 DB
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
ResponseContent: req.Messages,
@@ -25,103 +31,161 @@ func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCal
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, fmt.Errorf("更新数据库失败: %w", err)
}
// 2) 查询完整记录
session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
})
if session == nil {
if err != nil || session == nil {
return nil, fmt.Errorf("会话不存在: epicycleId=%d", req.EpicycleId)
}
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, fmt.Errorf("获取会话数据失败: %w", err)
}
if err = saveToRedis(ctx, session); err != nil {
// 3) entity → HistoryRound → 写入 Redis
round := entityToHistoryRound(session)
round.Assistant = req.Messages
if err = SaveToRedis(ctx, session.TenantId, session.SessionId, session.NodeId, round); err != nil {
return nil, fmt.Errorf("redis存储失败: %w", err)
}
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
session.SessionId, session.Id, len(session.RequestContent), len(session.ResponseContent))
return &dto.SessionCallbackRes{
Status: true,
SessionId: session.SessionId,
}, nil
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d", session.SessionId, session.Id)
return &dto.SessionCallbackRes{Status: true, SessionId: session.SessionId}, nil
}
// GetHistoryMessages 获取历史信息
func GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) {
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
// ============================================
// 场景1前端历史列表按 creator
// ============================================
redisHistory, err := GetSessionHistoryForInference(ctx, sessionId)
if err == nil && len(redisHistory) > 0 {
return redisHistory, nil
// GetHistoryList 获取历史列表
func GetHistoryList(ctx context.Context, req *dto.GetHistoryListReq) (*dto.GetHistoryListRes, error) {
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
sessions, total, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
}, req.Page, req.Size)
if err != nil {
return nil, fmt.Errorf("DB获取历史列表失败: %w", err)
}
rounds := sessionsToHistoryRounds(sessions)
return &dto.GetHistoryListRes{List: rounds, Total: total}, nil
}
// ============================================
// 场景2提示词拼接按 sessionId + nodeId
// ============================================
// GetHistoryMessages 获取历史消息Redis → DB → 异步回种)
func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*dto.GetHistoryMessagesRes, error) {
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
return getHistoryFromDatabase(ctx, sessionId, maxRounds)
}
// 1) Redis
if rounds, err := GetFromRedis(ctx, user.TenantId, req.SessionId, req.NodeId); err == nil && len(rounds) > 0 {
g.Log().Debugf(ctx, "[历史消息] Redis命中 sessionId=%s count=%d", req.SessionId, len(rounds))
return &dto.GetHistoryMessagesRes{Messages: flattenRounds(rounds)}, nil
}
// getHistoryFromDatabase 从数据库获取历史记录
func getHistoryFromDatabase(ctx context.Context, sessionId string, maxRounds int) ([]map[string]any, error) {
// 2) DB
maxRounds := util.GetMaxRounds(ctx)
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
SessionId: sessionId,
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
SessionId: req.SessionId,
NodeId: req.NodeId,
}, 1, maxRounds)
if err != nil {
return nil, fmt.Errorf("DB获取历史失败: %w", err)
}
messages := extractMessagesFromSessions(sessions)
cacheSessionsToRedis(ctx, sessions)
return messages, nil
}
// extractMessagesFromSessions 从会话列表中提取消息
func extractMessagesFromSessions(sessions []*entity.ComposeSession) []map[string]any {
var messages []map[string]any
for _, session := range sessions {
appendRequestMessages(session.RequestContent, &messages)
appendResponseMessages(session.ResponseContent, &messages)
if len(sessions) == 0 {
return &dto.GetHistoryMessagesRes{Messages: []dto.FlatMessage{}}, nil
}
return messages
// 3) 转换 + 异步回种
rounds := sessionsToHistoryRounds(sessions)
go asyncCacheToRedis(context.WithoutCancel(ctx), user.TenantId, req.SessionId, req.NodeId, rounds)
return &dto.GetHistoryMessagesRes{Messages: flattenRounds(rounds)}, nil
}
// appendRequestMessages 追加请求消息
func appendRequestMessages(requestContent any, messages *[]map[string]any) {
reqMsgs := util.ConvertToMessages(requestContent)
for _, m := range reqMsgs {
role := gconv.String(m["role"])
if role == "user" || role == "assistant" {
*messages = append(*messages, m)
}
// ============================================
// 删除
// ============================================
// DeleteMessages 删除消息
func DeleteMessages(ctx context.Context, req *dto.DeleteMessagesReq) (*dto.DeleteMessagesRes, error) {
if len(req.MsgIds) == 0 {
return &dto.DeleteMessagesRes{Ok: false}, fmt.Errorf("msgIds不能为空")
}
user, _ := utils.GetUserInfo(ctx)
// 1) 批量查询
sessions, _ := dao.ComposeSession.ListByIds(ctx, req.MsgIds, user.UserName, req.SessionId)
// 2) 批量删 DB
_, _ = dao.ComposeSession.DeleteByIds(ctx, req.MsgIds, user.UserName, req.SessionId)
// 3) 按 nodeId 分组删 Redis
for _, s := range sessions {
_ = DeleteRedisMessages(ctx, user.TenantId, req.SessionId, s.NodeId, req.MsgIds)
}
return &dto.DeleteMessagesRes{Ok: true}, nil
}
// DeleteSession 删除整个会话
func DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (*dto.DeleteSessionRes, error) {
// 1) 删 DB
if _, err := dao.ComposeSession.Delete(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
}); err != nil {
return nil, fmt.Errorf("DB删除失败: %w", err)
}
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
// 2) 删 Redis
if err := DeleteSessionHistory(ctx, user.TenantId, req.SessionId); err != nil {
g.Log().Warningf(ctx, "[删除会话] Redis删除失败 sessionId=%s err=%v", req.SessionId, err)
}
return &dto.DeleteSessionRes{Ok: true}, nil
}
// ============================================
// 转换方法entity ↔ dto集中管理
// ============================================
// entityToHistoryRound entity → HistoryRound
func entityToHistoryRound(s *entity.ComposeSession) *dto.HistoryRound {
return &dto.HistoryRound{
Id: s.Id,
SessionId: s.SessionId,
NodeId: s.NodeId,
CreatedAt: gconv.String(s.CreatedAt),
UpdatedAt: gconv.String(s.UpdatedAt),
User: s.RequestContent,
Assistant: s.ResponseContent,
}
}
// appendResponseMessages 追加响应消息
func appendResponseMessages(responseContent any, messages *[]map[string]any) {
respMsgs := util.ConvertToMessages(responseContent)
for _, m := range respMsgs {
if m["role"] == nil {
m["role"] = "assistant"
}
*messages = append(*messages, m)
// sessionsToHistoryRounds 批量转换
func sessionsToHistoryRounds(sessions []*entity.ComposeSession) []dto.HistoryRound {
rounds := make([]dto.HistoryRound, 0, len(sessions))
for _, s := range sessions {
rounds = append(rounds, *entityToHistoryRound(s))
}
return rounds
}
// cacheSessionsToRedis 将会话缓存到Redis
func cacheSessionsToRedis(ctx context.Context, sessions []*entity.ComposeSession) {
for _, session := range sessions {
reqMsgs := util.ConvertToMessages(session.RequestContent)
respMsgs := util.ConvertToMessages(session.ResponseContent)
for i := range respMsgs {
if respMsgs[i]["role"] == nil {
respMsgs[i]["role"] = "assistant"
}
}
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
_ = saveToRedis(ctx, session)
// asyncCacheToRedis 异步缓存到 Redis
func asyncCacheToRedis(ctx context.Context, tenantID uint64, sessionID, nodeID string, rounds []dto.HistoryRound) {
for i := range rounds {
if rounds[i].User != nil || rounds[i].Assistant != nil {
_ = SaveToRedis(ctx, tenantID, sessionID, nodeID, &rounds[i])
}
}
}