refactor(service): 重构服务代码结构并更新配置

This commit is contained in:
2026-05-18 19:19:17 +08:00
parent 5f98e52b34
commit c49144794d
35 changed files with 1281 additions and 1162 deletions

18
common/util/config.go Normal file
View File

@@ -0,0 +1,18 @@
package util
import (
"context"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// GetModelPrompt 获取请求模型的提示词
func GetModelPrompt(ctx context.Context, Type int) string {
return g.Cfg().MustGet(ctx, "modelPrompts.types."+gconv.String(Type), "").String()
}
// GetBuildPrompt 获取构建提示词
func GetBuildPrompt(ctx context.Context, Type int) string {
return g.Cfg().MustGet(ctx, "buildProject.types."+gconv.String(Type), "").String()
}

88
common/util/files.go Normal file
View File

@@ -0,0 +1,88 @@
package util
import (
"path/filepath"
"regexp"
"strings"
)
// AllowedMIMEPrefixes 允许的文本类 MIME 类型前缀
var AllowedMIMEPrefixes = []string{
"text/",
"application/json",
"application/xml",
"application/javascript",
"application/x-yaml",
"application/yaml",
"application/toml",
"application/x-httpd-php",
"application/x-sh",
"application/x-python",
"application/x-perl",
"application/x-ruby",
}
// BannedExtensions 禁止的文件扩展名
var BannedExtensions = map[string]bool{
".png": true, ".jpg": true, ".jpeg": true, ".gif": true, ".bmp": true,
".webp": true, ".svg": true, ".ico": true, ".tiff": true, ".tif": true,
".mp3": true, ".wav": true, ".ogg": true, ".flac": true, ".aac": true,
".wma": true, ".m4a": true,
".mp4": true, ".avi": true, ".mkv": true, ".mov": true, ".wmv": true,
".flv": true, ".webm": true,
".tar": true, ".gz": true, ".rar": true, ".7z": true,
".exe": true, ".dll": true, ".so": true, ".bin": true, ".dat": true,
".class": true, ".pyc": true,
".pdf": true, ".doc": true, ".docx": true, ".xls": true, ".xlsx": true,
".ppt": true, ".pptx": true,
}
var symbolCleaner = regexp.MustCompile(`[\x00-\x08\x0B\x0C\x0E-\x1F]`)
// SanitizeURL 清洗 URL 字符串
func SanitizeURL(raw string) string {
s := strings.TrimSpace(raw)
s = strings.Trim(s, "`\"")
return s
}
// CleanSymbols 清洗文本中的控制字符和多余空行
func CleanSymbols(text string) string {
text = symbolCleaner.ReplaceAllString(text, "")
text = strings.ReplaceAll(text, "\r\n", "\n")
text = strings.ReplaceAll(text, "\r", "\n")
text = regexp.MustCompile(`\n{3,}`).ReplaceAllString(text, "\n\n")
return strings.TrimSpace(text)
}
// IsBannedExtension 判断是否为禁止的文件扩展名
func IsBannedExtension(url string) bool {
ext := strings.ToLower(filepath.Ext(url))
if idx := strings.Index(ext, "?"); idx != -1 {
ext = ext[:idx]
}
return BannedExtensions[ext]
}
// IsZipExtension 判断是否为 zip 文件
func IsZipExtension(url string) bool {
ext := strings.ToLower(filepath.Ext(url))
if idx := strings.Index(ext, "?"); idx != -1 {
ext = ext[:idx]
}
return ext == ".zip"
}
// IsReadableContentType 判断是否为可读的文本类型
func IsReadableContentType(contentType string) bool {
if contentType == "" {
return false
}
ct := strings.ToLower(contentType)
for _, prefix := range AllowedMIMEPrefixes {
if strings.HasPrefix(ct, prefix) {
return true
}
}
return false
}

View File

@@ -1,4 +1,4 @@
package service package util
import ( import (
"context" "context"
@@ -7,9 +7,8 @@ import (
"github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/frame/g"
) )
// asyncCtx 固化异步执行所需的 token/user,避免请求结束后丢失(仅在“同请求内起 goroutine”有用 // AsyncCtx 固化异步上下文中的 token 和用户信息,避免请求结束后丢失
// 本项目当前是“落库 + 后台 worker”模式因此还会把必要信息持久化到任务表的 request_payload 中。 func AsyncCtx(ctx context.Context) context.Context {
func asyncCtx(ctx context.Context) context.Context {
asyncCtx := context.WithoutCancel(ctx) asyncCtx := context.WithoutCancel(ctx)
if r := g.RequestFromCtx(ctx); r != nil { if r := g.RequestFromCtx(ctx); r != nil {
if token := r.Header.Get("Authorization"); token != "" { if token := r.Header.Get("Authorization"); token != "" {
@@ -25,8 +24,8 @@ func asyncCtx(ctx context.Context) context.Context {
return asyncCtx return asyncCtx
} }
// forwardHeaders 透传调用链路中必须的头信息优先使用 ctx 里固化的 token / xUserInfo // ForwardHeaders 透传调用链路的头信息优先使用 ctx 中的固化值
func forwardHeaders(ctx context.Context) map[string]string { func ForwardHeaders(ctx context.Context) map[string]string {
headers := make(map[string]string) headers := make(map[string]string)
if token, ok := ctx.Value("token").(string); ok && token != "" { if token, ok := ctx.Value("token").(string); ok && token != "" {
@@ -36,7 +35,7 @@ func forwardHeaders(ctx context.Context) map[string]string {
headers["X-User-Info"] = x headers["X-User-Info"] = x
} }
// 兜底:从请求头 // 兜底:从请求头获取
if r := g.RequestFromCtx(ctx); r != nil { if r := g.RequestFromCtx(ctx); r != nil {
if headers["Authorization"] == "" { if headers["Authorization"] == "" {
if token := r.Header.Get("Authorization"); token != "" { if token := r.Header.Get("Authorization"); token != "" {

103
common/util/json.go Normal file
View File

@@ -0,0 +1,103 @@
package util
import (
"encoding/json"
"fmt"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/util/gconv"
)
// ParseOutput 解析模型输出为 JSON 格式
func ParseOutput(text string) (map[string]any, error) {
j, err := gjson.LoadJson([]byte(text))
if err != nil {
return nil, fmt.Errorf("解析模型输出失败: %w", err)
}
return j.Map(), nil
}
// ConvertToMessages 将原始数据转换为消息列表
func ConvertToMessages(raw any) []map[string]any {
if raw == nil {
return nil
}
j, err := gjson.LoadJson(gconv.Bytes(raw))
if err != nil {
return nil
}
// 如果有 messages 字段,直接返回
if j.Contains("messages") {
return gconv.Maps(j.Get("messages").Array())
}
// 否则当成单条 message
return []map[string]any{
j.Map(),
}
}
// IsMessageValid 校验推理结果是否合法
func IsMessageValid(message map[string]any) bool {
if message == nil {
return false
}
return true
}
// FormToJSON 将表单数据转换为 JSON 字符串
func FormToJSON(form map[string]any) string {
if form == nil {
return "{}"
}
b, _ := json.Marshal(form)
return string(b)
}
// MustMarshal 将对象序列化为 JSON 字符串,失败时返回空对象
func MustMarshal(v any) string {
b, err := json.Marshal(v)
if err != nil {
return "{}"
}
return string(b)
}
// ParseJSONField 解析 JSON 字段
func ParseJSONField(field any) any {
var v *gvar.Var
switch val := field.(type) {
case *gvar.Var:
v = val
default:
return field
}
if v == nil || v.IsNil() || v.IsEmpty() {
return nil
}
str := v.String()
var result any
if json.Unmarshal([]byte(str), &result) == nil {
return result
}
return str
}
// JSONPretty 将任意类型转为格式化的 JSON 字符串
func JSONPretty(v any) string {
// 处理 *gvar.Var 类型
if gv, ok := v.(*gvar.Var); ok {
v = gconv.Map(gv.String())
}
// 统一转 map 再美化
var tmp map[string]any
if err := gconv.Struct(v, &tmp); err != nil {
return gconv.String(v)
}
b, _ := json.MarshalIndent(tmp, "", " ")
return string(b)
}

View File

@@ -26,17 +26,38 @@ database:
updatedAt: "updated_at" # (可选)自动更新时间字段名称 updatedAt: "updated_at" # (可选)自动更新时间字段名称
deletedAt: "deleted_at" # (可选)软删除时间字段名称 deletedAt: "deleted_at" # (可选)软删除时间字段名称
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性为true时CreatedAt/UpdatedAt/DeletedAt都将失效 timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性为true时CreatedAt/UpdatedAt/DeletedAt都将失效
model_gateway:
- type: "pgsql"
host: "116.204.74.41"
port: "15432"
user: "postgres"
pass: "Bjang09@686^*^"
name: "model-gateway"
prefix: ""
role: "master"
debug: true
dryRun: false
charset: "utf8"
timezone: "Asia/Shanghai"
maxIdle: 5
maxOpen: 20
maxLifetime: "30s"
maxIdleConnTime: "30s"
createdAt: "created_at"
updatedAt: "updated_at"
deletedAt: "deleted_at"
timeMaintainDisabled: false
redis: redis:
default: default:
address: 116.204.74.41:6379 address: 192.168.3.30:6379
db: 0 db: 0
consul: consul:
address: 116.204.74.41:8500 address: 192.168.3.30:8500
jaeger: jaeger:
addr: 116.204.74.41:4318 addr: 192.168.3.30:4318
task: task:
waitTimeoutSeconds: 300 # /composeMessages 同步等待最终结果的最长时间(秒) waitTimeoutSeconds: 300 # /composeMessages 同步等待最终结果的最长时间(秒)

View File

@@ -1,8 +1,12 @@
package public package public
const ( const (
TableNameModel = "asynch_models" // 模型表 DbNameModelGateway = "model_gateway" //数据库名称
TableNamePromptConfig = "prompts_model_prompt" // 模型提示词配置表prompts-core )
TableNameComposeTask = "prompts_compose_task" // 拼接提示词任务记录表
TableNameComposeSession = "prompts_compose_session" // 拼接提示词会话记录表 const (
TableNameModel = "asynch_models" // 模型表
TableNameComposeTask = "prompts_compose_task" // 拼接提示词任务记录表
TableNameComposeSession = "prompts_compose_session" // 拼接提示词会话记录表
TableNameProviderProtocol = "prompts_provider_protocol"
) )

View File

@@ -1,20 +0,0 @@
package controller
import (
"context"
"prompts-core/model/dto"
"prompts-core/service"
"gitea.com/red-future/common/beans"
)
type session struct{}
// Prompt 提示词配置控制器
var Session = new(session)
// SessionCallback 会话回调
func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *beans.ResponseEmpty, err error) {
return service.Session.SessionCallback(ctx, req)
}

View File

@@ -0,0 +1,29 @@
package prompt
import (
"context"
promptDto "prompts-core/model/dto/prompt"
promptService "prompts-core/service/prompt"
)
type prompt struct{}
// Prompt 提示词配置控制器
var Prompt = new(prompt)
// ComposeMessages 调用 model-gateway 异步任务并同步等待结果,
func (c *prompt) ComposeMessages(ctx context.Context, req *promptDto.ComposeMessagesReq) (res *promptDto.ComposeMessagesRes, err error) {
return promptService.ComposeMessages(ctx, req)
}
// Callback model-gateway 提示词回调
func (c *prompt) Callback(ctx context.Context, req *promptDto.CallbackReq) (res *promptDto.CallbackRes, err error) {
err = promptService.Callback(ctx, req)
return
}
// GetComposeTask 查询拼接任务结果
func (c *prompt) GetComposeTask(ctx context.Context, req *promptDto.GetComposeTaskReq) (res *promptDto.GetComposeTaskRes, err error) {
return promptService.GetComposeTask(ctx, req.TaskId)
}

View File

@@ -0,0 +1,18 @@
package prompt
import (
"context"
promptDto "prompts-core/model/dto/prompt"
promptService "prompts-core/service/prompt"
)
type session struct{}
// Session 提示词会话控制器
var Session = new(session)
// SessionCallback 会话回调
func (c *session) SessionCallback(ctx context.Context, req *promptDto.SessionCallbackReq) (res *promptDto.SessionCallbackRes, err error) {
return promptService.SessionCallback(ctx, req)
}

View File

@@ -1,69 +0,0 @@
package controller
import (
"context"
"prompts-core/model/dto"
"prompts-core/service"
"gitea.com/red-future/common/beans"
)
type prompt struct{}
// Prompt 提示词配置控制器
var Prompt = new(prompt)
// ComposeMessages 调用 model-gateway 异步任务并同步等待结果,
func (c *prompt) ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (res *dto.ComposeMessagesRes, err error) {
return service.Prompt.ComposeMessages(ctx, req)
}
// ComposeMessagesCallback model-gateway 提示词回调
func (c *prompt) Callback(ctx context.Context, req *dto.CallbackReq) (res *beans.ResponseEmpty, err error) {
err = service.Prompt.Callback(ctx, req)
return
}
// GetComposeTask 查询拼接任务结果
func (c *prompt) GetComposeTask(ctx context.Context, req *dto.GetComposeTaskReq) (res *dto.GetComposeTaskRes, err error) {
return service.Prompt.GetComposeTask(ctx, req.TaskId)
}
// CreatePrompt 添加配置(默认启用)
func (c *prompt) CreatePrompt(ctx context.Context, req *dto.CreatePromptReq) (res *dto.CreatePromptRes, err error) {
return service.Prompt.Create(ctx, req)
}
// UpdatePrompt 更新配置
func (c *prompt) UpdatePrompt(ctx context.Context, req *dto.UpdatePromptReq) (res *beans.ResponseEmpty, err error) {
err = service.Prompt.Update(ctx, req)
return
}
// DeletePrompt 删除配置
func (c *prompt) DeletePrompt(ctx context.Context, req *dto.DeletePromptReq) (res *beans.ResponseEmpty, err error) {
err = service.Prompt.Delete(ctx, req.ID)
return
}
// GetPrompt 获取配置详情
func (c *prompt) GetPrompt(ctx context.Context, req *dto.GetPromptReq) (res *dto.GetPromptRes, err error) {
m, err := service.Prompt.Get(ctx, req.ID)
if err != nil {
return nil, err
}
return &dto.GetPromptRes{Prompt: m}, nil
}
// ListPrompt 配置列表
func (c *prompt) ListPrompt(ctx context.Context, req *dto.ListPromptReq) (res *dto.ListPromptRes, err error) {
list, total, err := service.Prompt.List(ctx, int(req.Page.PageNum), int(req.Page.PageSize), req.ModelTypeId, req.ModelType)
if err != nil {
return nil, err
}
return &dto.ListPromptRes{
List: list,
Total: total,
}, nil
}

View File

@@ -2,82 +2,75 @@ package dao
import ( import (
"context" "context"
"prompts-core/consts/public" "prompts-core/consts/public"
"prompts-core/model/entity" "prompts-core/model/entity"
"gitea.com/red-future/common/db/gfdb" "gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
) )
var ComposeSession = &composeSessionDao{} var ComposeSession = &composeSessionDao{}
type composeSessionDao struct{} type composeSessionDao struct{}
func (d *composeSessionDao) Insert(ctx context.Context, m *entity.ComposeSession) (id int64, err error) { // Insert 插入
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession).Data(m).Insert() func (d *composeSessionDao) Insert(ctx context.Context, req *entity.ComposeSession) (id int64, err error) {
var m = new(entity.ComposeTask)
err = gconv.Struct(req, &m)
if err != nil { if err != nil {
return 0, err return
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
Insert(m)
if err != nil {
return
} }
return r.LastInsertId() return r.LastInsertId()
} }
func (d *composeSessionDao) Update(ctx context.Context, m *entity.ComposeSession) (rows int64, err error) { // Update 更新
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession). func (d *composeSessionDao) Update(ctx context.Context, req *entity.ComposeSession) (rows int64, err error) {
Where(entity.ComposeSessionCol.Id, m.Id). r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
Data(m).
OmitEmpty(). OmitEmpty().
Data(&req).
Where(entity.ComposeSessionCol.Id, req.Id).
Update() Update()
if err != nil { if err != nil {
return 0, err return
} }
return r.RowsAffected() return r.RowsAffected()
} }
func (d *composeSessionDao) List(ctx context.Context, page, size int, where map[string]any) (list []*entity.ComposeSession, total int, err error) { // List 查询编排会话列表
model := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession). func (d *composeSessionDao) List(ctx context.Context, req *entity.ComposeSession, page, size int, fields ...string) (list []*entity.ComposeSession, total int, err error) {
Where("deleted_at IS NULL") if page <= 0 {
page = 1
// 动态拼接查询条件
for k, v := range where {
model = model.Where(k, v)
} }
if size <= 0 {
// 查询总数 size = 10
total, err = model.Count() }
model := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
Fields(fields).
OmitEmpty()
model.Where(entity.ComposeSessionCol.Creator, req.Creator)
model.Where(entity.ComposeSessionCol.SessionId, req.SessionId)
model.OrderDesc(entity.ComposeSessionCol.CreatedAt)
model.Page(page, size)
r, total, err := model.AllAndCount(false)
if err != nil { if err != nil {
return nil, 0, err return
} }
err = r.Structs(&list)
// 分页查询
err = model.Order("created_at DESC").
Page(page, size).
Scan(&list)
return return
} }
func (d *composeSessionDao) GetListBySessionId(ctx context.Context, sessionId string, limit int) ([]*entity.ComposeSession, error) { // Get 查询编排会话
var sessions []*entity.ComposeSession func (d *composeSessionDao) Get(ctx context.Context, req *entity.ComposeSession, fields ...string) (m *entity.ComposeSession, err error) {
err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession). r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
Where(entity.ComposeSessionCol.SessionId, sessionId). OmitEmpty().
WhereNull(entity.ComposeSessionCol.DeletedAt). Where(entity.ComposeSessionCol.Id, req.Id).
OrderDesc(entity.ComposeSessionCol.Id). Where(entity.ComposeSessionCol.SessionId, req.SessionId).
Limit(limit). Fields(fields).One()
Scan(&sessions)
if err != nil {
return nil, err
}
// 反转成时间正序
for i, j := 0, len(sessions)-1; i < j; i, j = i+1, j-1 {
sessions[i], sessions[j] = sessions[j], sessions[i]
}
return sessions, nil
}
func (d *composeSessionDao) GetById(ctx context.Context, Id int64) (m *entity.ComposeSession, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession).
Where(entity.ComposeSessionCol.Id, Id).
One()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -88,29 +81,15 @@ func (d *composeSessionDao) GetById(ctx context.Context, Id int64) (m *entity.Co
return return
} }
func (d *composeSessionDao) GetBySessionId(ctx context.Context, sessionId string) (m *entity.ComposeSession, err error) { // Delete 软删除编排会话
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession). func (d *composeSessionDao) Delete(ctx context.Context, req *entity.ComposeSession) (rows int64, err error) {
Where(entity.ComposeSessionCol.SessionId, sessionId). r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
One() OmitEmpty().
Where(entity.ComposeSessionCol.Id, req.Id).
Where(entity.ComposeSessionCol.SessionId, req.SessionId).
Delete()
if err != nil { if err != nil {
return nil, err return
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&m)
return
}
func (d *composeSessionDao) DeleteBySessionId(ctx context.Context, sessionId string) (rows int64, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession).
Where(entity.ComposeSessionCol.SessionId, sessionId).
Data(map[string]any{
entity.ComposeSessionCol.DeletedAt: "NOW()",
}).
Update()
if err != nil {
return 0, err
} }
return r.RowsAffected() return r.RowsAffected()
} }

View File

@@ -2,47 +2,54 @@ package dao
import ( import (
"context" "context"
"prompts-core/consts/public" "prompts-core/consts/public"
"prompts-core/model/entity" "prompts-core/model/entity"
"gitea.com/red-future/common/db/gfdb" "gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
) )
var ComposeTask = &composeTaskDao{} var ComposeTask = &composeTaskDao{}
type composeTaskDao struct{} type composeTaskDao struct{}
func (d *composeTaskDao) Insert(ctx context.Context, m *entity.ComposeTask) (id int64, err error) { // Insert 插入
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeTask).Data(m).Insert() func (d *composeTaskDao) Insert(ctx context.Context, req *entity.ComposeTask) (id int64, err error) {
var m = new(entity.ComposeTask)
err = gconv.Struct(req, &m)
if err != nil { if err != nil {
return 0, err return
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeTask).
Insert(m)
if err != nil {
return
} }
return r.LastInsertId() return r.LastInsertId()
} }
func (d *composeTaskDao) GetByTaskId(ctx context.Context, taskId string) (m *entity.ComposeTask, err error) { // Get 获取
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeTask). func (d *composeTaskDao) Get(ctx context.Context, req *entity.ComposeTask, fields ...string) (m *entity.ComposeTask, err error) {
Where(entity.ComposeTaskCol.TaskId, taskId). r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeTask).
One() OmitEmpty().
Where(entity.ComposeTaskCol.TaskId, req.TaskId).
Fields(fields).One()
if err != nil { if err != nil {
return nil, err return
}
if r.IsEmpty() {
return nil, nil
} }
err = r.Struct(&m) err = r.Struct(&m)
return return
} }
func (d *composeTaskDao) UpdateByTaskId(ctx context.Context, taskId string, data map[string]any) (rows int64, err error) { // Update 更新
data[entity.ComposeTaskCol.Updater] = "" func (d *composeTaskDao) Update(ctx context.Context, req *entity.ComposeTask) (rows int64, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeTask). r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeTask).
Where(entity.ComposeTaskCol.TaskId, taskId). OmitEmpty().
Data(data). Data(&req).
Where(entity.ComposeTaskCol.TaskId, req.TaskId).
Update() Update()
if err != nil { if err != nil {
return 0, err return
} }
return r.RowsAffected() return r.RowsAffected()
} }

View File

@@ -2,62 +2,27 @@ package dao
import ( import (
"context" "context"
"fmt"
"prompts-core/consts/public" "prompts-core/consts/public"
"prompts-core/model/entity" "prompts-core/model/entity"
"gitea.com/red-future/common/db/gfdb" "gitea.com/red-future/common/db/gfdb"
"gitea.com/red-future/common/utils"
) )
var Model = &modelDao{} var Model = &modelDao{}
type modelDao struct{} type modelDao struct{}
func (d *modelDao) GetByModelName(ctx context.Context, modelName string) (m *entity.AsynchModel, err error) { // Get 获取模型
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel). func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
Where(entity.AsynchModelCol.ModelName, modelName). r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
One() OmitEmpty().
Where(entity.AsynchModelCol.Creator, req.Creator).
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
Where(entity.AsynchModelCol.ModelName, req.ModelName).
Fields(fields).One()
if err != nil { if err != nil {
return nil, err return
}
if r.IsEmpty() {
return nil, nil
} }
err = r.Struct(&m) err = r.Struct(&m)
return return
} }
func (d *modelDao) GetByIsChatModel(ctx context.Context) (m *entity.AsynchModel, err error) {
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
Where(entity.AsynchModelCol.IsChatModel, 1).
Where(entity.AsynchModelCol.Creator, userInfo.UserName).
One()
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&m)
return
}
// GetBySuperAdmin 查询超级管理员tenant_id=1的模型
func (d *modelDao) GetBySuperAdmin(ctx context.Context, modelName string) (m *entity.AsynchModel, err error) {
sql := fmt.Sprintf("SELECT * FROM %s WHERE model_name = ? AND tenant_id = 1 AND deleted_at IS NULL LIMIT 1", public.TableNameModel)
r, err := gfdb.DB(ctx).GetAll(ctx, sql, modelName)
if err != nil {
return nil, err
}
if len(r) == 0 {
return nil, nil
}
err = r[0].Struct(&m)
return
}

View File

@@ -1,97 +0,0 @@
package dao
import (
"context"
"prompts-core/consts/public"
"prompts-core/model/entity"
"gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
)
var Prompt = &promptDao{}
type promptDao struct{}
func (d *promptDao) Insert(ctx context.Context, m *entity.PromptConfig) (id int64, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig).Data(m).Insert()
if err != nil {
return 0, err
}
return r.LastInsertId()
}
func (d *promptDao) UpdateByID(ctx context.Context, id int64, data map[string]any) (rows int64, err error) {
// 触发 gfdb 的 updateHook 自动填充 updater需要显式带 updater 字段
data[entity.PromptConfigCol.Updater] = ""
r, err := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig).
Where(entity.PromptConfigCol.Id, id).
Data(data).
Update()
if err != nil {
return 0, err
}
return r.RowsAffected()
}
func (d *promptDao) DeleteByID(ctx context.Context, id int64) (rows int64, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig).
Where(entity.PromptConfigCol.Id, id).
Delete()
if err != nil {
return 0, err
}
return r.RowsAffected()
}
func (d *promptDao) GetByID(ctx context.Context, id int64) (m *entity.PromptConfig, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig).
Where(entity.PromptConfigCol.Id, id).
One()
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&m)
return
}
func (d *promptDao) GetLatestEnabledByModelTypeID(ctx context.Context, modelTypeID int) (m *entity.PromptConfig, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig).
Where("deleted_at IS NULL").
Where(entity.PromptConfigCol.ModelTypeId, modelTypeID).
Where(entity.PromptConfigCol.Enabled, 1).
OrderDesc(entity.PromptConfigCol.CreatedAt).
One()
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&m)
return
}
func (d *promptDao) List(ctx context.Context, pageNum, pageSize int, modelTypeID *int, modelTypeLike string) (list []*entity.PromptConfig, total int64, err error) {
model := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig).Where("deleted_at IS NULL").OrderDesc(entity.PromptConfigCol.CreatedAt)
if modelTypeID != nil && *modelTypeID > 0 {
model = model.Where(entity.PromptConfigCol.ModelTypeId, *modelTypeID)
}
if modelTypeLike != "" {
model = model.WhereLike(entity.PromptConfigCol.ModelType, "%"+modelTypeLike+"%")
}
if pageNum > 0 && pageSize > 0 {
model = model.Page(pageNum, pageSize)
}
r, totalInt, err := model.AllAndCount(false)
if err != nil {
return nil, 0, err
}
total = gconv.Int64(totalInt)
err = r.Structs(&list)
return
}

View File

@@ -0,0 +1,91 @@
package dao
import (
"context"
"prompts-core/consts/public"
"prompts-core/model/entity"
"gitea.com/red-future/common/db/gfdb"
)
var ProviderProtocol = &providerProtocolDao{}
type providerProtocolDao struct{}
// Insert 新增协议配置
func (d *providerProtocolDao) Insert(ctx context.Context, req *entity.ProviderProtocol) (id int64, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).OmitEmpty().Data(req).Insert()
if err != nil {
return 0, err
}
return r.LastInsertId()
}
// Get 查询协议配置
func (d *providerProtocolDao) Get(ctx context.Context, req *entity.ProviderProtocol, fields ...string) (res *entity.ProviderProtocol, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).
NoTenantId(ctx).
OmitEmpty().
Where(entity.ProviderProtocolCol.Id, req.Id).
Where(entity.ProviderProtocolCol.ProviderName, req.ProviderName). //主要是根据运营商查询
Where(entity.ProviderProtocolCol.Status, 1).
Fields(fields).One()
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&res)
return
}
// List 列表查询
func (d *providerProtocolDao) List(ctx context.Context, req *entity.ProviderProtocol, page, size int, fields ...string) (list []*entity.ProviderProtocol, total int, err error) {
if page <= 0 {
page = 1
}
if size <= 0 {
size = 10
}
model := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).
Fields(fields).
OmitEmpty()
model.Where(entity.ProviderProtocolCol.ProviderName, req.ProviderName)
model.Where(entity.ProviderProtocolCol.Status, req.Status)
model.OrderDesc(entity.ProviderProtocolCol.CreatedAt)
model.Page(page, size)
r, total, err := model.AllAndCount(false)
if err != nil {
return
}
err = r.Structs(&list)
return
}
// Update 更新协议配置
func (d *providerProtocolDao) Update(ctx context.Context, req *entity.ProviderProtocol) (rows int64, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).
OmitEmpty().
Where(entity.ProviderProtocolCol.Id, req.Id).
Data(req).
Update()
if err != nil {
return 0, err
}
return r.RowsAffected()
}
// Delete 软删除协议配置
func (d *providerProtocolDao) Delete(ctx context.Context, id int64) (rows int64, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).
Where(entity.ProviderProtocolCol.Id, id).
Data(map[string]any{
entity.ProviderProtocolCol.DeletedAt: "NOW()",
}).
Update()
if err != nil {
return 0, err
}
return r.RowsAffected()
}

10
main.go
View File

@@ -4,10 +4,9 @@ import (
"context" "context"
"os" "os"
"os/signal" "os/signal"
"prompts-core/controller/prompt"
"syscall" "syscall"
"prompts-core/controller"
"gitea.com/red-future/common/http" "gitea.com/red-future/common/http"
"gitea.com/red-future/common/jaeger" "gitea.com/red-future/common/jaeger"
_ "gitea.com/red-future/common/swagger" _ "gitea.com/red-future/common/swagger"
@@ -20,14 +19,13 @@ func main() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
defer jaeger.ShutDown(ctx) defer jaeger.ShutDown(ctx)
// 注册路由 // 注册路由
http.RouteRegister([]interface{}{ http.RouteRegister([]interface{}{
controller.Prompt, prompt.Prompt,
controller.Session, prompt.Session,
}) })
// 监听退出信号,确保 Ctrl+C 能完整退出并关闭 http server // 监听退出信号,确保 Ctrl+C 能完整退出并关闭 gateway server
quit := make(chan os.Signal, 1) quit := make(chan os.Signal, 1)
signal.Notify(quit, os.Interrupt, syscall.SIGTERM) signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
<-quit <-quit

View File

@@ -1,12 +1,7 @@
package dto package prompt
import "github.com/gogf/gf/v2/frame/g" import "github.com/gogf/gf/v2/frame/g"
type Message struct {
Role string `json:"role" dc:"角色system/user/assistant"`
Content any `json:"content" dc:"消息内容"`
}
type ComposeMessagesReq struct { type ComposeMessagesReq struct {
g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schemaform 作为系统表单userForm 作为用户表单,结合 userFiles 调用 model-gateway并直接返回最终 messages"` g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schemaform 作为系统表单userForm 作为用户表单,结合 userFiles 调用 model-gateway并直接返回最终 messages"`
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"` ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"`
@@ -35,6 +30,9 @@ type CallbackReq struct {
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
} }
type CallbackRes struct {
}
type GetComposeTaskReq struct { type GetComposeTaskReq struct {
g.Meta `path:"/getComposeTask" method:"get" tags:"提示词处理" summary:"查询拼接任务" dc:"按 taskId 查询提示词拼接任务结果"` g.Meta `path:"/getComposeTask" method:"get" tags:"提示词处理" summary:"查询拼接任务" dc:"按 taskId 查询提示词拼接任务结果"`
TaskId string `p:"taskId" json:"taskId" v:"required#taskId不能为空" dc:"任务ID"` TaskId string `p:"taskId" json:"taskId" v:"required#taskId不能为空" dc:"任务ID"`

View File

@@ -1,4 +1,4 @@
package dto package prompt
import "github.com/gogf/gf/v2/frame/g" import "github.com/gogf/gf/v2/frame/g"
@@ -7,3 +7,6 @@ type SessionCallbackReq struct {
Text string `json:"text" dc:"文本结果"` Text string `json:"text" dc:"文本结果"`
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
} }
type SessionCallbackRes struct {
}

View File

@@ -1,63 +0,0 @@
package dto
import (
"gitea.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
)
// CreatePromptReq 添加提示词配置(默认启用)
type CreatePromptReq struct {
g.Meta `path:"/createPrompt" method:"post" tags:"提示词管理" summary:"创建提示词配置" dc:"创建新的模型提示词配置(默认启用)"`
ModelTypeId int `p:"modelTypeId" json:"modelTypeId" v:"required#modelTypeId不能为空" dc:"模型分类ID"`
ModelType string `p:"modelType" json:"modelType" v:"required#modelType不能为空" dc:"模型类别/模型类型"`
PromptInfo any `p:"promptInfo" json:"promptInfo" v:"required#promptInfo不能为空" dc:"数据库定义的表单规则数据JSON"`
ResponseJsonSchema any `p:"responseJsonSchema" json:"responseJsonSchema" v:"required#responseJsonSchema不能为空" dc:"模型返回表单 JSON 格式约束"`
// Version 预留字段:先不使用,但表结构保留
Version string `p:"version" json:"version" dc:"版本号(预留)"`
}
type CreatePromptRes struct {
ID int64 `json:"id,string" dc:"配置ID"`
}
// UpdatePromptReq 更新提示词配置
type UpdatePromptReq struct {
g.Meta `path:"/updatePrompt" method:"put" tags:"提示词管理" summary:"更新提示词配置" dc:"更新指定ID的提示词配置"`
ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
ModelTypeId *int `p:"modelTypeId" json:"modelTypeId" dc:"模型分类ID可选更新"`
ModelType *string `p:"modelType" json:"modelType" dc:"模型类别/模型类型(可选更新)"`
PromptInfo any `p:"promptInfo" json:"promptInfo" dc:"数据库定义的表单规则数据JSON可选更新"`
ResponseJsonSchema any `p:"responseJsonSchema" json:"responseJsonSchema" dc:"模型返回表单 JSON 格式约束(可选更新)"`
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用0-禁用1-启用(可选更新)"`
Version *string `p:"version" json:"version" dc:"版本号(预留,可选更新)"`
}
// DeletePromptReq 删除提示词配置
type DeletePromptReq struct {
g.Meta `path:"/deletePrompt" method:"delete" tags:"提示词管理" summary:"删除提示词配置" dc:"删除指定ID的提示词配置"`
ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
}
// GetPromptReq 获取提示词配置详情
type GetPromptReq struct {
g.Meta `path:"/getPrompt" method:"get" tags:"提示词管理" summary:"获取提示词配置" dc:"根据ID获取提示词配置详情"`
ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
}
type GetPromptRes struct {
Prompt any `json:"prompt" dc:"提示词配置详情"`
}
// ListPromptReq 配置列表
type ListPromptReq struct {
g.Meta `path:"/listPrompt" method:"post" tags:"提示词管理" summary:"提示词配置列表" dc:"分页获取提示词配置列表"`
Page *beans.Page `p:"page" json:"page" dc:"分页参数"`
ModelTypeId *int `p:"modelTypeId" json:"modelTypeId" dc:"模型分类ID可选"`
ModelType string `p:"modelType" json:"modelType" dc:"模型类型名称(可选,模糊查询)"`
}
type ListPromptRes struct {
List any `json:"list" dc:"列表数据"`
Total int64 `json:"total" dc:"总数"`
}

View File

@@ -2,6 +2,34 @@ package entity
import "gitea.com/red-future/common/beans" import "gitea.com/red-future/common/beans"
// AsynchModel 异步模型配置
type AsynchModel struct {
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
ModelType int `orm:"model_type" json:"modelType"`
BaseURL string `orm:"base_url" json:"baseUrl"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
HeadMsg string `orm:"head_msg" json:"headMsg"`
Form any `orm:"form_json" json:"form"`
RequestMapping any `orm:"request_mapping" json:"requestMapping"`
ResponseMapping any `orm:"response_mapping" json:"responseMapping"`
ResponseBody any `orm:"response_body" json:"responseBody"`
TokenMapping string `orm:"token_mapping" json:"tokenMapping"`
Prompt string `orm:"prompt" json:"prompt"`
IsPrivate int `orm:"is_private" json:"isPrivate"`
IsChatModel int `orm:"is_chat_model" json:"isChatModel"`
ApiKey string `orm:"api_key" json:"apiKey"`
Enabled int `orm:"enabled" json:"enabled"`
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
QueueLimit int `orm:"queue_limit" json:"queueLimit"`
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
ExpectedSeconds int `orm:"expected_seconds" json:"expectedSeconds"`
RetryTimes int `orm:"retry_times" json:"retryTimes"`
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"`
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
Remark string `orm:"remark" json:"remark"`
}
type asynchModelCol struct { type asynchModelCol struct {
beans.SQLBaseCol beans.SQLBaseCol
ModelName string ModelName string
@@ -55,31 +83,3 @@ var AsynchModelCol = asynchModelCol{
AutoCleanSeconds: "auto_clean_seconds", AutoCleanSeconds: "auto_clean_seconds",
Remark: "remark", Remark: "remark",
} }
// AsynchModel 异步模型配置
type AsynchModel struct {
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
ModelType int `orm:"model_type" json:"modelType"`
BaseURL string `orm:"base_url" json:"baseUrl"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
HeadMsg string `orm:"head_msg" json:"headMsg"`
Form any `orm:"form_json" json:"form"`
RequestMapping any `orm:"request_mapping" json:"requestMapping"`
ResponseMapping any `orm:"response_mapping" json:"responseMapping"`
ResponseBody any `orm:"response_body" json:"responseBody"`
TokenMapping string `orm:"token_mapping" json:"tokenMapping"`
Prompt string `orm:"prompt" json:"prompt"`
IsPrivate int `orm:"is_private" json:"isPrivate"`
IsChatModel int `orm:"is_chat_model" json:"isChatModel"`
ApiKey string `orm:"api_key" json:"apiKey"`
Enabled int `orm:"enabled" json:"enabled"`
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
QueueLimit int `orm:"queue_limit" json:"queueLimit"`
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
ExpectedSeconds int `orm:"expected_seconds" json:"expectedSeconds"`
RetryTimes int `orm:"retry_times" json:"retryTimes"`
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"`
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
Remark string `orm:"remark" json:"remark"`
}

View File

@@ -1,39 +0,0 @@
package entity
import "gitea.com/red-future/common/beans"
type promptConfigCol struct {
beans.SQLBaseCol
ModelTypeId string
ModelType string
PromptInfo string
ResponseJsonSchema string
Enabled string
Version string
}
var PromptConfigCol = promptConfigCol{
SQLBaseCol: beans.DefSQLBaseCol,
ModelTypeId: "model_type_id",
ModelType: "model_type",
PromptInfo: "prompt_info",
ResponseJsonSchema: "response_json_schema",
Enabled: "enabled",
Version: "version",
}
// PromptConfig 模型提示词配置
//
// 说明:
// - prompt_info 使用 JSONB 保存(对外用 json 传输)
// - response_json_schema 为模型返回 JSON 格式约束
// - enabled1启用/0禁用
type PromptConfig struct {
beans.SQLBaseDO `orm:",inline"`
ModelTypeId int `orm:"model_type_id" json:"modelTypeId"`
ModelType string `orm:"model_type" json:"modelType"`
PromptInfo any `orm:"prompt_info" json:"promptInfo"`
ResponseJsonSchema any `orm:"response_json_schema" json:"responseJsonSchema"`
Enabled int `orm:"enabled" json:"enabled"`
Version string `orm:"version" json:"version"`
}

View File

@@ -2,6 +2,14 @@ package entity
import "gitea.com/red-future/common/beans" import "gitea.com/red-future/common/beans"
type ComposeSession struct {
beans.SQLBaseDO `orm:",inline"`
SessionId string `orm:"session_id" json:"sessionId"`
RequestContent any `orm:"request_content" json:"requestContent"`
ResponseContent any `orm:"response_content" json:"responseContent"`
Remark string `orm:"remark" json:"remark"`
}
type composeSessionCol struct { type composeSessionCol struct {
beans.SQLBaseCol beans.SQLBaseCol
SessionId string SessionId string
@@ -17,11 +25,3 @@ var ComposeSessionCol = composeSessionCol{
ResponseContent: "response_content", ResponseContent: "response_content",
Remark: "remark", Remark: "remark",
} }
type ComposeSession struct {
beans.SQLBaseDO `orm:",inline"`
SessionId string `orm:"session_id" json:"sessionId"`
RequestContent any `orm:"request_content" json:"requestContent"`
ResponseContent any `orm:"response_content" json:"responseContent"`
Remark string `orm:"remark" json:"remark"`
}

View File

@@ -2,6 +2,20 @@ package entity
import "gitea.com/red-future/common/beans" import "gitea.com/red-future/common/beans"
type ComposeTask struct {
beans.SQLBaseDO `orm:",inline"`
TaskId string `orm:"task_id" json:"taskId"`
ModelName string `orm:"model_name" json:"modelName"`
SkillName string `orm:"skill_name" json:"skillName"`
LimitWords int `orm:"limit_words" json:"limitWords"`
RequestPayload any `orm:"request_payload" json:"requestPayload"`
CallbackPayload any `orm:"callback_payload" json:"callbackPayload"`
ModelResult any `orm:"model_result" json:"modelResult"`
Messages any `orm:"messages" json:"messages"`
Status string `orm:"status" json:"status"`
ErrorMessage string `orm:"error_message" json:"errorMessage"`
}
type composeTaskCol struct { type composeTaskCol struct {
beans.SQLBaseCol beans.SQLBaseCol
TaskId string TaskId string
@@ -29,17 +43,3 @@ var ComposeTaskCol = composeTaskCol{
Status: "status", Status: "status",
ErrorMessage: "error_message", ErrorMessage: "error_message",
} }
type ComposeTask struct {
beans.SQLBaseDO `orm:",inline"`
TaskId string `orm:"task_id" json:"taskId"`
ModelName string `orm:"model_name" json:"modelName"`
SkillName string `orm:"skill_name" json:"skillName"`
LimitWords int `orm:"limit_words" json:"limitWords"`
RequestPayload any `orm:"request_payload" json:"requestPayload"`
CallbackPayload any `orm:"callback_payload" json:"callbackPayload"`
ModelResult any `orm:"model_result" json:"modelResult"`
Messages any `orm:"messages" json:"messages"`
Status string `orm:"status" json:"status"`
ErrorMessage string `orm:"error_message" json:"errorMessage"`
}

View File

@@ -0,0 +1,49 @@
package entity
import "gitea.com/red-future/common/beans"
// ProviderProtocol 模型协议映射配置
type ProviderProtocol struct {
beans.SQLBaseDO `orm:",inherit"`
// 业务字段
ProviderName string `orm:"provider_name" json:"providerName"`
TargetField string `orm:"target_field" json:"targetField"`
MergeOrder any `orm:"merge_order" json:"mergeOrder"`
RoleMapping any `orm:"role_mapping" json:"roleMapping"`
ContentMapping any `orm:"content_mapping" json:"contentMapping"`
Capabilities any `orm:"capabilities" json:"capabilities"`
RequestTemplate any `orm:"request_template" json:"requestTemplate"`
SystemPromptTemplate string `orm:"system_prompt_template" json:"systemPromptTemplate"`
Status int `orm:"status" json:"status"`
Remark string `orm:"remark" json:"remark"`
}
// providerProtocolCol 列名
type providerProtocolCol struct {
beans.SQLBaseCol
ProviderName string
TargetField string
MergeOrder string
RoleMapping string
ContentMapping string
Capabilities string
RequestTemplate string
SystemPromptTemplate string
Status string
Remark string
}
// ProviderProtocolCol 列名常量
var ProviderProtocolCol = providerProtocolCol{
SQLBaseCol: beans.DefSQLBaseCol,
ProviderName: "provider_name",
TargetField: "target_field",
MergeOrder: "merge_order",
RoleMapping: "role_mapping",
ContentMapping: "content_mapping",
Capabilities: "capabilities",
RequestTemplate: "request_template",
SystemPromptTemplate: "system_prompt_template",
Status: "status",
Remark: "remark",
}

View File

@@ -1,144 +0,0 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"prompts-core/model/dto"
"prompts-core/model/entity"
"strings"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// 获取请求模型的提示词
func GetModelPrompt(ctx context.Context, Type int) string {
return g.Cfg().MustGet(ctx, "modelPrompts.types."+gconv.String(Type), "").String()
}
// 获取构建提示词
func GetBuildPrompt(ctx context.Context, Type int) string {
return g.Cfg().MustGet(ctx, "buildProject.types."+gconv.String(Type), "").String()
}
// buildInferenceRequest 构建返回请求
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
messages := []map[string]any{}
switch req.BuildType {
//构建提示词请求
case 1:
//1. 构建系统提示词
messages = append(messages, map[string]any{
"role": "system",
"content": promptBuild(ctx, req, model),
})
// 2. 构建历史会话提示词
for _, msg := range history {
role := gconv.String(msg["role"])
content := gconv.String(msg["content"])
if role != "user" && role != "assistant" {
continue
}
messages = append(messages, map[string]any{
"role": role,
"content": content,
})
}
// 3. 当前用户问题(原来的最后一条)
messages = append(messages, map[string]any{
"role": "user",
"content": buildUserPrompt(ctx, req, GetModelPrompt(ctx, model.ModelType)),
})
//构建节点请求
case 2:
messages = append(messages, map[string]any{
"role": "user",
"content": NodeBuid(ctx, req),
})
default:
return nil, errors.New("不支持的构建类型")
}
// 构建请求体
return map[string]any{
"modelName": chatModel.ModelName,
"bizName": "prompts-core",
"callbackUrl": "/prompt/callback",
"requestPayload": map[string]any{
"model": chatModel.ModelName,
"messages": messages,
"stream": false,
},
}, nil
}
// ============================================
// 构建用户提示词
// ============================================
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
payload := map[string]any{
"model": req.ModelName,
//数据库提示信息
"promptInfo": prompt,
// 系统表单
"form": req.Form,
// 用户表单
"userForm": req.UserForm,
//文件url
"userFiles": req.UserFiles,
//解读文件(只支持可读类型 如xmljson,yaml
"userFilesText": FetchFileTexts(ctx, req.UserFiles),
//skill 相关(根据传入的 skillName 获取 zip 内所有 md 文件拼接内容)
"skills": SkillMdContent(ctx, req.SkillName),
}
return mustMarshal(payload)
}
// promptBuild 提示词构建
func promptBuild(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) string {
// 1. 从配置文件读取提示词模板
promptTpl := GetBuildPrompt(ctx, req.BuildType)
if promptTpl == "" {
return ""
}
// 2. 构建字段映射说明
mappingBytes, _ := json.Marshal(model.RequestMapping)
mappingStr := string(mappingBytes)
var mapping map[string]string
_ = json.Unmarshal(mappingBytes, &mapping)
var fieldDesc strings.Builder
for key, path := range mapping {
fieldDesc.WriteString(fmt.Sprintf("- %s → %s\n", key, path))
}
// 3. 拼接 UserForm 全文(必须完整阅读)
var userFormContent strings.Builder
for k, v := range req.UserForm {
userFormContent.WriteString(fmt.Sprintf("%s=%v", k, v))
}
userFormFullText := strings.TrimSuffix(userFormContent.String(), "")
// 4. 双表单信息
formInfo := fmt.Sprintf(`
【系统表单(系统提示词/参数)】
%s
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
%s
`, formToJSON(req.Form), userFormFullText)
// 5. 格式化最终提示词(替换配置里的 %s
return fmt.Sprintf(promptTpl, mappingStr, fieldDesc.String(), formInfo)
}
// NodeBuid 节点构建
func NodeBuid(ctx context.Context, req *dto.ComposeMessagesReq) string {
promptTpl := GetBuildPrompt(ctx, req.BuildType)
if promptTpl == "" {
return ""
}
formStr := formToJSON(req.Form)
userFormStr := formToJSON(req.UserForm)
return fmt.Sprintf(promptTpl, formStr, userFormStr)
}

View File

@@ -1,9 +1,10 @@
package service package gateway
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"prompts-core/common/util"
commonHttp "gitea.com/red-future/common/http" commonHttp "gitea.com/red-future/common/http"
"github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/os/gtime"
@@ -19,10 +20,10 @@ type CreateTaskReq struct {
ErrorMsg string `json:"error_msg"` ErrorMsg string `json:"error_msg"`
} }
// createGatewayTask 调用 model-gateway 异步任务并同步等待结果 // CreateGatewayTask 创建网关异步任务
func createGatewayTask(ctx context.Context, payload map[string]any) (string, error) { func CreateGatewayTask(ctx context.Context, payload map[string]any) (string, error) {
fullURL := "model-gateway/task/createTask" fullURL := "model-gateway/task/createTask"
headers := forwardHeaders(ctx) headers := util.ForwardHeaders(ctx)
var req CreateTaskReq var req CreateTaskReq
body, err := json.Marshal(payload) body, err := json.Marshal(payload)
if err != nil { if err != nil {
@@ -34,15 +35,16 @@ func createGatewayTask(ctx context.Context, payload map[string]any) (string, err
return req.TaskId, nil return req.TaskId, nil
} }
// GetTaskResultRes 任务结果响应
type GetTaskResultRes struct { type GetTaskResultRes struct {
OssFile string `json:"ossFile" dc:"结果文件OSS地址"` OssFile string `json:"ossFile" dc:"结果文件OSS地址"`
State int `json:"state" dc:"任务状态"` State int `json:"state" dc:"任务状态"`
} }
// queryGatewayTaskState 查询网关任务状态 // QueryGatewayTaskState 查询网关任务状态
func queryGatewayTaskState(ctx context.Context, taskID string) (int, error) { func QueryGatewayTaskState(ctx context.Context, taskID string) (int, error) {
fullURL := fmt.Sprintf("model-gateway/task/getTaskResult?taskId=%s", taskID) fullURL := fmt.Sprintf("model-gateway/task/getTaskResult?taskId=%s", taskID)
headers := forwardHeaders(ctx) headers := util.ForwardHeaders(ctx)
var req GetTaskResultRes var req GetTaskResultRes
if err := commonHttp.Get(ctx, fullURL, headers, &req, nil); err != nil { if err := commonHttp.Get(ctx, fullURL, headers, &req, nil); err != nil {
return 0, err return 0, err
@@ -56,16 +58,16 @@ type SkillUserVO struct {
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description"` Description string `json:"description"`
FileName string `json:"fileName"` FileName string `json:"fileName"`
FileUrl string `json:"fileUrl"` // html 后缀 FileUrl string `json:"fileUrl"`
CreatedAt *gtime.Time `json:"createdAt"` CreatedAt *gtime.Time `json:"createdAt"`
UpdatedAt *gtime.Time `json:"updatedAt"` UpdatedAt *gtime.Time `json:"updatedAt"`
ImgAddressPrefix string `json:"imgAddressPrefix"` // htmml 前缀 ImgAddressPrefix string `json:"imgAddressPrefix"`
} }
// GetSkillUser 根据 name 获取技能用户信息 // GetSkillUser 获取技能用户信息
func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) { func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) {
fullURL := fmt.Sprintf("ai-agent/skill/user/getUserOrTemplate?name=%s", name) fullURL := fmt.Sprintf("ai-agent/skill/user/getUserOrTemplate?name=%s", name)
headers := forwardHeaders(ctx) headers := util.ForwardHeaders(ctx)
var resp SkillUserVO var resp SkillUserVO
var req struct{} var req struct{}
if err := commonHttp.Get(ctx, fullURL, headers, &resp, req); err != nil { if err := commonHttp.Get(ctx, fullURL, headers, &resp, req); err != nil {

View File

@@ -0,0 +1,112 @@
package prompt
import (
"context"
"errors"
"fmt"
"strings"
"prompts-core/common/util"
"prompts-core/dao"
"prompts-core/model/dto/prompt"
"prompts-core/model/entity"
"github.com/gogf/gf/v2/util/gconv"
)
// buildInferenceRequest 构建返回请求
func buildInferenceRequest(ctx context.Context, req *prompt.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
ir := NewPromptIR()
// 1. 统一 Prompt IR
switch req.BuildType {
case 1: //构建提示词请求
ir.AddSystem(promptBuild(ctx, req, model))
for _, msg := range history {
role := gconv.String(msg["role"])
if role != "user" && role != "assistant" {
continue
}
ir.AddHistory(role, gconv.String(msg["content"]))
}
ir.AddUser(buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, model.ModelType)))
case 2: //构建节点请求
ir.AddUser(NodeBuild(ctx, req))
default:
return nil, errors.New("不支持的构建类型")
}
// 2. 获取协议配置
protocol, err := GetProtocolByProvider(ctx, "qwen")
if err != nil {
return nil, err
}
if protocol == nil {
return nil, errors.New("协议配置不存在")
}
// 3. 编译为 Provider Request
providerReq, err := Compile(ir, protocol, chatModel)
if err != nil {
return nil, err
}
// 4. 构建请求体
return map[string]any{
"modelName": chatModel.ModelName,
"bizName": "prompts-core",
"callbackUrl": "/prompt/callback",
"requestPayload": providerReq,
}, nil
}
// promptBuild 构建系统提示词
func promptBuild(ctx context.Context, req *prompt.ComposeMessagesReq, model *entity.AsynchModel) string {
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: "qwen",
Status: 1,
})
if err != nil || providerProtocol == nil {
return ""
}
outputJSON := util.JSONPretty(model.RequestMapping)
var userFormContent strings.Builder
for k, v := range req.UserForm {
userFormContent.WriteString(fmt.Sprintf("%s=%v", k, v))
}
userFormFullText := strings.TrimSuffix(userFormContent.String(), "")
formInfo := fmt.Sprintf(`
【系统表单(系统提示词/参数)】
%s
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
%s
`, util.FormToJSON(req.Form), userFormFullText)
return fmt.Sprintf(providerProtocol.SystemPromptTemplate, outputJSON, formInfo)
}
// 构建用户提示词
func buildUserPrompt(ctx context.Context, req *prompt.ComposeMessagesReq, prompt string) string {
payload := map[string]any{
"model": req.ModelName, // 请求模型名称
"promptInfo": prompt, // 数据库提示信息
"form": req.Form, // 系统表单
"userForm": req.UserForm, // 用户表单
"userFiles": req.UserFiles, //文件url
"userFilesText": FetchFileTexts(ctx, req.UserFiles), //解读文件(只支持可读类型 如xmljson,yaml
"skills": SkillMdContent(ctx, req.SkillName), //skill 相关(根据传入的 skillName 获取 zip 内所有 md 文件拼接内容)
}
return util.MustMarshal(payload)
}
// NodeBuild 节点构建
func NodeBuild(ctx context.Context, req *prompt.ComposeMessagesReq) string {
promptTpl := util.GetBuildPrompt(ctx, req.BuildType)
if promptTpl == "" {
return ""
}
formStr := util.FormToJSON(req.Form)
userFormStr := util.FormToJSON(req.UserForm)
return fmt.Sprintf(promptTpl, formStr, userFormStr)
}

View File

@@ -1,28 +1,28 @@
package service package prompt
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"prompts-core/dao"
"prompts-core/model/entity"
"strings" "strings"
"time" "time"
"prompts-core/common/util"
"prompts-core/consts/public" "prompts-core/consts/public"
"prompts-core/dao" promptDto "prompts-core/model/dto/prompt"
"prompts-core/model/dto" "prompts-core/service/gateway"
"prompts-core/model/entity"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/container/gvar" "github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/frame/g"
) )
// ============================================ // ComposeMessages 核心拼接提示词主流程
// 核心业务流程 func ComposeMessages(ctx context.Context, req *promptDto.ComposeMessagesReq) (*promptDto.ComposeMessagesRes, error) {
// ============================================
// ComposeMessages 拼接提示词主流程
func (s *promptService) ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
var ( var (
epicycleId int64 epicycleId int64
taskID string taskID string
@@ -32,7 +32,7 @@ func (s *promptService) ComposeMessages(ctx context.Context, req *dto.ComposeMes
taskRecord *entity.ComposeTask taskRecord *entity.ComposeTask
) )
// 获取模型信息 // 获取模型信息
chatModel, model, err := s.GetModelMessage(ctx, req) chatModel, aiModel, err := GetModelMessage(ctx, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -42,18 +42,18 @@ func (s *promptService) ComposeMessages(ctx context.Context, req *dto.ComposeMes
case 1: case 1:
maxRetryTimes := g.Cfg().MustGet(ctx, "promptsRetry.maxRetryTimes", 3).Int() maxRetryTimes := g.Cfg().MustGet(ctx, "promptsRetry.maxRetryTimes", 3).Int()
//1. 获取历史会话 //1. 获取历史会话
history, err = Session.GetHistoryMessages(ctx, req.SessionId) history, err = GetHistoryMessages(ctx, req.SessionId)
if err != nil { if err != nil {
g.Log().Errorf(ctx, "获取历史会话失败: %v将不使用历史会话", err) g.Log().Errorf(ctx, "获取历史会话失败: %v将不使用历史会话", err)
history = nil // 出错就用空的,不影响主流程 history = nil // 出错就用空的,不影响主流程
} }
// 重试循环 // 重试循环
for attempt := 0; attempt <= maxRetryTimes; attempt++ { for attempt := 0; attempt <= 0; attempt++ {
if attempt > 0 { if attempt > 0 {
g.Log().Warningf(ctx, "[重试]第 %d/%d 次调用推理模型", attempt, maxRetryTimes) g.Log().Warningf(ctx, "[重试]第 %d/%d 次调用推理模型", attempt, maxRetryTimes)
} }
// 2. 调用推理模型 // 2. 调用推理模型
taskID, err = s.callInferenceModel(ctx, req, chatModel, model, history) taskID, err = callInferenceModel(ctx, req, chatModel, aiModel, history)
if err != nil { if err != nil {
g.Log().Errorf(ctx, "调用推理模型失败(第%d次): %v", attempt+1, err) g.Log().Errorf(ctx, "调用推理模型失败(第%d次): %v", attempt+1, err)
continue continue
@@ -64,7 +64,7 @@ func (s *promptService) ComposeMessages(ctx context.Context, req *dto.ComposeMes
TaskId: taskID, TaskId: taskID,
ModelName: req.ModelName, ModelName: req.ModelName,
SkillName: req.SkillName, SkillName: req.SkillName,
RequestPayload: mustMarshal(req), RequestPayload: util.MustMarshal(req),
Status: public.ComposeStatusPending, Status: public.ComposeStatusPending,
}) })
if err != nil { if err != nil {
@@ -73,14 +73,14 @@ func (s *promptService) ComposeMessages(ctx context.Context, req *dto.ComposeMes
} }
// 4. 等待结果 // 4. 等待结果
taskRecord, err = s.waitForResult(ctx, taskID) taskRecord, err = waitForResult(ctx, taskID)
if err != nil { if err != nil {
g.Log().Errorf(ctx, "等待结果失败(第%d次): %v", attempt+1, err) g.Log().Errorf(ctx, "等待结果失败(第%d次): %v", attempt+1, err)
continue continue
} }
// 校验结果 // 校验结果
message = s.parsePromptBuild(taskRecord, chatModel) message = parsePromptBuild(taskRecord, chatModel)
if message != nil && isMessageValid(message) { if message != nil && util.IsMessageValid(message) {
break break
} }
g.Log().Warningf(ctx, "[重试] 推理结果不合法(第%d次),准备重新请求", attempt+1) g.Log().Warningf(ctx, "[重试] 推理结果不合法(第%d次),准备重新请求", attempt+1)
@@ -97,7 +97,7 @@ func (s *promptService) ComposeMessages(ctx context.Context, req *dto.ComposeMes
//节点构建 //节点构建
case 2: case 2:
//1. 调用推理模型 //1. 调用推理模型
taskID, err = s.callInferenceModel(ctx, req, chatModel, model, nil) taskID, err = callInferenceModel(ctx, req, chatModel, aiModel, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -106,115 +106,41 @@ func (s *promptService) ComposeMessages(ctx context.Context, req *dto.ComposeMes
TaskId: taskID, TaskId: taskID,
ModelName: req.ModelName, ModelName: req.ModelName,
SkillName: req.SkillName, SkillName: req.SkillName,
RequestPayload: mustMarshal(req), RequestPayload: util.MustMarshal(req),
Status: public.ComposeStatusPending, Status: public.ComposeStatusPending,
}) })
//5. 等待结果 //5. 等待结果
taskRecord, err := s.waitForResult(ctx, taskID) taskRecord, err := waitForResult(ctx, taskID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
fmt.Println("构建节点前", taskRecord) message = parseNodeBuild(taskRecord)
message = s.parseNodeBuild(taskRecord)
fmt.Println("构建节点后", message)
default: default:
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId, SessionId: req.SessionId,
Remark: req.Cause, Remark: req.Cause,
}) })
return &dto.ComposeMessagesRes{ return &promptDto.ComposeMessagesRes{
EpicycleId: epicycleId, EpicycleId: epicycleId,
}, nil }, nil
} }
return &dto.ComposeMessagesRes{ return &promptDto.ComposeMessagesRes{
Messages: message, Messages: message,
EpicycleId: epicycleId, EpicycleId: epicycleId,
}, nil }, nil
} }
func (s *promptService) Callback(ctx context.Context, req *dto.CallbackReq) error {
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
// ============ 先查任务是否存在 ============
task, err := dao.ComposeTask.GetByTaskId(ctx, req.TaskId)
if err != nil {
return err
}
if task == nil {
return fmt.Errorf("任务不存在: %s", req.TaskId)
}
// ============ 根据状态区分处理 ============
if req.State == 3 {
// 失败:直接更新状态
_, err = dao.ComposeTask.UpdateByTaskId(ctx, req.TaskId, map[string]any{
entity.ComposeTaskCol.Status: public.ComposeStatusFailed,
entity.ComposeTaskCol.ErrorMessage: req.ErrorMsg,
})
return err
}
// ======================================
// 成功:解析模型输出
result, err := parseOutput(req.Text)
if err != nil {
_, updateErr := dao.ComposeTask.UpdateByTaskId(ctx, req.TaskId, map[string]any{
entity.ComposeTaskCol.Status: public.ComposeStatusFailed,
entity.ComposeTaskCol.ErrorMessage: err.Error(),
})
if updateErr != nil {
g.Log().Warningf(ctx, "[Callback] 更新失败状态出错 taskId=%s err=%v", req.TaskId, updateErr)
}
return err
}
// ============ result 可能为 nil ============
var messages any
if result != nil {
messages = result
}
// =======================================
_, err = dao.ComposeTask.UpdateByTaskId(ctx, req.TaskId, map[string]any{
entity.ComposeTaskCol.Status: public.ComposeStatusSuccess,
entity.ComposeTaskCol.Messages: messages,
})
if err != nil {
g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err)
}
return err
}
// GetComposeTask 查询任务结果
func (s *promptService) GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
record, err := dao.ComposeTask.GetByTaskId(ctx, taskID)
if err != nil {
return nil, err
}
if record == nil {
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
}
// 如果 Messages 是字符串,反序列化为 JSON 数组
messages := record.Messages
if str, ok := messages.(string); ok && str != "" {
var parsed any
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
messages = parsed
}
}
return &dto.GetComposeTaskRes{
TaskId: record.TaskId,
Status: record.Status,
ErrorMessage: record.ErrorMessage,
Messages: messages,
}, nil
}
// GetModelMessage 获取模型信息 // GetModelMessage 获取模型信息
func (s *promptService) GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) { func GetModelMessage(ctx context.Context, req *promptDto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) {
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, nil, err
}
// 1. 获取当前用户的会话模型 // 1. 获取当前用户的会话模型
chatModel, err := dao.Model.GetByIsChatModel(ctx) chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
IsChatModel: 1,
})
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -222,18 +148,21 @@ func (s *promptService) GetModelMessage(ctx context.Context, req *dto.ComposeMes
return nil, nil, errors.New("当前没有对话模型,请添加") return nil, nil, errors.New("当前没有对话模型,请添加")
} }
// 2. 获取要构建的模型信息 // 2. 获取要构建的模型信息
model, err := dao.Model.GetByModelName(ctx, req.ModelName) aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
ModelName: req.ModelName,
})
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if model == nil { if aiModel == nil {
return nil, nil, fmt.Errorf("需要构建的模型 %s 不存在", req.ModelName) return nil, nil, fmt.Errorf("需要构建的模型 %s 不存在", req.ModelName)
} }
return chatModel, model, nil return chatModel, aiModel, nil
} }
// callInferenceModel 调用推理模型 // callInferenceModel 调用推理模型
func (s *promptService) callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (string, error) { func callInferenceModel(ctx context.Context, req *promptDto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (string, error) {
// 构建推理模型请求 // 构建推理模型请求
taskReq, err := buildInferenceRequest(ctx, req, chatModel, model, history) taskReq, err := buildInferenceRequest(ctx, req, chatModel, model, history)
if err != nil { if err != nil {
@@ -241,7 +170,7 @@ func (s *promptService) callInferenceModel(ctx context.Context, req *dto.Compose
} }
// 创建网关任务 // 创建网关任务
taskID, err := createGatewayTask(ctx, taskReq) taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
if err != nil { if err != nil {
return "", fmt.Errorf("创建网关任务失败: %w", err) return "", fmt.Errorf("创建网关任务失败: %w", err)
} }
@@ -253,10 +182,8 @@ func (s *promptService) callInferenceModel(ctx context.Context, req *dto.Compose
return taskID, nil return taskID, nil
} }
// ============================================ // waitForResult 等待结果
// 步骤6等待结果 func waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) {
// ============================================
func (s *promptService) waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) {
timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second
pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond
deadline := time.Now().Add(timeout) deadline := time.Now().Add(timeout)
@@ -271,7 +198,9 @@ func (s *promptService) waitForResult(ctx context.Context, taskID string) (*enti
} }
// 1. 查数据库 // 1. 查数据库
record, err := dao.ComposeTask.GetByTaskId(ctx, taskID) record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: taskID,
})
if err != nil { if err != nil {
// ===================== 修复点 2如果是上下文取消直接返回 ===================== // ===================== 修复点 2如果是上下文取消直接返回 =====================
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
@@ -292,7 +221,7 @@ func (s *promptService) waitForResult(ctx context.Context, taskID string) (*enti
} }
// 2. 查网关状态 // 2. 查网关状态
state, err := queryGatewayTaskState(ctx, taskID) state, err := gateway.QueryGatewayTaskState(ctx, taskID)
if err != nil { if err != nil {
// 网关不可达不终止,继续轮询 // 网关不可达不终止,继续轮询
g.Log().Warningf(ctx, "[waitForResult] 查询网关失败 taskId=%s err=%v", taskID, err) g.Log().Warningf(ctx, "[waitForResult] 查询网关失败 taskId=%s err=%v", taskID, err)
@@ -301,16 +230,24 @@ func (s *promptService) waitForResult(ctx context.Context, taskID string) (*enti
case 2: // 网关成功 case 2: // 网关成功
// 网关已成功,主动更新数据库 // 网关已成功,主动更新数据库
if record != nil { if record != nil {
dao.ComposeTask.UpdateByTaskId(ctx, taskID, map[string]any{ _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
entity.ComposeTaskCol.Status: public.ComposeStatusSuccess, TaskId: taskID,
Status: public.ComposeStatusSuccess,
}) })
if err != nil {
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
}
} }
case 3: // 网关失败 case 3: // 网关失败
if record != nil { if record != nil {
dao.ComposeTask.UpdateByTaskId(ctx, taskID, map[string]any{ _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
entity.ComposeTaskCol.Status: public.ComposeStatusFailed, TaskId: taskID,
entity.ComposeTaskCol.ErrorMessage: "model-gateway 任务执行失败", Status: public.ComposeStatusFailed,
ErrorMessage: "model-gateway 任务执行失败",
}) })
if err != nil {
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
}
} }
return nil, fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID) return nil, fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID)
} }
@@ -331,7 +268,7 @@ func (s *promptService) waitForResult(ctx context.Context, taskID string) (*enti
} }
// parsePromptBuild 解析提示词构建结果BuildType == 1 // parsePromptBuild 解析提示词构建结果BuildType == 1
func (s *promptService) parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) map[string]any { func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) map[string]any {
if taskRecord == nil { if taskRecord == nil {
return nil return nil
} }
@@ -394,7 +331,7 @@ func (s *promptService) parsePromptBuild(taskRecord *entity.ComposeTask, model *
} }
// parseNodeBuild 解析节点构建结果BuildType == 2 // parseNodeBuild 解析节点构建结果BuildType == 2
func (s *promptService) parseNodeBuild(taskRecord *entity.ComposeTask) map[string]any { func parseNodeBuild(taskRecord *entity.ComposeTask) map[string]any {
if taskRecord == nil { if taskRecord == nil {
return nil return nil
} }
@@ -414,3 +351,90 @@ func (s *promptService) parseNodeBuild(taskRecord *entity.ComposeTask) map[strin
} }
return result return result
} }
// Callback 回调处理
func Callback(ctx context.Context, req *promptDto.CallbackReq) error {
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
// ============ 先查任务是否存在 ============
task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
})
if err != nil {
return err
}
if task == nil {
return fmt.Errorf("任务不存在: %s", req.TaskId)
}
// ============ 根据状态区分处理 ============
if req.State == 3 {
// 失败:直接更新状态
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
})
return err
}
// ======================================
// 成功:解析模型输出
result, err := util.ParseOutput(req.Text)
if err != nil {
_, updateErr := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
})
if updateErr != nil {
g.Log().Warningf(ctx, "[Callback] 更新失败状态出错 taskId=%s err=%v", req.TaskId, updateErr)
}
return err
}
// ============ result 可能为 nil ============
var messages any
if result != nil {
messages = result
}
// =======================================
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
Messages: messages,
})
if err != nil {
g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err)
}
return err
}
// GetComposeTask 查询任务结果
func GetComposeTask(ctx context.Context, taskID string) (*promptDto.GetComposeTaskRes, error) {
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: taskID,
})
if err != nil {
return nil, err
}
if record == nil {
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
}
// 如果 Messages 是字符串,反序列化为 JSON 数组
messages := record.Messages
if str, ok := messages.(string); ok && str != "" {
var parsed any
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
messages = parsed
}
}
return &promptDto.GetComposeTaskRes{
TaskId: record.TaskId,
Status: record.Status,
ErrorMessage: record.ErrorMessage,
Messages: messages,
}, nil
}

View File

@@ -1,4 +1,4 @@
package service package prompt
import ( import (
"archive/zip" "archive/zip"
@@ -7,52 +7,16 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"path/filepath"
"regexp"
"strings" "strings"
"time" "time"
"prompts-core/common/util"
"prompts-core/service/gateway"
"github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/frame/g"
) )
// ============================================ // FetchFileTexts 从 URL 列表获取文件内容,支持 zip 内文件
// 文件处理(配置直接内联 + zip 支持)
// ============================================
// 允许的文本类 MIME 类型前缀
var allowedMIMEPrefixes = []string{
"text/",
"application/json",
"application/xml",
"application/javascript",
"application/x-yaml",
"application/yaml",
"application/toml",
"application/x-httpd-php",
"application/x-sh",
"application/x-python",
"application/x-perl",
"application/x-ruby",
}
// 禁止的文件扩展名
var bannedExtensions = map[string]bool{
".png": true, ".jpg": true, ".jpeg": true, ".gif": true, ".bmp": true,
".webp": true, ".svg": true, ".ico": true, ".tiff": true, ".tif": true,
".mp3": true, ".wav": true, ".ogg": true, ".flac": true, ".aac": true,
".wma": true, ".m4a": true,
".mp4": true, ".avi": true, ".mkv": true, ".mov": true, ".wmv": true,
".flv": true, ".webm": true,
".tar": true, ".gz": true, ".rar": true, ".7z": true,
".exe": true, ".dll": true, ".so": true, ".bin": true, ".dat": true,
".class": true, ".pyc": true,
".pdf": true, ".doc": true, ".docx": true, ".xls": true, ".xlsx": true,
".ppt": true, ".pptx": true,
}
var symbolCleaner = regexp.MustCompile(`[\x00-\x08\x0B\x0C\x0E-\x1F]`)
// FetchFileTexts 从 URL 列表获取文件内容(支持 zip 内文件)
func FetchFileTexts(ctx context.Context, urls []string) map[string]string { func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
result := make(map[string]string) result := make(map[string]string)
@@ -65,16 +29,16 @@ func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
} }
for _, rawURL := range urls { for _, rawURL := range urls {
url := sanitizeURL(rawURL) url := util.SanitizeURL(rawURL)
if url == "" { if url == "" {
continue continue
} }
if isBannedExtension(url) { if util.IsBannedExtension(url) {
continue continue
} }
if isZipExtension(url) { if util.IsZipExtension(url) {
zipTexts := fetchZipFileTexts(ctx, client, url) zipTexts := fetchZipFileTexts(ctx, client, url)
for k, v := range zipTexts { for k, v := range zipTexts {
result[k] = v result[k] = v
@@ -91,21 +55,14 @@ func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
continue continue
} }
text = cleanSymbols(text) text = util.CleanSymbols(text)
result[url] = text result[url] = text
} }
return result return result
} }
func isZipExtension(url string) bool { // fetchZipFileTexts 下载并解压 zip 文件,提取可读文本内容
ext := strings.ToLower(filepath.Ext(url))
if idx := strings.Index(ext, "?"); idx != -1 {
ext = ext[:idx]
}
return ext == ".zip"
}
func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map[string]string { func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map[string]string {
result := make(map[string]string) result := make(map[string]string)
@@ -130,11 +87,11 @@ func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map
fileName := file.Name fileName := file.Name
if isBannedExtension(fileName) { if util.IsBannedExtension(fileName) {
continue continue
} }
if isZipExtension(fileName) { if util.IsZipExtension(fileName) {
continue continue
} }
@@ -150,11 +107,11 @@ func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map
} }
contentType := http.DetectContentType(content) contentType := http.DetectContentType(content)
if !isReadableContentType(contentType) { if !util.IsReadableContentType(contentType) {
continue continue
} }
text := cleanSymbols(string(content)) text := util.CleanSymbols(string(content))
if text == "" { if text == "" {
continue continue
} }
@@ -166,6 +123,7 @@ func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map
return result return result
} }
// downloadFile 下载文件,限制最大大小
func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error) { func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error) {
req, err := http.NewRequest(http.MethodGet, url, nil) req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil { if err != nil {
@@ -185,35 +143,7 @@ func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error
return io.ReadAll(io.LimitReader(resp.Body, maxSize)) return io.ReadAll(io.LimitReader(resp.Body, maxSize))
} }
func isBannedExtension(url string) bool { // fetchFileContent 获取单个文本文件内容
ext := strings.ToLower(filepath.Ext(url))
if idx := strings.Index(ext, "?"); idx != -1 {
ext = ext[:idx]
}
return bannedExtensions[ext]
}
func isReadableContentType(contentType string) bool {
if contentType == "" {
return false
}
ct := strings.ToLower(contentType)
for _, prefix := range allowedMIMEPrefixes {
if strings.HasPrefix(ct, prefix) {
return true
}
}
return false
}
func cleanSymbols(text string) string {
text = symbolCleaner.ReplaceAllString(text, "")
text = strings.ReplaceAll(text, "\r\n", "\n")
text = strings.ReplaceAll(text, "\r", "\n")
text = regexp.MustCompile(`\n{3,}`).ReplaceAllString(text, "\n\n")
return strings.TrimSpace(text)
}
func fetchFileContent(ctx context.Context, client *http.Client, url string) (string, error) { func fetchFileContent(ctx context.Context, client *http.Client, url string) (string, error) {
req, err := http.NewRequest(http.MethodGet, url, nil) req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil { if err != nil {
@@ -231,7 +161,7 @@ func fetchFileContent(ctx context.Context, client *http.Client, url string) (str
} }
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
if !isReadableContentType(contentType) { if !util.IsReadableContentType(contentType) {
return "", fmt.Errorf("unreadable content-type: %s", contentType) return "", fmt.Errorf("unreadable content-type: %s", contentType)
} }
@@ -247,22 +177,15 @@ func fetchFileContent(ctx context.Context, client *http.Client, url string) (str
return strings.TrimSpace(string(body)), nil return strings.TrimSpace(string(body)), nil
} }
func sanitizeURL(raw string) string {
s := strings.TrimSpace(raw)
s = strings.Trim(s, "`\"")
return s
}
// SkillMdContent 根据 skillName 获取 zip 内所有 md 文件拼接内容 // SkillMdContent 根据 skillName 获取 zip 内所有 md 文件拼接内容
func SkillMdContent(ctx context.Context, skillName string) string { func SkillMdContent(ctx context.Context, skillName string) string {
// 1. 请求接口获取 SkillUserVO skillResp, err := gateway.GetSkillUser(ctx, skillName)
skillResp, err := GetSkillUser(ctx, skillName)
if err != nil { if err != nil {
return "" return ""
} }
fullUrl := skillResp.ImgAddressPrefix + skillResp.FileUrl fullUrl := skillResp.ImgAddressPrefix + skillResp.FileUrl
// 2. 下载 zip 文件
client := &http.Client{ client := &http.Client{
Timeout: time.Duration(g.Cfg().MustGet(ctx, "skillFiles.httpTimeoutSec", 30).Int()) * time.Second, Timeout: time.Duration(g.Cfg().MustGet(ctx, "skillFiles.httpTimeoutSec", 30).Int()) * time.Second,
} }
@@ -274,7 +197,6 @@ func SkillMdContent(ctx context.Context, skillName string) string {
return "" return ""
} }
// 3. 解压 zip 并提取所有 md 文件内容
mdContents, err := extractMdFiles(ctx, zipBytes) mdContents, err := extractMdFiles(ctx, zipBytes)
if err != nil { if err != nil {
return "" return ""
@@ -284,7 +206,6 @@ func SkillMdContent(ctx context.Context, skillName string) string {
return "" return ""
} }
// 4. 拼接所有 md 内容
var builder strings.Builder var builder strings.Builder
builder.WriteString(fmt.Sprintf("# Skill: %s\n\n", skillResp.Name)) builder.WriteString(fmt.Sprintf("# Skill: %s\n\n", skillResp.Name))
if skillResp.Description != "" { if skillResp.Description != "" {

View File

@@ -0,0 +1,264 @@
package prompt
import (
"context"
"encoding/json"
"fmt"
"prompts-core/common/util"
"strings"
"prompts-core/dao"
"prompts-core/model/entity"
)
// PromptIR 统一 Prompt 中间表示
type PromptIR struct {
System []Segment `json:"system"`
History []Segment `json:"history"`
User []Segment `json:"user"`
}
// Segment 消息片段
type Segment struct {
Type string `json:"type"` // text/image
Content string `json:"content"`
Role string `json:"role,omitempty"`
}
// NewPromptIR 创建空 PromptIR
func NewPromptIR() *PromptIR {
return &PromptIR{
System: make([]Segment, 0),
History: make([]Segment, 0),
User: make([]Segment, 0),
}
}
// AddSystem 添加系统提示
func (ir *PromptIR) AddSystem(content string) *PromptIR {
if content != "" {
ir.System = append(ir.System, Segment{Type: "text", Content: content})
}
return ir
}
// AddUser 添加用户消息
func (ir *PromptIR) AddUser(content string) *PromptIR {
if content != "" {
ir.User = append(ir.User, Segment{Type: "text", Content: content})
}
return ir
}
// AddHistory 添加历史消息
func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
if content != "" {
ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role})
}
return ir
}
// ToMessages 转换为 OpenAI 兼容的 messages 格式MVP 默认)
func (ir *PromptIR) ToMessages() []map[string]any {
var messages []map[string]any
// 1. 系统消息
for _, seg := range ir.System {
messages = append(messages, map[string]any{
"role": "system",
"content": seg.Content,
})
}
// 2. 历史消息
for _, seg := range ir.History {
messages = append(messages, map[string]any{
"role": seg.Role,
"content": seg.Content,
})
}
// 3. 用户消息
for _, seg := range ir.User {
messages = append(messages, map[string]any{
"role": "user",
"content": seg.Content,
})
}
return messages
}
// GetProtocolByProvider 根据 provider_name 获取协议配置
func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderProtocol, error) {
entity, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: providerName,
Status: 1,
})
if err != nil || entity == nil {
return nil, err
}
entity.MergeOrder = util.ParseJSONField(entity.MergeOrder)
entity.RoleMapping = util.ParseJSONField(entity.RoleMapping)
entity.ContentMapping = util.ParseJSONField(entity.ContentMapping)
entity.RequestTemplate = util.ParseJSONField(entity.RequestTemplate)
entity.ContentMapping = util.ParseJSONField(entity.ContentMapping)
return parseProtocol(entity), nil
}
// parseProtocol 将 DB entity 转为编译用协议配置
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
p := &ProviderProtocol{
TargetField: e.TargetField,
}
// MergeOrder: any → []string
if e.MergeOrder != nil {
b, _ := json.Marshal(e.MergeOrder)
json.Unmarshal(b, &p.MergeOrder)
}
// RoleMapping: any → map[string]string
if e.RoleMapping != nil {
b, _ := json.Marshal(e.RoleMapping)
json.Unmarshal(b, &p.RoleMapping)
}
// ContentMapping: any → ContentMapping
if e.ContentMapping != nil {
b, _ := json.Marshal(e.ContentMapping)
json.Unmarshal(b, &p.ContentMapping)
}
// RequestTemplate: any → map[string]any
if e.RequestTemplate != nil {
b, _ := json.Marshal(e.RequestTemplate)
json.Unmarshal(b, &p.RequestTemplate)
}
fmt.Printf("parseProtocol: %+v\n", p)
return p
}
// ProviderProtocol 协议编译配置(从 DB JSONB 字段解析)
type ProviderProtocol struct {
TargetField string `json:"target_field"`
MergeOrder []string `json:"merge_order"`
RoleMapping map[string]string `json:"role_mapping"`
ContentMapping ContentMapping `json:"content_mapping"`
RequestTemplate map[string]any `json:"request_template"`
}
// ContentMapping 内容字段映射
type ContentMapping struct {
Type string `json:"type"` // direct/parts
Field string `json:"field"` // content/text
}
// Compile 将 PromptIR 按协议配置编译为 Provider Request
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (map[string]any, error) {
if ir == nil || p == nil {
return nil, fmt.Errorf("ir and protocol are required")
}
// 1. 按 merge_order 拼接消息
messages := mergeByOrder(ir, p.MergeOrder)
// 2. 角色映射
messages = mapRoles(messages, p.RoleMapping)
// 3. 内容字段映射
messages = mapContent(messages, p.ContentMapping)
// 4. 按 target_field + request_template 构建请求体
return buildRequest(messages, p, chatModel), nil
}
// mergeByOrder 按协议配置顺序拼接消息
func mergeByOrder(ir *PromptIR, order []string) []map[string]any {
var messages []map[string]any
for _, part := range order {
switch part {
case "system":
for _, seg := range ir.System {
messages = append(messages, map[string]any{
"role": "system",
"content": seg.Content,
})
}
case "history":
for _, seg := range ir.History {
messages = append(messages, map[string]any{
"role": seg.Role,
"content": seg.Content,
})
}
case "user":
for _, seg := range ir.User {
messages = append(messages, map[string]any{
"role": "user",
"content": seg.Content,
})
}
}
}
return messages
}
// mapRoles 角色映射
func mapRoles(messages []map[string]any, mapping map[string]string) []map[string]any {
if len(mapping) == 0 {
return messages
}
for i, msg := range messages {
role, ok := msg["role"].(string)
if !ok {
continue
}
if mapped, exists := mapping[role]; exists {
messages[i]["role"] = mapped
}
}
return messages
}
// mapContent 内容字段映射
func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
for _, msg := range messages {
content := msg["content"]
delete(msg, "content")
switch cm.Type {
case "parts":
// Gemini 格式: {"parts": [{"text": "..."}]}
msg["parts"] = []map[string]any{
{cm.Field: content},
}
default:
// direct: {"content": "..."}
msg[cm.Field] = content
}
}
return messages
}
// buildRequest 按 target_field 和 request_template 构建请求体
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *entity.AsynchModel) map[string]any {
if len(p.RequestTemplate) > 0 {
return renderTemplate(p.RequestTemplate, messages, chatModel)
}
return map[string]any{
p.TargetField: messages,
}
}
// renderTemplate 简单的 {{key}} 模板替换
func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *entity.AsynchModel) map[string]any {
b, _ := json.Marshal(tmpl)
str := string(b)
// 替换 {{model}}
str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`)
// 替换 {{messages}}
msgBytes, _ := json.Marshal(messages)
str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes))
var result map[string]any
json.Unmarshal([]byte(str), &result)
return result
}

View File

@@ -1,4 +1,4 @@
package service package prompt
import ( import (
"context" "context"
@@ -12,7 +12,7 @@ import (
// ==================== Redis 操作 ==================== // ==================== Redis 操作 ====================
// saveToRedis 保存会话数据到Redis // saveToRedis 保存会话数据到Redis
func (s *sessionService) saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error { func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error {
key := fmt.Sprintf("chat:session:%s", sessionId) key := fmt.Sprintf("chat:session:%s", sessionId)
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int() maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
@@ -50,7 +50,7 @@ func (s *sessionService) saveToRedis(ctx context.Context, sessionId string, requ
} }
// getFromRedis 从Redis获取会话历史 // getFromRedis 从Redis获取会话历史
func (s *sessionService) getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) { func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) {
key := fmt.Sprintf("chat:session:%s", sessionId) key := fmt.Sprintf("chat:session:%s", sessionId)
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1) result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
@@ -82,8 +82,8 @@ func (s *sessionService) getFromRedis(ctx context.Context, sessionId string) ([]
} }
// GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用) // GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用)
func (s *sessionService) GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map[string]any, error) { func GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map[string]any, error) {
historyData, err := s.getFromRedis(ctx, sessionId) historyData, err := getFromRedis(ctx, sessionId)
if err != nil { if err != nil {
return nil, fmt.Errorf("获取历史会话失败: %w", err) return nil, fmt.Errorf("获取历史会话失败: %w", err)
} }

View File

@@ -1,24 +1,22 @@
package service package prompt
import ( import (
"context" "context"
"fmt" "fmt"
"prompts-core/dao" sessionDao "prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity" "prompts-core/model/entity"
"prompts-core/common/util"
sessionDto "prompts-core/model/dto/prompt"
"gitea.com/red-future/common/beans" "gitea.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv" "github.com/gogf/gf/v2/util/gconv"
) )
var Session = &sessionService{} func SessionCallback(ctx context.Context, req *sessionDto.SessionCallbackReq) (res *sessionDto.SessionCallbackRes, err error) {
type sessionService struct{}
func (s *sessionService) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *beans.ResponseEmpty, err error) {
// 1. 解析AI返回的文本 // 1. 解析AI返回的文本
result, err := parseOutput(req.Text) result, err := util.ParseOutput(req.Text)
if err != nil { if err != nil {
g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err) g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, err return nil, err
@@ -26,7 +24,7 @@ func (s *sessionService) SessionCallback(ctx context.Context, req *dto.SessionCa
// 2. 更新数据库 // 2. 更新数据库
result["role"] = "assistant" result["role"] = "assistant"
_, err = dao.ComposeSession.Update(ctx, &entity.ComposeSession{ _, err = sessionDao.ComposeSession.Update(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId}, SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
ResponseContent: result, ResponseContent: result,
}) })
@@ -36,17 +34,19 @@ func (s *sessionService) SessionCallback(ctx context.Context, req *dto.SessionCa
} }
// 3. 获取当前轮次完整数据 // 3. 获取当前轮次完整数据
session, err := dao.ComposeSession.GetById(ctx, req.EpicycleId) session, err := sessionDao.ComposeSession.Get(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
})
if err != nil { if err != nil {
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err) g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, err return nil, err
} }
// 4. 转换 json 并存入 Redis // 4. 转换 json 并存入 Redis
requestMessages := convertToMessages(session.RequestContent) requestMessages := util.ConvertToMessages(session.RequestContent)
responseMessages := convertToMessages(session.ResponseContent) responseMessages := util.ConvertToMessages(session.ResponseContent)
if err = s.saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil { if err = saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil {
g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v", g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v",
session.SessionId, session.Id, err) session.SessionId, session.Id, err)
return nil, err return nil, err
@@ -54,21 +54,23 @@ func (s *sessionService) SessionCallback(ctx context.Context, req *dto.SessionCa
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d", g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
session.SessionId, session.Id, len(requestMessages), len(responseMessages)) session.SessionId, session.Id, len(requestMessages), len(responseMessages))
return &beans.ResponseEmpty{}, nil return &sessionDto.SessionCallbackRes{}, nil
} }
// GetHistoryMessages 获取历史信息 // GetHistoryMessages 获取历史信息
func (s *sessionService) GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) { func GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) {
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int() maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
// 1. 先从 Redis 拿 // 1. 先从 Redis 拿
redisHistory, err := s.GetSessionHistoryForInference(ctx, sessionId) redisHistory, err := GetSessionHistoryForInference(ctx, sessionId)
if err == nil && len(redisHistory) > 0 { if err == nil && len(redisHistory) > 0 {
return redisHistory, nil return redisHistory, nil
} }
// 2. Redis 没有 → fallback DB // 2. Redis 没有 → fallback DB
sessions, err := dao.ComposeSession.GetListBySessionId(ctx, sessionId, maxRounds) sessions, _, err := sessionDao.ComposeSession.List(ctx, &entity.ComposeSession{
SessionId: sessionId,
}, 1, maxRounds)
if err != nil { if err != nil {
return nil, fmt.Errorf("DB获取历史失败: %w", err) return nil, fmt.Errorf("DB获取历史失败: %w", err)
} }
@@ -77,7 +79,7 @@ func (s *sessionService) GetHistoryMessages(ctx context.Context, sessionId strin
for _, session := range sessions { for _, session := range sessions {
// request // request
reqMsgs := convertToMessages(session.RequestContent) reqMsgs := util.ConvertToMessages(session.RequestContent)
for _, m := range reqMsgs { for _, m := range reqMsgs {
role := gconv.String(m["role"]) role := gconv.String(m["role"])
if role == "user" || role == "assistant" { if role == "user" || role == "assistant" {
@@ -86,7 +88,7 @@ func (s *sessionService) GetHistoryMessages(ctx context.Context, sessionId strin
} }
// response // response
respMsgs := convertToMessages(session.ResponseContent) respMsgs := util.ConvertToMessages(session.ResponseContent)
for _, m := range respMsgs { for _, m := range respMsgs {
if m["role"] == nil { if m["role"] == nil {
m["role"] = "assistant" m["role"] = "assistant"
@@ -97,15 +99,15 @@ func (s *sessionService) GetHistoryMessages(ctx context.Context, sessionId strin
// 3. 回写 Redis // 3. 回写 Redis
for _, session := range sessions { for _, session := range sessions {
reqMsgs := convertToMessages(session.RequestContent) reqMsgs := util.ConvertToMessages(session.RequestContent)
respMsgs := convertToMessages(session.ResponseContent) respMsgs := util.ConvertToMessages(session.ResponseContent)
for i := range respMsgs { for i := range respMsgs {
if respMsgs[i]["role"] == nil { if respMsgs[i]["role"] == nil {
respMsgs[i]["role"] = "assistant" respMsgs[i]["role"] = "assistant"
} }
} }
if len(reqMsgs) > 0 || len(respMsgs) > 0 { if len(reqMsgs) > 0 || len(respMsgs) > 0 {
_ = s.saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs) _ = saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs)
} }
} }
return messages, nil return messages, nil

View File

@@ -1,92 +0,0 @@
package service
import (
"context"
"encoding/json"
"errors"
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
)
var Prompt = &promptService{}
type promptService struct{}
func (s *promptService) Create(ctx context.Context, req *dto.CreatePromptReq) (res *dto.CreatePromptRes, err error) {
// promptInfo 兜底校验:必须可序列化为 JSON
if req.PromptInfo == nil {
return nil, errors.New("promptInfo不能为空")
}
if _, err := json.Marshal(req.PromptInfo); err != nil {
return nil, errors.New("promptInfo不是合法JSON")
}
if req.ResponseJsonSchema == nil {
return nil, errors.New("responseJsonSchema不能为空")
}
if _, err := json.Marshal(req.ResponseJsonSchema); err != nil {
return nil, errors.New("responseJsonSchema不是合法JSON")
}
m := &entity.PromptConfig{
ModelTypeId: req.ModelTypeId,
ModelType: req.ModelType,
PromptInfo: req.PromptInfo,
ResponseJsonSchema: req.ResponseJsonSchema,
Enabled: 1,
Version: req.Version,
}
id, err := dao.Prompt.Insert(ctx, m)
if err != nil {
return nil, err
}
return &dto.CreatePromptRes{ID: id}, nil
}
func (s *promptService) Update(ctx context.Context, req *dto.UpdatePromptReq) error {
data := map[string]any{}
if req.ModelTypeId != nil && *req.ModelTypeId > 0 {
data[entity.PromptConfigCol.ModelTypeId] = *req.ModelTypeId
}
if req.ModelType != nil && *req.ModelType != "" {
data[entity.PromptConfigCol.ModelType] = *req.ModelType
}
if req.PromptInfo != nil {
if _, err := json.Marshal(req.PromptInfo); err != nil {
return errors.New("promptInfo不是合法JSON")
}
data[entity.PromptConfigCol.PromptInfo] = req.PromptInfo
}
if req.ResponseJsonSchema != nil {
if _, err := json.Marshal(req.ResponseJsonSchema); err != nil {
return errors.New("responseJsonSchema不是合法JSON")
}
data[entity.PromptConfigCol.ResponseJsonSchema] = req.ResponseJsonSchema
}
if req.Enabled != nil {
data[entity.PromptConfigCol.Enabled] = *req.Enabled
}
if req.Version != nil {
data[entity.PromptConfigCol.Version] = *req.Version
}
if len(data) == 0 {
return errors.New("无可更新字段")
}
_, err := dao.Prompt.UpdateByID(ctx, req.ID, data)
return err
}
func (s *promptService) Delete(ctx context.Context, id int64) error {
_, err := dao.Prompt.DeleteByID(ctx, id)
return err
}
func (s *promptService) Get(ctx context.Context, id int64) (*entity.PromptConfig, error) {
return dao.Prompt.GetByID(ctx, id)
}
func (s *promptService) List(ctx context.Context, pageNum, pageSize int, modelTypeID *int, modelTypeLike string) (list []*entity.PromptConfig, total int64, err error) {
return dao.Prompt.List(ctx, pageNum, pageSize, modelTypeID, modelTypeLike)
}

View File

@@ -1,65 +0,0 @@
// utils 工具函数
package service
import (
"encoding/json"
"fmt"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/util/gconv"
)
// ============================================
// json 相关处理
// ============================================
// parseOutput 解析模型输出为 JSON 格式
func parseOutput(text string) (map[string]any, error) {
j, err := gjson.LoadJson([]byte(text))
if err != nil {
return nil, fmt.Errorf("解析模型输出失败: %w", err)
}
return j.Map(), nil
}
func convertToMessages(raw any) []map[string]any {
if raw == nil {
return nil
}
j, err := gjson.LoadJson(gconv.Bytes(raw))
if err != nil {
return nil
}
// 1. 如果有 messages
if j.Contains("messages") {
return gconv.Maps(j.Get("messages").Array())
}
// 2. 否则当成单条 message
return []map[string]any{
j.Map(),
}
}
// isMessageValid 校验推理结果是否合法
func isMessageValid(message map[string]any) bool {
if message == nil {
return false
}
return true
}
func formToJSON(form map[string]any) string {
if form == nil {
return "{}"
}
b, _ := json.Marshal(form)
return string(b)
}
func mustMarshal(v any) string {
b, err := json.Marshal(v)
if err != nil {
return "{}"
}
return string(b)
}

View File

@@ -1,117 +1,130 @@
-- prompts-core 核心表pgsql
-- 说明字段风格尽量与参考项目一致tenant/creator/updater/created_at/updated_at/deleted_at
-- prompts_model_prompt 模型提示词配置表
CREATE TABLE IF NOT EXISTS prompts_model_prompt (
-- 基础字段(与 common/db/gfdb 的 Hook 约定保持一致)
id BIGINT PRIMARY KEY, -- 主键ID非自增
tenant_id BIGINT NOT NULL DEFAULT 0, -- 租户ID
creator VARCHAR(64) NOT NULL, -- 创建人
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 创建时间
updater VARCHAR(64) NOT NULL, -- 更新人
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 更新时间
deleted_at TIMESTAMP(6), -- 删除时间(软删)
-- 业务字段(按你当前的最小字段集)
model_type_id INT NOT NULL DEFAULT 0, -- 模型分类ID
model_type VARCHAR(64) NOT NULL, -- 模型类别
prompt_info JSONB NOT NULL DEFAULT '{}'::jsonb, -- 提示词信息JSON
response_json_schema JSONB NOT NULL DEFAULT '{}'::jsonb, -- 模型返回表单 JSON 格式约束
enabled SMALLINT NOT NULL DEFAULT 1, -- 是否启用1启用/0禁用
version VARCHAR(64) NOT NULL DEFAULT '' -- 版本号(预留)
);
CREATE INDEX IF NOT EXISTS idx_prompts_model_prompt_tenant_id ON prompts_model_prompt(tenant_id);
CREATE INDEX IF NOT EXISTS idx_prompts_model_prompt_model_type_id ON prompts_model_prompt(model_type_id);
CREATE INDEX IF NOT EXISTS idx_prompts_model_prompt_model_type ON prompts_model_prompt(model_type);
CREATE INDEX IF NOT EXISTS idx_prompts_model_prompt_enabled ON prompts_model_prompt(enabled);
CREATE INDEX IF NOT EXISTS idx_prompts_model_prompt_deleted_at ON prompts_model_prompt(deleted_at);
COMMENT ON TABLE prompts_model_prompt IS '模型提示词配置表';
COMMENT ON COLUMN prompts_model_prompt.id IS '主键ID非自增';
COMMENT ON COLUMN prompts_model_prompt.tenant_id IS '租户ID';
COMMENT ON COLUMN prompts_model_prompt.creator IS '创建人';
COMMENT ON COLUMN prompts_model_prompt.created_at IS '创建时间';
COMMENT ON COLUMN prompts_model_prompt.updater IS '更新人';
COMMENT ON COLUMN prompts_model_prompt.updated_at IS '更新时间';
COMMENT ON COLUMN prompts_model_prompt.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN prompts_model_prompt.model_type_id IS '模型分类ID';
COMMENT ON COLUMN prompts_model_prompt.model_type IS '模型类别';
COMMENT ON COLUMN prompts_model_prompt.prompt_info IS '提示词信息JSON';
COMMENT ON COLUMN prompts_model_prompt.response_json_schema IS '模型返回表单 JSON 格式约束';
COMMENT ON COLUMN prompts_model_prompt.enabled IS '是否启用1启用/0禁用';
COMMENT ON COLUMN prompts_model_prompt.version IS '版本号(预留)';
-- prompts_compose_task 拼接提示词任务记录表 -- prompts_compose_task 拼接提示词任务记录表
CREATE TABLE IF NOT EXISTS prompts_compose_task ( CREATE TABLE IF NOT EXISTS prompts_compose_task (
id BIGINT PRIMARY KEY, id BIGINT PRIMARY KEY,
tenant_id BIGINT NOT NULL DEFAULT 0, tenant_id BIGINT NOT NULL DEFAULT 0,
creator VARCHAR(64) NOT NULL, creator VARCHAR(64) NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updater VARCHAR(64) NOT NULL, updater VARCHAR(64) NOT NULL,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
deleted_at TIMESTAMP(6), deleted_at TIMESTAMP(6),
task_id VARCHAR(64) NOT NULL, task_id VARCHAR(64) NOT NULL,
model_name VARCHAR(128) NOT NULL DEFAULT '', model_name VARCHAR(128) NOT NULL DEFAULT '',
skill_name VARCHAR(128) NOT NULL DEFAULT '', skill_name VARCHAR(128) NOT NULL DEFAULT '',
gateway_state INT NOT NULL DEFAULT 0, gateway_state INT NOT NULL DEFAULT 0,
limit_words INT NOT NULL DEFAULT 0, limit_words INT NOT NULL DEFAULT 0,
request_payload JSONB NOT NULL DEFAULT '{}'::jsonb, request_payload JSONB NOT NULL DEFAULT '{}'::jsonb,
result_text TEXT NOT NULL DEFAULT '', result_text TEXT NOT NULL DEFAULT '',
messages JSONB NOT NULL DEFAULT '[]'::jsonb, messages JSONB NOT NULL DEFAULT '{}'::jsonb,
status VARCHAR(32) NOT NULL DEFAULT 'pending', status VARCHAR(32) NOT NULL DEFAULT 'pending',
error_message TEXT NOT NULL DEFAULT '', error_message TEXT NOT NULL DEFAULT '',
oss_file VARCHAR(1024) NOT NULL DEFAULT '', oss_file VARCHAR(1024) NOT NULL DEFAULT '',
file_type VARCHAR(64) NOT NULL DEFAULT '' file_type VARCHAR(64) NOT NULL DEFAULT ''
); );
-- 索引
CREATE UNIQUE INDEX IF NOT EXISTS uk_prompts_compose_task_task_id ON prompts_compose_task(task_id); CREATE UNIQUE INDEX IF NOT EXISTS uk_prompts_compose_task_task_id ON prompts_compose_task(task_id);
CREATE INDEX IF NOT EXISTS idx_prompts_compose_task_status ON prompts_compose_task(status); CREATE INDEX IF NOT EXISTS idx_prompts_compose_task_status ON prompts_compose_task(status);
CREATE INDEX IF NOT EXISTS idx_prompts_compose_task_deleted_at ON prompts_compose_task(deleted_at); CREATE INDEX IF NOT EXISTS idx_prompts_compose_task_deleted_at ON prompts_compose_task
-- 注释
COMMENT ON TABLE prompts_compose_task IS '拼接提示词任务记录表'; COMMENT ON TABLE prompts_compose_task IS '拼接提示词任务记录表';
COMMENT ON COLUMN prompts_compose_task.task_id IS 'model-gateway 任务ID'; COMMENT ON COLUMN prompts_compose_task.id IS '主键ID';
COMMENT ON COLUMN prompts_compose_task.model_name IS '业务模型名称'; COMMENT ON COLUMN prompts_compose_task.tenant_id IS '租户ID';
COMMENT ON COLUMN prompts_compose_task.skill_name IS '技能名称'; COMMENT ON COLUMN prompts_compose_task.creator IS '创建人';
COMMENT ON COLUMN prompts_compose_task.gateway_state IS 'model-gateway 状态0排队/1执行/2成功/3失败/4已下载'; COMMENT ON COLUMN prompts_compose_task.created_at IS '创建时间';
COMMENT ON COLUMN prompts_compose_task.limit_words IS '提示词限制字数'; COMMENT ON COLUMN prompts_compose_task.updater IS '更新人';
COMMENT ON COLUMN prompts_compose_task.updated_at IS '更新时间';
COMMENT ON COLUMN prompts_compose_task.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN prompts_compose_task.task_id IS 'model-gateway 任务ID';
COMMENT ON COLUMN prompts_compose_task.model_name IS '业务模型名称';
COMMENT ON COLUMN prompts_compose_task.skill_name IS '技能名称';
COMMENT ON COLUMN prompts_compose_task.gateway_state IS 'model-gateway 状态0排队/1执行/2成功/3失败/4已下载';
COMMENT ON COLUMN prompts_compose_task.limit_words IS '提示词限制字数';
COMMENT ON COLUMN prompts_compose_task.request_payload IS '发给 model-gateway 的请求内容'; COMMENT ON COLUMN prompts_compose_task.request_payload IS '发给 model-gateway 的请求内容';
COMMENT ON COLUMN prompts_compose_task.result_text IS '回调返回的文本结果'; COMMENT ON COLUMN prompts_compose_task.result_text IS '回调返回的文本结果';
COMMENT ON COLUMN prompts_compose_task.messages IS '最终解析后的 messages'; COMMENT ON COLUMN prompts_compose_task.messages IS '最终解析后的 messages';
COMMENT ON COLUMN prompts_compose_task.status IS '业务状态pending/success/failed'; COMMENT ON COLUMN prompts_compose_task.status IS '业务状态pending/success/failed';
COMMENT ON COLUMN prompts_compose_task.error_message IS '业务错误信息'; COMMENT ON COLUMN prompts_compose_task.error_message IS '业务错误信息';
COMMENT ON COLUMN prompts_compose_task.oss_file IS '网关返回的结果文件地址'; COMMENT ON COLUMN prompts_compose_task.oss_file IS '网关返回的结果文件地址';
COMMENT ON COLUMN prompts_compose_task.file_type IS '结果文件类型'; COMMENT ON COLUMN prompts_compose_task.file_type IS '结果文件类型';
-- prompts_compose_session 提示词历史会话表 -- prompts_compose_session 提示词历史会话表
CREATE TABLE IF NOT EXISTS prompts_compose_session ( CREATE TABLE IF NOT EXISTS prompts_compose_session (
id BIGINT PRIMARY KEY, id BIGINT NOT NULL,
tenant_id BIGINT NOT NULL DEFAULT 0, tenant_id BIGINT NOT NULL DEFAULT 0,
creator VARCHAR(64) NOT NULL, creator VARCHAR(64) NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updater VARCHAR(64) NOT NULL, updater VARCHAR(64) NOT NULL,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
deleted_at TIMESTAMP(6), deleted_at TIMESTAMP(6),
session_id VARCHAR(64) NOT NULL, session_id VARCHAR(64) NOT NULL,
request_content JSONB NOT NULL DEFAULT '{}'::jsonb, request_content JSONB NOT NULL DEFAULT '{}'::jsonb,
response_content JSONB NOT NULL DEFAULT '{}'::jsonb, response_content JSONB NOT NULL DEFAULT '{}'::jsonb,
remark VARCHAR(500) NOT NULL DEFAULT '' remark VARCHAR(500) NOT NULL DEFAULT ''
); );
-- 索引
CREATE INDEX IF NOT EXISTS idx_prompts_compose_session_session_id ON prompts_compose_session(session_id); CREATE INDEX IF NOT EXISTS idx_prompts_compose_session_session_id ON prompts_compose_session(session_id);
CREATE INDEX IF NOT EXISTS idx_prompts_compose_session_deleted_at ON prompts_compose_session(deleted_at); CREATE INDEX IF NOT EXISTS idx_prompts_compose_session_deleted_at ON prompts_compose_session(deleted_at);
-- 注释
COMMENT ON TABLE prompts_compose_session IS '提示词历史会话表'; COMMENT ON TABLE prompts_compose_session IS '提示词历史会话表';
COMMENT ON COLUMN prompts_compose_session.id IS '主键ID(非自增)'; COMMENT ON COLUMN prompts_compose_session.id IS '主键ID';
COMMENT ON COLUMN prompts_compose_session.tenant_id IS '租户ID'; COMMENT ON COLUMN prompts_compose_session.tenant_id IS '租户ID';
COMMENT ON COLUMN prompts_compose_session.creator IS '创建人'; COMMENT ON COLUMN prompts_compose_session.creator IS '创建人';
COMMENT ON COLUMN prompts_compose_session.created_at IS '创建时间'; COMMENT ON COLUMN prompts_compose_session.created_at IS '创建时间';
COMMENT ON COLUMN prompts_compose_session.updater IS '更新人'; COMMENT ON COLUMN prompts_compose_session.updater IS '更新人';
COMMENT ON COLUMN prompts_compose_session.updated_at IS '更新时间'; COMMENT ON COLUMN prompts_compose_session.updated_at IS '更新时间';
COMMENT ON COLUMN prompts_compose_session.deleted_at IS '删除时间(软删)'; COMMENT ON COLUMN prompts_compose_session.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN prompts_compose_session.session_id IS '会话ID'; COMMENT ON COLUMN prompts_compose_session.session_id IS '会话ID';
COMMENT ON COLUMN prompts_compose_session.request_content IS '请求内容JSON格式'; COMMENT ON COLUMN prompts_compose_session.request_content IS '请求内容JSON格式';
COMMENT ON COLUMN prompts_compose_session.response_content IS '返回内容JSON格式'; COMMENT ON COLUMN prompts_compose_session.response_content IS '返回内容JSON格式';
COMMENT ON COLUMN prompts_compose_session.remark IS '备注'; COMMENT ON COLUMN prompts_compose_session.remark IS '备注';
-- prompts_provider_protocol 模型协议映射配置表
CREATE TABLE IF NOT EXISTS prompts_provider_protocol (
id BIGINT PRIMARY KEY,
tenant_id BIGINT NOT NULL DEFAULT 0,
creator VARCHAR(64) NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updater VARCHAR(64) NOT NULL,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
deleted_at TIMESTAMP(6),
provider_name VARCHAR(64) NOT NULL DEFAULT '',
target_field VARCHAR(64) NOT NULL DEFAULT '',
merge_order JSONB NOT NULL DEFAULT '[]'::jsonb,
role_mapping JSONB NOT NULL DEFAULT '{}'::jsonb,
content_mapping JSONB NOT NULL DEFAULT '{}'::jsonb,
capabilities JSONB NOT NULL DEFAULT '{}'::jsonb,
request_template JSONB NOT NULL DEFAULT '{}'::jsonb,
system_prompt_template TEXT NOT NULL DEFAULT '',
user_prompt_template TEXT NOT NULL DEFAULT '',
status INT NOT NULL DEFAULT 1,
remark VARCHAR(500) NOT NULL DEFAULT ''
);
-- 索引
CREATE INDEX IF NOT EXISTS idx_prompts_provider_protocol_provider_name ON prompts_provider_protocol(provider_name);
CREATE INDEX IF NOT EXISTS idx_prompts_provider_protocol_status ON prompts_provider_protocol(status);
CREATE INDEX IF NOT EXISTS idx_prompts_provider_protocol_deleted_at ON prompts_provider_protocol(deleted_at);
-- 注释
COMMENT ON TABLE prompts_provider_protocol IS '模型协议映射配置表';
COMMENT ON COLUMN prompts_provider_protocol.id IS '主键ID';
COMMENT ON COLUMN prompts_provider_protocol.tenant_id IS '租户ID';
COMMENT ON COLUMN prompts_provider_protocol.creator IS '创建人';
COMMENT ON COLUMN prompts_provider_protocol.created_at IS '创建时间';
COMMENT ON COLUMN prompts_provider_protocol.updater IS '更新人';
COMMENT ON COLUMN prompts_provider_protocol.updated_at IS '更新时间';
COMMENT ON COLUMN prompts_provider_protocol.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN prompts_provider_protocol.provider_name IS '运营商名称openai/deepseek/qwen/anthropic/gemini等';
COMMENT ON COLUMN prompts_provider_protocol.target_field IS '目标字段messages/contents/prompt';
COMMENT ON COLUMN prompts_provider_protocol.merge_order IS 'Prompt IR 拼接顺序system/history/user';
COMMENT ON COLUMN prompts_provider_protocol.role_mapping IS '角色映射system/user/assistant -> provider role';
COMMENT ON COLUMN prompts_provider_protocol.content_mapping IS '内容字段映射content/parts.text等';
COMMENT ON COLUMN prompts_provider_protocol.capabilities IS '协议能力配置system/history/tools/stream等支持情况';
COMMENT ON COLUMN prompts_provider_protocol.request_template IS '请求模板JSON结构模板';
COMMENT ON COLUMN prompts_provider_protocol.system_prompt_template IS '系统提示词模板';
COMMENT ON COLUMN prompts_provider_protocol.status IS '状态1启用/0禁用';
COMMENT ON COLUMN prompts_provider_protocol.remark IS '备注';