Compare commits
10 Commits
aa7804656f
...
dev未优化
| Author | SHA1 | Date | |
|---|---|---|---|
| 0d52b631b9 | |||
| c22d578e1a | |||
| df26329836 | |||
| 40abf0f606 | |||
| b69e7386e2 | |||
| 1c1db7e30c | |||
| 78114f99c7 | |||
| 9410199fbe | |||
| 1f9a2b9b5f | |||
| e1461cf0f0 |
28
Dockerfile
28
Dockerfile
@@ -1,43 +1,23 @@
|
||||
# 多阶段构建 - 第一阶段:编译(使用已安装的镜像)
|
||||
# 阶段1: 构建
|
||||
FROM golang:alpine AS builder
|
||||
|
||||
RUN apk add --no-cache git ca-certificates tzdata
|
||||
|
||||
ENV TZ=Asia/Shanghai
|
||||
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
|
||||
|
||||
ENV GO111MODULE=on
|
||||
ENV GOPROXY=https://goproxy.cn,direct
|
||||
ENV CGO_ENABLED=0
|
||||
ENV GOTOOLCHAIN=auto
|
||||
ENV GOPRIVATE=gitea.com/red-future/common
|
||||
|
||||
# 配置git使用私有Gitea仓库(带Token认证)
|
||||
RUN git config --global url."http://x-token-auth:619679cd366aefea3a50f0622d842a41f2209e08595767bba49c3836ef57d415@116.204.74.41:3000/red-future/common.git".insteadOf "https://gitea.com/red-future/common.git" && \
|
||||
git config --global credential.helper store
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
# 复制父目录的 common 模块(因为 go.mod 中使用了本地 replace)
|
||||
#COPY ../common /build/common
|
||||
COPY . .
|
||||
|
||||
RUN go mod download && go mod tidy
|
||||
|
||||
RUN go build -ldflags="-s -w" -o main ./main.go
|
||||
|
||||
# 第二阶段:运行
|
||||
FROM alpine:3.19
|
||||
|
||||
ENV TIME_ZONE=Asia/Shanghai
|
||||
RUN apk add --no-cache ca-certificates tzdata && \
|
||||
ln -sf /usr/share/zoneinfo/$TIME_ZONE /etc/localtime
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 复制编译好的二进制文件
|
||||
COPY --from=builder /build/main .
|
||||
COPY --from=builder /build/config.yml ./
|
||||
|
||||
# 创建日志目录
|
||||
RUN mkdir -p /logs /app/resource/log/run /app/resource/log/server
|
||||
|
||||
EXPOSE 3009
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
@@ -13,16 +12,6 @@ func GetServerName(ctx context.Context) string {
|
||||
return g.Cfg().MustGet(ctx, "server.name", "").String()
|
||||
}
|
||||
|
||||
// GetServerPort 从配置获取服务端口
|
||||
func GetServerPort(ctx context.Context) string {
|
||||
address := g.Cfg().MustGet(ctx, "server.address", ":8080").String()
|
||||
// address 格式如 ":3009",去掉冒号
|
||||
if strings.HasPrefix(address, ":") {
|
||||
return address[1:]
|
||||
}
|
||||
return "8080"
|
||||
}
|
||||
|
||||
// GetModelPrompt 获取请求模型的提示词
|
||||
func GetModelPrompt(ctx context.Context, modelType int) string {
|
||||
key := "modelPrompts.types." + gconv.String(modelType)
|
||||
@@ -33,3 +22,13 @@ func GetModelPrompt(ctx context.Context, modelType int) string {
|
||||
func GetBuildPrompt(ctx context.Context) string {
|
||||
return g.Cfg().MustGet(ctx, "nodePrompts", "").String()
|
||||
}
|
||||
|
||||
// GetMaxRounds 获取最大轮数配置
|
||||
func GetMaxRounds(ctx context.Context) int {
|
||||
return g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
||||
}
|
||||
|
||||
// GetExpireMinutes 获取过期时间配置
|
||||
func GetExpireMinutes(ctx context.Context) int {
|
||||
return g.Cfg().MustGet(ctx, "session.expireMinutes", 30).Int()
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package util
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gitea.com/red-future/common/utils"
|
||||
"gitea.redpowerfuture.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,233 +1,81 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"fmt"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
gfgjson "github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
tGjson "github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ConvertToMessages 将原始数据转换为消息列表
|
||||
func ConvertToMessages(raw any) []map[string]any {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
j := gfgjson.New(raw)
|
||||
messages := j.Get("messages")
|
||||
if !messages.IsNil() {
|
||||
return gconv.Maps(messages.Val())
|
||||
}
|
||||
return []map[string]any{j.Map()}
|
||||
}
|
||||
|
||||
// FormToJSON 将表单数据转换为 JSON 字符串
|
||||
func FormToJSON(form []map[string]any) string {
|
||||
if form == nil {
|
||||
return "[]"
|
||||
}
|
||||
b, _ := json.Marshal(form)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// UserFormToJSON 将用户表单数据转换为 JSON 字符串
|
||||
func UserFormToJSON(form []map[string]any) string {
|
||||
if form == nil {
|
||||
return "{}"
|
||||
}
|
||||
|
||||
b, _ := json.Marshal(form)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// MustMarshalToMap 将对象序列化为 map[string]any,失败时返回空 map
|
||||
func MustMarshalToMap(v any) map[string]any {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return make(map[string]any)
|
||||
}
|
||||
var m map[string]any
|
||||
json.Unmarshal(b, &m)
|
||||
return m
|
||||
}
|
||||
|
||||
// JSONPretty 将任意类型转为格式化的 JSON 字符串
|
||||
func JSONPretty(v any) string {
|
||||
if gv, ok := v.(*gvar.Var); ok {
|
||||
v = gconv.Map(gv.String())
|
||||
}
|
||||
|
||||
var tmp map[string]any
|
||||
if err := gconv.Struct(v, &tmp); err != nil {
|
||||
return gconv.String(v)
|
||||
}
|
||||
|
||||
b, _ := json.MarshalIndent(tmp, "", " ")
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// ParseJSONFieldFromGvar 专门处理 *gvar.Var 类型的 JSON 字段解析
|
||||
func ParseJSONFieldFromGvar(source any, target any) {
|
||||
if source == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch v := source.(type) {
|
||||
case *gvar.Var:
|
||||
if v.IsNil() {
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试获取 map
|
||||
if m := v.Map(); len(m) > 0 {
|
||||
data, _ := json.Marshal(m)
|
||||
json.Unmarshal(data, target)
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试解析 JSON 字符串
|
||||
str := v.String()
|
||||
if str != "" && str != "<nil>" {
|
||||
json.Unmarshal([]byte(str), target)
|
||||
}
|
||||
|
||||
default:
|
||||
// 其他类型走原来的逻辑
|
||||
data, _ := json.Marshal(source)
|
||||
json.Unmarshal(data, target)
|
||||
}
|
||||
}
|
||||
|
||||
// MergeConsult 将 consult 附件合并到模型生成的 messages 结构中。
|
||||
//
|
||||
// 参数说明:
|
||||
// - req: 请求参数 map,需包含 "consult" 字段,值为 []any,每个元素是 {"type":"xxx","url":"..."}
|
||||
// - messages: 模型生成的返回结构(如 rounds[...].messages[...].content 数组)
|
||||
// - extendMapping: 附加映射配置,格式:
|
||||
// {"attachments": {"image": {"template": {...}, "target_path": "...", "field_mapping": {...}}, ...}}
|
||||
//
|
||||
// 返回值:合并后的完整 map。
|
||||
// MergeConsult 将 consult 附件合并到模型生成的 messages 结构中
|
||||
func MergeConsult(req map[string]any, messages map[string]any, extendMapping map[string]any) map[string]any {
|
||||
if len(req) == 0 || len(messages) == 0 || len(extendMapping) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
reqJSON, _ := json.Marshal(req)
|
||||
msgJSON, _ := json.Marshal(messages)
|
||||
extJSON, _ := json.Marshal(extendMapping)
|
||||
|
||||
reqStr := string(reqJSON)
|
||||
msgStr := string(msgJSON)
|
||||
extStr := string(extJSON)
|
||||
|
||||
// 获取 consult 数组
|
||||
consultResult := tGjson.Get(reqStr, "consult")
|
||||
if !consultResult.Exists() || !consultResult.IsArray() {
|
||||
consult := gconv.Interfaces(req["consult"])
|
||||
if len(consult) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
// 获取 attachments 配置
|
||||
attachmentsResult := tGjson.Get(extStr, "attachments")
|
||||
if !attachmentsResult.Exists() || !attachmentsResult.IsObject() {
|
||||
targetPath := gconv.String(extendMapping["target_content_path"])
|
||||
templates := gconv.Map(extendMapping["attachment_templates"])
|
||||
if targetPath == "" || len(templates) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
consultArr := consultResult.Array()
|
||||
attachmentsMap := attachmentsResult.Map()
|
||||
msgJson := gjson.New(messages)
|
||||
|
||||
for _, consultItem := range consultArr {
|
||||
if !consultItem.IsObject() {
|
||||
// rounds 路径修正
|
||||
if !msgJson.Get("rounds.0").IsNil() {
|
||||
targetPath = "rounds.0." + targetPath
|
||||
}
|
||||
|
||||
// 遍历追加
|
||||
for _, item := range consult {
|
||||
itemJson := gjson.New(item)
|
||||
itemType := itemJson.Get("type").String()
|
||||
tmpl := gconv.Map(templates[itemType])
|
||||
if itemType == "" || len(tmpl) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
itemType := consultItem.Get("type").String()
|
||||
if itemType == "" {
|
||||
attachment := buildAttachment(tmpl, itemJson.Get("url").String())
|
||||
if attachment == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 查找对应类型的附件配置
|
||||
attachResult, ok := attachmentsMap[itemType]
|
||||
if !ok || !attachResult.IsObject() {
|
||||
continue
|
||||
}
|
||||
idx := len(msgJson.Get(targetPath).Array())
|
||||
_ = msgJson.Set(fmt.Sprintf("%s.%d", targetPath, idx), attachment)
|
||||
}
|
||||
|
||||
// 获取模板
|
||||
templateResult := attachResult.Get("template")
|
||||
if !templateResult.Exists() || !templateResult.IsObject() {
|
||||
continue
|
||||
}
|
||||
return msgJson.Map()
|
||||
}
|
||||
|
||||
// 深拷贝模板
|
||||
filledTemplateStr := templateResult.Raw
|
||||
func buildAttachment(tmpl map[string]any, url string) map[string]any {
|
||||
typ := gconv.String(tmpl["type"])
|
||||
if typ == "" || url == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 应用字段映射
|
||||
fieldMappingResult := attachResult.Get("field_mapping")
|
||||
if fieldMappingResult.Exists() && fieldMappingResult.IsObject() {
|
||||
fieldMapping := fieldMappingResult.Map()
|
||||
for fieldPath, valueSource := range fieldMapping {
|
||||
sourceKey := valueSource.String()
|
||||
valueResult := consultItem.Get(sourceKey)
|
||||
if valueResult.Exists() {
|
||||
var err error
|
||||
filledTemplateStr, err = sjson.SetRaw(filledTemplateStr, fieldPath, valueResult.Raw)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
body := gconv.Map(tmpl["body"])
|
||||
fillEmptyInPlace(body, url)
|
||||
|
||||
return map[string]any{
|
||||
"type": typ,
|
||||
typ: body,
|
||||
}
|
||||
}
|
||||
|
||||
func fillEmptyInPlace(m map[string]any, value string) {
|
||||
for k, v := range m {
|
||||
switch vv := v.(type) {
|
||||
case string:
|
||||
if vv == "" {
|
||||
m[k] = value
|
||||
}
|
||||
}
|
||||
|
||||
// 获取目标路径
|
||||
targetPath := attachResult.Get("target_path").String()
|
||||
if targetPath == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查目标路径是否存在且为数组
|
||||
targetResult := tGjson.Get(msgStr, targetPath)
|
||||
if !targetResult.Exists() || !targetResult.IsArray() {
|
||||
continue
|
||||
}
|
||||
|
||||
// 追加到数组末尾
|
||||
arrLen := len(targetResult.Array())
|
||||
appendPath := targetPath + "." + strconv.Itoa(arrLen)
|
||||
var err error
|
||||
msgStr, err = sjson.SetRaw(msgStr, appendPath, filledTemplateStr)
|
||||
if err != nil {
|
||||
continue
|
||||
case map[string]any:
|
||||
fillEmptyInPlace(vv, value)
|
||||
}
|
||||
}
|
||||
|
||||
// 转回 map[string]any
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal([]byte(msgStr), &result); err != nil {
|
||||
return messages
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetUserMessage 获取用户消息
|
||||
func GetUserMessage(taskReq map[string]any) map[string]any {
|
||||
// 先取 requestPayload
|
||||
rp, ok := taskReq["requestPayload"].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
// 再取 messages
|
||||
messages, ok := rp["messages"].([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
for _, msg := range messages {
|
||||
m, ok := msg.(map[string]any)
|
||||
if ok && m["role"] == "user" {
|
||||
return m
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
@@ -22,86 +21,37 @@ func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
|
||||
return jsonObj.Map()
|
||||
}
|
||||
|
||||
// MapResponsePayload 映射模型响应为标准格式
|
||||
func MapResponsePayload(mapping map[string]any, responseBytes []byte) ([]byte, error) {
|
||||
if len(mapping) == 0 {
|
||||
return responseBytes, nil
|
||||
// ExtractUserText 从 messages 中提取所有 user 文本
|
||||
func ExtractUserText(messages map[string]any) map[string]any {
|
||||
msgJson := gjson.New(messages)
|
||||
|
||||
msgs := msgJson.Get("rounds.0.messages")
|
||||
if msgs.IsNil() {
|
||||
msgs = msgJson.Get("messages")
|
||||
}
|
||||
|
||||
responseJson := gjson.New(responseBytes)
|
||||
resultJson := gjson.New("{}")
|
||||
|
||||
for standardField, modelPath := range mapping {
|
||||
path := gconv.String(modelPath)
|
||||
if path == "" {
|
||||
var texts []string
|
||||
for _, m := range msgs.Array() {
|
||||
msg := gjson.New(m)
|
||||
if msg.Get("role").String() != "user" {
|
||||
continue
|
||||
}
|
||||
val := responseJson.Get(path)
|
||||
if val.IsNil() {
|
||||
continue
|
||||
}
|
||||
resultJson.Set(standardField, val.Val())
|
||||
}
|
||||
|
||||
return []byte(resultJson.String()), nil
|
||||
}
|
||||
|
||||
// ParseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
|
||||
// 示例:
|
||||
// - X-API-Key:qwen3-tts-key,operation:true,count:123
|
||||
// - X-API-Key:"qwen3-tts-key",operation:"true"
|
||||
//
|
||||
// 说明:
|
||||
// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。
|
||||
// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。
|
||||
func ParseHeadMsgHeaders(headMsg string) map[string]string {
|
||||
headMsg = strings.TrimSpace(headMsg)
|
||||
if headMsg == "" {
|
||||
return nil
|
||||
}
|
||||
out := map[string]string{}
|
||||
parts := strings.Split(headMsg, ",")
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
// HeaderName:HeaderValue(推荐) / HeaderName=HeaderValue(兼容)
|
||||
if strings.Contains(p, ":") {
|
||||
kv := strings.SplitN(p, ":", 2)
|
||||
k := strings.TrimSpace(kv[0])
|
||||
v := strings.TrimSpace(kv[1])
|
||||
v = strings.Trim(v, "\"")
|
||||
if k != "" && v != "" {
|
||||
out[k] = v
|
||||
content := msg.Get("content").Val()
|
||||
switch c := content.(type) {
|
||||
case string:
|
||||
texts = append(texts, c)
|
||||
case []any:
|
||||
for _, item := range c {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if t := gconv.String(m["text"]); t != "" {
|
||||
texts = append(texts, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if strings.Contains(p, "=") {
|
||||
kv := strings.SplitN(p, "=", 2)
|
||||
k := strings.TrimSpace(kv[0])
|
||||
v := strings.TrimSpace(kv[1])
|
||||
v = strings.Trim(v, "\"")
|
||||
if k != "" && v != "" {
|
||||
out[k] = v
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// PayloadToQuery 将 payload 转为 url.Values
|
||||
func PayloadToQuery(payload map[string]any) (url.Values, error) {
|
||||
q := url.Values{}
|
||||
for k, v := range payload {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
q.Set(k, gconv.String(v))
|
||||
return map[string]any{
|
||||
"role": "user",
|
||||
"content": strings.Join(texts, "\n"),
|
||||
}
|
||||
return q, nil
|
||||
}
|
||||
|
||||
@@ -50,14 +50,14 @@ database:
|
||||
|
||||
redis:
|
||||
default:
|
||||
address: 116.204.74.41:6379
|
||||
address: 192.168.3.30:6379
|
||||
db: 0
|
||||
|
||||
consul:
|
||||
address: 116.204.74.41:8500
|
||||
address: 192.168.3.30:8500
|
||||
|
||||
jaeger:
|
||||
addr: 116.204.74.41:4318
|
||||
addr: 192.168.3.30:4318
|
||||
|
||||
task:
|
||||
waitTimeoutSeconds: 600 # /composeMessages 同步等待最终结果的最长时间(秒)
|
||||
|
||||
@@ -11,3 +11,7 @@ const (
|
||||
BuildTypeNode = 2 //节点构建
|
||||
BuildTypeStruct = 3 //结构构建
|
||||
)
|
||||
|
||||
const (
|
||||
ModelTypeInference = 100 // 推理模型
|
||||
)
|
||||
|
||||
@@ -1,18 +1,36 @@
|
||||
// ============================================
|
||||
// controller/session.go
|
||||
// ============================================
|
||||
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"prompts-core/model/dto"
|
||||
|
||||
"prompts-core/model/dto"
|
||||
sessionService "prompts-core/service/session"
|
||||
)
|
||||
|
||||
type session struct{}
|
||||
|
||||
// Session 提示词会话控制器
|
||||
var Session = new(session)
|
||||
|
||||
// SessionCallback 会话回调
|
||||
func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) {
|
||||
return sessionService.Callback(ctx, req)
|
||||
}
|
||||
|
||||
// GetHistoryList 获取历史列表(前端列表)
|
||||
func (c *session) GetHistoryList(ctx context.Context, req *dto.GetHistoryListReq) (res *dto.GetHistoryListRes, err error) {
|
||||
return sessionService.GetHistoryList(ctx, req)
|
||||
}
|
||||
|
||||
// DeleteMessages 批量删除消息
|
||||
func (c *session) DeleteMessages(ctx context.Context, req *dto.DeleteMessagesReq) (res *dto.DeleteMessagesRes, err error) {
|
||||
return sessionService.DeleteMessages(ctx, req)
|
||||
}
|
||||
|
||||
// DeleteSession 删除整个会话
|
||||
func (c *session) DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (res *dto.DeleteSessionRes, err error) {
|
||||
return sessionService.DeleteSession(ctx, req)
|
||||
}
|
||||
|
||||
@@ -5,8 +5,7 @@ import (
|
||||
"prompts-core/consts/public"
|
||||
"prompts-core/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
)
|
||||
|
||||
var ComposeSession = &composeSessionDao{}
|
||||
@@ -15,13 +14,8 @@ type composeSessionDao struct{}
|
||||
|
||||
// Insert 插入
|
||||
func (d *composeSessionDao) Insert(ctx context.Context, req *entity.ComposeSession) (id int64, err error) {
|
||||
var m = new(entity.ComposeSession)
|
||||
err = gconv.Struct(req, &m)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
||||
Insert(m)
|
||||
Insert(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -69,6 +63,7 @@ func (d *composeSessionDao) Get(ctx context.Context, req *entity.ComposeSession,
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
||||
OmitEmpty().
|
||||
Where(entity.ComposeSessionCol.Id, req.Id).
|
||||
Where(entity.ComposeSessionCol.Creator, req.Creator).
|
||||
Where(entity.ComposeSessionCol.SessionId, req.SessionId).
|
||||
Fields(fields).One()
|
||||
if err != nil {
|
||||
@@ -86,6 +81,7 @@ func (d *composeSessionDao) Delete(ctx context.Context, req *entity.ComposeSessi
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
||||
OmitEmpty().
|
||||
Where(entity.ComposeSessionCol.Id, req.Id).
|
||||
Where(entity.ComposeSessionCol.Creator, req.Creator).
|
||||
Where(entity.ComposeSessionCol.SessionId, req.SessionId).
|
||||
Delete()
|
||||
if err != nil {
|
||||
@@ -93,3 +89,36 @@ func (d *composeSessionDao) Delete(ctx context.Context, req *entity.ComposeSessi
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
// ListByIds 根据 ID 列表批量查询
|
||||
func (d *composeSessionDao) ListByIds(ctx context.Context, ids []int64, creator, sessionId string) (list []*entity.ComposeSession, err error) {
|
||||
if len(ids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
||||
WhereIn(entity.ComposeSessionCol.Id, ids).
|
||||
Where(entity.ComposeSessionCol.Creator, creator).
|
||||
Where(entity.ComposeSessionCol.SessionId, sessionId).
|
||||
All()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
// DeleteByIds 批量删除编排会话
|
||||
func (d *composeSessionDao) DeleteByIds(ctx context.Context, ids []int64, creator, sessionId string) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
||||
WhereIn(entity.ComposeSessionCol.Id, ids).
|
||||
Where(entity.ComposeSessionCol.Creator, creator).
|
||||
Where(entity.ComposeSessionCol.SessionId, sessionId).
|
||||
Delete()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"prompts-core/consts/public"
|
||||
"prompts-core/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"prompts-core/consts/public"
|
||||
"prompts-core/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
|
||||
6
go.mod
6
go.mod
@@ -3,12 +3,10 @@ module prompts-core
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
gitea.com/red-future/common v0.0.20
|
||||
gitea.redpowerfuture.com/red-future/common v0.0.23
|
||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.2
|
||||
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.2
|
||||
github.com/gogf/gf/v2 v2.10.2
|
||||
github.com/tidwall/gjson v1.19.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -65,8 +63,6 @@ require (
|
||||
github.com/r3labs/diff/v2 v2.15.1 // indirect
|
||||
github.com/redis/go-redis/v9 v9.12.1 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tiger1103/gfast-token v1.0.10 // indirect
|
||||
github.com/vcaesar/cedar v0.30.0 // indirect
|
||||
github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect
|
||||
|
||||
13
go.sum
13
go.sum
@@ -1,6 +1,6 @@
|
||||
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
gitea.com/red-future/common v0.0.20 h1:KlKINnJFmOVkDzgkptEAFsdpMUZb0zK9BTdiXRxVfAo=
|
||||
gitea.com/red-future/common v0.0.20/go.mod h1:6/nqIucVzmjOyqDTIq71feYBXXFNBy0rFwzaQ0/Ueoo=
|
||||
gitea.redpowerfuture.com/red-future/common v0.0.23 h1:xieoA00iKOCDm5SO9iXn+cSyMKBAlZwI0fuEVPWrHLg=
|
||||
gitea.redpowerfuture.com/red-future/common v0.0.23/go.mod h1:50U1Xi+Ie56z09S5LQbZvaken0Mxv3OeS9LgR7U/ZRY=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
||||
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
@@ -288,15 +288,6 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.19.0 h1:xwxm7n691Uf3u5OFjzngavjGTh55KX5q/9w9xHW88JU=
|
||||
github.com/tidwall/gjson v1.19.0/go.mod h1:V37/opeE/JbLUOfH0QTXiNez2l0RUjYUhpT4szFQAfc=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/tiger1103/gfast-token v1.0.10 h1:fNiBE/Dq5iTHvTGlCx3DmXa2o4hr0NtumFpffZ39k6s=
|
||||
github.com/tiger1103/gfast-token v1.0.10/go.mod h1:a/21mxmj7zFeNvjhZSC0XpEAFHfb1aT2k6DXnufFU1s=
|
||||
github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM=
|
||||
|
||||
6
main.go
6
main.go
@@ -7,9 +7,9 @@ import (
|
||||
"prompts-core/controller"
|
||||
"syscall"
|
||||
|
||||
"gitea.com/red-future/common/http"
|
||||
"gitea.com/red-future/common/jaeger"
|
||||
_ "gitea.com/red-future/common/swagger"
|
||||
"gitea.redpowerfuture.com/red-future/common/http"
|
||||
"gitea.redpowerfuture.com/red-future/common/jaeger"
|
||||
_ "gitea.redpowerfuture.com/red-future/common/swagger"
|
||||
_ "github.com/gogf/gf/contrib/drivers/pgsql/v2"
|
||||
_ "github.com/gogf/gf/contrib/nosql/redis/v2"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
|
||||
@@ -25,21 +25,13 @@ type ComposeMessagesRes struct {
|
||||
TaskId string `json:"taskId" dc:"任务ID"`
|
||||
}
|
||||
|
||||
// MultiRoundResult 多轮返回结果
|
||||
type MultiRoundResult struct {
|
||||
TotalRounds int `json:"total_rounds"` // 总轮数
|
||||
Rounds []map[string]any `json:"rounds"` // 每轮详情(动态类型)
|
||||
}
|
||||
|
||||
type CallbackReq struct {
|
||||
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调:callbackUrl/{bizName}"`
|
||||
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
|
||||
State int `json:"state" dc:"网关任务状态"`
|
||||
OssFile string `json:"oss_file" dc:"结果文件地址"`
|
||||
FileType string `json:"file_type" dc:"结果文件类型"`
|
||||
Messages map[string]any `json:"messages" dc:"消息数组"`
|
||||
ErrorMsg string `json:"error_msg" dc:"错误信息"`
|
||||
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
|
||||
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调:callbackUrl/{bizName}"`
|
||||
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
|
||||
State int `json:"state" dc:"网关任务状态"`
|
||||
OssFile string `json:"oss_file" dc:"结果文件地址"`
|
||||
FileType string `json:"file_type" dc:"结果文件类型"`
|
||||
ErrorMsg string `json:"error_msg" dc:"错误信息"`
|
||||
}
|
||||
|
||||
type CallbackRes struct {
|
||||
@@ -51,11 +43,11 @@ type GetComposeTaskReq struct {
|
||||
}
|
||||
|
||||
type GetComposeTaskRes struct {
|
||||
TaskId string `json:"taskId" dc:"任务ID"`
|
||||
Status string `json:"status" dc:"业务状态"`
|
||||
GatewayState int `json:"gatewayState" dc:"网关状态"`
|
||||
ErrorMessage string `json:"errorMessage" dc:"错误信息"`
|
||||
Messages any `json:"messages" dc:"最终消息数组"`
|
||||
OssFile string `json:"ossFile" dc:"结果文件地址"`
|
||||
FileType string `json:"fileType" dc:"结果文件类型"`
|
||||
TaskId string `json:"taskId" dc:"任务ID"`
|
||||
Status string `json:"status" dc:"业务状态"`
|
||||
GatewayState int `json:"gatewayState" dc:"网关状态"`
|
||||
ErrorMessage string `json:"errorMessage" dc:"错误信息"`
|
||||
Messages map[string]any `json:"messages" dc:"最终消息数组"`
|
||||
OssFile string `json:"ossFile" dc:"结果文件地址"`
|
||||
FileType string `json:"fileType" dc:"结果文件类型"`
|
||||
}
|
||||
|
||||
@@ -2,13 +2,79 @@ package dto
|
||||
|
||||
import "github.com/gogf/gf/v2/frame/g"
|
||||
|
||||
type SessionCallbackReq struct {
|
||||
g.Meta `path:"/sessionCallback" method:"post" tags:"提示词处理"`
|
||||
Messages map[string]any `json:"messages" dc:"消息数组"`
|
||||
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
|
||||
// HistoryRound 一轮对话
|
||||
type HistoryRound struct {
|
||||
Id int64 `json:"id" dc:"记录ID"`
|
||||
SessionId string `json:"sessionId" dc:"会话ID"`
|
||||
NodeId string `json:"nodeId" dc:"节点ID"`
|
||||
User map[string]any `json:"user" dc:"用户消息"`
|
||||
Assistant map[string]any `json:"assistant" dc:"助手回复"`
|
||||
CreatedAt string `json:"createdAt" dc:"创建时间"`
|
||||
UpdatedAt string `json:"updatedAt" dc:"更新时间"`
|
||||
}
|
||||
|
||||
// SessionCallbackReq 会话回调请求
|
||||
type SessionCallbackReq struct {
|
||||
g.Meta `path:"/callback" method:"post" tags:"会话管理" summary:"会话回调"`
|
||||
Messages map[string]any `json:"messages" v:"required" dc:"消息数组"`
|
||||
EpicycleId int64 `json:"epicycleId" v:"required" dc:"轮次ID"`
|
||||
}
|
||||
|
||||
// SessionCallbackRes 会话回调响应
|
||||
type SessionCallbackRes struct {
|
||||
Status bool `json:"status" dc:"状态"`
|
||||
SessionId string `json:"sessionId" dc:"会话ID"`
|
||||
}
|
||||
|
||||
// GetHistoryListReq 获取历史列表请求(前端)
|
||||
type GetHistoryListReq struct {
|
||||
g.Meta `path:"/historyList" method:"get" tags:"会话管理" summary:"获取历史列表"`
|
||||
Page int `json:"page" d:"1" dc:"页码"`
|
||||
Size int `json:"size" d:"10" dc:"每页条数"`
|
||||
}
|
||||
|
||||
// GetHistoryListRes 获取历史列表响应
|
||||
type GetHistoryListRes struct {
|
||||
List []HistoryRound `json:"list" dc:"历史列表"`
|
||||
Total int `json:"total" dc:"总数"`
|
||||
}
|
||||
|
||||
// GetHistoryMessagesReq 获取历史消息请求(提示词拼接)
|
||||
type GetHistoryMessagesReq struct {
|
||||
g.Meta `path:"/historyMessages" method:"get" tags:"会话管理" summary:"获取历史消息"`
|
||||
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
|
||||
NodeId string `json:"nodeId" dc:"节点ID"`
|
||||
}
|
||||
|
||||
// GetHistoryMessagesRes 获取历史消息响应
|
||||
type GetHistoryMessagesRes struct {
|
||||
Messages []FlatMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type FlatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// DeleteMessagesReq 批量删除消息请求
|
||||
type DeleteMessagesReq struct {
|
||||
g.Meta `path:"/deleteMessages" method:"post" tags:"会话管理" summary:"批量删除消息"`
|
||||
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
|
||||
MsgIds []int64 `json:"msgIds" v:"required" dc:"消息ID列表"`
|
||||
}
|
||||
|
||||
// DeleteMessagesRes 批量删除消息响应
|
||||
type DeleteMessagesRes struct {
|
||||
Ok bool `json:"ok" dc:"是否成功"`
|
||||
}
|
||||
|
||||
// DeleteSessionReq 删除整个会话请求
|
||||
type DeleteSessionReq struct {
|
||||
g.Meta `path:"/deleteSession" method:"post" tags:"会话管理" summary:"删除整个会话"`
|
||||
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
|
||||
}
|
||||
|
||||
// DeleteSessionRes 删除整个会话响应
|
||||
type DeleteSessionRes struct {
|
||||
Ok bool `json:"ok" dc:"是否成功"`
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package entity
|
||||
|
||||
import "gitea.com/red-future/common/beans"
|
||||
import "gitea.redpowerfuture.com/red-future/common/beans"
|
||||
|
||||
type ComposeSession struct {
|
||||
beans.SQLBaseDO `orm:",inline"`
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package entity
|
||||
|
||||
import "gitea.com/red-future/common/beans"
|
||||
import "gitea.redpowerfuture.com/red-future/common/beans"
|
||||
|
||||
type ComposeTask struct {
|
||||
beans.SQLBaseDO `orm:",inline"`
|
||||
@@ -11,8 +11,7 @@ type ComposeTask struct {
|
||||
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
|
||||
GatewayState int `orm:"gateway_state" json:"gatewayState"`
|
||||
RequestPayload map[string]any `orm:"request_payload" json:"requestPayload"`
|
||||
ResultText map[string]any `orm:"result_text" json:"resultText"`
|
||||
Messages map[string]any `orm:"messages" json:"messages"`
|
||||
ResultJson map[string]any `orm:"result_json" json:"resultJson"`
|
||||
Status string `orm:"status" json:"status"`
|
||||
ErrorMessage string `orm:"error_message" json:"errorMessage"`
|
||||
OssFile string `orm:"oss_file" json:"ossFile"`
|
||||
@@ -28,8 +27,7 @@ type composeTaskCol struct {
|
||||
CallbackUrl string
|
||||
GatewayState string
|
||||
RequestPayload string
|
||||
ResultText string
|
||||
Messages string
|
||||
ResultJson string
|
||||
Status string
|
||||
ErrorMessage string
|
||||
OssFile string
|
||||
@@ -45,8 +43,7 @@ var ComposeTaskCol = composeTaskCol{
|
||||
CallbackUrl: "callback_url",
|
||||
GatewayState: "gateway_state",
|
||||
RequestPayload: "request_payload",
|
||||
ResultText: "result_text",
|
||||
Messages: "messages",
|
||||
ResultJson: "result_json",
|
||||
Status: "status",
|
||||
ErrorMessage: "error_message",
|
||||
OssFile: "oss_file",
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
package entity
|
||||
|
||||
import "gitea.com/red-future/common/beans"
|
||||
import "gitea.redpowerfuture.com/red-future/common/beans"
|
||||
|
||||
// ProviderProtocol 模型协议映射配置
|
||||
type ProviderProtocol struct {
|
||||
beans.SQLBaseDO `orm:",inherit"`
|
||||
// 业务字段
|
||||
ProviderName string `orm:"provider_name" json:"providerName"`
|
||||
TargetField string `orm:"target_field" json:"targetField"`
|
||||
MergeOrder any `orm:"merge_order" json:"mergeOrder"`
|
||||
RoleMapping any `orm:"role_mapping" json:"roleMapping"`
|
||||
ContentMapping any `orm:"content_mapping" json:"contentMapping"`
|
||||
Capabilities any `orm:"capabilities" json:"capabilities"`
|
||||
RequestTemplate any `orm:"request_template" json:"requestTemplate"`
|
||||
SystemPromptTemplate string `orm:"system_prompt_template" json:"systemPromptTemplate"`
|
||||
Status int `orm:"status" json:"status"`
|
||||
Remark string `orm:"remark" json:"remark"`
|
||||
ProviderName string `orm:"provider_name" json:"providerName"`
|
||||
TargetField string `orm:"target_field" json:"targetField"`
|
||||
MergeOrder []string `orm:"merge_order" json:"mergeOrder"`
|
||||
RoleMapping map[string]any `orm:"role_mapping" json:"roleMapping"`
|
||||
ContentMapping map[string]any `orm:"content_mapping" json:"contentMapping"`
|
||||
Capabilities map[string]any `orm:"capabilities" json:"capabilities"`
|
||||
RequestTemplate map[string]any `orm:"request_template" json:"requestTemplate"`
|
||||
SystemPromptTemplate string `orm:"system_prompt_template" json:"systemPromptTemplate"`
|
||||
Status int `orm:"status" json:"status"`
|
||||
Remark string `orm:"remark" json:"remark"`
|
||||
}
|
||||
|
||||
// providerProtocolCol 列名
|
||||
|
||||
@@ -4,13 +4,14 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/model/entity"
|
||||
"strings"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
commonHttp "gitea.com/red-future/common/http"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||
commonHttp "gitea.redpowerfuture.com/red-future/common/http"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
)
|
||||
@@ -58,8 +59,7 @@ type AsynchModel struct {
|
||||
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
|
||||
IsPrivate *int `orm:"is_private" json:"isPrivate"`
|
||||
IsChatModel int `orm:"is_chat_model" json:"isChatModel"`
|
||||
IsAsync *int `orm:"is_async" json:"isAsync"`
|
||||
IsStream *int `orm:"is_stream" json:"isStream"`
|
||||
CallModel int `orm:"call_model" json:"callModel"`
|
||||
ApiKey string `orm:"api_key" json:"apiKey"`
|
||||
Enabled *int `orm:"enabled" json:"enabled"`
|
||||
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
|
||||
@@ -148,11 +148,10 @@ func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) {
|
||||
|
||||
// SendCallbackReq 发送回调的请求体
|
||||
type SendCallbackReq struct {
|
||||
TaskId string `json:"taskId"`
|
||||
Status string `json:"status"`
|
||||
Messages map[string]any `json:"messages,omitempty"`
|
||||
EpicycleId int64 `json:"epicycleId"`
|
||||
ErrorMsg string `json:"errorMsg,omitempty"`
|
||||
TaskId string `json:"taskId"`
|
||||
Status string `json:"status"`
|
||||
EpicycleId int64 `json:"epicycleId"`
|
||||
ErrorMsg string `json:"errorMsg,omitempty"`
|
||||
}
|
||||
|
||||
// SendCallback 向业务方发送回调
|
||||
@@ -165,18 +164,32 @@ func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycle
|
||||
req := SendCallbackReq{
|
||||
TaskId: composeTask.TaskId,
|
||||
Status: composeTask.Status,
|
||||
Messages: composeTask.Messages,
|
||||
ErrorMsg: composeTask.ErrorMessage,
|
||||
EpicycleId: epicycleId,
|
||||
}
|
||||
// 3. 发送 POST 请求
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
var resp struct{}
|
||||
g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s 消息=%v",
|
||||
composeTask.TaskId, composeTask.CallbackUrl, gjson.New(req.Messages).String())
|
||||
g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s",
|
||||
composeTask.TaskId, composeTask.CallbackUrl)
|
||||
if err := commonHttp.Post(ctx, composeTask.CallbackUrl, headers, &resp, req); err != nil {
|
||||
return fmt.Errorf("[回调业务] 发送失败 taskId=%s url=%s err=%w", composeTask.TaskId, composeTask.CallbackUrl, err)
|
||||
}
|
||||
g.Log().Infof(ctx, "[回调业务] 发送成功 taskId=%s 回调地址=%s ", composeTask.TaskId, composeTask.CallbackUrl)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DownloadFile 从 OSS 下载文件内容
|
||||
func DownloadFile(ossURL string) ([]byte, error) {
|
||||
resp, err := http.Get(ossURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("下载OSS文件失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("下载OSS文件返回非200: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
@@ -11,13 +11,12 @@ import (
|
||||
"prompts-core/model/dto"
|
||||
"prompts-core/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/utils"
|
||||
"gitea.redpowerfuture.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// buildPromptTypeRequest 构建提示词类型请求(BuildType=1)
|
||||
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *PromptIR, totalBatches int) (map[string]any, error) {
|
||||
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
|
||||
//1) 构建系统提示词
|
||||
systemPrompt := promptBuildWithRounds(ctx, chatModel, aiModel)
|
||||
ir.AddSystem(systemPrompt)
|
||||
@@ -28,40 +27,37 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
|
||||
availableWindow := util.GetAvailableWindow(aiModel.TokenConfig)
|
||||
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
|
||||
}
|
||||
return compileToProviderRequest(ctx, ir, chatModel)
|
||||
return compileToProviderRequest(ctx, ir, chatModel, req)
|
||||
}
|
||||
|
||||
// buildNodeTypeRequest 构建节点类型请求(BuildType=2)
|
||||
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
|
||||
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
|
||||
ir.AddUser(NodeBuild(ctx, req))
|
||||
return compileToProviderRequest(ctx, ir, chatModel)
|
||||
return compileToProviderRequest(ctx, ir, chatModel, req)
|
||||
}
|
||||
|
||||
// buildStructTypeRequest 构建结构体类型请求(BuildType=3)
|
||||
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
|
||||
// 提取 userForm 中的 prompt 作为自定义提示词
|
||||
var customPrompt string
|
||||
for _, item := range req.UserForm {
|
||||
if prompt, ok := item["prompt"]; ok && gconv.String(prompt) != "" {
|
||||
customPrompt = gconv.String(prompt)
|
||||
break
|
||||
}
|
||||
}
|
||||
// 用户消息
|
||||
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
|
||||
customPrompt := gjson.New(req.UserForm).Get("0.prompt").String()
|
||||
ir.AddSystem(customPrompt)
|
||||
ir.AddUser(buildUserPrompt(ctx, req, ""))
|
||||
return compileToProviderRequest(ctx, ir, chatModel, customPrompt)
|
||||
return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt)
|
||||
}
|
||||
|
||||
// compileToProviderRequest 编译为 Provider 请求
|
||||
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel, customPrompt ...string) (map[string]any, error) {
|
||||
func compileToProviderRequest(ctx context.Context, ir *IR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) {
|
||||
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
|
||||
if err != nil || protocol == nil {
|
||||
return nil, fmt.Errorf("协议配置不存在或获取失败: %w", err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if protocol == nil {
|
||||
return nil, fmt.Errorf("协议配置不存在或获取失败")
|
||||
}
|
||||
// 如果传了自定义提示词,替换掉协议模板
|
||||
if len(customPrompt) > 0 && customPrompt[0] != "" {
|
||||
protocol.SystemPromptTemplate = customPrompt[0]
|
||||
protocol.SystemPromptTemplate = customPrompt[0] +
|
||||
"【核心铁律】" +
|
||||
"1.【技能内容skill相关】必须完整拼接到System提示词中,作为System提示词的组成部分,不得拆分到其他位置。"
|
||||
}
|
||||
providerReq, err := Compile(ir, protocol, chatModel)
|
||||
if err != nil {
|
||||
@@ -72,9 +68,11 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gate
|
||||
"bizName": util.GetServerName(ctx),
|
||||
"callbackUrl": utils.GetCallbackURL(ctx, "/prompt/callback"),
|
||||
"requestPayload": providerReq,
|
||||
"buildType": req.BuildType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// promptBuildWithRounds 构建提示词
|
||||
func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel) string {
|
||||
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||
ProviderName: chatModel.OperatorName,
|
||||
@@ -83,23 +81,15 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel,
|
||||
if err != nil || providerProtocol == nil {
|
||||
return ""
|
||||
}
|
||||
outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{}))
|
||||
outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{})).MustToJsonIndentString()
|
||||
|
||||
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
|
||||
outputJSON, //【输出结构】 %s
|
||||
)
|
||||
}
|
||||
|
||||
// buildUserFormContent 构建用户表单内容字符串
|
||||
func buildUserFormContent(userForm []map[string]any) string {
|
||||
var builder strings.Builder
|
||||
for _, item := range userForm {
|
||||
builder.WriteString(fmt.Sprintf("%v\n", item))
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// checkOverallContent 检查整体内容是否超出窗口
|
||||
func checkOverallContent(ir *PromptIR, model *gateway.AsynchModel) bool {
|
||||
func checkOverallContent(ir *IR, model *gateway.AsynchModel) bool {
|
||||
fullContent := ir.String()
|
||||
return util.CountToken(fullContent, model.TokenConfig)
|
||||
}
|
||||
@@ -129,7 +119,6 @@ func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt st
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// buildUserFormText 构建用户表单内容字符串
|
||||
func buildUserFormText(form []map[string]any) string {
|
||||
if len(form) == 0 {
|
||||
return ""
|
||||
@@ -137,32 +126,22 @@ func buildUserFormText(form []map[string]any) string {
|
||||
var builder strings.Builder
|
||||
for _, item := range form {
|
||||
for k, v := range item {
|
||||
builder.WriteString(fmt.Sprintf("%s:\n", k))
|
||||
switch val := v.(type) {
|
||||
case []any:
|
||||
// 数组类型:逐条列出
|
||||
builder.WriteString(fmt.Sprintf("%s:\n", k))
|
||||
for i, elem := range val {
|
||||
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
|
||||
if m, ok := elem.(map[string]any); ok {
|
||||
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
|
||||
for mk, mv := range m {
|
||||
builder.WriteString(fmt.Sprintf("%s:%v ", mk, mv))
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
} else {
|
||||
builder.WriteString(fmt.Sprintf(" %d. %v\n", i+1, elem))
|
||||
}
|
||||
}
|
||||
case []map[string]any:
|
||||
builder.WriteString(fmt.Sprintf("%s:\n", k))
|
||||
for i, m := range val {
|
||||
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
|
||||
for mk, mv := range m {
|
||||
builder.WriteString(fmt.Sprintf("%s:%v ", mk, mv))
|
||||
builder.WriteString(fmt.Sprint(elem))
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
default:
|
||||
builder.WriteString(fmt.Sprintf("%s:%v\n", k, v))
|
||||
builder.WriteString(fmt.Sprintf(" %v\n", v))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -175,9 +154,8 @@ func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string {
|
||||
if promptTpl == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
formStr := util.FormToJSON(req.Form)
|
||||
userFormStr := util.UserFormToJSON(req.UserForm)
|
||||
|
||||
return fmt.Sprintf(promptTpl, formStr, userFormStr)
|
||||
return fmt.Sprintf(promptTpl,
|
||||
gjson.New(req.Form).MustToJsonString(),
|
||||
gjson.New(req.UserForm).MustToJsonString(),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -2,9 +2,9 @@ package prompt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"prompts-core/service/session"
|
||||
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/consts/public"
|
||||
@@ -13,8 +13,9 @@ import (
|
||||
"prompts-core/model/entity"
|
||||
"prompts-core/service/gateway"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||
"gitea.redpowerfuture.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
@@ -79,7 +80,7 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *gateway.AsynchModel) e
|
||||
// handleBuild 通用构建处理
|
||||
func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
||||
// 1) 处理表单分批
|
||||
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
|
||||
processedReq, _, err := ProcessUserFormBatches(ctx, req, aiModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
|
||||
}
|
||||
@@ -89,7 +90,7 @@ func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, ai
|
||||
var taskReq map[string]any
|
||||
switch req.BuildType {
|
||||
case public.BuildTypePrompt:
|
||||
taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir, totalBatches)
|
||||
taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir)
|
||||
case public.BuildTypeNode:
|
||||
taskReq, err = buildNodeTypeRequest(ctx, req, chatModel, ir)
|
||||
case public.BuildTypeStruct:
|
||||
@@ -117,7 +118,7 @@ func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, ai
|
||||
SkillName: req.SkillName,
|
||||
BuildType: req.BuildType,
|
||||
CallbackUrl: req.CallbackUrl,
|
||||
RequestPayload: util.MustMarshalToMap(req),
|
||||
RequestPayload: gconv.Map(req),
|
||||
Status: public.ComposeStatusPending,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
@@ -128,24 +129,43 @@ func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, ai
|
||||
// Callback 回调处理
|
||||
func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||
g.Log().Infof(ctx, "[开始回调处理] taskId=%s state=%d", req.TaskId, req.State)
|
||||
|
||||
// 1) 查询任务
|
||||
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{TaskId: req.TaskId})
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询任务失败: %w", err)
|
||||
}
|
||||
// 2) 处理失败
|
||||
|
||||
// 2) 读取 OSS 文件内容
|
||||
var ossContent []byte
|
||||
if req.OssFile != "" {
|
||||
ossContent, err = gateway.DownloadFile(req.OssFile)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[回调处理] 读取OSS失败 taskId=%s err=%v", req.TaskId, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3) 解析 OSS 内容为消息
|
||||
var messages map[string]any
|
||||
if len(ossContent) > 0 {
|
||||
messages, _ = gjson.New(ossContent).Map(), nil
|
||||
}
|
||||
|
||||
// 4) 处理失败
|
||||
if req.State == 3 {
|
||||
return handleCallbackFailed(ctx, req, composeTask)
|
||||
return handleCallbackFailed(ctx, req, composeTask, messages)
|
||||
}
|
||||
// 3) 处理成功
|
||||
|
||||
// 5) 处理成功
|
||||
if req.State == 2 {
|
||||
return handleCallbackSuccess(ctx, req, composeTask)
|
||||
return handleCallbackSuccess(ctx, req, composeTask, messages)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleCallbackFailed 处理回调失败
|
||||
func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
|
||||
func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask, messages map[string]any) error {
|
||||
_, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusFailed,
|
||||
@@ -153,7 +173,7 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
|
||||
GatewayState: req.State,
|
||||
OssFile: req.OssFile,
|
||||
FileType: req.FileType,
|
||||
ResultText: req.Messages,
|
||||
ResultJson: messages,
|
||||
})
|
||||
if composeTask.CallbackUrl != "" {
|
||||
composeTask.Status = public.ComposeStatusFailed
|
||||
@@ -164,7 +184,7 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
|
||||
}
|
||||
|
||||
// handleCallbackSuccess 处理回调成功
|
||||
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
|
||||
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask, messages map[string]any) error {
|
||||
// 1) 获取模型配置
|
||||
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
|
||||
@@ -173,146 +193,125 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询模型失败: %w", err)
|
||||
}
|
||||
// 2) 根据运营商获取协议配置
|
||||
//protocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||
// ProviderName: model.OperatorName,
|
||||
//})
|
||||
|
||||
// 2) 解析结果
|
||||
var messages map[string]any
|
||||
switch composeTask.BuildType {
|
||||
case public.BuildTypePrompt, public.BuildTypeNode:
|
||||
messages = ParseResult(req.Messages, model.ResponseBody)
|
||||
case public.BuildTypeStruct:
|
||||
messages = ParseStructResult(req.Messages, model.ResponseBody)
|
||||
default:
|
||||
messages = req.Messages
|
||||
// 2) 获取协议配置
|
||||
protocol, _ := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||
ProviderName: model.OperatorName,
|
||||
Status: 1,
|
||||
})
|
||||
|
||||
// 3) 获取历史消息 + 保存当前轮
|
||||
payload := composeTask.RequestPayload
|
||||
sessionId := gconv.String(payload["sessionId"])
|
||||
nodeId := gconv.String(payload["nodeId"])
|
||||
var history []dto.FlatMessage
|
||||
var epicycleId int64
|
||||
|
||||
if sessionId != "" && nodeId != "" && model.ModelType == public.ModelTypeInference {
|
||||
// 3.1 获取历史
|
||||
h, _ := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
|
||||
SessionId: sessionId,
|
||||
NodeId: nodeId,
|
||||
})
|
||||
if h != nil {
|
||||
history = h.Messages
|
||||
}
|
||||
|
||||
// 3.2 保存当前轮(先存,下次查询就能拿到)
|
||||
if userMsg := util.ExtractUserText(messages); userMsg != nil {
|
||||
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
NodeId: nodeId,
|
||||
SessionId: sessionId,
|
||||
RequestContent: userMsg,
|
||||
})
|
||||
}
|
||||
}
|
||||
// 3) 合并附加结构
|
||||
|
||||
// 4) 合并附加结构
|
||||
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
|
||||
// 4) 更新数据库
|
||||
// 5) 注入历史
|
||||
if len(history) > 0 {
|
||||
messages = InjectHistory(messages, history, protocol)
|
||||
}
|
||||
|
||||
// 6) 更新数据库
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusSuccess,
|
||||
Messages: messages,
|
||||
GatewayState: req.State,
|
||||
OssFile: req.OssFile,
|
||||
FileType: req.FileType,
|
||||
ResultText: req.Messages,
|
||||
ResultJson: messages,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 5) 存储提示词结果作为历史请求
|
||||
var epicycleId int64
|
||||
payload := composeTask.RequestPayload
|
||||
sessionId := gconv.String(payload["sessionId"])
|
||||
nodeId := gconv.String(payload["nodeId"])
|
||||
buildType := gconv.Int(payload["buildType"])
|
||||
if buildType == public.BuildTypePrompt && sessionId != "" && nodeId != "" {
|
||||
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
NodeId: nodeId,
|
||||
SessionId: sessionId,
|
||||
RequestContent: messages,
|
||||
})
|
||||
}
|
||||
// 6) 拼接历史内容
|
||||
// 7) 回调业务方
|
||||
|
||||
// 8) 回调业务方
|
||||
if composeTask.CallbackUrl != "" {
|
||||
composeTask.Status = public.ComposeStatusSuccess
|
||||
composeTask.Messages = messages
|
||||
composeTask.ResultJson = messages
|
||||
_ = gateway.SendCallback(ctx, composeTask, epicycleId)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseResult 解析结果
|
||||
func ParseResult(raw map[string]any, responseBody string) map[string]any {
|
||||
if responseBody == "" {
|
||||
return raw
|
||||
// InjectHistory 插入历史会话
|
||||
func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protocol *entity.ProviderProtocol) map[string]any {
|
||||
if protocol == nil || len(history) == 0 {
|
||||
return roundsData
|
||||
}
|
||||
// 1) 提取第一轮的 messages
|
||||
rounds := roundsData["rounds"].([]any)
|
||||
firstRound := rounds[0].(map[string]any)
|
||||
original := firstRound["messages"].([]any)
|
||||
|
||||
contentVal := raw[responseBody]
|
||||
if contentVal == nil {
|
||||
return raw
|
||||
}
|
||||
// 2) 按 merge_order 拼接
|
||||
result := make([]any, 0, len(original)+len(history))
|
||||
|
||||
// 已经是数组
|
||||
if arr, ok := contentVal.([]any); ok {
|
||||
rounds := gconv.Maps(arr)
|
||||
if len(rounds) > 0 {
|
||||
return map[string]any{"total_rounds": len(rounds), "rounds": rounds}
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
// 是字符串
|
||||
contentStr := gconv.String(contentVal)
|
||||
if contentStr == "" {
|
||||
return raw
|
||||
}
|
||||
|
||||
// 尝试解析为数组
|
||||
var arr []map[string]any
|
||||
if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 {
|
||||
return map[string]any{"total_rounds": len(arr), "rounds": arr}
|
||||
}
|
||||
|
||||
// 尝试解析为单对象
|
||||
var obj map[string]any
|
||||
if err := json.Unmarshal([]byte(contentStr), &obj); err == nil && len(obj) > 0 {
|
||||
return map[string]any{"total_rounds": 1, "rounds": []map[string]any{obj}}
|
||||
}
|
||||
|
||||
return map[string]any{"content": contentStr}
|
||||
}
|
||||
|
||||
func ParseStructResult(raw map[string]any, responseBody string) map[string]any {
|
||||
// 如果外层已有 rounds,直接返回
|
||||
if _, ok := raw["rounds"]; ok {
|
||||
return raw
|
||||
}
|
||||
|
||||
contentVal := raw[responseBody]
|
||||
|
||||
var rounds []map[string]any
|
||||
|
||||
// 是字符串,尝试解析
|
||||
contentStr := gconv.String(contentVal)
|
||||
if contentStr == "" || contentStr == "0" {
|
||||
rounds = append(rounds, map[string]any{responseBody: raw})
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": rounds,
|
||||
for _, part := range protocol.MergeOrder {
|
||||
switch part {
|
||||
case "system":
|
||||
for _, m := range original {
|
||||
msg := m.(map[string]any)
|
||||
if gconv.String(msg["role"]) == "system" {
|
||||
result = append(result, msg)
|
||||
}
|
||||
}
|
||||
case "history":
|
||||
if gconv.Bool(protocol.Capabilities["support_history"]) {
|
||||
for _, msg := range history {
|
||||
result = append(result, map[string]any{
|
||||
"role": msg.Role,
|
||||
"content": msg.Content, // 纯字符串,不转换
|
||||
})
|
||||
}
|
||||
}
|
||||
case "user":
|
||||
for _, m := range original {
|
||||
msg := m.(map[string]any)
|
||||
if gconv.String(msg["role"]) == "user" {
|
||||
result = append(result, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试解析为数组
|
||||
var arr []any
|
||||
if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 {
|
||||
rounds = append(rounds, map[string]any{responseBody: arr})
|
||||
return map[string]any{
|
||||
"total_rounds": len(rounds),
|
||||
"rounds": rounds,
|
||||
// 3) 角色映射
|
||||
if len(protocol.RoleMapping) > 0 {
|
||||
for _, m := range result {
|
||||
msg := m.(map[string]any)
|
||||
role := gconv.String(msg["role"])
|
||||
if mapped, ok := protocol.RoleMapping[role]; ok {
|
||||
msg["role"] = mapped
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试解析为单个对象
|
||||
var parsed any
|
||||
if err := json.Unmarshal([]byte(contentStr), &parsed); err == nil {
|
||||
rounds = append(rounds, map[string]any{responseBody: parsed})
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": rounds,
|
||||
}
|
||||
}
|
||||
|
||||
// 兜底:原始字符串作为内容
|
||||
rounds = append(rounds, map[string]any{responseBody: contentStr})
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": rounds,
|
||||
}
|
||||
// 4) 直接修改原对象
|
||||
firstRound["messages"] = result
|
||||
return roundsData
|
||||
}
|
||||
|
||||
// GetComposeTask 查询任务结果
|
||||
@@ -323,31 +322,10 @@ func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes,
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询任务失败: %w", err)
|
||||
}
|
||||
if record == nil {
|
||||
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
|
||||
}
|
||||
|
||||
messages := parseMessagesForResponse(record.Messages)
|
||||
|
||||
return &dto.GetComposeTaskRes{
|
||||
TaskId: record.TaskId,
|
||||
Status: record.Status,
|
||||
ErrorMessage: record.ErrorMessage,
|
||||
Messages: messages,
|
||||
Messages: record.ResultJson,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseMessagesForResponse 解析用于响应的消息
|
||||
func parseMessagesForResponse(messages any) any {
|
||||
str, ok := messages.(string)
|
||||
if !ok || str == "" {
|
||||
return messages
|
||||
}
|
||||
|
||||
var parsed any
|
||||
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
|
||||
return parsed
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
@@ -190,6 +190,9 @@ func fetchFileContent(ctx context.Context, client *http.Client, url string) (str
|
||||
}
|
||||
|
||||
func SkillMdContent(ctx context.Context, skillName string) string {
|
||||
if skillName == "" {
|
||||
return ""
|
||||
}
|
||||
skillResp, err := gateway.GetSkillUser(ctx, skillName)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[SkillMd] GetSkillUser 失败: %v", err)
|
||||
|
||||
@@ -2,18 +2,18 @@ package prompt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/service/gateway"
|
||||
"strings"
|
||||
|
||||
"prompts-core/dao"
|
||||
"prompts-core/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// PromptIR 统一 Prompt 中间表示
|
||||
type PromptIR struct {
|
||||
// IR 统一 Prompt 中间表示
|
||||
type IR struct {
|
||||
System []Segment `json:"system"`
|
||||
History []Segment `json:"history"`
|
||||
User []Segment `json:"user"`
|
||||
@@ -34,6 +34,7 @@ type ProviderProtocol struct {
|
||||
ContentMapping ContentMapping `json:"content_mapping"`
|
||||
RequestTemplate map[string]any `json:"request_template"`
|
||||
SystemPromptTemplate string `json:"system_prompt_template"`
|
||||
Capabilities map[string]any `json:"capabilities"`
|
||||
}
|
||||
|
||||
// ContentMapping 内容字段映射
|
||||
@@ -43,8 +44,8 @@ type ContentMapping struct {
|
||||
}
|
||||
|
||||
// NewPromptIR 创建空 PromptIR
|
||||
func NewPromptIR() *PromptIR {
|
||||
return &PromptIR{
|
||||
func NewPromptIR() *IR {
|
||||
return &IR{
|
||||
System: make([]Segment, 0),
|
||||
History: make([]Segment, 0),
|
||||
User: make([]Segment, 0),
|
||||
@@ -52,7 +53,7 @@ func NewPromptIR() *PromptIR {
|
||||
}
|
||||
|
||||
// String 返回 PromptIR 的完整内容字符串(用于 token 计算)
|
||||
func (ir *PromptIR) String() string {
|
||||
func (ir *IR) String() string {
|
||||
var builder strings.Builder
|
||||
|
||||
for _, seg := range ir.System {
|
||||
@@ -78,7 +79,7 @@ func (ir *PromptIR) String() string {
|
||||
}
|
||||
|
||||
// GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算)
|
||||
func (ir *PromptIR) GetTotalContent() string {
|
||||
func (ir *IR) GetTotalContent() string {
|
||||
var builder strings.Builder
|
||||
|
||||
for _, seg := range ir.System {
|
||||
@@ -100,7 +101,7 @@ func (ir *PromptIR) GetTotalContent() string {
|
||||
}
|
||||
|
||||
// AddSystem 添加系统提示
|
||||
func (ir *PromptIR) AddSystem(content string) *PromptIR {
|
||||
func (ir *IR) AddSystem(content string) *IR {
|
||||
if content != "" {
|
||||
ir.System = append(ir.System, Segment{Type: "text", Content: content})
|
||||
}
|
||||
@@ -108,7 +109,7 @@ func (ir *PromptIR) AddSystem(content string) *PromptIR {
|
||||
}
|
||||
|
||||
// AddUser 添加用户消息
|
||||
func (ir *PromptIR) AddUser(content string) *PromptIR {
|
||||
func (ir *IR) AddUser(content string) *IR {
|
||||
if content != "" {
|
||||
ir.User = append(ir.User, Segment{Type: "text", Content: content})
|
||||
}
|
||||
@@ -116,7 +117,7 @@ func (ir *PromptIR) AddUser(content string) *PromptIR {
|
||||
}
|
||||
|
||||
// AddHistory 添加历史消息
|
||||
func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
|
||||
func (ir *IR) AddHistory(role, content string) *IR {
|
||||
if content != "" {
|
||||
ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role})
|
||||
}
|
||||
@@ -124,7 +125,7 @@ func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
|
||||
}
|
||||
|
||||
// ToMessages 转换为 OpenAI 兼容的 messages 格式(MVP 默认)
|
||||
func (ir *PromptIR) ToMessages() []map[string]any {
|
||||
func (ir *IR) ToMessages() []map[string]any {
|
||||
var messages []map[string]any
|
||||
|
||||
for _, seg := range ir.System {
|
||||
@@ -165,21 +166,22 @@ func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderP
|
||||
|
||||
// parseProtocol 将 DB entity 转为编译用协议配置
|
||||
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
|
||||
p := &ProviderProtocol{
|
||||
return &ProviderProtocol{
|
||||
TargetField: e.TargetField,
|
||||
SystemPromptTemplate: e.SystemPromptTemplate,
|
||||
MergeOrder: e.MergeOrder,
|
||||
RoleMapping: gconv.MapStrStr(e.RoleMapping),
|
||||
ContentMapping: ContentMapping{
|
||||
Type: gconv.String(e.ContentMapping["type"]),
|
||||
Field: gconv.String(e.ContentMapping["field"]),
|
||||
},
|
||||
RequestTemplate: e.RequestTemplate,
|
||||
Capabilities: e.Capabilities,
|
||||
}
|
||||
|
||||
// 使用通用解析方法处理各个字段
|
||||
util.ParseJSONFieldFromGvar(e.MergeOrder, &p.MergeOrder)
|
||||
util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping)
|
||||
util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping)
|
||||
util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate)
|
||||
return p
|
||||
}
|
||||
|
||||
// Compile 将 PromptIR 按协议配置编译为 Provider Request
|
||||
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) {
|
||||
func Compile(ir *IR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) {
|
||||
if ir == nil || p == nil {
|
||||
return nil, fmt.Errorf("ir and protocol are required")
|
||||
}
|
||||
@@ -191,35 +193,25 @@ func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *gateway.AsynchModel)
|
||||
}
|
||||
|
||||
// mergeByOrder 按协议配置顺序拼接消息
|
||||
func mergeByOrder(ir *PromptIR, order []string) []map[string]any {
|
||||
var messages []map[string]any
|
||||
|
||||
for _, part := range order {
|
||||
switch part {
|
||||
case "system":
|
||||
for _, seg := range ir.System {
|
||||
messages = append(messages, map[string]any{
|
||||
"role": "system",
|
||||
"content": seg.Content,
|
||||
})
|
||||
}
|
||||
case "history":
|
||||
for _, seg := range ir.History {
|
||||
messages = append(messages, map[string]any{
|
||||
"role": seg.Role,
|
||||
"content": seg.Content,
|
||||
})
|
||||
}
|
||||
case "user":
|
||||
for _, seg := range ir.User {
|
||||
messages = append(messages, map[string]any{
|
||||
"role": "user",
|
||||
"content": seg.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
func mergeByOrder(ir *IR, order []string) []map[string]any {
|
||||
roleMap := map[string][]Segment{
|
||||
"system": ir.System,
|
||||
"history": ir.History,
|
||||
"user": ir.User,
|
||||
}
|
||||
|
||||
var messages []map[string]any
|
||||
for _, part := range order {
|
||||
for _, seg := range roleMap[part] {
|
||||
msg := map[string]any{"content": seg.Content}
|
||||
if part == "history" {
|
||||
msg["role"] = seg.Role
|
||||
} else {
|
||||
msg["role"] = part
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
@@ -243,29 +235,29 @@ func mapRoles(messages []map[string]any, mapping map[string]string) []map[string
|
||||
return messages
|
||||
}
|
||||
|
||||
// mapContent 内容字段映射
|
||||
func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
|
||||
for _, msg := range messages {
|
||||
content := msg["content"]
|
||||
delete(msg, "content")
|
||||
|
||||
switch cm.Type {
|
||||
case "parts":
|
||||
msg["parts"] = []map[string]any{
|
||||
{cm.Field: content},
|
||||
}
|
||||
default:
|
||||
msg[cm.Field] = content
|
||||
}
|
||||
if cm.Field == "" || cm.Field == "content" {
|
||||
return messages
|
||||
}
|
||||
|
||||
for i, msg := range messages {
|
||||
if content, ok := msg["content"]; ok {
|
||||
delete(msg, "content")
|
||||
switch cm.Type {
|
||||
case "parts":
|
||||
messages[i]["parts"] = []map[string]any{{cm.Field: content}}
|
||||
default:
|
||||
messages[i][cm.Field] = content
|
||||
}
|
||||
}
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
// buildRequest 按 target_field 和 request_template 构建请求体
|
||||
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gateway.AsynchModel) map[string]any {
|
||||
if len(p.RequestTemplate) > 0 {
|
||||
return renderTemplate(p.RequestTemplate, messages, chatModel)
|
||||
return renderTemplate(p, messages, chatModel)
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
@@ -273,20 +265,21 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gat
|
||||
}
|
||||
}
|
||||
|
||||
// renderTemplate 简单的 {{key}} 模板替换
|
||||
func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
|
||||
b, _ := json.Marshal(tmpl)
|
||||
str := string(b)
|
||||
|
||||
if chatModel != nil {
|
||||
str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`)
|
||||
// renderTemplate 模板渲染
|
||||
func renderTemplate(p *ProviderProtocol, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
|
||||
result := make(map[string]any, len(p.RequestTemplate)+1)
|
||||
for k, v := range p.RequestTemplate {
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
msgBytes, _ := json.Marshal(messages)
|
||||
str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes))
|
||||
if chatModel != nil {
|
||||
result["model"] = chatModel.ModelName
|
||||
}
|
||||
result["messages"] = messages
|
||||
|
||||
var result map[string]any
|
||||
_ = json.Unmarshal([]byte(str), &result)
|
||||
if maxTokens := gconv.Int(p.Capabilities["max_tokens"]); maxTokens > 0 {
|
||||
result["max_tokens"] = maxTokens
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -4,139 +4,148 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"prompts-core/model/entity"
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/model/dto"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
const (
|
||||
redisKeyPrefix = "chat:session:%s"
|
||||
// RedisKeySessionHistory 会话历史缓存 key: session:history:{tenantId}:{sessionId}:{nodeId}
|
||||
RedisKeySessionHistory = "session:history:%d:%s:%s"
|
||||
)
|
||||
|
||||
// formatRedisKey 格式化Redis键
|
||||
func formatRedisKey(sessionId string) string {
|
||||
return fmt.Sprintf(redisKeyPrefix, sessionId)
|
||||
// formatRedisKey 格式化 Redis key
|
||||
func formatRedisKey(tenantID uint64, sessionID, nodeID string) string {
|
||||
return fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID, nodeID)
|
||||
}
|
||||
|
||||
// saveToRedis 保存会话数据到Redis
|
||||
func saveToRedis(ctx context.Context, session *entity.ComposeSession) error {
|
||||
key := formatRedisKey(session.SessionId)
|
||||
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
||||
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
|
||||
data := map[string]any{
|
||||
"sessionId": session.SessionId,
|
||||
"requestContent": session.RequestContent,
|
||||
"responseContent": session.ResponseContent,
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
b, err := json.Marshal(data)
|
||||
// ============================================
|
||||
// 写操作
|
||||
// ============================================
|
||||
|
||||
// SaveToRedis 保存一轮对话到 Redis ZSET
|
||||
func SaveToRedis(ctx context.Context, tenantID uint64, sessionID, nodeID string, round *dto.HistoryRound) error {
|
||||
key := formatRedisKey(tenantID, sessionID, nodeID)
|
||||
maxRounds := util.GetMaxRounds(ctx)
|
||||
expireSeconds := int64(util.GetExpireMinutes(ctx) * 60)
|
||||
|
||||
b, err := json.Marshal(round)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化会话数据失败: %w", err)
|
||||
}
|
||||
|
||||
if err = executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil {
|
||||
score := float64(time.Now().UnixMilli())
|
||||
|
||||
if _, err = g.Redis().Do(ctx, "ZADD", key, score, string(b)); err != nil {
|
||||
return fmt.Errorf("ZADD失败: %w", err)
|
||||
}
|
||||
if _, err = g.Redis().Do(ctx, "ZREMRANGEBYRANK", key, 0, -(maxRounds + 1)); err != nil {
|
||||
return fmt.Errorf("裁剪失败: %w", err)
|
||||
}
|
||||
if _, err = g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
|
||||
return fmt.Errorf("设置过期失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteSessionHistory 删除整个 session 下所有 node 的缓存
|
||||
func DeleteSessionHistory(ctx context.Context, tenantID uint64, sessionID string) error {
|
||||
pattern := fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID, "*")
|
||||
keys, err := g.Redis().Do(ctx, "KEYS", pattern)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, key := range keys.Strings() {
|
||||
_, _ = g.Redis().Do(ctx, "DEL", key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeRedisCommands 执行Redis命令
|
||||
func executeRedisCommands(ctx context.Context, key string, value string, maxRounds int, expireSeconds int64) error {
|
||||
if _, err := g.Redis().Do(ctx, "LPUSH", key, value); err != nil {
|
||||
return fmt.Errorf("写入Redis失败: %w", err)
|
||||
// DeleteRedisMessages 批量删除指定 node 下的消息
|
||||
func DeleteRedisMessages(ctx context.Context, tenantID uint64, sessionID, nodeID string, msgIDs []int64) error {
|
||||
key := formatRedisKey(tenantID, sessionID, nodeID)
|
||||
for _, msgID := range msgIDs {
|
||||
cursor := "0"
|
||||
for {
|
||||
result, err := g.Redis().Do(ctx, "ZSCAN", key, cursor, "MATCH", fmt.Sprintf("*\"id\":%d*", msgID), "COUNT", 10)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[会话Redis] ZSCAN失败 msgID=%d err=%v", msgID, err)
|
||||
break
|
||||
}
|
||||
parts := result.Strings()
|
||||
if len(parts) < 2 {
|
||||
break
|
||||
}
|
||||
cursor = parts[0]
|
||||
for _, member := range parts[1:] {
|
||||
_, _ = g.Redis().Do(ctx, "ZREM", key, member)
|
||||
}
|
||||
if cursor == "0" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1); err != nil {
|
||||
return fmt.Errorf("裁剪Redis列表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
|
||||
return fmt.Errorf("设置过期时间失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getFromRedis 从Redis获取会话历史
|
||||
func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
||||
key := formatRedisKey(sessionId)
|
||||
// ============================================
|
||||
// 读操作
|
||||
// ============================================
|
||||
|
||||
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
|
||||
// GetFromRedis 从 Redis ZSET 获取会话历史
|
||||
func GetFromRedis(ctx context.Context, tenantID uint64, sessionID, nodeID string) ([]dto.HistoryRound, error) {
|
||||
key := formatRedisKey(tenantID, sessionID, nodeID)
|
||||
maxRounds := util.GetMaxRounds(ctx)
|
||||
|
||||
result, err := g.Redis().Do(ctx, "ZREVRANGE", key, 0, maxRounds-1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("从Redis获取数据失败: %w", err)
|
||||
return nil, fmt.Errorf("ZREVRANGE失败: %w", err)
|
||||
}
|
||||
|
||||
if result == nil || result.IsNil() {
|
||||
return []map[string]any{}, nil
|
||||
return []dto.HistoryRound{}, nil
|
||||
}
|
||||
|
||||
sessions := parseRedisSessions(ctx, result.Strings())
|
||||
|
||||
reverseSlice(sessions)
|
||||
|
||||
return sessions, nil
|
||||
return parseRounds(result.Strings()), nil
|
||||
}
|
||||
|
||||
// parseRedisSessions 解析Redis会话数据
|
||||
func parseRedisSessions(ctx context.Context, values []string) []map[string]any {
|
||||
var sessions []map[string]any
|
||||
// ============================================
|
||||
// 解析
|
||||
// ============================================
|
||||
|
||||
for _, str := range values {
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(str), &data); err != nil {
|
||||
g.Log().Warningf(ctx, "[会话] 解析Redis数据失败 err=%v", err)
|
||||
func parseRounds(members []string) []dto.HistoryRound {
|
||||
rounds := make([]dto.HistoryRound, 0, len(members))
|
||||
for _, member := range members {
|
||||
var round dto.HistoryRound
|
||||
if err := json.Unmarshal([]byte(member), &round); err != nil {
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, data)
|
||||
}
|
||||
|
||||
return sessions
|
||||
}
|
||||
|
||||
// reverseSlice 反转切片
|
||||
func reverseSlice(s []map[string]any) {
|
||||
for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {
|
||||
s[i], s[j] = s[j], s[i]
|
||||
}
|
||||
}
|
||||
|
||||
// GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用)
|
||||
func GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
||||
historyData, err := getFromRedis(ctx, sessionId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取历史会话失败: %w", err)
|
||||
}
|
||||
|
||||
if len(historyData) == 0 {
|
||||
return []map[string]any{}, nil
|
||||
}
|
||||
|
||||
return flattenHistoryMessages(historyData), nil
|
||||
}
|
||||
|
||||
// flattenHistoryMessages 扁平化历史消息
|
||||
func flattenHistoryMessages(historyData []map[string]any) []map[string]any {
|
||||
var messages []map[string]any
|
||||
|
||||
for _, round := range historyData {
|
||||
appendMessagesFromField(round, "requestContent", &messages)
|
||||
appendMessagesFromField(round, "responseContent", &messages)
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
// appendMessagesFromField 从指定字段追加消息
|
||||
func appendMessagesFromField(data map[string]any, field string, messages *[]map[string]any) {
|
||||
msgs, ok := data[field].([]interface{})
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
for _, m := range msgs {
|
||||
if msg, ok := m.(map[string]interface{}); ok {
|
||||
*messages = append(*messages, msg)
|
||||
if round.User != nil || round.Assistant != nil {
|
||||
rounds = append(rounds, round)
|
||||
}
|
||||
}
|
||||
return rounds
|
||||
}
|
||||
|
||||
func flattenRounds(rounds []dto.HistoryRound) []dto.FlatMessage {
|
||||
var messages []dto.FlatMessage
|
||||
for i := len(rounds) - 1; i >= 0; i-- {
|
||||
if rounds[i].User != nil && gconv.String(rounds[i].User["content"]) != "" {
|
||||
messages = append(messages, dto.FlatMessage{
|
||||
Role: gconv.String(rounds[i].User["role"]),
|
||||
Content: gconv.String(rounds[i].User["content"]),
|
||||
})
|
||||
}
|
||||
if rounds[i].Assistant != nil && gconv.String(rounds[i].Assistant["content"]) != "" {
|
||||
messages = append(messages, dto.FlatMessage{
|
||||
Role: gconv.String(rounds[i].Assistant["role"]),
|
||||
Content: gconv.String(rounds[i].Assistant["content"]),
|
||||
})
|
||||
}
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
@@ -4,7 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||
"gitea.redpowerfuture.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
|
||||
@@ -14,9 +15,14 @@ import (
|
||||
"prompts-core/model/entity"
|
||||
)
|
||||
|
||||
// ============================================
|
||||
// 回调存储
|
||||
// ============================================
|
||||
|
||||
// Callback 会话回调
|
||||
func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
|
||||
req.Messages["role"] = "assistant"
|
||||
// 1) 更新 DB
|
||||
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
||||
ResponseContent: req.Messages,
|
||||
@@ -25,103 +31,161 @@ func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCal
|
||||
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
||||
return nil, fmt.Errorf("更新数据库失败: %w", err)
|
||||
}
|
||||
|
||||
// 2) 查询完整记录
|
||||
session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
||||
})
|
||||
if session == nil {
|
||||
if err != nil || session == nil {
|
||||
return nil, fmt.Errorf("会话不存在: epicycleId=%d", req.EpicycleId)
|
||||
}
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
||||
return nil, fmt.Errorf("获取会话数据失败: %w", err)
|
||||
}
|
||||
if err = saveToRedis(ctx, session); err != nil {
|
||||
|
||||
// 3) entity → HistoryRound → 写入 Redis
|
||||
round := entityToHistoryRound(session)
|
||||
round.Assistant = req.Messages
|
||||
if err = SaveToRedis(ctx, session.TenantId, session.SessionId, session.NodeId, round); err != nil {
|
||||
return nil, fmt.Errorf("redis存储失败: %w", err)
|
||||
}
|
||||
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
|
||||
session.SessionId, session.Id, len(session.RequestContent), len(session.ResponseContent))
|
||||
return &dto.SessionCallbackRes{
|
||||
Status: true,
|
||||
SessionId: session.SessionId,
|
||||
}, nil
|
||||
|
||||
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d", session.SessionId, session.Id)
|
||||
return &dto.SessionCallbackRes{Status: true, SessionId: session.SessionId}, nil
|
||||
}
|
||||
|
||||
// GetHistoryMessages 获取历史信息
|
||||
func GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
||||
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
||||
// ============================================
|
||||
// 场景1:前端历史列表(按 creator)
|
||||
// ============================================
|
||||
|
||||
redisHistory, err := GetSessionHistoryForInference(ctx, sessionId)
|
||||
if err == nil && len(redisHistory) > 0 {
|
||||
return redisHistory, nil
|
||||
// GetHistoryList 获取历史列表
|
||||
func GetHistoryList(ctx context.Context, req *dto.GetHistoryListReq) (*dto.GetHistoryListRes, error) {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessions, total, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||
}, req.Page, req.Size)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DB获取历史列表失败: %w", err)
|
||||
}
|
||||
rounds := sessionsToHistoryRounds(sessions)
|
||||
return &dto.GetHistoryListRes{List: rounds, Total: total}, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 场景2:提示词拼接(按 sessionId + nodeId)
|
||||
// ============================================
|
||||
|
||||
// GetHistoryMessages 获取历史消息(Redis → DB → 异步回种)
|
||||
func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*dto.GetHistoryMessagesRes, error) {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return getHistoryFromDatabase(ctx, sessionId, maxRounds)
|
||||
}
|
||||
// 1) Redis
|
||||
if rounds, err := GetFromRedis(ctx, user.TenantId, req.SessionId, req.NodeId); err == nil && len(rounds) > 0 {
|
||||
g.Log().Debugf(ctx, "[历史消息] Redis命中 sessionId=%s count=%d", req.SessionId, len(rounds))
|
||||
return &dto.GetHistoryMessagesRes{Messages: flattenRounds(rounds)}, nil
|
||||
}
|
||||
|
||||
// getHistoryFromDatabase 从数据库获取历史记录
|
||||
func getHistoryFromDatabase(ctx context.Context, sessionId string, maxRounds int) ([]map[string]any, error) {
|
||||
// 2) DB
|
||||
maxRounds := util.GetMaxRounds(ctx)
|
||||
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
|
||||
SessionId: sessionId,
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||
SessionId: req.SessionId,
|
||||
NodeId: req.NodeId,
|
||||
}, 1, maxRounds)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DB获取历史失败: %w", err)
|
||||
}
|
||||
|
||||
messages := extractMessagesFromSessions(sessions)
|
||||
|
||||
cacheSessionsToRedis(ctx, sessions)
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// extractMessagesFromSessions 从会话列表中提取消息
|
||||
func extractMessagesFromSessions(sessions []*entity.ComposeSession) []map[string]any {
|
||||
var messages []map[string]any
|
||||
|
||||
for _, session := range sessions {
|
||||
appendRequestMessages(session.RequestContent, &messages)
|
||||
appendResponseMessages(session.ResponseContent, &messages)
|
||||
if len(sessions) == 0 {
|
||||
return &dto.GetHistoryMessagesRes{Messages: []dto.FlatMessage{}}, nil
|
||||
}
|
||||
|
||||
return messages
|
||||
// 3) 转换 + 异步回种
|
||||
rounds := sessionsToHistoryRounds(sessions)
|
||||
go asyncCacheToRedis(context.WithoutCancel(ctx), user.TenantId, req.SessionId, req.NodeId, rounds)
|
||||
|
||||
return &dto.GetHistoryMessagesRes{Messages: flattenRounds(rounds)}, nil
|
||||
}
|
||||
|
||||
// appendRequestMessages 追加请求消息
|
||||
func appendRequestMessages(requestContent any, messages *[]map[string]any) {
|
||||
reqMsgs := util.ConvertToMessages(requestContent)
|
||||
for _, m := range reqMsgs {
|
||||
role := gconv.String(m["role"])
|
||||
if role == "user" || role == "assistant" {
|
||||
*messages = append(*messages, m)
|
||||
}
|
||||
// ============================================
|
||||
// 删除
|
||||
// ============================================
|
||||
|
||||
// DeleteMessages 删除消息
|
||||
func DeleteMessages(ctx context.Context, req *dto.DeleteMessagesReq) (*dto.DeleteMessagesRes, error) {
|
||||
if len(req.MsgIds) == 0 {
|
||||
return &dto.DeleteMessagesRes{Ok: false}, fmt.Errorf("msgIds不能为空")
|
||||
}
|
||||
|
||||
user, _ := utils.GetUserInfo(ctx)
|
||||
|
||||
// 1) 批量查询
|
||||
sessions, _ := dao.ComposeSession.ListByIds(ctx, req.MsgIds, user.UserName, req.SessionId)
|
||||
|
||||
// 2) 批量删 DB
|
||||
_, _ = dao.ComposeSession.DeleteByIds(ctx, req.MsgIds, user.UserName, req.SessionId)
|
||||
|
||||
// 3) 按 nodeId 分组删 Redis
|
||||
for _, s := range sessions {
|
||||
_ = DeleteRedisMessages(ctx, user.TenantId, req.SessionId, s.NodeId, req.MsgIds)
|
||||
}
|
||||
return &dto.DeleteMessagesRes{Ok: true}, nil
|
||||
}
|
||||
|
||||
// DeleteSession 删除整个会话
|
||||
func DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (*dto.DeleteSessionRes, error) {
|
||||
// 1) 删 DB
|
||||
if _, err := dao.ComposeSession.Delete(ctx, &entity.ComposeSession{
|
||||
SessionId: req.SessionId,
|
||||
}); err != nil {
|
||||
return nil, fmt.Errorf("DB删除失败: %w", err)
|
||||
}
|
||||
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 2) 删 Redis
|
||||
if err := DeleteSessionHistory(ctx, user.TenantId, req.SessionId); err != nil {
|
||||
g.Log().Warningf(ctx, "[删除会话] Redis删除失败 sessionId=%s err=%v", req.SessionId, err)
|
||||
}
|
||||
|
||||
return &dto.DeleteSessionRes{Ok: true}, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 转换方法(entity ↔ dto,集中管理)
|
||||
// ============================================
|
||||
|
||||
// entityToHistoryRound entity → HistoryRound
|
||||
func entityToHistoryRound(s *entity.ComposeSession) *dto.HistoryRound {
|
||||
return &dto.HistoryRound{
|
||||
Id: s.Id,
|
||||
SessionId: s.SessionId,
|
||||
NodeId: s.NodeId,
|
||||
CreatedAt: gconv.String(s.CreatedAt),
|
||||
UpdatedAt: gconv.String(s.UpdatedAt),
|
||||
User: s.RequestContent,
|
||||
Assistant: s.ResponseContent,
|
||||
}
|
||||
}
|
||||
|
||||
// appendResponseMessages 追加响应消息
|
||||
func appendResponseMessages(responseContent any, messages *[]map[string]any) {
|
||||
respMsgs := util.ConvertToMessages(responseContent)
|
||||
for _, m := range respMsgs {
|
||||
if m["role"] == nil {
|
||||
m["role"] = "assistant"
|
||||
}
|
||||
*messages = append(*messages, m)
|
||||
// sessionsToHistoryRounds 批量转换
|
||||
func sessionsToHistoryRounds(sessions []*entity.ComposeSession) []dto.HistoryRound {
|
||||
rounds := make([]dto.HistoryRound, 0, len(sessions))
|
||||
for _, s := range sessions {
|
||||
rounds = append(rounds, *entityToHistoryRound(s))
|
||||
}
|
||||
return rounds
|
||||
}
|
||||
|
||||
// cacheSessionsToRedis 将会话缓存到Redis
|
||||
func cacheSessionsToRedis(ctx context.Context, sessions []*entity.ComposeSession) {
|
||||
for _, session := range sessions {
|
||||
reqMsgs := util.ConvertToMessages(session.RequestContent)
|
||||
respMsgs := util.ConvertToMessages(session.ResponseContent)
|
||||
|
||||
for i := range respMsgs {
|
||||
if respMsgs[i]["role"] == nil {
|
||||
respMsgs[i]["role"] = "assistant"
|
||||
}
|
||||
}
|
||||
|
||||
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
|
||||
_ = saveToRedis(ctx, session)
|
||||
// asyncCacheToRedis 异步缓存到 Redis
|
||||
func asyncCacheToRedis(ctx context.Context, tenantID uint64, sessionID, nodeID string, rounds []dto.HistoryRound) {
|
||||
for i := range rounds {
|
||||
if rounds[i].User != nil || rounds[i].Assistant != nil {
|
||||
_ = SaveToRedis(ctx, tenantID, sessionID, nodeID, &rounds[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user