Compare commits
25 Commits
92092575bc
...
dev未优化
| Author | SHA1 | Date | |
|---|---|---|---|
| 0d52b631b9 | |||
| c22d578e1a | |||
| df26329836 | |||
| 40abf0f606 | |||
| b69e7386e2 | |||
| 1c1db7e30c | |||
| 78114f99c7 | |||
| 9410199fbe | |||
| 1f9a2b9b5f | |||
| e1461cf0f0 | |||
| aa7804656f | |||
| 5494a0c480 | |||
| ee6677c1f8 | |||
| de70d33115 | |||
| b2cad4cac2 | |||
| 05cf1b9828 | |||
| 3fa2896fc3 | |||
| c11a9ad5c8 | |||
| 0bbaddace0 | |||
| 1bcf8f6e10 | |||
| 55eb436639 | |||
| d74559ae74 | |||
| 2548ffc7ac | |||
| 855d5b9abe | |||
| 866b97d098 |
24
Dockerfile
Normal file
24
Dockerfile
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
# 阶段1: 构建
|
||||||
|
FROM golang:alpine AS builder
|
||||||
|
|
||||||
|
RUN apk add --no-cache git ca-certificates tzdata
|
||||||
|
|
||||||
|
ENV TZ=Asia/Shanghai
|
||||||
|
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
|
||||||
|
|
||||||
|
ENV GO111MODULE=on
|
||||||
|
ENV GOPROXY=https://goproxy.cn,direct
|
||||||
|
ENV CGO_ENABLED=0
|
||||||
|
ENV GOTOOLCHAIN=auto
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
RUN go mod download && go mod tidy
|
||||||
|
|
||||||
|
RUN go build -ldflags="-s -w" -o main ./main.go
|
||||||
|
|
||||||
|
|
||||||
|
EXPOSE 3009
|
||||||
|
|
||||||
|
CMD ["./main"]
|
||||||
@@ -2,20 +2,14 @@ package util
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetServerPort 从配置获取服务端口
|
// GetServerName 获取服务名称
|
||||||
func GetServerPort(ctx context.Context) string {
|
func GetServerName(ctx context.Context) string {
|
||||||
address := g.Cfg().MustGet(ctx, "server.address", ":8080").String()
|
return g.Cfg().MustGet(ctx, "server.name", "").String()
|
||||||
// address 格式如 ":3009",去掉冒号
|
|
||||||
if strings.HasPrefix(address, ":") {
|
|
||||||
return address[1:]
|
|
||||||
}
|
|
||||||
return "8080"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelPrompt 获取请求模型的提示词
|
// GetModelPrompt 获取请求模型的提示词
|
||||||
@@ -28,3 +22,13 @@ func GetModelPrompt(ctx context.Context, modelType int) string {
|
|||||||
func GetBuildPrompt(ctx context.Context) string {
|
func GetBuildPrompt(ctx context.Context) string {
|
||||||
return g.Cfg().MustGet(ctx, "nodePrompts", "").String()
|
return g.Cfg().MustGet(ctx, "nodePrompts", "").String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetMaxRounds 获取最大轮数配置
|
||||||
|
func GetMaxRounds(ctx context.Context) int {
|
||||||
|
return g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetExpireMinutes 获取过期时间配置
|
||||||
|
func GetExpireMinutes(ctx context.Context) int {
|
||||||
|
return g.Cfg().MustGet(ctx, "session.expireMinutes", 30).Int()
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package util
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"gitea.com/red-future/common/utils"
|
"gitea.redpowerfuture.com/red-future/common/utils"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,151 +1,81 @@
|
|||||||
package util
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/container/gvar"
|
|
||||||
"github.com/gogf/gf/v2/encoding/gjson"
|
"github.com/gogf/gf/v2/encoding/gjson"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ParseOutput 解析模型输出为 JSON 格式
|
// MergeConsult 将 consult 附件合并到模型生成的 messages 结构中
|
||||||
func ParseOutput(text string) (map[string]any, error) {
|
func MergeConsult(req map[string]any, messages map[string]any, extendMapping map[string]any) map[string]any {
|
||||||
j, err := gjson.LoadJson([]byte(text))
|
if len(req) == 0 || len(messages) == 0 || len(extendMapping) == 0 {
|
||||||
if err != nil {
|
return messages
|
||||||
return nil, fmt.Errorf("解析模型输出失败: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return j.Map(), nil
|
consult := gconv.Interfaces(req["consult"])
|
||||||
}
|
if len(consult) == 0 {
|
||||||
|
return messages
|
||||||
// ConvertToMessages 将原始数据转换为消息列表
|
|
||||||
func ConvertToMessages(raw any) []map[string]any {
|
|
||||||
if raw == nil {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
j, err := gjson.LoadJson(gconv.Bytes(raw))
|
targetPath := gconv.String(extendMapping["target_content_path"])
|
||||||
if err != nil {
|
templates := gconv.Map(extendMapping["attachment_templates"])
|
||||||
return nil
|
if targetPath == "" || len(templates) == 0 {
|
||||||
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
if j.Contains("messages") {
|
msgJson := gjson.New(messages)
|
||||||
return gconv.Maps(j.Get("messages").Array())
|
|
||||||
|
// rounds 路径修正
|
||||||
|
if !msgJson.Get("rounds.0").IsNil() {
|
||||||
|
targetPath = "rounds.0." + targetPath
|
||||||
}
|
}
|
||||||
|
|
||||||
return []map[string]any{j.Map()}
|
// 遍历追加
|
||||||
}
|
for _, item := range consult {
|
||||||
|
itemJson := gjson.New(item)
|
||||||
// FormToJSON 将表单数据转换为 JSON 字符串
|
itemType := itemJson.Get("type").String()
|
||||||
func FormToJSON(form map[string]any) string {
|
tmpl := gconv.Map(templates[itemType])
|
||||||
if form == nil {
|
if itemType == "" || len(tmpl) == 0 {
|
||||||
return "{}"
|
continue
|
||||||
}
|
|
||||||
|
|
||||||
b, _ := json.Marshal(form)
|
|
||||||
return string(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UserFormToJSON 将用户表单数据转换为 JSON 字符串
|
|
||||||
func UserFormToJSON(form []map[string]any) string {
|
|
||||||
if form == nil {
|
|
||||||
return "{}"
|
|
||||||
}
|
|
||||||
|
|
||||||
b, _ := json.Marshal(form)
|
|
||||||
return string(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustMarshal 将对象序列化为 JSON 字符串,失败时返回空对象
|
|
||||||
func MustMarshal(v any) string {
|
|
||||||
b, err := json.Marshal(v)
|
|
||||||
if err != nil {
|
|
||||||
return "{}"
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// JSONPretty 将任意类型转为格式化的 JSON 字符串
|
|
||||||
func JSONPretty(v any) string {
|
|
||||||
if gv, ok := v.(*gvar.Var); ok {
|
|
||||||
v = gconv.Map(gv.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
var tmp map[string]any
|
|
||||||
if err := gconv.Struct(v, &tmp); err != nil {
|
|
||||||
return gconv.String(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
b, _ := json.MarshalIndent(tmp, "", " ")
|
|
||||||
return string(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GvarToMap 将 *gvar.Var 类型转换为 map[string]any
|
|
||||||
func GvarToMap(v *gvar.Var) map[string]any {
|
|
||||||
if v == nil || v.IsNil() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
result := make(map[string]any)
|
|
||||||
|
|
||||||
// 方法1:尝试获取 map 值
|
|
||||||
if m := v.Map(); len(m) > 0 {
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
// 方法2:尝试解析 JSON 字符串
|
|
||||||
str := v.String()
|
|
||||||
if str != "" && str != "<nil>" {
|
|
||||||
json.Unmarshal([]byte(str), &result)
|
|
||||||
if len(result) > 0 {
|
|
||||||
return result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
attachment := buildAttachment(tmpl, itemJson.Get("url").String())
|
||||||
|
if attachment == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
idx := len(msgJson.Get(targetPath).Array())
|
||||||
|
_ = msgJson.Set(fmt.Sprintf("%s.%d", targetPath, idx), attachment)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 方法3:尝试获取 interface 再转换
|
return msgJson.Map()
|
||||||
if val := v.Val(); val != nil {
|
}
|
||||||
switch val.(type) {
|
|
||||||
|
func buildAttachment(tmpl map[string]any, url string) map[string]any {
|
||||||
|
typ := gconv.String(tmpl["type"])
|
||||||
|
if typ == "" || url == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
body := gconv.Map(tmpl["body"])
|
||||||
|
fillEmptyInPlace(body, url)
|
||||||
|
|
||||||
|
return map[string]any{
|
||||||
|
"type": typ,
|
||||||
|
typ: body,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fillEmptyInPlace(m map[string]any, value string) {
|
||||||
|
for k, v := range m {
|
||||||
|
switch vv := v.(type) {
|
||||||
|
case string:
|
||||||
|
if vv == "" {
|
||||||
|
m[k] = value
|
||||||
|
}
|
||||||
case map[string]any:
|
case map[string]any:
|
||||||
return val.(map[string]any)
|
fillEmptyInPlace(vv, value)
|
||||||
default:
|
|
||||||
data, _ := json.Marshal(val)
|
|
||||||
json.Unmarshal(data, &result)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseJSONFieldFromGvar 专门处理 *gvar.Var 类型的 JSON 字段解析
|
|
||||||
func ParseJSONFieldFromGvar(source any, target any) {
|
|
||||||
if source == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch v := source.(type) {
|
|
||||||
case *gvar.Var:
|
|
||||||
if v.IsNil() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 尝试获取 map
|
|
||||||
if m := v.Map(); len(m) > 0 {
|
|
||||||
data, _ := json.Marshal(m)
|
|
||||||
json.Unmarshal(data, target)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 尝试解析 JSON 字符串
|
|
||||||
str := v.String()
|
|
||||||
if str != "" && str != "<nil>" {
|
|
||||||
json.Unmarshal([]byte(str), target)
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
// 其他类型走原来的逻辑
|
|
||||||
data, _ := json.Marshal(source)
|
|
||||||
json.Unmarshal(data, target)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
57
common/util/mapping.go
Normal file
57
common/util/mapping.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gogf/gf/v2/encoding/gjson"
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReverseMap 映射 payload 到 mapping
|
||||||
|
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
|
||||||
|
jsonObj := gjson.New("{}")
|
||||||
|
for path, defaultValue := range mapping {
|
||||||
|
val := gjson.New(payload).Get(path)
|
||||||
|
if !val.IsNil() {
|
||||||
|
_ = jsonObj.Set(path, val.Val())
|
||||||
|
} else if defaultValue != nil {
|
||||||
|
_ = jsonObj.Set(path, defaultValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jsonObj.Map()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractUserText 从 messages 中提取所有 user 文本
|
||||||
|
func ExtractUserText(messages map[string]any) map[string]any {
|
||||||
|
msgJson := gjson.New(messages)
|
||||||
|
|
||||||
|
msgs := msgJson.Get("rounds.0.messages")
|
||||||
|
if msgs.IsNil() {
|
||||||
|
msgs = msgJson.Get("messages")
|
||||||
|
}
|
||||||
|
var texts []string
|
||||||
|
for _, m := range msgs.Array() {
|
||||||
|
msg := gjson.New(m)
|
||||||
|
if msg.Get("role").String() != "user" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content := msg.Get("content").Val()
|
||||||
|
switch c := content.(type) {
|
||||||
|
case string:
|
||||||
|
texts = append(texts, c)
|
||||||
|
case []any:
|
||||||
|
for _, item := range c {
|
||||||
|
if m, ok := item.(map[string]any); ok {
|
||||||
|
if t := gconv.String(m["text"]); t != "" {
|
||||||
|
texts = append(texts, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": strings.Join(texts, "\n"),
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,130 +0,0 @@
|
|||||||
package util
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
|
||||||
)
|
|
||||||
|
|
||||||
// GetLocalIP 获取本机有效的局域网 IPv4 地址
|
|
||||||
func GetLocalIP() string {
|
|
||||||
addrs, err := net.InterfaceAddrs()
|
|
||||||
if err != nil {
|
|
||||||
return "127.0.0.1"
|
|
||||||
}
|
|
||||||
|
|
||||||
var validIPs []string
|
|
||||||
|
|
||||||
for _, addr := range addrs {
|
|
||||||
ipnet, ok := addr.(*net.IPNet)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := ipnet.IP
|
|
||||||
|
|
||||||
if isIPValid(ip) {
|
|
||||||
validIPs = append(validIPs, ip.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 优先返回非 169.254.x.x 的 IP
|
|
||||||
for _, ip := range validIPs {
|
|
||||||
if !strings.HasPrefix(ip, "169.254.") {
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 其次返回 169.254.x.x(最后的选择)
|
|
||||||
if len(validIPs) > 0 {
|
|
||||||
return validIPs[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
return "127.0.0.1"
|
|
||||||
}
|
|
||||||
|
|
||||||
// isIPValid 判断 IP 是否有效
|
|
||||||
func isIPValid(ip net.IP) bool {
|
|
||||||
// 不是 loopback (127.0.0.1)
|
|
||||||
if ip.IsLoopback() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// 是 IPv4
|
|
||||||
if ip.To4() == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// 不是链路本地地址 (169.254.0.0/16)
|
|
||||||
if ip[0] == 169 && ip[1] == 254 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// 不是组播地址
|
|
||||||
if ip.IsMulticast() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// 不是未指定地址 (0.0.0.0)
|
|
||||||
if ip.IsUnspecified() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLocalAddress 获取局域网地址(IP:端口)
|
|
||||||
func GetLocalAddress(ctx context.Context) string {
|
|
||||||
ip := GetLocalIP()
|
|
||||||
port := GetServerPort(ctx)
|
|
||||||
|
|
||||||
if port == "80" || port == "443" {
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
return ip + ":" + port
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSchemaFromRequest 从当前请求中获取协议(http/https)
|
|
||||||
func GetSchemaFromRequest(ctx context.Context) string {
|
|
||||||
r := g.RequestFromCtx(ctx)
|
|
||||||
if r == nil {
|
|
||||||
return "http"
|
|
||||||
}
|
|
||||||
|
|
||||||
// 1. 代理场景:X-Forwarded-Proto
|
|
||||||
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
|
|
||||||
return proto
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. 代理场景:X-Forwarded-Scheme
|
|
||||||
if proto := r.Header.Get("X-Forwarded-Scheme"); proto != "" {
|
|
||||||
return proto
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. TLS 连接(直接 HTTPS)
|
|
||||||
if r.TLS != nil {
|
|
||||||
return "https"
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4. 默认 HTTP(这行很重要!)
|
|
||||||
return "http" // ← 确保有这行
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLocalBaseURL 获取局域网基础 URL(动态协议 + IP + 端口)
|
|
||||||
func GetLocalBaseURL(ctx context.Context) string {
|
|
||||||
schema := GetSchemaFromRequest(ctx)
|
|
||||||
addr := GetLocalAddress(ctx)
|
|
||||||
return schema + "://" + addr
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCallbackURL 获取回调地址(完整 URL)
|
|
||||||
func GetCallbackURL(ctx context.Context, path string) string {
|
|
||||||
baseURL := GetLocalBaseURL(ctx)
|
|
||||||
// 确保 path 以 / 开头
|
|
||||||
if !strings.HasPrefix(path, "/") {
|
|
||||||
path = "/" + path
|
|
||||||
}
|
|
||||||
return baseURL + path
|
|
||||||
}
|
|
||||||
38
config.yml
38
config.yml
@@ -112,41 +112,3 @@ nodePrompts: |
|
|||||||
%s
|
%s
|
||||||
上下文内容:
|
上下文内容:
|
||||||
%s
|
%s
|
||||||
|
|
||||||
#你是专业的JSON结构生成专家,必须严格遵守以下全部规则。
|
|
||||||
# 【强制规则】
|
|
||||||
# 必须根据【输出结构】里面返回的JSON结构进行生成,不得任何更改,最终内容与输出结构返回一致;
|
|
||||||
# 完整阅读所有文本、规则、表单内容,禁止跳读、漏读;
|
|
||||||
# 完整读取UserForm所有字段,不得忽略任何字段;
|
|
||||||
# 如果有skill相关内容必须完整的将内容拼接到system角色描述中;
|
|
||||||
# 理解全部语义后再输出,禁止断章取义;
|
|
||||||
# UserForm所有字段内容必须完整拼接赋值到user角色描述中,不得有任何遗漏。
|
|
||||||
# 【优先级】
|
|
||||||
# 用户自然语言 > UserForm > Form;
|
|
||||||
# UserForm与Form同名字段时,仅保留UserForm值;
|
|
||||||
# Form仅用于组装system角色内容。
|
|
||||||
# 【表单处理】
|
|
||||||
# Form:系统提示词、默认参数、基础配置 → 专属填充system角色;
|
|
||||||
# UserForm:用户业务输入、文案、配图数量、比例、prompt等 → 全部解析后拼接进user角色content;
|
|
||||||
# 自动提取UserForm中每条文案的配图数量,总图片数 = 各文案配图数累加求和,用户没有相关数量必须默认1;
|
|
||||||
# 图片尺寸为空时自动填充size=1024*1024。
|
|
||||||
# 【结构铁律】
|
|
||||||
# 严格沿用固定输出结构,不增删字段或修改层级;
|
|
||||||
# messages元素必须按结构返回;
|
|
||||||
# 禁止将role对象转为字符串、禁止嵌套错乱;
|
|
||||||
# 输出纯净JSON:无多余转义符、无换行符、无额外字符;
|
|
||||||
# 所有括号、引号必须成对闭合,保证JSON合法。
|
|
||||||
# 【参数赋值】
|
|
||||||
# model固定沿用传入值;
|
|
||||||
# 返回结构里面的参数,需要根据语意进行赋值,缺失补默认值;
|
|
||||||
# history历史信息必须结合UserForm里的内容对用户描述部分进行补充;
|
|
||||||
# 从UserForm提取信息整合进user描述,确保数量、尺寸、文案语义无遗漏。
|
|
||||||
# 【输出要求】
|
|
||||||
# 仅输出单行纯净JSON,无任何解释、备注、Markdown或多余符号;
|
|
||||||
# 完整合UserForm全部字段语义到user描述;
|
|
||||||
# 生成后自检JSON语法、结构、数量;错误则自动重新生成。
|
|
||||||
# 【输出结构】
|
|
||||||
# %s
|
|
||||||
# 【完整输入信息】
|
|
||||||
# %s
|
|
||||||
# 直接输出最终JSON:
|
|
||||||
@@ -9,4 +9,9 @@ const (
|
|||||||
const (
|
const (
|
||||||
BuildTypePrompt = 1 //提示词构建
|
BuildTypePrompt = 1 //提示词构建
|
||||||
BuildTypeNode = 2 //节点构建
|
BuildTypeNode = 2 //节点构建
|
||||||
|
BuildTypeStruct = 3 //结构构建
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ModelTypeInference = 100 // 推理模型
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"prompts-core/model/dto"
|
"prompts-core/model/dto"
|
||||||
|
|
||||||
promptService "prompts-core/service/prompt"
|
promptService "prompts-core/service/prompt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,18 +1,36 @@
|
|||||||
|
// ============================================
|
||||||
|
// controller/session.go
|
||||||
|
// ============================================
|
||||||
|
|
||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"prompts-core/model/dto"
|
|
||||||
|
|
||||||
promptService "prompts-core/service/prompt"
|
"prompts-core/model/dto"
|
||||||
|
sessionService "prompts-core/service/session"
|
||||||
)
|
)
|
||||||
|
|
||||||
type session struct{}
|
type session struct{}
|
||||||
|
|
||||||
// Session 提示词会话控制器
|
|
||||||
var Session = new(session)
|
var Session = new(session)
|
||||||
|
|
||||||
// SessionCallback 会话回调
|
// SessionCallback 会话回调
|
||||||
func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) {
|
func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) {
|
||||||
return promptService.SessionCallback(ctx, req)
|
return sessionService.Callback(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHistoryList 获取历史列表(前端列表)
|
||||||
|
func (c *session) GetHistoryList(ctx context.Context, req *dto.GetHistoryListReq) (res *dto.GetHistoryListRes, err error) {
|
||||||
|
return sessionService.GetHistoryList(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteMessages 批量删除消息
|
||||||
|
func (c *session) DeleteMessages(ctx context.Context, req *dto.DeleteMessagesReq) (res *dto.DeleteMessagesRes, err error) {
|
||||||
|
return sessionService.DeleteMessages(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSession 删除整个会话
|
||||||
|
func (c *session) DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (res *dto.DeleteSessionRes, err error) {
|
||||||
|
return sessionService.DeleteSession(ctx, req)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ import (
|
|||||||
"prompts-core/consts/public"
|
"prompts-core/consts/public"
|
||||||
"prompts-core/model/entity"
|
"prompts-core/model/entity"
|
||||||
|
|
||||||
"gitea.com/red-future/common/db/gfdb"
|
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var ComposeSession = &composeSessionDao{}
|
var ComposeSession = &composeSessionDao{}
|
||||||
@@ -15,13 +14,8 @@ type composeSessionDao struct{}
|
|||||||
|
|
||||||
// Insert 插入
|
// Insert 插入
|
||||||
func (d *composeSessionDao) Insert(ctx context.Context, req *entity.ComposeSession) (id int64, err error) {
|
func (d *composeSessionDao) Insert(ctx context.Context, req *entity.ComposeSession) (id int64, err error) {
|
||||||
var m = new(entity.ComposeSession)
|
|
||||||
err = gconv.Struct(req, &m)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
||||||
Insert(m)
|
Insert(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -69,6 +63,7 @@ func (d *composeSessionDao) Get(ctx context.Context, req *entity.ComposeSession,
|
|||||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
||||||
OmitEmpty().
|
OmitEmpty().
|
||||||
Where(entity.ComposeSessionCol.Id, req.Id).
|
Where(entity.ComposeSessionCol.Id, req.Id).
|
||||||
|
Where(entity.ComposeSessionCol.Creator, req.Creator).
|
||||||
Where(entity.ComposeSessionCol.SessionId, req.SessionId).
|
Where(entity.ComposeSessionCol.SessionId, req.SessionId).
|
||||||
Fields(fields).One()
|
Fields(fields).One()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -86,6 +81,7 @@ func (d *composeSessionDao) Delete(ctx context.Context, req *entity.ComposeSessi
|
|||||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
||||||
OmitEmpty().
|
OmitEmpty().
|
||||||
Where(entity.ComposeSessionCol.Id, req.Id).
|
Where(entity.ComposeSessionCol.Id, req.Id).
|
||||||
|
Where(entity.ComposeSessionCol.Creator, req.Creator).
|
||||||
Where(entity.ComposeSessionCol.SessionId, req.SessionId).
|
Where(entity.ComposeSessionCol.SessionId, req.SessionId).
|
||||||
Delete()
|
Delete()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -93,3 +89,36 @@ func (d *composeSessionDao) Delete(ctx context.Context, req *entity.ComposeSessi
|
|||||||
}
|
}
|
||||||
return r.RowsAffected()
|
return r.RowsAffected()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListByIds 根据 ID 列表批量查询
|
||||||
|
func (d *composeSessionDao) ListByIds(ctx context.Context, ids []int64, creator, sessionId string) (list []*entity.ComposeSession, err error) {
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
||||||
|
WhereIn(entity.ComposeSessionCol.Id, ids).
|
||||||
|
Where(entity.ComposeSessionCol.Creator, creator).
|
||||||
|
Where(entity.ComposeSessionCol.SessionId, sessionId).
|
||||||
|
All()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = r.Structs(&list)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteByIds 批量删除编排会话
|
||||||
|
func (d *composeSessionDao) DeleteByIds(ctx context.Context, ids []int64, creator, sessionId string) (int64, error) {
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
|
||||||
|
WhereIn(entity.ComposeSessionCol.Id, ids).
|
||||||
|
Where(entity.ComposeSessionCol.Creator, creator).
|
||||||
|
Where(entity.ComposeSessionCol.SessionId, sessionId).
|
||||||
|
Delete()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return r.RowsAffected()
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"prompts-core/consts/public"
|
"prompts-core/consts/public"
|
||||||
"prompts-core/model/entity"
|
"prompts-core/model/entity"
|
||||||
|
|
||||||
"gitea.com/red-future/common/db/gfdb"
|
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,28 +0,0 @@
|
|||||||
package dao
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"prompts-core/consts/public"
|
|
||||||
"prompts-core/model/entity"
|
|
||||||
|
|
||||||
"gitea.com/red-future/common/db/gfdb"
|
|
||||||
)
|
|
||||||
|
|
||||||
var Model = &modelDao{}
|
|
||||||
|
|
||||||
type modelDao struct{}
|
|
||||||
|
|
||||||
// Get 获取模型
|
|
||||||
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
|
||||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
|
||||||
OmitEmpty().
|
|
||||||
Where(entity.AsynchModelCol.Creator, req.Creator).
|
|
||||||
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
|
|
||||||
Where(entity.AsynchModelCol.ModelName, req.ModelName).
|
|
||||||
Fields(fields).One()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = r.Struct(&m)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"prompts-core/consts/public"
|
"prompts-core/consts/public"
|
||||||
"prompts-core/model/entity"
|
"prompts-core/model/entity"
|
||||||
|
|
||||||
"gitea.com/red-future/common/db/gfdb"
|
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
16
go.mod
16
go.mod
@@ -1,17 +1,12 @@
|
|||||||
module prompts-core
|
module prompts-core
|
||||||
|
|
||||||
go 1.26.0
|
go 1.26.1
|
||||||
|
|
||||||
require (
|
require (
|
||||||
gitea.com/red-future/common v0.0.19
|
gitea.redpowerfuture.com/red-future/common v0.0.23
|
||||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0
|
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.2
|
||||||
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.0
|
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.2
|
||||||
github.com/gogf/gf/v2 v2.10.0
|
github.com/gogf/gf/v2 v2.10.2
|
||||||
)
|
|
||||||
|
|
||||||
require (
|
|
||||||
github.com/tidwall/match v1.1.1 // indirect
|
|
||||||
github.com/tidwall/pretty v1.2.0 // indirect
|
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
@@ -68,7 +63,6 @@ require (
|
|||||||
github.com/r3labs/diff/v2 v2.15.1 // indirect
|
github.com/r3labs/diff/v2 v2.15.1 // indirect
|
||||||
github.com/redis/go-redis/v9 v9.12.1 // indirect
|
github.com/redis/go-redis/v9 v9.12.1 // indirect
|
||||||
github.com/rivo/uniseg v0.4.7 // indirect
|
github.com/rivo/uniseg v0.4.7 // indirect
|
||||||
github.com/tidwall/gjson v1.19.0
|
|
||||||
github.com/tiger1103/gfast-token v1.0.10 // indirect
|
github.com/tiger1103/gfast-token v1.0.10 // indirect
|
||||||
github.com/vcaesar/cedar v0.30.0 // indirect
|
github.com/vcaesar/cedar v0.30.0 // indirect
|
||||||
github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect
|
github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect
|
||||||
|
|||||||
22
go.sum
22
go.sum
@@ -1,6 +1,6 @@
|
|||||||
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||||
gitea.com/red-future/common v0.0.19 h1:9/WrfCFUCeFUYwuhBYF+JOQi5F5xuOy+gVnf2ZvHZu4=
|
gitea.redpowerfuture.com/red-future/common v0.0.23 h1:xieoA00iKOCDm5SO9iXn+cSyMKBAlZwI0fuEVPWrHLg=
|
||||||
gitea.com/red-future/common v0.0.19/go.mod h1:6/nqIucVzmjOyqDTIq71feYBXXFNBy0rFwzaQ0/Ueoo=
|
gitea.redpowerfuture.com/red-future/common v0.0.23/go.mod h1:50U1Xi+Ie56z09S5LQbZvaken0Mxv3OeS9LgR7U/ZRY=
|
||||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||||
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
||||||
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||||
@@ -77,16 +77,16 @@ github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ4
|
|||||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0 h1:39+jbTenm7KBj4hO2C8ANAxVHpX/7OuRDs1VcGC9ylA=
|
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.2 h1:u8EpP24GkprogROnJ7htMov9Fc66pTP1eVYrWxiCYOs=
|
||||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0/go.mod h1:B0s0fVzn0W220E8UTpSGzrrGKsop5KcB90twBeLCiz0=
|
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.2/go.mod h1:GmvM3r8GVByVMi4RD2+MCs5+CfxVXPMeT8mVDkAaAXE=
|
||||||
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.0 h1:N/F9CuDdUZLoM1nVRqrDE/33pDZuhVxpNY4wYdeIaBs=
|
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.2 h1:iTQegT+lEg/wDKvj2mi3W1wrdrwFarjokf88EXVVgu4=
|
||||||
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.0/go.mod h1:x6uoJGfZOtirIRQls8xUlYzC6f7T/eULPUa9er368X0=
|
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.2/go.mod h1:ZRw3GNz5cq4uYrW4TPSVyrYWaoqzujKdWro/AOcGBaE=
|
||||||
github.com/gogf/gf/contrib/registry/consul/v2 v2.9.5 h1:eUqwJ/qNH8lJ6yssiqskazgp1ACQuNU6zXlLOZVuXTQ=
|
github.com/gogf/gf/contrib/registry/consul/v2 v2.9.5 h1:eUqwJ/qNH8lJ6yssiqskazgp1ACQuNU6zXlLOZVuXTQ=
|
||||||
github.com/gogf/gf/contrib/registry/consul/v2 v2.9.5/go.mod h1:sjQyMry9+0POYZCA6lHXBxO77WoNKkruJpRB4xKqk5k=
|
github.com/gogf/gf/contrib/registry/consul/v2 v2.9.5/go.mod h1:sjQyMry9+0POYZCA6lHXBxO77WoNKkruJpRB4xKqk5k=
|
||||||
github.com/gogf/gf/contrib/trace/otlphttp/v2 v2.9.5 h1:tHUEZYB5GTqEYYVDYnlGobf1xISARKDE4KHVlgjwTec=
|
github.com/gogf/gf/contrib/trace/otlphttp/v2 v2.9.5 h1:tHUEZYB5GTqEYYVDYnlGobf1xISARKDE4KHVlgjwTec=
|
||||||
github.com/gogf/gf/contrib/trace/otlphttp/v2 v2.9.5/go.mod h1:cfzTn2HS9RDX8f5pUVkbGxUWcSosouqfNQ1G6cY0V88=
|
github.com/gogf/gf/contrib/trace/otlphttp/v2 v2.9.5/go.mod h1:cfzTn2HS9RDX8f5pUVkbGxUWcSosouqfNQ1G6cY0V88=
|
||||||
github.com/gogf/gf/v2 v2.10.0 h1:rzDROlyqGMe/eM6dCalSR8dZOuMIdLhmxKSH1DGhbFs=
|
github.com/gogf/gf/v2 v2.10.2 h1:46IO0Uc8e85/FqdftJFskfDejJLBL0JBnGS5qOftUu8=
|
||||||
github.com/gogf/gf/v2 v2.10.0/go.mod h1:Svl1N+E8G/QshU2DUbh/3J/AJauqCgUnxHurXWR4Qx0=
|
github.com/gogf/gf/v2 v2.10.2/go.mod h1:Svl1N+E8G/QshU2DUbh/3J/AJauqCgUnxHurXWR4Qx0=
|
||||||
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||||
@@ -288,12 +288,6 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
|||||||
github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals=
|
github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals=
|
||||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
github.com/tidwall/gjson v1.19.0 h1:xwxm7n691Uf3u5OFjzngavjGTh55KX5q/9w9xHW88JU=
|
|
||||||
github.com/tidwall/gjson v1.19.0/go.mod h1:V37/opeE/JbLUOfH0QTXiNez2l0RUjYUhpT4szFQAfc=
|
|
||||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
|
||||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
|
||||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
|
||||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
|
||||||
github.com/tiger1103/gfast-token v1.0.10 h1:fNiBE/Dq5iTHvTGlCx3DmXa2o4hr0NtumFpffZ39k6s=
|
github.com/tiger1103/gfast-token v1.0.10 h1:fNiBE/Dq5iTHvTGlCx3DmXa2o4hr0NtumFpffZ39k6s=
|
||||||
github.com/tiger1103/gfast-token v1.0.10/go.mod h1:a/21mxmj7zFeNvjhZSC0XpEAFHfb1aT2k6DXnufFU1s=
|
github.com/tiger1103/gfast-token v1.0.10/go.mod h1:a/21mxmj7zFeNvjhZSC0XpEAFHfb1aT2k6DXnufFU1s=
|
||||||
github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM=
|
github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM=
|
||||||
|
|||||||
6
main.go
6
main.go
@@ -7,9 +7,9 @@ import (
|
|||||||
"prompts-core/controller"
|
"prompts-core/controller"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"gitea.com/red-future/common/http"
|
"gitea.redpowerfuture.com/red-future/common/http"
|
||||||
"gitea.com/red-future/common/jaeger"
|
"gitea.redpowerfuture.com/red-future/common/jaeger"
|
||||||
_ "gitea.com/red-future/common/swagger"
|
_ "gitea.redpowerfuture.com/red-future/common/swagger"
|
||||||
_ "github.com/gogf/gf/contrib/drivers/pgsql/v2"
|
_ "github.com/gogf/gf/contrib/drivers/pgsql/v2"
|
||||||
_ "github.com/gogf/gf/contrib/nosql/redis/v2"
|
_ "github.com/gogf/gf/contrib/nosql/redis/v2"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
|||||||
@@ -6,39 +6,32 @@ type ComposeMessagesReq struct {
|
|||||||
g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schema;form 作为系统表单,userForm 作为用户表单,结合 userFiles 调用 model-gateway,并直接返回最终 messages"`
|
g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schema;form 作为系统表单,userForm 作为用户表单,结合 userFiles 调用 model-gateway,并直接返回最终 messages"`
|
||||||
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"`
|
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"`
|
||||||
BuildType int `p:"buildType" json:"buildType" v:"required#buildType不能为空" dc:"构建类型"` //判断节点
|
BuildType int `p:"buildType" json:"buildType" v:"required#buildType不能为空" dc:"构建类型"` //判断节点
|
||||||
SessionId string `p:"sessionId" json:"sessionId" v:"required#sessionId不能为空" dc:"会话ID"`
|
NodeId string `p:"nodeId" json:"nodeId" dc:"节点ID"`
|
||||||
|
SessionId string `p:"sessionId" json:"sessionId" dc:"会话ID"` //v:"required#sessionId不能为空"
|
||||||
Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"`
|
Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"`
|
||||||
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
|
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
|
||||||
Form map[string]any `p:"form" json:"form" dc:"系统表单:form 下所有字段都作为系统提示词来源"`
|
Form []map[string]any `p:"form" json:"form" dc:"系统表单:form 下所有字段都作为系统提示词来源"`
|
||||||
UserForm []map[string]any `p:"userForm" json:"userForm" dc:"用户表单:userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"`
|
UserForm []map[string]any `p:"userForm" json:"userForm" dc:"用户表单:userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"`
|
||||||
|
Consult []ConsultItem `json:"consult" dc:"附件列表(图片/视频/音频)"`
|
||||||
SkillName string `p:"skillName" json:"skillName" dc:"技能名称"`
|
SkillName string `p:"skillName" json:"skillName" dc:"技能名称"`
|
||||||
UserFiles []string `p:"userFiles" json:"userFiles" dc:"用户附件地址列表"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConsultItem 单个附件
|
||||||
|
type ConsultItem struct {
|
||||||
|
Type string `json:"type" dc:"附件类型:image/video/audio"`
|
||||||
|
Url string `json:"url" dc:"附件地址"`
|
||||||
|
}
|
||||||
type ComposeMessagesRes struct {
|
type ComposeMessagesRes struct {
|
||||||
TaskId string `json:"taskId" dc:"任务ID"`
|
TaskId string `json:"taskId" dc:"任务ID"`
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
Messages *MultiRoundResult `json:"messages,omitempty" dc:"最终消息数组"`
|
|
||||||
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
|
|
||||||
*/
|
|
||||||
|
|
||||||
// MultiRoundResult 多轮返回结果
|
|
||||||
type MultiRoundResult struct {
|
|
||||||
TotalRounds int `json:"total_rounds"` // 总轮数
|
|
||||||
Rounds []map[string]any `json:"rounds"` // 每轮详情(动态类型)
|
|
||||||
}
|
|
||||||
|
|
||||||
type CallbackReq struct {
|
type CallbackReq struct {
|
||||||
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调:callbackUrl/{bizName}"`
|
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调:callbackUrl/{bizName}"`
|
||||||
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
|
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
|
||||||
State int `json:"state" dc:"网关任务状态"`
|
State int `json:"state" dc:"网关任务状态"`
|
||||||
OssFile string `json:"oss_file" dc:"结果文件地址"`
|
OssFile string `json:"oss_file" dc:"结果文件地址"`
|
||||||
FileType string `json:"file_type" dc:"结果文件类型"`
|
FileType string `json:"file_type" dc:"结果文件类型"`
|
||||||
Text string `json:"text" dc:"文本结果"`
|
ErrorMsg string `json:"error_msg" dc:"错误信息"`
|
||||||
ErrorMsg string `json:"error_msg" dc:"错误信息"`
|
|
||||||
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CallbackRes struct {
|
type CallbackRes struct {
|
||||||
@@ -50,11 +43,11 @@ type GetComposeTaskReq struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type GetComposeTaskRes struct {
|
type GetComposeTaskRes struct {
|
||||||
TaskId string `json:"taskId" dc:"任务ID"`
|
TaskId string `json:"taskId" dc:"任务ID"`
|
||||||
Status string `json:"status" dc:"业务状态"`
|
Status string `json:"status" dc:"业务状态"`
|
||||||
GatewayState int `json:"gatewayState" dc:"网关状态"`
|
GatewayState int `json:"gatewayState" dc:"网关状态"`
|
||||||
ErrorMessage string `json:"errorMessage" dc:"错误信息"`
|
ErrorMessage string `json:"errorMessage" dc:"错误信息"`
|
||||||
Messages any `json:"messages" dc:"最终消息数组"`
|
Messages map[string]any `json:"messages" dc:"最终消息数组"`
|
||||||
OssFile string `json:"ossFile" dc:"结果文件地址"`
|
OssFile string `json:"ossFile" dc:"结果文件地址"`
|
||||||
FileType string `json:"fileType" dc:"结果文件类型"`
|
FileType string `json:"fileType" dc:"结果文件类型"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,11 +2,79 @@ package dto
|
|||||||
|
|
||||||
import "github.com/gogf/gf/v2/frame/g"
|
import "github.com/gogf/gf/v2/frame/g"
|
||||||
|
|
||||||
type SessionCallbackReq struct {
|
// HistoryRound 一轮对话
|
||||||
g.Meta `path:"/sessionCallback" method:"post" tags:"提示词处理"`
|
type HistoryRound struct {
|
||||||
Text string `json:"text" dc:"文本结果"`
|
Id int64 `json:"id" dc:"记录ID"`
|
||||||
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
|
SessionId string `json:"sessionId" dc:"会话ID"`
|
||||||
|
NodeId string `json:"nodeId" dc:"节点ID"`
|
||||||
|
User map[string]any `json:"user" dc:"用户消息"`
|
||||||
|
Assistant map[string]any `json:"assistant" dc:"助手回复"`
|
||||||
|
CreatedAt string `json:"createdAt" dc:"创建时间"`
|
||||||
|
UpdatedAt string `json:"updatedAt" dc:"更新时间"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SessionCallbackRes struct {
|
// SessionCallbackReq 会话回调请求
|
||||||
|
type SessionCallbackReq struct {
|
||||||
|
g.Meta `path:"/callback" method:"post" tags:"会话管理" summary:"会话回调"`
|
||||||
|
Messages map[string]any `json:"messages" v:"required" dc:"消息数组"`
|
||||||
|
EpicycleId int64 `json:"epicycleId" v:"required" dc:"轮次ID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionCallbackRes 会话回调响应
|
||||||
|
type SessionCallbackRes struct {
|
||||||
|
Status bool `json:"status" dc:"状态"`
|
||||||
|
SessionId string `json:"sessionId" dc:"会话ID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHistoryListReq 获取历史列表请求(前端)
|
||||||
|
type GetHistoryListReq struct {
|
||||||
|
g.Meta `path:"/historyList" method:"get" tags:"会话管理" summary:"获取历史列表"`
|
||||||
|
Page int `json:"page" d:"1" dc:"页码"`
|
||||||
|
Size int `json:"size" d:"10" dc:"每页条数"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHistoryListRes 获取历史列表响应
|
||||||
|
type GetHistoryListRes struct {
|
||||||
|
List []HistoryRound `json:"list" dc:"历史列表"`
|
||||||
|
Total int `json:"total" dc:"总数"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHistoryMessagesReq 获取历史消息请求(提示词拼接)
|
||||||
|
type GetHistoryMessagesReq struct {
|
||||||
|
g.Meta `path:"/historyMessages" method:"get" tags:"会话管理" summary:"获取历史消息"`
|
||||||
|
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
|
||||||
|
NodeId string `json:"nodeId" dc:"节点ID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHistoryMessagesRes 获取历史消息响应
|
||||||
|
type GetHistoryMessagesRes struct {
|
||||||
|
Messages []FlatMessage `json:"messages"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FlatMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteMessagesReq 批量删除消息请求
|
||||||
|
type DeleteMessagesReq struct {
|
||||||
|
g.Meta `path:"/deleteMessages" method:"post" tags:"会话管理" summary:"批量删除消息"`
|
||||||
|
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
|
||||||
|
MsgIds []int64 `json:"msgIds" v:"required" dc:"消息ID列表"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteMessagesRes 批量删除消息响应
|
||||||
|
type DeleteMessagesRes struct {
|
||||||
|
Ok bool `json:"ok" dc:"是否成功"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSessionReq 删除整个会话请求
|
||||||
|
type DeleteSessionReq struct {
|
||||||
|
g.Meta `path:"/deleteSession" method:"post" tags:"会话管理" summary:"删除整个会话"`
|
||||||
|
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSessionRes 删除整个会话响应
|
||||||
|
type DeleteSessionRes struct {
|
||||||
|
Ok bool `json:"ok" dc:"是否成功"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,94 +0,0 @@
|
|||||||
package entity
|
|
||||||
|
|
||||||
import "gitea.com/red-future/common/beans"
|
|
||||||
|
|
||||||
// AsynchModel 异步模型配置
|
|
||||||
type AsynchModel struct {
|
|
||||||
beans.SQLBaseDO `orm:",inline"`
|
|
||||||
ModelName string `orm:"model_name" json:"modelName"`
|
|
||||||
ModelType int `orm:"model_type" json:"modelType"`
|
|
||||||
BaseURL string `orm:"base_url" json:"baseUrl"`
|
|
||||||
HttpMethod string `orm:"http_method" json:"httpMethod"`
|
|
||||||
HeadMsg string `orm:"head_msg" json:"headMsg"`
|
|
||||||
Form any `orm:"form_json" json:"form"`
|
|
||||||
RequestMapping any `orm:"request_mapping" json:"requestMapping"`
|
|
||||||
ResponseMapping any `orm:"response_mapping" json:"responseMapping"`
|
|
||||||
ResponseBody any `orm:"response_body" json:"responseBody"`
|
|
||||||
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
|
|
||||||
Prompt string `orm:"prompt" json:"prompt"`
|
|
||||||
IsPrivate *int `orm:"is_private" json:"isPrivate"`
|
|
||||||
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
|
|
||||||
ApiKey string `orm:"api_key" json:"apiKey"`
|
|
||||||
Enabled *int `orm:"enabled" json:"enabled"`
|
|
||||||
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
|
|
||||||
QueueLimit int `orm:"queue_limit" json:"queueLimit"`
|
|
||||||
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
|
|
||||||
ExpectedSeconds int `orm:"expected_seconds" json:"expectedSeconds"`
|
|
||||||
RetryTimes int `orm:"retry_times" json:"retryTimes"`
|
|
||||||
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"`
|
|
||||||
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
|
|
||||||
Remark string `orm:"remark" json:"remark"`
|
|
||||||
IsOwner *int `json:"isOwner" orm:"is_owner"`
|
|
||||||
OperatorName string `orm:"operator_name" json:"operatorName"`
|
|
||||||
TokenConfig any `orm:"token_config" json:"tokenConfig"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type asynchModelCol struct {
|
|
||||||
beans.SQLBaseCol
|
|
||||||
ModelName string
|
|
||||||
ModelType string
|
|
||||||
BaseURL string
|
|
||||||
HttpMethod string
|
|
||||||
HeadMsg string
|
|
||||||
FormJSON string
|
|
||||||
RequestMapping string
|
|
||||||
ResponseMapping string
|
|
||||||
ResponseBody string
|
|
||||||
ResponseTokenField string
|
|
||||||
Prompt string
|
|
||||||
IsPrivate string
|
|
||||||
IsChatModel string
|
|
||||||
ApiKey string
|
|
||||||
Enabled string
|
|
||||||
MaxConcurrency string
|
|
||||||
QueueLimit string
|
|
||||||
TimeoutSeconds string
|
|
||||||
ExpectedSeconds string
|
|
||||||
RetryTimes string
|
|
||||||
RetryQueueMaxSecs string
|
|
||||||
AutoCleanSeconds string
|
|
||||||
Remark string
|
|
||||||
IsOwner string
|
|
||||||
OperatorName string
|
|
||||||
TokenConfig string
|
|
||||||
}
|
|
||||||
|
|
||||||
var AsynchModelCol = asynchModelCol{
|
|
||||||
SQLBaseCol: beans.DefSQLBaseCol,
|
|
||||||
ModelName: "model_name",
|
|
||||||
ModelType: "model_type",
|
|
||||||
BaseURL: "base_url",
|
|
||||||
HttpMethod: "http_method",
|
|
||||||
HeadMsg: "head_msg",
|
|
||||||
FormJSON: "form_json",
|
|
||||||
RequestMapping: "request_mapping",
|
|
||||||
ResponseMapping: "response_mapping",
|
|
||||||
ResponseBody: "response_body",
|
|
||||||
ResponseTokenField: "response_token_field",
|
|
||||||
Prompt: "prompt",
|
|
||||||
IsPrivate: "is_private",
|
|
||||||
IsChatModel: "is_chat_model",
|
|
||||||
ApiKey: "api_key",
|
|
||||||
Enabled: "enabled",
|
|
||||||
MaxConcurrency: "max_concurrency",
|
|
||||||
QueueLimit: "queue_limit",
|
|
||||||
TimeoutSeconds: "timeout_seconds",
|
|
||||||
ExpectedSeconds: "expected_seconds",
|
|
||||||
RetryTimes: "retry_times",
|
|
||||||
RetryQueueMaxSecs: "retry_queue_max_seconds",
|
|
||||||
AutoCleanSeconds: "auto_clean_seconds",
|
|
||||||
Remark: "remark",
|
|
||||||
IsOwner: "is_owner",
|
|
||||||
OperatorName: "operator_name",
|
|
||||||
TokenConfig: "token_config",
|
|
||||||
}
|
|
||||||
@@ -1,18 +1,20 @@
|
|||||||
package entity
|
package entity
|
||||||
|
|
||||||
import "gitea.com/red-future/common/beans"
|
import "gitea.redpowerfuture.com/red-future/common/beans"
|
||||||
|
|
||||||
type ComposeSession struct {
|
type ComposeSession struct {
|
||||||
beans.SQLBaseDO `orm:",inline"`
|
beans.SQLBaseDO `orm:",inline"`
|
||||||
SessionId string `orm:"session_id" json:"sessionId"`
|
SessionId string `orm:"session_id" json:"sessionId"`
|
||||||
RequestContent any `orm:"request_content" json:"requestContent"`
|
NodeId string `orm:"node_id" json:"nodeId"`
|
||||||
ResponseContent any `orm:"response_content" json:"responseContent"`
|
RequestContent map[string]any `orm:"request_content" json:"requestContent"`
|
||||||
Remark string `orm:"remark" json:"remark"`
|
ResponseContent map[string]any `orm:"response_content" json:"responseContent"`
|
||||||
|
Remark string `orm:"remark" json:"remark"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type composeSessionCol struct {
|
type composeSessionCol struct {
|
||||||
beans.SQLBaseCol
|
beans.SQLBaseCol
|
||||||
SessionId string
|
SessionId string
|
||||||
|
NodeId string
|
||||||
RequestContent string
|
RequestContent string
|
||||||
ResponseContent string
|
ResponseContent string
|
||||||
Remark string
|
Remark string
|
||||||
@@ -21,6 +23,7 @@ type composeSessionCol struct {
|
|||||||
var ComposeSessionCol = composeSessionCol{
|
var ComposeSessionCol = composeSessionCol{
|
||||||
SQLBaseCol: beans.DefSQLBaseCol,
|
SQLBaseCol: beans.DefSQLBaseCol,
|
||||||
SessionId: "session_id",
|
SessionId: "session_id",
|
||||||
|
NodeId: "node_id",
|
||||||
RequestContent: "request_content",
|
RequestContent: "request_content",
|
||||||
ResponseContent: "response_content",
|
ResponseContent: "response_content",
|
||||||
Remark: "remark",
|
Remark: "remark",
|
||||||
|
|||||||
@@ -1,22 +1,21 @@
|
|||||||
package entity
|
package entity
|
||||||
|
|
||||||
import "gitea.com/red-future/common/beans"
|
import "gitea.redpowerfuture.com/red-future/common/beans"
|
||||||
|
|
||||||
type ComposeTask struct {
|
type ComposeTask struct {
|
||||||
beans.SQLBaseDO `orm:",inline"`
|
beans.SQLBaseDO `orm:",inline"`
|
||||||
TaskId string `orm:"task_id" json:"taskId"`
|
TaskId string `orm:"task_id" json:"taskId"`
|
||||||
ModelName string `orm:"model_name" json:"modelName"`
|
ModelName string `orm:"model_name" json:"modelName"`
|
||||||
SkillName string `orm:"skill_name" json:"skillName"`
|
SkillName string `orm:"skill_name" json:"skillName"`
|
||||||
BuildType int `orm:"build_type" json:"buildType"`
|
BuildType int `orm:"build_type" json:"buildType"`
|
||||||
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
|
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
|
||||||
GatewayState int `orm:"gateway_state" json:"gatewayState"`
|
GatewayState int `orm:"gateway_state" json:"gatewayState"`
|
||||||
RequestPayload any `orm:"request_payload" json:"requestPayload"`
|
RequestPayload map[string]any `orm:"request_payload" json:"requestPayload"`
|
||||||
ResultText string `orm:"result_text" json:"resultText"`
|
ResultJson map[string]any `orm:"result_json" json:"resultJson"`
|
||||||
Messages any `orm:"messages" json:"messages"`
|
Status string `orm:"status" json:"status"`
|
||||||
Status string `orm:"status" json:"status"`
|
ErrorMessage string `orm:"error_message" json:"errorMessage"`
|
||||||
ErrorMessage string `orm:"error_message" json:"errorMessage"`
|
OssFile string `orm:"oss_file" json:"ossFile"`
|
||||||
OssFile string `orm:"oss_file" json:"ossFile"`
|
FileType string `orm:"file_type" json:"fileType"`
|
||||||
FileType string `orm:"file_type" json:"fileType"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type composeTaskCol struct {
|
type composeTaskCol struct {
|
||||||
@@ -28,8 +27,7 @@ type composeTaskCol struct {
|
|||||||
CallbackUrl string
|
CallbackUrl string
|
||||||
GatewayState string
|
GatewayState string
|
||||||
RequestPayload string
|
RequestPayload string
|
||||||
ResultText string
|
ResultJson string
|
||||||
Messages string
|
|
||||||
Status string
|
Status string
|
||||||
ErrorMessage string
|
ErrorMessage string
|
||||||
OssFile string
|
OssFile string
|
||||||
@@ -45,8 +43,7 @@ var ComposeTaskCol = composeTaskCol{
|
|||||||
CallbackUrl: "callback_url",
|
CallbackUrl: "callback_url",
|
||||||
GatewayState: "gateway_state",
|
GatewayState: "gateway_state",
|
||||||
RequestPayload: "request_payload",
|
RequestPayload: "request_payload",
|
||||||
ResultText: "result_text",
|
ResultJson: "result_json",
|
||||||
Messages: "messages",
|
|
||||||
Status: "status",
|
Status: "status",
|
||||||
ErrorMessage: "error_message",
|
ErrorMessage: "error_message",
|
||||||
OssFile: "oss_file",
|
OssFile: "oss_file",
|
||||||
|
|||||||
@@ -1,21 +1,21 @@
|
|||||||
package entity
|
package entity
|
||||||
|
|
||||||
import "gitea.com/red-future/common/beans"
|
import "gitea.redpowerfuture.com/red-future/common/beans"
|
||||||
|
|
||||||
// ProviderProtocol 模型协议映射配置
|
// ProviderProtocol 模型协议映射配置
|
||||||
type ProviderProtocol struct {
|
type ProviderProtocol struct {
|
||||||
beans.SQLBaseDO `orm:",inherit"`
|
beans.SQLBaseDO `orm:",inherit"`
|
||||||
// 业务字段
|
// 业务字段
|
||||||
ProviderName string `orm:"provider_name" json:"providerName"`
|
ProviderName string `orm:"provider_name" json:"providerName"`
|
||||||
TargetField string `orm:"target_field" json:"targetField"`
|
TargetField string `orm:"target_field" json:"targetField"`
|
||||||
MergeOrder any `orm:"merge_order" json:"mergeOrder"`
|
MergeOrder []string `orm:"merge_order" json:"mergeOrder"`
|
||||||
RoleMapping any `orm:"role_mapping" json:"roleMapping"`
|
RoleMapping map[string]any `orm:"role_mapping" json:"roleMapping"`
|
||||||
ContentMapping any `orm:"content_mapping" json:"contentMapping"`
|
ContentMapping map[string]any `orm:"content_mapping" json:"contentMapping"`
|
||||||
Capabilities any `orm:"capabilities" json:"capabilities"`
|
Capabilities map[string]any `orm:"capabilities" json:"capabilities"`
|
||||||
RequestTemplate any `orm:"request_template" json:"requestTemplate"`
|
RequestTemplate map[string]any `orm:"request_template" json:"requestTemplate"`
|
||||||
SystemPromptTemplate string `orm:"system_prompt_template" json:"systemPromptTemplate"`
|
SystemPromptTemplate string `orm:"system_prompt_template" json:"systemPromptTemplate"`
|
||||||
Status int `orm:"status" json:"status"`
|
Status int `orm:"status" json:"status"`
|
||||||
Remark string `orm:"remark" json:"remark"`
|
Remark string `orm:"remark" json:"remark"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// providerProtocolCol 列名
|
// providerProtocolCol 列名
|
||||||
|
|||||||
@@ -4,10 +4,14 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
"prompts-core/common/util"
|
"prompts-core/common/util"
|
||||||
"prompts-core/model/entity"
|
"prompts-core/model/entity"
|
||||||
|
"strings"
|
||||||
|
|
||||||
commonHttp "gitea.com/red-future/common/http"
|
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||||
|
commonHttp "gitea.redpowerfuture.com/red-future/common/http"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
"github.com/gogf/gf/v2/os/gtime"
|
"github.com/gogf/gf/v2/os/gtime"
|
||||||
)
|
)
|
||||||
@@ -37,6 +41,70 @@ func CreateGatewayTask(ctx context.Context, payload map[string]any) (string, err
|
|||||||
return req.TaskId, nil
|
return req.TaskId, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type GetModelConfigResp struct {
|
||||||
|
Model *AsynchModel `json:"model"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AsynchModel struct {
|
||||||
|
beans.SQLBaseDO `orm:",inline"`
|
||||||
|
ModelName string `orm:"model_name" json:"modelName"`
|
||||||
|
ModelType int `orm:"model_type" json:"modelType"`
|
||||||
|
BaseURL string `orm:"base_url" json:"baseUrl"`
|
||||||
|
HttpMethod string `orm:"http_method" json:"httpMethod"`
|
||||||
|
HeadMsg map[string]any `orm:"head_msg" json:"headMsg"`
|
||||||
|
Form []map[string]any `orm:"form_json" json:"form"`
|
||||||
|
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
|
||||||
|
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
|
||||||
|
ResponseBody string `orm:"response_body" json:"responseBody"`
|
||||||
|
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
|
||||||
|
IsPrivate *int `orm:"is_private" json:"isPrivate"`
|
||||||
|
IsChatModel int `orm:"is_chat_model" json:"isChatModel"`
|
||||||
|
CallModel int `orm:"call_model" json:"callModel"`
|
||||||
|
ApiKey string `orm:"api_key" json:"apiKey"`
|
||||||
|
Enabled *int `orm:"enabled" json:"enabled"`
|
||||||
|
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
|
||||||
|
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
|
||||||
|
RetryTimes int `orm:"retry_times" json:"retryTimes"`
|
||||||
|
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
|
||||||
|
IsOwner *int `json:"isOwner" orm:"is_owner"`
|
||||||
|
OperatorName string `orm:"operator_name" json:"operatorName"`
|
||||||
|
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
|
||||||
|
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
|
||||||
|
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
|
||||||
|
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
|
||||||
|
FirstFrame string `orm:"first_frame" json:"firstFrame"`
|
||||||
|
LastFrame string `orm:"last_frame" json:"lastFrame"`
|
||||||
|
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelConfig 获取模型配置
|
||||||
|
func GetModelConfig(ctx context.Context, req *AsynchModel) (model *AsynchModel, err error) {
|
||||||
|
fullURL := "model-gateway/model/getModel"
|
||||||
|
// 拼接 query 参数
|
||||||
|
var params []string
|
||||||
|
if req.Creator != "" {
|
||||||
|
params = append(params, fmt.Sprintf("creator=%s", req.Creator))
|
||||||
|
}
|
||||||
|
if req.ModelName != "" {
|
||||||
|
params = append(params, fmt.Sprintf("modelName=%s", req.ModelName))
|
||||||
|
}
|
||||||
|
if req.IsChatModel != 0 {
|
||||||
|
params = append(params, fmt.Sprintf("isChatModel=%d", req.IsChatModel))
|
||||||
|
}
|
||||||
|
if len(params) > 0 {
|
||||||
|
fullURL += "?" + strings.Join(params, "&")
|
||||||
|
}
|
||||||
|
headers := util.ForwardHeaders(ctx)
|
||||||
|
var resp GetModelConfigResp
|
||||||
|
if err = commonHttp.Get(ctx, fullURL, headers, &resp, nil); err != nil {
|
||||||
|
return nil, fmt.Errorf("获取模型配置失败: %w", err)
|
||||||
|
}
|
||||||
|
if resp.Model == nil {
|
||||||
|
return nil, fmt.Errorf("模型不存在")
|
||||||
|
}
|
||||||
|
return resp.Model, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetTaskResultRes 任务结果响应
|
// GetTaskResultRes 任务结果响应
|
||||||
type GetTaskResultRes struct {
|
type GetTaskResultRes struct {
|
||||||
OssFile string `json:"ossFile" dc:"结果文件OSS地址"`
|
OssFile string `json:"ossFile" dc:"结果文件OSS地址"`
|
||||||
@@ -80,78 +148,48 @@ func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) {
|
|||||||
|
|
||||||
// SendCallbackReq 发送回调的请求体
|
// SendCallbackReq 发送回调的请求体
|
||||||
type SendCallbackReq struct {
|
type SendCallbackReq struct {
|
||||||
TaskId string `json:"taskId"`
|
TaskId string `json:"taskId"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Messages *MultiRoundResult `json:"messages,omitempty"`
|
EpicycleId int64 `json:"epicycleId"`
|
||||||
EpicycleId int64 `json:"epicycleId"`
|
ErrorMsg string `json:"errorMsg,omitempty"`
|
||||||
ErrorMsg string `json:"errorMsg,omitempty"`
|
|
||||||
}
|
|
||||||
type MultiRoundResult struct {
|
|
||||||
TotalRounds int `json:"total_rounds"` // 总轮数
|
|
||||||
Rounds []map[string]any `json:"rounds"` // 每轮详情(动态类型)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendCallback 向业务方发送回调
|
// SendCallback 向业务方发送回调
|
||||||
func SendCallback(ctx context.Context, composeTask *entity.ComposeTask) error {
|
func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycleId int64) error {
|
||||||
// 1. 检查回调地址
|
// 1. 检查回调地址
|
||||||
if composeTask.CallbackUrl == "" {
|
if composeTask.CallbackUrl == "" {
|
||||||
return fmt.Errorf("回调地址为空,taskId=%s", composeTask.TaskId)
|
return fmt.Errorf("回调地址为空,taskId=%s", composeTask.TaskId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 构造请求体
|
// 2. 构造请求体
|
||||||
req := SendCallbackReq{
|
req := SendCallbackReq{
|
||||||
TaskId: composeTask.TaskId,
|
TaskId: composeTask.TaskId,
|
||||||
Status: composeTask.Status,
|
Status: composeTask.Status,
|
||||||
Messages: parseMessagesToResult(composeTask.Messages), // 需要将 JSON 字符串转为结构体
|
ErrorMsg: composeTask.ErrorMessage,
|
||||||
ErrorMsg: composeTask.ErrorMessage,
|
EpicycleId: epicycleId,
|
||||||
}
|
}
|
||||||
// 3. 发送 POST 请求
|
// 3. 发送 POST 请求
|
||||||
headers := util.ForwardHeaders(ctx)
|
headers := util.ForwardHeaders(ctx)
|
||||||
var resp struct{}
|
var resp struct{}
|
||||||
g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s 消息=%v",
|
g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s",
|
||||||
composeTask.TaskId, composeTask.CallbackUrl, req.Messages)
|
composeTask.TaskId, composeTask.CallbackUrl)
|
||||||
if err := commonHttp.Post(ctx, composeTask.CallbackUrl, headers, &resp, req); err != nil {
|
if err := commonHttp.Post(ctx, composeTask.CallbackUrl, headers, &resp, req); err != nil {
|
||||||
return fmt.Errorf("[回调业务] 发送失败 taskId=%s url=%s err=%w", composeTask.TaskId, composeTask.CallbackUrl, err)
|
return fmt.Errorf("[回调业务] 发送失败 taskId=%s url=%s err=%w", composeTask.TaskId, composeTask.CallbackUrl, err)
|
||||||
}
|
}
|
||||||
g.Log().Infof(ctx, "[回调业务] 发送成功 taskId=%s 回调地址=%s", composeTask.TaskId, composeTask.CallbackUrl)
|
g.Log().Infof(ctx, "[回调业务] 发送成功 taskId=%s 回调地址=%s ", composeTask.TaskId, composeTask.CallbackUrl)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseMessagesToResult 将 any 类型的 Messages 转为 *MultiRoundResult
|
// DownloadFile 从 OSS 下载文件内容
|
||||||
func parseMessagesToResult(messages any) *MultiRoundResult {
|
func DownloadFile(ossURL string) ([]byte, error) {
|
||||||
if messages == nil {
|
resp, err := http.Get(ossURL)
|
||||||
return nil
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("下载OSS文件失败: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("下载OSS文件返回非200: %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
var result MultiRoundResult
|
return io.ReadAll(resp.Body)
|
||||||
|
|
||||||
switch v := messages.(type) {
|
|
||||||
case *MultiRoundResult:
|
|
||||||
return v
|
|
||||||
case MultiRoundResult:
|
|
||||||
return &v
|
|
||||||
case string:
|
|
||||||
if err := json.Unmarshal([]byte(v), &result); err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case []byte:
|
|
||||||
if err := json.Unmarshal(v, &result); err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
case map[string]any:
|
|
||||||
// 通过 JSON 序列化再反序列化
|
|
||||||
data, _ := json.Marshal(v)
|
|
||||||
if err := json.Unmarshal(data, &result); err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
data, err := json.Marshal(v)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err = json.Unmarshal(data, &result); err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &result
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,9 +2,8 @@ package prompt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"prompts-core/consts/public"
|
"prompts-core/service/gateway"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"prompts-core/common/util"
|
"prompts-core/common/util"
|
||||||
@@ -12,182 +11,141 @@ import (
|
|||||||
"prompts-core/model/dto"
|
"prompts-core/model/dto"
|
||||||
"prompts-core/model/entity"
|
"prompts-core/model/entity"
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"gitea.redpowerfuture.com/red-future/common/utils"
|
||||||
|
"github.com/gogf/gf/v2/encoding/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// buildInferenceRequest 构建推理请求
|
|
||||||
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
|
|
||||||
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ir := NewPromptIR()
|
|
||||||
|
|
||||||
switch req.BuildType {
|
|
||||||
case public.BuildTypePrompt:
|
|
||||||
return buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, history, ir, totalBatches)
|
|
||||||
case public.BuildTypeNode:
|
|
||||||
return buildNodeTypeRequest(ctx, req, chatModel, ir)
|
|
||||||
default:
|
|
||||||
return nil, errors.New("不支持的构建类型")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildPromptTypeRequest 构建提示词类型请求(BuildType=1)
|
// buildPromptTypeRequest 构建提示词类型请求(BuildType=1)
|
||||||
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
|
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
|
||||||
systemPrompt := promptBuildWithRounds(ctx, req, aiModel, totalBatches)
|
//1) 构建系统提示词
|
||||||
|
systemPrompt := promptBuildWithRounds(ctx, chatModel, aiModel)
|
||||||
ir.AddSystem(systemPrompt)
|
ir.AddSystem(systemPrompt)
|
||||||
|
|
||||||
for _, msg := range history {
|
|
||||||
role := gconv.String(msg["role"])
|
|
||||||
if role != "user" && role != "assistant" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ir.AddHistory(role, gconv.String(msg["content"]))
|
|
||||||
}
|
|
||||||
|
|
||||||
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType))
|
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType))
|
||||||
ir.AddUser(userPrompt)
|
ir.AddUser(userPrompt)
|
||||||
|
//2) 检查整体内容是否超出窗口
|
||||||
if !checkOverallContent(ir, aiModel) {
|
if !checkOverallContent(ir, aiModel) {
|
||||||
availableWindow := util.GetAvailableWindow(aiModel.TokenConfig)
|
availableWindow := util.GetAvailableWindow(aiModel.TokenConfig)
|
||||||
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
|
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
|
||||||
}
|
}
|
||||||
// 记录历史会话
|
return compileToProviderRequest(ctx, ir, chatModel, req)
|
||||||
_, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
|
||||||
SessionId: req.SessionId,
|
|
||||||
RequestContent: ir.User,
|
|
||||||
})
|
|
||||||
return compileToProviderRequest(ctx, ir, chatModel)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildNodeTypeRequest 构建节点类型请求(BuildType=2)
|
// buildNodeTypeRequest 构建节点类型请求(BuildType=2)
|
||||||
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, ir *PromptIR) (map[string]any, error) {
|
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
|
||||||
ir.AddUser(NodeBuild(ctx, req))
|
ir.AddUser(NodeBuild(ctx, req))
|
||||||
|
return compileToProviderRequest(ctx, ir, chatModel, req)
|
||||||
|
}
|
||||||
|
|
||||||
return compileToProviderRequest(ctx, ir, chatModel)
|
// buildStructTypeRequest 构建结构体类型请求(BuildType=3)
|
||||||
|
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
|
||||||
|
customPrompt := gjson.New(req.UserForm).Get("0.prompt").String()
|
||||||
|
ir.AddSystem(customPrompt)
|
||||||
|
ir.AddUser(buildUserPrompt(ctx, req, ""))
|
||||||
|
return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt)
|
||||||
}
|
}
|
||||||
|
|
||||||
// compileToProviderRequest 编译为 Provider 请求
|
// compileToProviderRequest 编译为 Provider 请求
|
||||||
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *entity.AsynchModel) (map[string]any, error) {
|
func compileToProviderRequest(ctx context.Context, ir *IR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) {
|
||||||
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
|
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("获取协议配置失败: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
if protocol == nil {
|
if protocol == nil {
|
||||||
return nil, errors.New("协议配置不存在")
|
return nil, fmt.Errorf("协议配置不存在或获取失败")
|
||||||
|
}
|
||||||
|
// 如果传了自定义提示词,替换掉协议模板
|
||||||
|
if len(customPrompt) > 0 && customPrompt[0] != "" {
|
||||||
|
protocol.SystemPromptTemplate = customPrompt[0] +
|
||||||
|
"【核心铁律】" +
|
||||||
|
"1.【技能内容skill相关】必须完整拼接到System提示词中,作为System提示词的组成部分,不得拆分到其他位置。"
|
||||||
}
|
}
|
||||||
providerReq, err := Compile(ir, protocol, chatModel)
|
providerReq, err := Compile(ir, protocol, chatModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("编译请求失败: %w", err)
|
return nil, fmt.Errorf("编译请求失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"modelName": chatModel.ModelName,
|
"modelName": chatModel.ModelName,
|
||||||
"bizName": "prompts-core",
|
"bizName": util.GetServerName(ctx),
|
||||||
"callbackUrl": util.GetCallbackURL(ctx, "/prompt/callback"),
|
"callbackUrl": utils.GetCallbackURL(ctx, "/prompt/callback"),
|
||||||
"requestPayload": providerReq,
|
"requestPayload": providerReq,
|
||||||
|
"buildType": req.BuildType,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// promptBuildWithRounds 构建系统提示词(包含轮次信息)
|
// promptBuildWithRounds 构建提示词
|
||||||
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel, totalRounds int) string {
|
func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel) string {
|
||||||
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||||
ProviderName: model.OperatorName,
|
ProviderName: chatModel.OperatorName,
|
||||||
Status: 1,
|
Status: 1,
|
||||||
})
|
})
|
||||||
if err != nil || providerProtocol == nil {
|
if err != nil || providerProtocol == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{})).MustToJsonIndentString()
|
||||||
outputJSON := util.JSONPretty(model.RequestMapping)
|
|
||||||
maxWindowSize := util.GetMaxWindowSize(model.TokenConfig)
|
|
||||||
availableWindow := util.GetAvailableWindow(model.TokenConfig)
|
|
||||||
|
|
||||||
userFormContent := buildUserFormContent(req.UserForm)
|
|
||||||
formInfo := fmt.Sprintf(`
|
|
||||||
【系统表单(系统提示词/参数)】
|
|
||||||
%s
|
|
||||||
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
|
|
||||||
%s
|
|
||||||
`, util.FormToJSON(req.Form), userFormContent)
|
|
||||||
|
|
||||||
inputInfo := fmt.Sprintf(`
|
|
||||||
目标模型: %s
|
|
||||||
%s
|
|
||||||
技能名称: %s
|
|
||||||
用户文件: %v
|
|
||||||
`, req.ModelName, formInfo, req.SkillName, req.UserFiles)
|
|
||||||
|
|
||||||
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
|
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
|
||||||
req.ModelName,
|
outputJSON, //【输出结构】 %s
|
||||||
maxWindowSize,
|
|
||||||
availableWindow,
|
|
||||||
totalRounds,
|
|
||||||
totalRounds,
|
|
||||||
totalRounds,
|
|
||||||
outputJSON,
|
|
||||||
inputInfo,
|
|
||||||
totalRounds,
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildUserFormContent 构建用户表单内容字符串
|
|
||||||
func buildUserFormContent(userForm []map[string]any) string {
|
|
||||||
var builder strings.Builder
|
|
||||||
for _, item := range userForm {
|
|
||||||
builder.WriteString(fmt.Sprintf("%v\n", item))
|
|
||||||
}
|
|
||||||
return builder.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkOverallContent 检查整体内容是否超出窗口
|
// checkOverallContent 检查整体内容是否超出窗口
|
||||||
func checkOverallContent(ir *PromptIR, model *entity.AsynchModel) bool {
|
func checkOverallContent(ir *IR, model *gateway.AsynchModel) bool {
|
||||||
fullContent := ir.String()
|
fullContent := ir.String()
|
||||||
return util.CountToken(fullContent, model.TokenConfig)
|
return util.CountToken(fullContent, model.TokenConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildUserPrompt 构建用户提示词
|
// buildUserPrompt 构建用户提示词
|
||||||
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
|
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
|
||||||
userFormForPayload := prepareUserFormPayload(req.UserForm)
|
var b strings.Builder
|
||||||
|
b.WriteString(fmt.Sprintf("目标模型:%s\n", req.ModelName))
|
||||||
payload := map[string]any{
|
if prompt != "" {
|
||||||
"model": req.ModelName,
|
b.WriteString(fmt.Sprintf("系统提示词:%s\n", prompt))
|
||||||
"promptInfo": prompt,
|
|
||||||
"form": req.Form,
|
|
||||||
"userForm": userFormForPayload,
|
|
||||||
"userFiles": req.UserFiles,
|
|
||||||
"userFilesText": FetchFileTexts(ctx, req.UserFiles),
|
|
||||||
"skills": SkillMdContent(ctx, req.SkillName),
|
|
||||||
}
|
}
|
||||||
|
if skills := SkillMdContent(ctx, req.SkillName); skills != "" {
|
||||||
return util.MustMarshal(payload)
|
b.WriteString(fmt.Sprintf("技能内容:\n%s\n", skills))
|
||||||
|
}
|
||||||
|
if formText := buildUserFormText(req.Form); formText != "" {
|
||||||
|
b.WriteString(fmt.Sprintf("系统参数:\n%s\n", formText))
|
||||||
|
}
|
||||||
|
if userFormText := buildUserFormText(req.UserForm); userFormText != "" {
|
||||||
|
b.WriteString(fmt.Sprintf("用户需求:\n%s\n", userFormText))
|
||||||
|
}
|
||||||
|
if len(req.Consult) > 0 {
|
||||||
|
b.WriteString(fmt.Sprintf("参考附件:%s\n", gjson.New(req.Consult).String()))
|
||||||
|
}
|
||||||
|
if fileTexts := ExtractFileTexts(ctx, req.Consult); fileTexts != "" {
|
||||||
|
b.WriteString(fmt.Sprintf("附件内容:\n%s\n", fileTexts))
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareUserFormPayload 准备用户表单载荷
|
func buildUserFormText(form []map[string]any) string {
|
||||||
func prepareUserFormPayload(userForm []map[string]any) any {
|
if len(form) == 0 {
|
||||||
if len(userForm) == 0 {
|
return ""
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := userForm[0]["batch_index"]; ok {
|
|
||||||
return userForm
|
|
||||||
}
|
|
||||||
|
|
||||||
return mergeUserFormTexts(userForm)
|
|
||||||
}
|
|
||||||
|
|
||||||
// mergeUserFormTexts 合并 UserForm 中的所有文本内容
|
|
||||||
func mergeUserFormTexts(userForm []map[string]any) string {
|
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
for i, item := range userForm {
|
for _, item := range form {
|
||||||
text := getItemText(item)
|
for k, v := range item {
|
||||||
if i > 0 {
|
builder.WriteString(fmt.Sprintf("%s:\n", k))
|
||||||
builder.WriteString("\n\n")
|
switch val := v.(type) {
|
||||||
|
case []any:
|
||||||
|
for i, elem := range val {
|
||||||
|
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
|
||||||
|
if m, ok := elem.(map[string]any); ok {
|
||||||
|
for mk, mv := range m {
|
||||||
|
builder.WriteString(fmt.Sprintf("%s:%v ", mk, mv))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
builder.WriteString(fmt.Sprint(elem))
|
||||||
|
}
|
||||||
|
builder.WriteString("\n")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
builder.WriteString(fmt.Sprintf(" %v\n", v))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
builder.WriteString(text)
|
|
||||||
}
|
}
|
||||||
return builder.String()
|
return strings.TrimSpace(builder.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// NodeBuild 节点构建
|
// NodeBuild 节点构建
|
||||||
@@ -196,9 +154,8 @@ func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string {
|
|||||||
if promptTpl == "" {
|
if promptTpl == "" {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
return fmt.Sprintf(promptTpl,
|
||||||
formStr := util.FormToJSON(req.Form)
|
gjson.New(req.Form).MustToJsonString(),
|
||||||
userFormStr := util.UserFormToJSON(req.UserForm)
|
gjson.New(req.UserForm).MustToJsonString(),
|
||||||
|
)
|
||||||
return fmt.Sprintf(promptTpl, formStr, userFormStr)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,13 +2,9 @@ package prompt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"prompts-core/service/session"
|
||||||
"gitea.com/red-future/common/beans"
|
|
||||||
"gitea.com/red-future/common/utils"
|
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
|
||||||
|
|
||||||
"prompts-core/common/util"
|
"prompts-core/common/util"
|
||||||
"prompts-core/consts/public"
|
"prompts-core/consts/public"
|
||||||
@@ -16,49 +12,55 @@ import (
|
|||||||
"prompts-core/model/dto"
|
"prompts-core/model/dto"
|
||||||
"prompts-core/model/entity"
|
"prompts-core/model/entity"
|
||||||
"prompts-core/service/gateway"
|
"prompts-core/service/gateway"
|
||||||
|
|
||||||
|
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||||
|
"gitea.redpowerfuture.com/red-future/common/utils"
|
||||||
|
"github.com/gogf/gf/v2/encoding/gjson"
|
||||||
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ComposeMessages 核心拼接提示词主流程
|
// ComposeMessages 核心拼接提示词主流程
|
||||||
func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
|
func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
|
||||||
|
// 1) 获取模型信息
|
||||||
chatModel, aiModel, err := GetModelMessage(ctx, req)
|
chatModel, aiModel, err := GetModelMessage(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
// 2) 校验用户表单
|
||||||
if err = validateUserForm(req, aiModel); err != nil {
|
if err = validateUserForm(req, aiModel); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
switch req.BuildType {
|
return handleBuild(ctx, req, chatModel, aiModel)
|
||||||
case public.BuildTypePrompt:
|
|
||||||
return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建
|
|
||||||
case public.BuildTypeNode:
|
|
||||||
return handleNodeBuild(ctx, req, chatModel, aiModel) // 节点构建
|
|
||||||
default:
|
|
||||||
return nil, errors.New("BuildType 不支持")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelMessage 获取模型信息
|
// GetModelMessage 获取模型信息
|
||||||
func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) {
|
func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*gateway.AsynchModel, *gateway.AsynchModel, error) {
|
||||||
userInfo, err := utils.GetUserInfo(ctx)
|
userInfo, err := utils.GetUserInfo(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
|
return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
|
||||||
}
|
}
|
||||||
|
chatModel, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||||
chatModel, err := getChatModel(ctx, userInfo.UserName)
|
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
|
||||||
if err != nil {
|
IsChatModel: 1,
|
||||||
return nil, nil, err
|
})
|
||||||
|
if err != nil || chatModel == nil {
|
||||||
|
return nil, nil, errors.New("当前没有对话模型,请添加")
|
||||||
}
|
}
|
||||||
|
|
||||||
aiModel, err := getAIModel(ctx, userInfo.UserName, req.ModelName)
|
aiModel, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||||
if err != nil {
|
SQLBaseDO: beans.SQLBaseDO{TenantId: userInfo.TenantId, Creator: userInfo.UserName},
|
||||||
return nil, nil, err
|
ModelName: req.ModelName,
|
||||||
|
})
|
||||||
|
if err != nil || aiModel == nil {
|
||||||
|
return nil, nil, errors.New("需要构建的模型不存在")
|
||||||
}
|
}
|
||||||
|
|
||||||
return chatModel, aiModel, nil
|
return chatModel, aiModel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateUserForm 校验用户表单
|
// validateUserForm 校验用户表单
|
||||||
func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) error {
|
func validateUserForm(req *dto.ComposeMessagesReq, model *gateway.AsynchModel) error {
|
||||||
if len(req.UserForm) == 0 {
|
if len(req.UserForm) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -72,274 +74,244 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) er
|
|||||||
return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens,可用窗口 %d tokens,请精简后重试",
|
return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens,可用窗口 %d tokens,请精简后重试",
|
||||||
exceedTokens, availableWindow)
|
exceedTokens, availableWindow)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handlePromptBuild 处理提示词构建(BuildType=1)
|
// handleBuild 通用构建处理
|
||||||
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
||||||
// 获取历史会话
|
// 1) 处理表单分批
|
||||||
history, err := GetHistoryMessages(ctx, req.SessionId)
|
processedReq, _, err := ProcessUserFormBatches(ctx, req, aiModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
g.Log().Errorf(ctx, "获取历史会话失败: %v,将不使用历史会话", err)
|
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
|
||||||
history = nil
|
}
|
||||||
|
|
||||||
|
// 2) 构建推理请求
|
||||||
|
ir := NewPromptIR()
|
||||||
|
var taskReq map[string]any
|
||||||
|
switch req.BuildType {
|
||||||
|
case public.BuildTypePrompt:
|
||||||
|
taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir)
|
||||||
|
case public.BuildTypeNode:
|
||||||
|
taskReq, err = buildNodeTypeRequest(ctx, req, chatModel, ir)
|
||||||
|
case public.BuildTypeStruct:
|
||||||
|
taskReq, err = buildStructTypeRequest(ctx, req, chatModel, ir)
|
||||||
|
default:
|
||||||
|
return nil, errors.New("不支持的构建类型")
|
||||||
}
|
}
|
||||||
// 调用推理模型
|
|
||||||
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, history)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("调用推理模型失败: %w", err)
|
return nil, fmt.Errorf("构建推理请求失败: %w", err)
|
||||||
}
|
}
|
||||||
// 保存任务记录
|
|
||||||
if err = saveComposeTask(ctx, taskID, req); err != nil {
|
|
||||||
return nil, fmt.Errorf("保存任务记录失败: %w", err)
|
|
||||||
}
|
|
||||||
return &dto.ComposeMessagesRes{
|
|
||||||
TaskId: taskID,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleNodeBuild 处理节点构建(BuildType=2)
|
// 3) 调用网关创建任务
|
||||||
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
|
||||||
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("调用推理模型失败: %w", err)
|
return nil, fmt.Errorf("创建网关任务失败: %w", err)
|
||||||
|
}
|
||||||
|
if taskID == "" {
|
||||||
|
return nil, errors.New("网关未返回taskId")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := saveComposeTask(ctx, taskID, req); err != nil {
|
// 4) 保存任务记录
|
||||||
return nil, fmt.Errorf("保存任务记录失败: %w", err)
|
if _, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
|
||||||
}
|
|
||||||
|
|
||||||
return &dto.ComposeMessagesRes{
|
|
||||||
TaskId: taskID,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// saveComposeTask 保存组合任务记录
|
|
||||||
func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessagesReq) error {
|
|
||||||
_, err := dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
|
|
||||||
TaskId: taskID,
|
TaskId: taskID,
|
||||||
ModelName: req.ModelName,
|
ModelName: req.ModelName,
|
||||||
SkillName: req.SkillName,
|
SkillName: req.SkillName,
|
||||||
BuildType: req.BuildType,
|
BuildType: req.BuildType,
|
||||||
CallbackUrl: req.CallbackUrl,
|
CallbackUrl: req.CallbackUrl,
|
||||||
RequestPayload: util.MustMarshal(req),
|
RequestPayload: gconv.Map(req),
|
||||||
Status: public.ComposeStatusPending,
|
Status: public.ComposeStatusPending,
|
||||||
})
|
}); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
|
||||||
|
|
||||||
// getChatModel 获取聊天模型
|
|
||||||
func getChatModel(ctx context.Context, userName string) (*entity.AsynchModel, error) {
|
|
||||||
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
|
|
||||||
IsChatModel: new(1),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("查询聊天模型失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if chatModel == nil {
|
|
||||||
return nil, errors.New("当前没有对话模型,请添加")
|
|
||||||
}
|
|
||||||
|
|
||||||
return chatModel, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getAIModel 获取AI模型
|
|
||||||
func getAIModel(ctx context.Context, userName, modelName string) (*entity.AsynchModel, error) {
|
|
||||||
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
|
|
||||||
ModelName: modelName,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("查询AI模型失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if aiModel == nil {
|
|
||||||
return nil, fmt.Errorf("需要构建的模型 %s 不存在", modelName)
|
|
||||||
}
|
|
||||||
|
|
||||||
return aiModel, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// callInferenceModel 调用推理模型
|
|
||||||
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, idModel *entity.AsynchModel, history []map[string]any) (string, error) {
|
|
||||||
taskReq, err := buildInferenceRequest(ctx, req, chatModel, idModel, history)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("构建推理请求失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("创建网关任务失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if taskID == "" {
|
|
||||||
return "", errors.New("网关未返回taskId")
|
|
||||||
}
|
|
||||||
|
|
||||||
return taskID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// createDefaultResult 创建默认结果
|
|
||||||
func createDefaultResult(data map[string]any) *dto.MultiRoundResult {
|
|
||||||
if data == nil {
|
|
||||||
data = make(map[string]any)
|
|
||||||
}
|
|
||||||
return &dto.MultiRoundResult{
|
|
||||||
TotalRounds: 1,
|
|
||||||
Rounds: []map[string]any{data},
|
|
||||||
}
|
}
|
||||||
|
return &dto.ComposeMessagesRes{TaskId: taskID}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Callback 回调处理
|
// Callback 回调处理
|
||||||
func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||||
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
|
g.Log().Infof(ctx, "[开始回调处理] taskId=%s state=%d", req.TaskId, req.State)
|
||||||
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
|
|
||||||
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
// 1) 查询任务
|
||||||
TaskId: req.TaskId,
|
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{TaskId: req.TaskId})
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("查询任务失败: %w", err)
|
return fmt.Errorf("查询任务失败: %w", err)
|
||||||
}
|
}
|
||||||
if composeTask == nil {
|
|
||||||
return fmt.Errorf("任务不存在: %s", req.TaskId)
|
// 2) 读取 OSS 文件内容
|
||||||
}
|
var ossContent []byte
|
||||||
//处理失败
|
if req.OssFile != "" {
|
||||||
if req.State == 3 {
|
ossContent, err = gateway.DownloadFile(req.OssFile)
|
||||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
|
||||||
TaskId: req.TaskId,
|
|
||||||
Status: public.ComposeStatusFailed,
|
|
||||||
ErrorMessage: req.ErrorMsg,
|
|
||||||
GatewayState: req.State,
|
|
||||||
OssFile: req.OssFile,
|
|
||||||
FileType: req.FileType,
|
|
||||||
ResultText: req.Text,
|
|
||||||
})
|
|
||||||
// 用更新后的值发送回调
|
|
||||||
if composeTask.CallbackUrl != "" {
|
|
||||||
failedTask := &entity.ComposeTask{
|
|
||||||
TaskId: req.TaskId,
|
|
||||||
Status: public.ComposeStatusFailed,
|
|
||||||
ErrorMessage: req.ErrorMsg,
|
|
||||||
CallbackUrl: composeTask.CallbackUrl,
|
|
||||||
Messages: composeTask.Messages,
|
|
||||||
}
|
|
||||||
gateway.SendCallback(ctx, failedTask)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
//处理成功
|
|
||||||
if req.State == 2 {
|
|
||||||
// 1. 根据 BuildType 解析结果
|
|
||||||
var messages any
|
|
||||||
switch composeTask.BuildType {
|
|
||||||
case public.BuildTypePrompt: // 提示词构建解析
|
|
||||||
messages = parsePromptResult(req.Text)
|
|
||||||
case public.BuildTypeNode: // 节点构建解析
|
|
||||||
messages = parseNodeResult(req.Text)
|
|
||||||
default:
|
|
||||||
messages = req.Text
|
|
||||||
}
|
|
||||||
// 2. 更新数据库
|
|
||||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
|
||||||
TaskId: req.TaskId,
|
|
||||||
Status: public.ComposeStatusSuccess,
|
|
||||||
Messages: messages,
|
|
||||||
GatewayState: req.State,
|
|
||||||
OssFile: req.OssFile,
|
|
||||||
FileType: req.FileType,
|
|
||||||
ResultText: req.Text,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
g.Log().Errorf(ctx, "[Callback] 更新成功状态失败 taskId=%s err=%v", req.TaskId, err)
|
g.Log().Warningf(ctx, "[回调处理] 读取OSS失败 taskId=%s err=%v", req.TaskId, err)
|
||||||
return err
|
|
||||||
}
|
|
||||||
// 4. 发送回调给业务方
|
|
||||||
if composeTask.CallbackUrl != "" {
|
|
||||||
successTask := &entity.ComposeTask{
|
|
||||||
TaskId: req.TaskId,
|
|
||||||
Status: public.ComposeStatusSuccess,
|
|
||||||
Messages: messages,
|
|
||||||
CallbackUrl: composeTask.CallbackUrl,
|
|
||||||
}
|
|
||||||
gateway.SendCallback(ctx, successTask)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 3) 解析 OSS 内容为消息
|
||||||
|
var messages map[string]any
|
||||||
|
if len(ossContent) > 0 {
|
||||||
|
messages, _ = gjson.New(ossContent).Map(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4) 处理失败
|
||||||
|
if req.State == 3 {
|
||||||
|
return handleCallbackFailed(ctx, req, composeTask, messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5) 处理成功
|
||||||
|
if req.State == 2 {
|
||||||
|
return handleCallbackSuccess(ctx, req, composeTask, messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleCallbackFailed 处理回调失败
|
||||||
|
func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask, messages map[string]any) error {
|
||||||
|
_, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||||
|
TaskId: req.TaskId,
|
||||||
|
Status: public.ComposeStatusFailed,
|
||||||
|
ErrorMessage: req.ErrorMsg,
|
||||||
|
GatewayState: req.State,
|
||||||
|
OssFile: req.OssFile,
|
||||||
|
FileType: req.FileType,
|
||||||
|
ResultJson: messages,
|
||||||
|
})
|
||||||
|
if composeTask.CallbackUrl != "" {
|
||||||
|
composeTask.Status = public.ComposeStatusFailed
|
||||||
|
composeTask.ErrorMessage = req.ErrorMsg
|
||||||
|
_ = gateway.SendCallback(ctx, composeTask, 0)
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// parsePromptResult 解析提示词构建结果
|
// handleCallbackSuccess 处理回调成功
|
||||||
func parsePromptResult(raw string) *dto.MultiRoundResult {
|
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask, messages map[string]any) error {
|
||||||
var wrapper map[string]any
|
// 1) 获取模型配置
|
||||||
if err := json.Unmarshal([]byte(raw), &wrapper); err != nil {
|
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||||
return createDefaultResult(map[string]any{"raw": raw})
|
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
|
||||||
|
ModelName: composeTask.ModelName,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("查询模型失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
contentStr, ok := wrapper["content"].(string)
|
// 2) 获取协议配置
|
||||||
if !ok || contentStr == "" {
|
protocol, _ := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||||
return createDefaultResult(wrapper)
|
ProviderName: model.OperatorName,
|
||||||
}
|
Status: 1,
|
||||||
|
})
|
||||||
|
|
||||||
// 先尝试解析为数组
|
// 3) 获取历史消息 + 保存当前轮
|
||||||
if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil {
|
payload := composeTask.RequestPayload
|
||||||
return &dto.MultiRoundResult{
|
sessionId := gconv.String(payload["sessionId"])
|
||||||
TotalRounds: len(roundsArray),
|
nodeId := gconv.String(payload["nodeId"])
|
||||||
Rounds: roundsArray,
|
var history []dto.FlatMessage
|
||||||
|
var epicycleId int64
|
||||||
|
|
||||||
|
if sessionId != "" && nodeId != "" && model.ModelType == public.ModelTypeInference {
|
||||||
|
// 3.1 获取历史
|
||||||
|
h, _ := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
|
||||||
|
SessionId: sessionId,
|
||||||
|
NodeId: nodeId,
|
||||||
|
})
|
||||||
|
if h != nil {
|
||||||
|
history = h.Messages
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3.2 保存当前轮(先存,下次查询就能拿到)
|
||||||
|
if userMsg := util.ExtractUserText(messages); userMsg != nil {
|
||||||
|
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||||
|
NodeId: nodeId,
|
||||||
|
SessionId: sessionId,
|
||||||
|
RequestContent: userMsg,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 再尝试解析为单个对象
|
// 4) 合并附加结构
|
||||||
if singleRound := tryParseAsMap(contentStr); singleRound != nil {
|
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
|
||||||
return &dto.MultiRoundResult{
|
// 5) 注入历史
|
||||||
TotalRounds: 1,
|
if len(history) > 0 {
|
||||||
Rounds: []map[string]any{singleRound},
|
messages = InjectHistory(messages, history, protocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6) 更新数据库
|
||||||
|
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||||
|
TaskId: req.TaskId,
|
||||||
|
Status: public.ComposeStatusSuccess,
|
||||||
|
GatewayState: req.State,
|
||||||
|
OssFile: req.OssFile,
|
||||||
|
FileType: req.FileType,
|
||||||
|
ResultJson: messages,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 8) 回调业务方
|
||||||
|
if composeTask.CallbackUrl != "" {
|
||||||
|
composeTask.Status = public.ComposeStatusSuccess
|
||||||
|
composeTask.ResultJson = messages
|
||||||
|
_ = gateway.SendCallback(ctx, composeTask, epicycleId)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InjectHistory 插入历史会话
|
||||||
|
func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protocol *entity.ProviderProtocol) map[string]any {
|
||||||
|
if protocol == nil || len(history) == 0 {
|
||||||
|
return roundsData
|
||||||
|
}
|
||||||
|
// 1) 提取第一轮的 messages
|
||||||
|
rounds := roundsData["rounds"].([]any)
|
||||||
|
firstRound := rounds[0].(map[string]any)
|
||||||
|
original := firstRound["messages"].([]any)
|
||||||
|
|
||||||
|
// 2) 按 merge_order 拼接
|
||||||
|
result := make([]any, 0, len(original)+len(history))
|
||||||
|
|
||||||
|
for _, part := range protocol.MergeOrder {
|
||||||
|
switch part {
|
||||||
|
case "system":
|
||||||
|
for _, m := range original {
|
||||||
|
msg := m.(map[string]any)
|
||||||
|
if gconv.String(msg["role"]) == "system" {
|
||||||
|
result = append(result, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "history":
|
||||||
|
if gconv.Bool(protocol.Capabilities["support_history"]) {
|
||||||
|
for _, msg := range history {
|
||||||
|
result = append(result, map[string]any{
|
||||||
|
"role": msg.Role,
|
||||||
|
"content": msg.Content, // 纯字符串,不转换
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "user":
|
||||||
|
for _, m := range original {
|
||||||
|
msg := m.(map[string]any)
|
||||||
|
if gconv.String(msg["role"]) == "user" {
|
||||||
|
result = append(result, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return createDefaultResult(map[string]any{"content": contentStr})
|
// 3) 角色映射
|
||||||
}
|
if len(protocol.RoleMapping) > 0 {
|
||||||
|
for _, m := range result {
|
||||||
func tryParseAsMapArray(jsonStr string) []map[string]any {
|
msg := m.(map[string]any)
|
||||||
var arr []map[string]any
|
role := gconv.String(msg["role"])
|
||||||
if err := json.Unmarshal([]byte(jsonStr), &arr); err != nil {
|
if mapped, ok := protocol.RoleMapping[role]; ok {
|
||||||
return nil
|
msg["role"] = mapped
|
||||||
}
|
}
|
||||||
if len(arr) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return arr
|
|
||||||
}
|
|
||||||
|
|
||||||
func tryParseAsMap(jsonStr string) map[string]any {
|
|
||||||
var obj map[string]any
|
|
||||||
if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if len(obj) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return obj
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseNodeResult 解析节点构建结果
|
|
||||||
func parseNodeResult(raw string) *dto.MultiRoundResult {
|
|
||||||
var result map[string]any
|
|
||||||
if err := json.Unmarshal([]byte(raw), &result); err != nil {
|
|
||||||
return createDefaultResult(map[string]any{"raw": raw})
|
|
||||||
}
|
|
||||||
|
|
||||||
if contentStr, ok := result["content"].(string); ok && contentStr != "" {
|
|
||||||
var inner map[string]any
|
|
||||||
if err := json.Unmarshal([]byte(contentStr), &inner); err == nil {
|
|
||||||
result = inner
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &dto.MultiRoundResult{
|
// 4) 直接修改原对象
|
||||||
TotalRounds: 1,
|
firstRound["messages"] = result
|
||||||
Rounds: []map[string]any{result},
|
return roundsData
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetComposeTask 查询任务结果
|
// GetComposeTask 查询任务结果
|
||||||
@@ -350,31 +322,10 @@ func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("查询任务失败: %w", err)
|
return nil, fmt.Errorf("查询任务失败: %w", err)
|
||||||
}
|
}
|
||||||
if record == nil {
|
|
||||||
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
|
|
||||||
}
|
|
||||||
|
|
||||||
messages := parseMessagesForResponse(record.Messages)
|
|
||||||
|
|
||||||
return &dto.GetComposeTaskRes{
|
return &dto.GetComposeTaskRes{
|
||||||
TaskId: record.TaskId,
|
TaskId: record.TaskId,
|
||||||
Status: record.Status,
|
Status: record.Status,
|
||||||
ErrorMessage: record.ErrorMessage,
|
ErrorMessage: record.ErrorMessage,
|
||||||
Messages: messages,
|
Messages: record.ResultJson,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseMessagesForResponse 解析用于响应的消息
|
|
||||||
func parseMessagesForResponse(messages any) any {
|
|
||||||
str, ok := messages.(string)
|
|
||||||
if !ok || str == "" {
|
|
||||||
return messages
|
|
||||||
}
|
|
||||||
|
|
||||||
var parsed any
|
|
||||||
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
|
|
||||||
return parsed
|
|
||||||
}
|
|
||||||
|
|
||||||
return messages
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"prompts-core/model/dto"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -21,15 +22,25 @@ const (
|
|||||||
bytesPerMB = 1024 * 1024
|
bytesPerMB = 1024 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
// FetchFileTexts 从 URL 列表获取文件内容,支持 zip 内文件
|
// ExtractFileTexts 从 ConsultItem 列表中提取文件内容,返回拼接文本
|
||||||
func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
|
func ExtractFileTexts(ctx context.Context, consult []dto.ConsultItem) string {
|
||||||
result := make(map[string]string)
|
urls := make([]string, 0, len(consult))
|
||||||
|
for _, item := range consult {
|
||||||
|
if item.Url != "" {
|
||||||
|
urls = append(urls, item.Url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return FetchFileTextsAsString(ctx, urls)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchFileTextsAsString 从 URL 列表获取文件内容,拼接为字符串
|
||||||
|
func FetchFileTextsAsString(ctx context.Context, urls []string) string {
|
||||||
if len(urls) == 0 {
|
if len(urls) == 0 {
|
||||||
return result
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
client := createHTTPClient(ctx, "userFiles.httpTimeoutSec", 8)
|
client := createHTTPClient(ctx, "userFiles.httpTimeoutSec", 8)
|
||||||
|
var builder strings.Builder
|
||||||
|
|
||||||
for _, rawURL := range urls {
|
for _, rawURL := range urls {
|
||||||
url := util.SanitizeURL(rawURL)
|
url := util.SanitizeURL(rawURL)
|
||||||
@@ -38,23 +49,19 @@ func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if util.IsZipExtension(url) {
|
if util.IsZipExtension(url) {
|
||||||
mergeMap(result, fetchZipFileTexts(ctx, client, url))
|
for _, text := range fetchZipFileTexts(ctx, client, url) {
|
||||||
|
builder.WriteString(text)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if text := fetchAndCleanFileContent(ctx, client, url); text != "" {
|
if text := fetchAndCleanFileContent(ctx, client, url); text != "" {
|
||||||
result[url] = text
|
builder.WriteString(fmt.Sprintf("【文件:%s】\n%s\n", url, text))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return builder.String()
|
||||||
}
|
|
||||||
|
|
||||||
// mergeMap 合并 map
|
|
||||||
func mergeMap(dst, src map[string]string) {
|
|
||||||
for k, v := range src {
|
|
||||||
dst[k] = v
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// fetchAndCleanFileContent 获取并清理文件内容
|
// fetchAndCleanFileContent 获取并清理文件内容
|
||||||
@@ -182,10 +189,13 @@ func fetchFileContent(ctx context.Context, client *http.Client, url string) (str
|
|||||||
return strings.TrimSpace(string(body)), nil
|
return strings.TrimSpace(string(body)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SkillMdContent 根据 skillName 获取 zip 内所有 md 文件拼接内容
|
|
||||||
func SkillMdContent(ctx context.Context, skillName string) string {
|
func SkillMdContent(ctx context.Context, skillName string) string {
|
||||||
|
if skillName == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
skillResp, err := gateway.GetSkillUser(ctx, skillName)
|
skillResp, err := gateway.GetSkillUser(ctx, skillName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
g.Log().Warningf(ctx, "[SkillMd] GetSkillUser 失败: %v", err)
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -196,11 +206,13 @@ func SkillMdContent(ctx context.Context, skillName string) string {
|
|||||||
|
|
||||||
zipBytes, err := downloadFile(client, fullUrl, maxSize)
|
zipBytes, err := downloadFile(client, fullUrl, maxSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
g.Log().Warningf(ctx, "[SkillMd] 下载失败 url=%s err=%v", fullUrl, err)
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
mdContents, err := extractMdFiles(ctx, zipBytes)
|
mdContents, err := extractMdFiles(ctx, zipBytes)
|
||||||
if err != nil || len(mdContents) == 0 {
|
if err != nil || len(mdContents) == 0 {
|
||||||
|
g.Log().Warningf(ctx, "[SkillMd] 提取md失败 count=%d err=%v", len(mdContents), err)
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,75 +0,0 @@
|
|||||||
# Prompts-Core(提示词核心服务)
|
|
||||||
|
|
||||||
> 智能提示词构建与管理系统,支持多模态 AI 模型的提示词组装、会话管理和协议适配。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 项目简介
|
|
||||||
|
|
||||||
**Prompts-Core** 是一个基于 Go 语言开发的提示词核心服务,作为 AI 应用层与模型网关之间的桥梁,负责将业务需求转换为标准化的模型请求。
|
|
||||||
|
|
||||||
### 核心价值
|
|
||||||
- **统一提示词管理**:集中化管理不同模型类型的提示词模板
|
|
||||||
- **智能会话维护**:基于 Redis + PostgreSQL 的双层会话存储
|
|
||||||
- **多协议适配**:支持 OpenAI、DeepSeek、Qwen、Gemini 等多种模型协议
|
|
||||||
- **文件处理能力**:自动提取文本文件和 ZIP 压缩包内容
|
|
||||||
- **技能系统集成**:支持从外部加载 Markdown 格式的技能描述
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 核心功能
|
|
||||||
|
|
||||||
### 1. 提示词构建引擎
|
|
||||||
|
|
||||||
#### 多模态支持
|
|
||||||
| 类型 | 说明 | 适用场景 |
|
|
||||||
|------|------|----------|
|
|
||||||
| Type 1 | 文字处理助手 | 文章撰写、文案优化、翻译等 |
|
|
||||||
| Type 2 | 图片处理助手 | 图像生成、风格迁移等 |
|
|
||||||
| Type 3 | 音频处理助手 | 语音合成、识别、降噪等 |
|
|
||||||
| Type 4 | 向量化处理助手 | 语义检索、知识索引等 |
|
|
||||||
| Type 5 | 全模态助手 | 跨模态转换、多模态融合等 |
|
|
||||||
|
|
||||||
#### 构建模式
|
|
||||||
- **BuildType 1(提示词构建)**:完整流程,包含系统提示词、历史会话、用户输入的智能组装
|
|
||||||
- **BuildType 2(节点构建)**:工作流路由决策,根据上下文选择节点 ID
|
|
||||||
|
|
||||||
#### 分批处理
|
|
||||||
当用户表单内容超出模型窗口限制时,自动按 Token 大小分批处理。
|
|
||||||
|
|
||||||
### 2. 会话管理系统
|
|
||||||
|
|
||||||
- **双层存储**:Redis 缓存(最近 N 轮)+ PostgreSQL 持久化
|
|
||||||
- **自动管理**:最大轮数控制(默认 10 轮)、自动过期(默认 30 分钟)
|
|
||||||
|
|
||||||
### 3. 协议适配器
|
|
||||||
|
|
||||||
通过配置动态支持多种模型协议:
|
|
||||||
- 角色映射:system/user/assistant → 目标协议角色
|
|
||||||
- 内容字段映射:content → parts.text 等
|
|
||||||
- 消息顺序控制:灵活配置拼接顺序
|
|
||||||
- 请求模板渲染:支持占位符替换
|
|
||||||
|
|
||||||
### 4. 任务调度
|
|
||||||
|
|
||||||
- **异步流程**:创建网关任务 → 轮询等待 → 接收回调 → 返回结果
|
|
||||||
- **重试机制**:可配置最大重试次数(默认 3 次)
|
|
||||||
- **超时保护**:默认 300 秒超时
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 技术架构
|
|
||||||
|
|
||||||
### 技术栈
|
|
||||||
|
|
||||||
| 组件 | 版本 | 用途 |
|
|
||||||
|------|------|------|
|
|
||||||
| Go | 1.26.0 | 编程语言 |
|
|
||||||
| GoFrame | v2.10.0 | Web 框架 |
|
|
||||||
| PostgreSQL | - | 关系型数据库 |
|
|
||||||
| Redis | - | 缓存与会话存储 |
|
|
||||||
| Consul | - | 服务注册与发现 |
|
|
||||||
| Jaeger | - | 分布式链路追踪 |
|
|
||||||
|
|
||||||
### 架构图
|
|
||||||
|
|
||||||
@@ -2,17 +2,18 @@ package prompt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"prompts-core/common/util"
|
"prompts-core/service/gateway"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"prompts-core/dao"
|
"prompts-core/dao"
|
||||||
"prompts-core/model/entity"
|
"prompts-core/model/entity"
|
||||||
|
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PromptIR 统一 Prompt 中间表示
|
// IR 统一 Prompt 中间表示
|
||||||
type PromptIR struct {
|
type IR struct {
|
||||||
System []Segment `json:"system"`
|
System []Segment `json:"system"`
|
||||||
History []Segment `json:"history"`
|
History []Segment `json:"history"`
|
||||||
User []Segment `json:"user"`
|
User []Segment `json:"user"`
|
||||||
@@ -33,6 +34,7 @@ type ProviderProtocol struct {
|
|||||||
ContentMapping ContentMapping `json:"content_mapping"`
|
ContentMapping ContentMapping `json:"content_mapping"`
|
||||||
RequestTemplate map[string]any `json:"request_template"`
|
RequestTemplate map[string]any `json:"request_template"`
|
||||||
SystemPromptTemplate string `json:"system_prompt_template"`
|
SystemPromptTemplate string `json:"system_prompt_template"`
|
||||||
|
Capabilities map[string]any `json:"capabilities"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ContentMapping 内容字段映射
|
// ContentMapping 内容字段映射
|
||||||
@@ -42,8 +44,8 @@ type ContentMapping struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewPromptIR 创建空 PromptIR
|
// NewPromptIR 创建空 PromptIR
|
||||||
func NewPromptIR() *PromptIR {
|
func NewPromptIR() *IR {
|
||||||
return &PromptIR{
|
return &IR{
|
||||||
System: make([]Segment, 0),
|
System: make([]Segment, 0),
|
||||||
History: make([]Segment, 0),
|
History: make([]Segment, 0),
|
||||||
User: make([]Segment, 0),
|
User: make([]Segment, 0),
|
||||||
@@ -51,7 +53,7 @@ func NewPromptIR() *PromptIR {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// String 返回 PromptIR 的完整内容字符串(用于 token 计算)
|
// String 返回 PromptIR 的完整内容字符串(用于 token 计算)
|
||||||
func (ir *PromptIR) String() string {
|
func (ir *IR) String() string {
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
|
|
||||||
for _, seg := range ir.System {
|
for _, seg := range ir.System {
|
||||||
@@ -77,7 +79,7 @@ func (ir *PromptIR) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算)
|
// GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算)
|
||||||
func (ir *PromptIR) GetTotalContent() string {
|
func (ir *IR) GetTotalContent() string {
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
|
|
||||||
for _, seg := range ir.System {
|
for _, seg := range ir.System {
|
||||||
@@ -99,7 +101,7 @@ func (ir *PromptIR) GetTotalContent() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddSystem 添加系统提示
|
// AddSystem 添加系统提示
|
||||||
func (ir *PromptIR) AddSystem(content string) *PromptIR {
|
func (ir *IR) AddSystem(content string) *IR {
|
||||||
if content != "" {
|
if content != "" {
|
||||||
ir.System = append(ir.System, Segment{Type: "text", Content: content})
|
ir.System = append(ir.System, Segment{Type: "text", Content: content})
|
||||||
}
|
}
|
||||||
@@ -107,7 +109,7 @@ func (ir *PromptIR) AddSystem(content string) *PromptIR {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddUser 添加用户消息
|
// AddUser 添加用户消息
|
||||||
func (ir *PromptIR) AddUser(content string) *PromptIR {
|
func (ir *IR) AddUser(content string) *IR {
|
||||||
if content != "" {
|
if content != "" {
|
||||||
ir.User = append(ir.User, Segment{Type: "text", Content: content})
|
ir.User = append(ir.User, Segment{Type: "text", Content: content})
|
||||||
}
|
}
|
||||||
@@ -115,7 +117,7 @@ func (ir *PromptIR) AddUser(content string) *PromptIR {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddHistory 添加历史消息
|
// AddHistory 添加历史消息
|
||||||
func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
|
func (ir *IR) AddHistory(role, content string) *IR {
|
||||||
if content != "" {
|
if content != "" {
|
||||||
ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role})
|
ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role})
|
||||||
}
|
}
|
||||||
@@ -123,7 +125,7 @@ func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ToMessages 转换为 OpenAI 兼容的 messages 格式(MVP 默认)
|
// ToMessages 转换为 OpenAI 兼容的 messages 格式(MVP 默认)
|
||||||
func (ir *PromptIR) ToMessages() []map[string]any {
|
func (ir *IR) ToMessages() []map[string]any {
|
||||||
var messages []map[string]any
|
var messages []map[string]any
|
||||||
|
|
||||||
for _, seg := range ir.System {
|
for _, seg := range ir.System {
|
||||||
@@ -164,21 +166,22 @@ func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderP
|
|||||||
|
|
||||||
// parseProtocol 将 DB entity 转为编译用协议配置
|
// parseProtocol 将 DB entity 转为编译用协议配置
|
||||||
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
|
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
|
||||||
p := &ProviderProtocol{
|
return &ProviderProtocol{
|
||||||
TargetField: e.TargetField,
|
TargetField: e.TargetField,
|
||||||
SystemPromptTemplate: e.SystemPromptTemplate,
|
SystemPromptTemplate: e.SystemPromptTemplate,
|
||||||
|
MergeOrder: e.MergeOrder,
|
||||||
|
RoleMapping: gconv.MapStrStr(e.RoleMapping),
|
||||||
|
ContentMapping: ContentMapping{
|
||||||
|
Type: gconv.String(e.ContentMapping["type"]),
|
||||||
|
Field: gconv.String(e.ContentMapping["field"]),
|
||||||
|
},
|
||||||
|
RequestTemplate: e.RequestTemplate,
|
||||||
|
Capabilities: e.Capabilities,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用通用解析方法处理各个字段
|
|
||||||
util.ParseJSONFieldFromGvar(e.MergeOrder, &p.MergeOrder)
|
|
||||||
util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping)
|
|
||||||
util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping)
|
|
||||||
util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate)
|
|
||||||
return p
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compile 将 PromptIR 按协议配置编译为 Provider Request
|
// Compile 将 PromptIR 按协议配置编译为 Provider Request
|
||||||
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (map[string]any, error) {
|
func Compile(ir *IR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) {
|
||||||
if ir == nil || p == nil {
|
if ir == nil || p == nil {
|
||||||
return nil, fmt.Errorf("ir and protocol are required")
|
return nil, fmt.Errorf("ir and protocol are required")
|
||||||
}
|
}
|
||||||
@@ -190,35 +193,25 @@ func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
// mergeByOrder 按协议配置顺序拼接消息
|
// mergeByOrder 按协议配置顺序拼接消息
|
||||||
func mergeByOrder(ir *PromptIR, order []string) []map[string]any {
|
func mergeByOrder(ir *IR, order []string) []map[string]any {
|
||||||
var messages []map[string]any
|
roleMap := map[string][]Segment{
|
||||||
|
"system": ir.System,
|
||||||
for _, part := range order {
|
"history": ir.History,
|
||||||
switch part {
|
"user": ir.User,
|
||||||
case "system":
|
|
||||||
for _, seg := range ir.System {
|
|
||||||
messages = append(messages, map[string]any{
|
|
||||||
"role": "system",
|
|
||||||
"content": seg.Content,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
case "history":
|
|
||||||
for _, seg := range ir.History {
|
|
||||||
messages = append(messages, map[string]any{
|
|
||||||
"role": seg.Role,
|
|
||||||
"content": seg.Content,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
case "user":
|
|
||||||
for _, seg := range ir.User {
|
|
||||||
messages = append(messages, map[string]any{
|
|
||||||
"role": "user",
|
|
||||||
"content": seg.Content,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var messages []map[string]any
|
||||||
|
for _, part := range order {
|
||||||
|
for _, seg := range roleMap[part] {
|
||||||
|
msg := map[string]any{"content": seg.Content}
|
||||||
|
if part == "history" {
|
||||||
|
msg["role"] = seg.Role
|
||||||
|
} else {
|
||||||
|
msg["role"] = part
|
||||||
|
}
|
||||||
|
messages = append(messages, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -242,29 +235,29 @@ func mapRoles(messages []map[string]any, mapping map[string]string) []map[string
|
|||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
// mapContent 内容字段映射
|
|
||||||
func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
|
func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
|
||||||
for _, msg := range messages {
|
if cm.Field == "" || cm.Field == "content" {
|
||||||
content := msg["content"]
|
return messages
|
||||||
delete(msg, "content")
|
|
||||||
|
|
||||||
switch cm.Type {
|
|
||||||
case "parts":
|
|
||||||
msg["parts"] = []map[string]any{
|
|
||||||
{cm.Field: content},
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
msg[cm.Field] = content
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for i, msg := range messages {
|
||||||
|
if content, ok := msg["content"]; ok {
|
||||||
|
delete(msg, "content")
|
||||||
|
switch cm.Type {
|
||||||
|
case "parts":
|
||||||
|
messages[i]["parts"] = []map[string]any{{cm.Field: content}}
|
||||||
|
default:
|
||||||
|
messages[i][cm.Field] = content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildRequest 按 target_field 和 request_template 构建请求体
|
// buildRequest 按 target_field 和 request_template 构建请求体
|
||||||
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *entity.AsynchModel) map[string]any {
|
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gateway.AsynchModel) map[string]any {
|
||||||
if len(p.RequestTemplate) > 0 {
|
if len(p.RequestTemplate) > 0 {
|
||||||
return renderTemplate(p.RequestTemplate, messages, chatModel)
|
return renderTemplate(p, messages, chatModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
@@ -272,20 +265,21 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *ent
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// renderTemplate 简单的 {{key}} 模板替换
|
// renderTemplate 模板渲染
|
||||||
func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *entity.AsynchModel) map[string]any {
|
func renderTemplate(p *ProviderProtocol, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
|
||||||
b, _ := json.Marshal(tmpl)
|
result := make(map[string]any, len(p.RequestTemplate)+1)
|
||||||
str := string(b)
|
for k, v := range p.RequestTemplate {
|
||||||
|
result[k] = v
|
||||||
if chatModel != nil {
|
|
||||||
str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
msgBytes, _ := json.Marshal(messages)
|
if chatModel != nil {
|
||||||
str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes))
|
result["model"] = chatModel.ModelName
|
||||||
|
}
|
||||||
|
result["messages"] = messages
|
||||||
|
|
||||||
var result map[string]any
|
if maxTokens := gconv.Int(p.Capabilities["max_tokens"]); maxTokens > 0 {
|
||||||
json.Unmarshal([]byte(str), &result)
|
result["max_tokens"] = maxTokens
|
||||||
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,145 +0,0 @@
|
|||||||
package prompt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
redisKeyPrefix = "chat:session:%s"
|
|
||||||
)
|
|
||||||
|
|
||||||
// saveToRedis 保存会话数据到Redis
|
|
||||||
func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error {
|
|
||||||
key := formatRedisKey(sessionId)
|
|
||||||
|
|
||||||
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
|
||||||
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
|
|
||||||
|
|
||||||
data := map[string]any{
|
|
||||||
"sessionId": sessionId,
|
|
||||||
"requestContent": requestMessages,
|
|
||||||
"responseContent": responseMessages,
|
|
||||||
"timestamp": time.Now().Unix(),
|
|
||||||
}
|
|
||||||
|
|
||||||
b, err := json.Marshal(data)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("序列化会话数据失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// formatRedisKey 格式化Redis键
|
|
||||||
func formatRedisKey(sessionId string) string {
|
|
||||||
return fmt.Sprintf(redisKeyPrefix, sessionId)
|
|
||||||
}
|
|
||||||
|
|
||||||
// executeRedisCommands 执行Redis命令
|
|
||||||
func executeRedisCommands(ctx context.Context, key string, value string, maxRounds int, expireSeconds int64) error {
|
|
||||||
if _, err := g.Redis().Do(ctx, "LPUSH", key, value); err != nil {
|
|
||||||
return fmt.Errorf("写入Redis失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1); err != nil {
|
|
||||||
return fmt.Errorf("裁剪Redis列表失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
|
|
||||||
return fmt.Errorf("设置过期时间失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getFromRedis 从Redis获取会话历史
|
|
||||||
func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
|
||||||
key := formatRedisKey(sessionId)
|
|
||||||
|
|
||||||
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("从Redis获取数据失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result == nil || result.IsNil() {
|
|
||||||
return []map[string]any{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sessions := parseRedisSessions(ctx, result.Strings())
|
|
||||||
|
|
||||||
reverseSlice(sessions)
|
|
||||||
|
|
||||||
return sessions, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseRedisSessions 解析Redis会话数据
|
|
||||||
func parseRedisSessions(ctx context.Context, values []string) []map[string]any {
|
|
||||||
var sessions []map[string]any
|
|
||||||
|
|
||||||
for _, str := range values {
|
|
||||||
var data map[string]any
|
|
||||||
if err := json.Unmarshal([]byte(str), &data); err != nil {
|
|
||||||
g.Log().Warningf(ctx, "[会话] 解析Redis数据失败 err=%v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
sessions = append(sessions, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
return sessions
|
|
||||||
}
|
|
||||||
|
|
||||||
// reverseSlice 反转切片
|
|
||||||
func reverseSlice(s []map[string]any) {
|
|
||||||
for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {
|
|
||||||
s[i], s[j] = s[j], s[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用)
|
|
||||||
func GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
|
||||||
historyData, err := getFromRedis(ctx, sessionId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("获取历史会话失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(historyData) == 0 {
|
|
||||||
return []map[string]any{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return flattenHistoryMessages(historyData), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// flattenHistoryMessages 扁平化历史消息
|
|
||||||
func flattenHistoryMessages(historyData []map[string]any) []map[string]any {
|
|
||||||
var messages []map[string]any
|
|
||||||
|
|
||||||
for _, round := range historyData {
|
|
||||||
appendMessagesFromField(round, "requestContent", &messages)
|
|
||||||
appendMessagesFromField(round, "responseContent", &messages)
|
|
||||||
}
|
|
||||||
|
|
||||||
return messages
|
|
||||||
}
|
|
||||||
|
|
||||||
// appendMessagesFromField 从指定字段追加消息
|
|
||||||
func appendMessagesFromField(data map[string]any, field string, messages *[]map[string]any) {
|
|
||||||
msgs, ok := data[field].([]interface{})
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, m := range msgs {
|
|
||||||
if msg, ok := m.(map[string]interface{}); ok {
|
|
||||||
*messages = append(*messages, msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,165 +0,0 @@
|
|||||||
package prompt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"gitea.com/red-future/common/beans"
|
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
|
||||||
|
|
||||||
"prompts-core/common/util"
|
|
||||||
"prompts-core/dao"
|
|
||||||
"prompts-core/model/dto"
|
|
||||||
"prompts-core/model/entity"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SessionCallback 会话回调
|
|
||||||
func SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
|
|
||||||
result, err := util.ParseOutput(req.Text)
|
|
||||||
if err != nil {
|
|
||||||
g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
|
||||||
return nil, fmt.Errorf("解析模型输出失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result["role"] = "assistant"
|
|
||||||
if err = updateSessionResponse(ctx, req.EpicycleId, result); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
session, err := getSessionById(ctx, req.EpicycleId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := saveSessionToRedis(ctx, session); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
requestMessages := util.ConvertToMessages(session.RequestContent)
|
|
||||||
responseMessages := util.ConvertToMessages(session.ResponseContent)
|
|
||||||
|
|
||||||
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
|
|
||||||
session.SessionId, session.Id, len(requestMessages), len(responseMessages))
|
|
||||||
|
|
||||||
return &dto.SessionCallbackRes{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateSessionResponse 更新会话响应
|
|
||||||
func updateSessionResponse(ctx context.Context, epicycleId int64, response any) error {
|
|
||||||
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{Id: epicycleId},
|
|
||||||
ResponseContent: response,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", epicycleId, err)
|
|
||||||
return fmt.Errorf("更新数据库失败: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getSessionById 根据ID获取会话
|
|
||||||
func getSessionById(ctx context.Context, epicycleId int64) (*entity.ComposeSession, error) {
|
|
||||||
session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{Id: epicycleId},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", epicycleId, err)
|
|
||||||
return nil, fmt.Errorf("获取会话数据失败: %w", err)
|
|
||||||
}
|
|
||||||
return session, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// saveSessionToRedis 保存会话到Redis
|
|
||||||
func saveSessionToRedis(ctx context.Context, session *entity.ComposeSession) error {
|
|
||||||
requestMessages := util.ConvertToMessages(session.RequestContent)
|
|
||||||
responseMessages := util.ConvertToMessages(session.ResponseContent)
|
|
||||||
|
|
||||||
if err := saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil {
|
|
||||||
g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v",
|
|
||||||
session.SessionId, session.Id, err)
|
|
||||||
return fmt.Errorf("Redis存储失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetHistoryMessages 获取历史信息
|
|
||||||
func GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
|
||||||
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
|
||||||
|
|
||||||
redisHistory, err := GetSessionHistoryForInference(ctx, sessionId)
|
|
||||||
if err == nil && len(redisHistory) > 0 {
|
|
||||||
return redisHistory, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return getHistoryFromDatabase(ctx, sessionId, maxRounds)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getHistoryFromDatabase 从数据库获取历史记录
|
|
||||||
func getHistoryFromDatabase(ctx context.Context, sessionId string, maxRounds int) ([]map[string]any, error) {
|
|
||||||
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
|
|
||||||
SessionId: sessionId,
|
|
||||||
}, 1, maxRounds)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("DB获取历史失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
messages := extractMessagesFromSessions(sessions)
|
|
||||||
|
|
||||||
cacheSessionsToRedis(ctx, sessions)
|
|
||||||
|
|
||||||
return messages, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractMessagesFromSessions 从会话列表中提取消息
|
|
||||||
func extractMessagesFromSessions(sessions []*entity.ComposeSession) []map[string]any {
|
|
||||||
var messages []map[string]any
|
|
||||||
|
|
||||||
for _, session := range sessions {
|
|
||||||
appendRequestMessages(session.RequestContent, &messages)
|
|
||||||
appendResponseMessages(session.ResponseContent, &messages)
|
|
||||||
}
|
|
||||||
|
|
||||||
return messages
|
|
||||||
}
|
|
||||||
|
|
||||||
// appendRequestMessages 追加请求消息
|
|
||||||
func appendRequestMessages(requestContent any, messages *[]map[string]any) {
|
|
||||||
reqMsgs := util.ConvertToMessages(requestContent)
|
|
||||||
for _, m := range reqMsgs {
|
|
||||||
role := gconv.String(m["role"])
|
|
||||||
if role == "user" || role == "assistant" {
|
|
||||||
*messages = append(*messages, m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// appendResponseMessages 追加响应消息
|
|
||||||
func appendResponseMessages(responseContent any, messages *[]map[string]any) {
|
|
||||||
respMsgs := util.ConvertToMessages(responseContent)
|
|
||||||
for _, m := range respMsgs {
|
|
||||||
if m["role"] == nil {
|
|
||||||
m["role"] = "assistant"
|
|
||||||
}
|
|
||||||
*messages = append(*messages, m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// cacheSessionsToRedis 将会话缓存到Redis
|
|
||||||
func cacheSessionsToRedis(ctx context.Context, sessions []*entity.ComposeSession) {
|
|
||||||
for _, session := range sessions {
|
|
||||||
reqMsgs := util.ConvertToMessages(session.RequestContent)
|
|
||||||
respMsgs := util.ConvertToMessages(session.ResponseContent)
|
|
||||||
|
|
||||||
for i := range respMsgs {
|
|
||||||
if respMsgs[i]["role"] == nil {
|
|
||||||
respMsgs[i]["role"] = "assistant"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
|
|
||||||
_ = saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -3,17 +3,17 @@ package prompt
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"prompts-core/service/gateway"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
|
||||||
"prompts-core/common/util"
|
"prompts-core/common/util"
|
||||||
"prompts-core/model/dto"
|
"prompts-core/model/dto"
|
||||||
"prompts-core/model/entity"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProcessUserFormBatches 处理 UserForm 分批(按 token 大小拼接内容)
|
// ProcessUserFormBatches 处理 UserForm 分批(按 token 大小拼接内容)
|
||||||
func ProcessUserFormBatches(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) (*dto.ComposeMessagesReq, int, error) {
|
func ProcessUserFormBatches(ctx context.Context, req *dto.ComposeMessagesReq, model *gateway.AsynchModel) (*dto.ComposeMessagesReq, int, error) {
|
||||||
if model.TokenConfig == nil || len(req.UserForm) == 0 {
|
if model.TokenConfig == nil || len(req.UserForm) == 0 {
|
||||||
return req, 1, nil
|
return req, 1, nil
|
||||||
}
|
}
|
||||||
|
|||||||
151
service/session/prompt_session_redis_service.go
Normal file
151
service/session/prompt_session_redis_service.go
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"prompts-core/common/util"
|
||||||
|
"prompts-core/model/dto"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// RedisKeySessionHistory 会话历史缓存 key: session:history:{tenantId}:{sessionId}:{nodeId}
|
||||||
|
RedisKeySessionHistory = "session:history:%d:%s:%s"
|
||||||
|
)
|
||||||
|
|
||||||
|
// formatRedisKey 格式化 Redis key
|
||||||
|
func formatRedisKey(tenantID uint64, sessionID, nodeID string) string {
|
||||||
|
return fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID, nodeID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================
|
||||||
|
// 写操作
|
||||||
|
// ============================================
|
||||||
|
|
||||||
|
// SaveToRedis 保存一轮对话到 Redis ZSET
|
||||||
|
func SaveToRedis(ctx context.Context, tenantID uint64, sessionID, nodeID string, round *dto.HistoryRound) error {
|
||||||
|
key := formatRedisKey(tenantID, sessionID, nodeID)
|
||||||
|
maxRounds := util.GetMaxRounds(ctx)
|
||||||
|
expireSeconds := int64(util.GetExpireMinutes(ctx) * 60)
|
||||||
|
|
||||||
|
b, err := json.Marshal(round)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("序列化会话数据失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
score := float64(time.Now().UnixMilli())
|
||||||
|
|
||||||
|
if _, err = g.Redis().Do(ctx, "ZADD", key, score, string(b)); err != nil {
|
||||||
|
return fmt.Errorf("ZADD失败: %w", err)
|
||||||
|
}
|
||||||
|
if _, err = g.Redis().Do(ctx, "ZREMRANGEBYRANK", key, 0, -(maxRounds + 1)); err != nil {
|
||||||
|
return fmt.Errorf("裁剪失败: %w", err)
|
||||||
|
}
|
||||||
|
if _, err = g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
|
||||||
|
return fmt.Errorf("设置过期失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSessionHistory 删除整个 session 下所有 node 的缓存
|
||||||
|
func DeleteSessionHistory(ctx context.Context, tenantID uint64, sessionID string) error {
|
||||||
|
pattern := fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID, "*")
|
||||||
|
keys, err := g.Redis().Do(ctx, "KEYS", pattern)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, key := range keys.Strings() {
|
||||||
|
_, _ = g.Redis().Do(ctx, "DEL", key)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRedisMessages 批量删除指定 node 下的消息
|
||||||
|
func DeleteRedisMessages(ctx context.Context, tenantID uint64, sessionID, nodeID string, msgIDs []int64) error {
|
||||||
|
key := formatRedisKey(tenantID, sessionID, nodeID)
|
||||||
|
for _, msgID := range msgIDs {
|
||||||
|
cursor := "0"
|
||||||
|
for {
|
||||||
|
result, err := g.Redis().Do(ctx, "ZSCAN", key, cursor, "MATCH", fmt.Sprintf("*\"id\":%d*", msgID), "COUNT", 10)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Warningf(ctx, "[会话Redis] ZSCAN失败 msgID=%d err=%v", msgID, err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
parts := result.Strings()
|
||||||
|
if len(parts) < 2 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
cursor = parts[0]
|
||||||
|
for _, member := range parts[1:] {
|
||||||
|
_, _ = g.Redis().Do(ctx, "ZREM", key, member)
|
||||||
|
}
|
||||||
|
if cursor == "0" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================
|
||||||
|
// 读操作
|
||||||
|
// ============================================
|
||||||
|
|
||||||
|
// GetFromRedis 从 Redis ZSET 获取会话历史
|
||||||
|
func GetFromRedis(ctx context.Context, tenantID uint64, sessionID, nodeID string) ([]dto.HistoryRound, error) {
|
||||||
|
key := formatRedisKey(tenantID, sessionID, nodeID)
|
||||||
|
maxRounds := util.GetMaxRounds(ctx)
|
||||||
|
|
||||||
|
result, err := g.Redis().Do(ctx, "ZREVRANGE", key, 0, maxRounds-1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("ZREVRANGE失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result == nil || result.IsNil() {
|
||||||
|
return []dto.HistoryRound{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseRounds(result.Strings()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================
|
||||||
|
// 解析
|
||||||
|
// ============================================
|
||||||
|
|
||||||
|
func parseRounds(members []string) []dto.HistoryRound {
|
||||||
|
rounds := make([]dto.HistoryRound, 0, len(members))
|
||||||
|
for _, member := range members {
|
||||||
|
var round dto.HistoryRound
|
||||||
|
if err := json.Unmarshal([]byte(member), &round); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if round.User != nil || round.Assistant != nil {
|
||||||
|
rounds = append(rounds, round)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rounds
|
||||||
|
}
|
||||||
|
|
||||||
|
func flattenRounds(rounds []dto.HistoryRound) []dto.FlatMessage {
|
||||||
|
var messages []dto.FlatMessage
|
||||||
|
for i := len(rounds) - 1; i >= 0; i-- {
|
||||||
|
if rounds[i].User != nil && gconv.String(rounds[i].User["content"]) != "" {
|
||||||
|
messages = append(messages, dto.FlatMessage{
|
||||||
|
Role: gconv.String(rounds[i].User["role"]),
|
||||||
|
Content: gconv.String(rounds[i].User["content"]),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if rounds[i].Assistant != nil && gconv.String(rounds[i].Assistant["content"]) != "" {
|
||||||
|
messages = append(messages, dto.FlatMessage{
|
||||||
|
Role: gconv.String(rounds[i].Assistant["role"]),
|
||||||
|
Content: gconv.String(rounds[i].Assistant["content"]),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return messages
|
||||||
|
}
|
||||||
191
service/session/prompt_session_service.go
Normal file
191
service/session/prompt_session_service.go
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||||
|
"gitea.redpowerfuture.com/red-future/common/utils"
|
||||||
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
|
|
||||||
|
"prompts-core/common/util"
|
||||||
|
"prompts-core/dao"
|
||||||
|
"prompts-core/model/dto"
|
||||||
|
"prompts-core/model/entity"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ============================================
|
||||||
|
// 回调存储
|
||||||
|
// ============================================
|
||||||
|
|
||||||
|
// Callback 会话回调
|
||||||
|
func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
|
||||||
|
req.Messages["role"] = "assistant"
|
||||||
|
// 1) 更新 DB
|
||||||
|
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
||||||
|
ResponseContent: req.Messages,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
||||||
|
return nil, fmt.Errorf("更新数据库失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) 查询完整记录
|
||||||
|
session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
||||||
|
})
|
||||||
|
if err != nil || session == nil {
|
||||||
|
return nil, fmt.Errorf("会话不存在: epicycleId=%d", req.EpicycleId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3) entity → HistoryRound → 写入 Redis
|
||||||
|
round := entityToHistoryRound(session)
|
||||||
|
round.Assistant = req.Messages
|
||||||
|
if err = SaveToRedis(ctx, session.TenantId, session.SessionId, session.NodeId, round); err != nil {
|
||||||
|
return nil, fmt.Errorf("redis存储失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d", session.SessionId, session.Id)
|
||||||
|
return &dto.SessionCallbackRes{Status: true, SessionId: session.SessionId}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================
|
||||||
|
// 场景1:前端历史列表(按 creator)
|
||||||
|
// ============================================
|
||||||
|
|
||||||
|
// GetHistoryList 获取历史列表
|
||||||
|
func GetHistoryList(ctx context.Context, req *dto.GetHistoryListReq) (*dto.GetHistoryListRes, error) {
|
||||||
|
user, err := utils.GetUserInfo(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sessions, total, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||||
|
}, req.Page, req.Size)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("DB获取历史列表失败: %w", err)
|
||||||
|
}
|
||||||
|
rounds := sessionsToHistoryRounds(sessions)
|
||||||
|
return &dto.GetHistoryListRes{List: rounds, Total: total}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================
|
||||||
|
// 场景2:提示词拼接(按 sessionId + nodeId)
|
||||||
|
// ============================================
|
||||||
|
|
||||||
|
// GetHistoryMessages 获取历史消息(Redis → DB → 异步回种)
|
||||||
|
func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*dto.GetHistoryMessagesRes, error) {
|
||||||
|
user, err := utils.GetUserInfo(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1) Redis
|
||||||
|
if rounds, err := GetFromRedis(ctx, user.TenantId, req.SessionId, req.NodeId); err == nil && len(rounds) > 0 {
|
||||||
|
g.Log().Debugf(ctx, "[历史消息] Redis命中 sessionId=%s count=%d", req.SessionId, len(rounds))
|
||||||
|
return &dto.GetHistoryMessagesRes{Messages: flattenRounds(rounds)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) DB
|
||||||
|
maxRounds := util.GetMaxRounds(ctx)
|
||||||
|
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||||
|
SessionId: req.SessionId,
|
||||||
|
NodeId: req.NodeId,
|
||||||
|
}, 1, maxRounds)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("DB获取历史失败: %w", err)
|
||||||
|
}
|
||||||
|
if len(sessions) == 0 {
|
||||||
|
return &dto.GetHistoryMessagesRes{Messages: []dto.FlatMessage{}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3) 转换 + 异步回种
|
||||||
|
rounds := sessionsToHistoryRounds(sessions)
|
||||||
|
go asyncCacheToRedis(context.WithoutCancel(ctx), user.TenantId, req.SessionId, req.NodeId, rounds)
|
||||||
|
|
||||||
|
return &dto.GetHistoryMessagesRes{Messages: flattenRounds(rounds)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================
|
||||||
|
// 删除
|
||||||
|
// ============================================
|
||||||
|
|
||||||
|
// DeleteMessages 删除消息
|
||||||
|
func DeleteMessages(ctx context.Context, req *dto.DeleteMessagesReq) (*dto.DeleteMessagesRes, error) {
|
||||||
|
if len(req.MsgIds) == 0 {
|
||||||
|
return &dto.DeleteMessagesRes{Ok: false}, fmt.Errorf("msgIds不能为空")
|
||||||
|
}
|
||||||
|
|
||||||
|
user, _ := utils.GetUserInfo(ctx)
|
||||||
|
|
||||||
|
// 1) 批量查询
|
||||||
|
sessions, _ := dao.ComposeSession.ListByIds(ctx, req.MsgIds, user.UserName, req.SessionId)
|
||||||
|
|
||||||
|
// 2) 批量删 DB
|
||||||
|
_, _ = dao.ComposeSession.DeleteByIds(ctx, req.MsgIds, user.UserName, req.SessionId)
|
||||||
|
|
||||||
|
// 3) 按 nodeId 分组删 Redis
|
||||||
|
for _, s := range sessions {
|
||||||
|
_ = DeleteRedisMessages(ctx, user.TenantId, req.SessionId, s.NodeId, req.MsgIds)
|
||||||
|
}
|
||||||
|
return &dto.DeleteMessagesRes{Ok: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSession 删除整个会话
|
||||||
|
func DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (*dto.DeleteSessionRes, error) {
|
||||||
|
// 1) 删 DB
|
||||||
|
if _, err := dao.ComposeSession.Delete(ctx, &entity.ComposeSession{
|
||||||
|
SessionId: req.SessionId,
|
||||||
|
}); err != nil {
|
||||||
|
return nil, fmt.Errorf("DB删除失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := utils.GetUserInfo(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// 2) 删 Redis
|
||||||
|
if err := DeleteSessionHistory(ctx, user.TenantId, req.SessionId); err != nil {
|
||||||
|
g.Log().Warningf(ctx, "[删除会话] Redis删除失败 sessionId=%s err=%v", req.SessionId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &dto.DeleteSessionRes{Ok: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================
|
||||||
|
// 转换方法(entity ↔ dto,集中管理)
|
||||||
|
// ============================================
|
||||||
|
|
||||||
|
// entityToHistoryRound entity → HistoryRound
|
||||||
|
func entityToHistoryRound(s *entity.ComposeSession) *dto.HistoryRound {
|
||||||
|
return &dto.HistoryRound{
|
||||||
|
Id: s.Id,
|
||||||
|
SessionId: s.SessionId,
|
||||||
|
NodeId: s.NodeId,
|
||||||
|
CreatedAt: gconv.String(s.CreatedAt),
|
||||||
|
UpdatedAt: gconv.String(s.UpdatedAt),
|
||||||
|
User: s.RequestContent,
|
||||||
|
Assistant: s.ResponseContent,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sessionsToHistoryRounds 批量转换
|
||||||
|
func sessionsToHistoryRounds(sessions []*entity.ComposeSession) []dto.HistoryRound {
|
||||||
|
rounds := make([]dto.HistoryRound, 0, len(sessions))
|
||||||
|
for _, s := range sessions {
|
||||||
|
rounds = append(rounds, *entityToHistoryRound(s))
|
||||||
|
}
|
||||||
|
return rounds
|
||||||
|
}
|
||||||
|
|
||||||
|
// asyncCacheToRedis 异步缓存到 Redis
|
||||||
|
func asyncCacheToRedis(ctx context.Context, tenantID uint64, sessionID, nodeID string, rounds []dto.HistoryRound) {
|
||||||
|
for i := range rounds {
|
||||||
|
if rounds[i].User != nil || rounds[i].Assistant != nil {
|
||||||
|
_ = SaveToRedis(ctx, tenantID, sessionID, nodeID, &rounds[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user