From c49144794da66c0c8c2d45eab4dace65bb922fcf Mon Sep 17 00:00:00 2001 From: WangLiZhao <1838393649@qq.com> Date: Mon, 18 May 2026 19:19:17 +0800 Subject: [PATCH] =?UTF-8?q?refactor(service):=20=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E6=9C=8D=E5=8A=A1=E4=BB=A3=E7=A0=81=E7=BB=93=E6=9E=84=E5=B9=B6?= =?UTF-8?q?=E6=9B=B4=E6=96=B0=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/util/config.go | 18 ++ common/util/files.go | 88 ++++++ {service => common/util}/headers.go | 13 +- common/util/json.go | 103 +++++++ config.yml | 27 +- consts/public/table_name.go | 12 +- controller/compose_session.go | 20 -- .../prompt/prompt_compose_controller.go | 29 ++ .../prompt/prompt_session_controller.go | 18 ++ controller/prompt_controller.go | 69 ----- dao/compose_session_dao.go | 119 ++++---- dao/compose_task_dao.go | 43 +-- dao/model_dao.go | 53 +--- dao/prompt_dao.go | 97 ------- dao/provider_protocol_dao.go | 91 ++++++ main.go | 10 +- .../prompt_compose_dto.go} | 10 +- .../prompt_session_dto.go} | 5 +- model/dto/prompt_dto.go | 63 ---- model/entity/asynch_model.go | 56 ++-- model/entity/prompt_config.go | 39 --- ..._session.go => prompts_compose_session.go} | 16 +- ...ompose_task.go => prompts_compose_task.go} | 28 +- model/entity/prompts_provider_protocol.go | 49 ++++ service/build_prompt.go | 144 --------- .../gateway_http_service.go} | 24 +- service/prompt/prompt_build_service.go | 112 +++++++ .../prompt_compose_service.go} | 274 ++++++++++-------- .../prompt_files_handle_service.go} | 117 ++------ service/prompt/prompt_ir_service.go | 264 +++++++++++++++++ .../prompt_session_redis_service.go} | 10 +- .../prompt_session_service.go} | 48 +-- service/prompt_service.go | 92 ------ service/utils.go | 65 ----- update.sql | 217 +++++++------- 35 files changed, 1281 insertions(+), 1162 deletions(-) create mode 100644 common/util/config.go create mode 100644 common/util/files.go rename {service => common/util}/headers.go (68%) create mode 100644 common/util/json.go delete mode 100644 controller/compose_session.go create mode 100644 controller/prompt/prompt_compose_controller.go create mode 100644 controller/prompt/prompt_session_controller.go delete mode 100644 controller/prompt_controller.go delete mode 100644 dao/prompt_dao.go create mode 100644 dao/provider_protocol_dao.go rename model/dto/{compose_messages_dto.go => prompt/prompt_compose_dto.go} (94%) rename model/dto/{compose_session_dto.go => prompt/prompt_session_dto.go} (83%) delete mode 100644 model/dto/prompt_dto.go delete mode 100644 model/entity/prompt_config.go rename model/entity/{compose_session.go => prompts_compose_session.go} (100%) rename model/entity/{compose_task.go => prompts_compose_task.go} (100%) create mode 100644 model/entity/prompts_provider_protocol.go delete mode 100644 service/build_prompt.go rename service/{http_service.go => gateway/gateway_http_service.go} (75%) create mode 100644 service/prompt/prompt_build_service.go rename service/{compose_service.go => prompt/prompt_compose_service.go} (69%) rename service/{files_handle.go => prompt/prompt_files_handle_service.go} (62%) create mode 100644 service/prompt/prompt_ir_service.go rename service/{session_redis_service.go => prompt/prompt_session_redis_service.go} (85%) rename service/{session_service.go => prompt/prompt_session_service.go} (60%) delete mode 100644 service/prompt_service.go delete mode 100644 service/utils.go diff --git a/common/util/config.go b/common/util/config.go new file mode 100644 index 0000000..55bfbf3 --- /dev/null +++ b/common/util/config.go @@ -0,0 +1,18 @@ +package util + +import ( + "context" + + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/util/gconv" +) + +// GetModelPrompt 获取请求模型的提示词 +func GetModelPrompt(ctx context.Context, Type int) string { + return g.Cfg().MustGet(ctx, "modelPrompts.types."+gconv.String(Type), "").String() +} + +// GetBuildPrompt 获取构建提示词 +func GetBuildPrompt(ctx context.Context, Type int) string { + return g.Cfg().MustGet(ctx, "buildProject.types."+gconv.String(Type), "").String() +} diff --git a/common/util/files.go b/common/util/files.go new file mode 100644 index 0000000..f59d410 --- /dev/null +++ b/common/util/files.go @@ -0,0 +1,88 @@ +package util + +import ( + "path/filepath" + "regexp" + "strings" +) + +// AllowedMIMEPrefixes 允许的文本类 MIME 类型前缀 +var AllowedMIMEPrefixes = []string{ + "text/", + "application/json", + "application/xml", + "application/javascript", + "application/x-yaml", + "application/yaml", + "application/toml", + "application/x-httpd-php", + "application/x-sh", + "application/x-python", + "application/x-perl", + "application/x-ruby", +} + +// BannedExtensions 禁止的文件扩展名 +var BannedExtensions = map[string]bool{ + ".png": true, ".jpg": true, ".jpeg": true, ".gif": true, ".bmp": true, + ".webp": true, ".svg": true, ".ico": true, ".tiff": true, ".tif": true, + ".mp3": true, ".wav": true, ".ogg": true, ".flac": true, ".aac": true, + ".wma": true, ".m4a": true, + ".mp4": true, ".avi": true, ".mkv": true, ".mov": true, ".wmv": true, + ".flv": true, ".webm": true, + ".tar": true, ".gz": true, ".rar": true, ".7z": true, + ".exe": true, ".dll": true, ".so": true, ".bin": true, ".dat": true, + ".class": true, ".pyc": true, + ".pdf": true, ".doc": true, ".docx": true, ".xls": true, ".xlsx": true, + ".ppt": true, ".pptx": true, +} + +var symbolCleaner = regexp.MustCompile(`[\x00-\x08\x0B\x0C\x0E-\x1F]`) + +// SanitizeURL 清洗 URL 字符串 +func SanitizeURL(raw string) string { + s := strings.TrimSpace(raw) + s = strings.Trim(s, "`\"") + return s +} + +// CleanSymbols 清洗文本中的控制字符和多余空行 +func CleanSymbols(text string) string { + text = symbolCleaner.ReplaceAllString(text, "") + text = strings.ReplaceAll(text, "\r\n", "\n") + text = strings.ReplaceAll(text, "\r", "\n") + text = regexp.MustCompile(`\n{3,}`).ReplaceAllString(text, "\n\n") + return strings.TrimSpace(text) +} + +// IsBannedExtension 判断是否为禁止的文件扩展名 +func IsBannedExtension(url string) bool { + ext := strings.ToLower(filepath.Ext(url)) + if idx := strings.Index(ext, "?"); idx != -1 { + ext = ext[:idx] + } + return BannedExtensions[ext] +} + +// IsZipExtension 判断是否为 zip 文件 +func IsZipExtension(url string) bool { + ext := strings.ToLower(filepath.Ext(url)) + if idx := strings.Index(ext, "?"); idx != -1 { + ext = ext[:idx] + } + return ext == ".zip" +} + +// IsReadableContentType 判断是否为可读的文本类型 +func IsReadableContentType(contentType string) bool { + if contentType == "" { + return false + } + ct := strings.ToLower(contentType) + for _, prefix := range AllowedMIMEPrefixes { + if strings.HasPrefix(ct, prefix) { + return true + } + } + return false +} diff --git a/service/headers.go b/common/util/headers.go similarity index 68% rename from service/headers.go rename to common/util/headers.go index e83ae4d..5615d45 100644 --- a/service/headers.go +++ b/common/util/headers.go @@ -1,4 +1,4 @@ -package service +package util import ( "context" @@ -7,9 +7,8 @@ import ( "github.com/gogf/gf/v2/frame/g" ) -// asyncCtx 固化异步执行所需的 token/user,避免请求结束后丢失(仅在“同请求内起 goroutine”有用)。 -// 本项目当前是“落库 + 后台 worker”模式,因此还会把必要信息持久化到任务表的 request_payload 中。 -func asyncCtx(ctx context.Context) context.Context { +// AsyncCtx 固化异步上下文中的 token 和用户信息,避免请求结束后丢失 +func AsyncCtx(ctx context.Context) context.Context { asyncCtx := context.WithoutCancel(ctx) if r := g.RequestFromCtx(ctx); r != nil { if token := r.Header.Get("Authorization"); token != "" { @@ -25,8 +24,8 @@ func asyncCtx(ctx context.Context) context.Context { return asyncCtx } -// forwardHeaders 透传调用链路中必须的头信息(优先使用 ctx 里固化的 token / xUserInfo)。 -func forwardHeaders(ctx context.Context) map[string]string { +// ForwardHeaders 透传调用链路的头信息,优先使用 ctx 中的固化值 +func ForwardHeaders(ctx context.Context) map[string]string { headers := make(map[string]string) if token, ok := ctx.Value("token").(string); ok && token != "" { @@ -36,7 +35,7 @@ func forwardHeaders(ctx context.Context) map[string]string { headers["X-User-Info"] = x } - // 兜底:从请求头拿 + // 兜底:从请求头获取 if r := g.RequestFromCtx(ctx); r != nil { if headers["Authorization"] == "" { if token := r.Header.Get("Authorization"); token != "" { diff --git a/common/util/json.go b/common/util/json.go new file mode 100644 index 0000000..c31d600 --- /dev/null +++ b/common/util/json.go @@ -0,0 +1,103 @@ +package util + +import ( + "encoding/json" + "fmt" + + "github.com/gogf/gf/v2/container/gvar" + "github.com/gogf/gf/v2/encoding/gjson" + "github.com/gogf/gf/v2/util/gconv" +) + +// ParseOutput 解析模型输出为 JSON 格式 +func ParseOutput(text string) (map[string]any, error) { + j, err := gjson.LoadJson([]byte(text)) + if err != nil { + return nil, fmt.Errorf("解析模型输出失败: %w", err) + } + return j.Map(), nil +} + +// ConvertToMessages 将原始数据转换为消息列表 +func ConvertToMessages(raw any) []map[string]any { + if raw == nil { + return nil + } + j, err := gjson.LoadJson(gconv.Bytes(raw)) + if err != nil { + return nil + } + // 如果有 messages 字段,直接返回 + if j.Contains("messages") { + return gconv.Maps(j.Get("messages").Array()) + } + // 否则当成单条 message + return []map[string]any{ + j.Map(), + } +} + +// IsMessageValid 校验推理结果是否合法 +func IsMessageValid(message map[string]any) bool { + if message == nil { + return false + } + return true +} + +// FormToJSON 将表单数据转换为 JSON 字符串 +func FormToJSON(form map[string]any) string { + if form == nil { + return "{}" + } + b, _ := json.Marshal(form) + return string(b) +} + +// MustMarshal 将对象序列化为 JSON 字符串,失败时返回空对象 +func MustMarshal(v any) string { + b, err := json.Marshal(v) + if err != nil { + return "{}" + } + return string(b) +} + +// ParseJSONField 解析 JSON 字段 +func ParseJSONField(field any) any { + var v *gvar.Var + switch val := field.(type) { + case *gvar.Var: + v = val + default: + return field + } + + if v == nil || v.IsNil() || v.IsEmpty() { + return nil + } + + str := v.String() + var result any + if json.Unmarshal([]byte(str), &result) == nil { + return result + } + return str +} + +// JSONPretty 将任意类型转为格式化的 JSON 字符串 +func JSONPretty(v any) string { + // 处理 *gvar.Var 类型 + if gv, ok := v.(*gvar.Var); ok { + v = gconv.Map(gv.String()) + } + + // 统一转 map 再美化 + var tmp map[string]any + if err := gconv.Struct(v, &tmp); err != nil { + return gconv.String(v) + } + + b, _ := json.MarshalIndent(tmp, "", " ") + return string(b) +} diff --git a/config.yml b/config.yml index fd4b23d..27a9507 100644 --- a/config.yml +++ b/config.yml @@ -26,17 +26,38 @@ database: updatedAt: "updated_at" # (可选)自动更新时间字段名称 deletedAt: "deleted_at" # (可选)软删除时间字段名称 timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效 + model_gateway: + - type: "pgsql" + host: "116.204.74.41" + port: "15432" + user: "postgres" + pass: "Bjang09@686^*^" + name: "model-gateway" + prefix: "" + role: "master" + debug: true + dryRun: false + charset: "utf8" + timezone: "Asia/Shanghai" + maxIdle: 5 + maxOpen: 20 + maxLifetime: "30s" + maxIdleConnTime: "30s" + createdAt: "created_at" + updatedAt: "updated_at" + deletedAt: "deleted_at" + timeMaintainDisabled: false redis: default: - address: 116.204.74.41:6379 + address: 192.168.3.30:6379 db: 0 consul: - address: 116.204.74.41:8500 + address: 192.168.3.30:8500 jaeger: - addr: 116.204.74.41:4318 + addr: 192.168.3.30:4318 task: waitTimeoutSeconds: 300 # /composeMessages 同步等待最终结果的最长时间(秒) diff --git a/consts/public/table_name.go b/consts/public/table_name.go index 933be53..8a01b23 100644 --- a/consts/public/table_name.go +++ b/consts/public/table_name.go @@ -1,8 +1,12 @@ package public const ( - TableNameModel = "asynch_models" // 模型表 - TableNamePromptConfig = "prompts_model_prompt" // 模型提示词配置表(prompts-core) - TableNameComposeTask = "prompts_compose_task" // 拼接提示词任务记录表 - TableNameComposeSession = "prompts_compose_session" // 拼接提示词会话记录表 + DbNameModelGateway = "model_gateway" //数据库名称 +) + +const ( + TableNameModel = "asynch_models" // 模型表 + TableNameComposeTask = "prompts_compose_task" // 拼接提示词任务记录表 + TableNameComposeSession = "prompts_compose_session" // 拼接提示词会话记录表 + TableNameProviderProtocol = "prompts_provider_protocol" ) diff --git a/controller/compose_session.go b/controller/compose_session.go deleted file mode 100644 index 5b3048b..0000000 --- a/controller/compose_session.go +++ /dev/null @@ -1,20 +0,0 @@ -package controller - -import ( - "context" - - "prompts-core/model/dto" - "prompts-core/service" - - "gitea.com/red-future/common/beans" -) - -type session struct{} - -// Prompt 提示词配置控制器 -var Session = new(session) - -// SessionCallback 会话回调 -func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *beans.ResponseEmpty, err error) { - return service.Session.SessionCallback(ctx, req) -} diff --git a/controller/prompt/prompt_compose_controller.go b/controller/prompt/prompt_compose_controller.go new file mode 100644 index 0000000..638a0d8 --- /dev/null +++ b/controller/prompt/prompt_compose_controller.go @@ -0,0 +1,29 @@ +package prompt + +import ( + "context" + + promptDto "prompts-core/model/dto/prompt" + promptService "prompts-core/service/prompt" +) + +type prompt struct{} + +// Prompt 提示词配置控制器 +var Prompt = new(prompt) + +// ComposeMessages 调用 model-gateway 异步任务并同步等待结果, +func (c *prompt) ComposeMessages(ctx context.Context, req *promptDto.ComposeMessagesReq) (res *promptDto.ComposeMessagesRes, err error) { + return promptService.ComposeMessages(ctx, req) +} + +// Callback model-gateway 提示词回调 +func (c *prompt) Callback(ctx context.Context, req *promptDto.CallbackReq) (res *promptDto.CallbackRes, err error) { + err = promptService.Callback(ctx, req) + return +} + +// GetComposeTask 查询拼接任务结果 +func (c *prompt) GetComposeTask(ctx context.Context, req *promptDto.GetComposeTaskReq) (res *promptDto.GetComposeTaskRes, err error) { + return promptService.GetComposeTask(ctx, req.TaskId) +} diff --git a/controller/prompt/prompt_session_controller.go b/controller/prompt/prompt_session_controller.go new file mode 100644 index 0000000..d08b0ce --- /dev/null +++ b/controller/prompt/prompt_session_controller.go @@ -0,0 +1,18 @@ +package prompt + +import ( + "context" + + promptDto "prompts-core/model/dto/prompt" + promptService "prompts-core/service/prompt" +) + +type session struct{} + +// Session 提示词会话控制器 +var Session = new(session) + +// SessionCallback 会话回调 +func (c *session) SessionCallback(ctx context.Context, req *promptDto.SessionCallbackReq) (res *promptDto.SessionCallbackRes, err error) { + return promptService.SessionCallback(ctx, req) +} diff --git a/controller/prompt_controller.go b/controller/prompt_controller.go deleted file mode 100644 index 00a9c91..0000000 --- a/controller/prompt_controller.go +++ /dev/null @@ -1,69 +0,0 @@ -package controller - -import ( - "context" - - "prompts-core/model/dto" - "prompts-core/service" - - "gitea.com/red-future/common/beans" -) - -type prompt struct{} - -// Prompt 提示词配置控制器 -var Prompt = new(prompt) - -// ComposeMessages 调用 model-gateway 异步任务并同步等待结果, -func (c *prompt) ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (res *dto.ComposeMessagesRes, err error) { - return service.Prompt.ComposeMessages(ctx, req) -} - -// ComposeMessagesCallback model-gateway 提示词回调 -func (c *prompt) Callback(ctx context.Context, req *dto.CallbackReq) (res *beans.ResponseEmpty, err error) { - err = service.Prompt.Callback(ctx, req) - return -} - -// GetComposeTask 查询拼接任务结果 -func (c *prompt) GetComposeTask(ctx context.Context, req *dto.GetComposeTaskReq) (res *dto.GetComposeTaskRes, err error) { - return service.Prompt.GetComposeTask(ctx, req.TaskId) -} - -// CreatePrompt 添加配置(默认启用) -func (c *prompt) CreatePrompt(ctx context.Context, req *dto.CreatePromptReq) (res *dto.CreatePromptRes, err error) { - return service.Prompt.Create(ctx, req) -} - -// UpdatePrompt 更新配置 -func (c *prompt) UpdatePrompt(ctx context.Context, req *dto.UpdatePromptReq) (res *beans.ResponseEmpty, err error) { - err = service.Prompt.Update(ctx, req) - return -} - -// DeletePrompt 删除配置 -func (c *prompt) DeletePrompt(ctx context.Context, req *dto.DeletePromptReq) (res *beans.ResponseEmpty, err error) { - err = service.Prompt.Delete(ctx, req.ID) - return -} - -// GetPrompt 获取配置详情 -func (c *prompt) GetPrompt(ctx context.Context, req *dto.GetPromptReq) (res *dto.GetPromptRes, err error) { - m, err := service.Prompt.Get(ctx, req.ID) - if err != nil { - return nil, err - } - return &dto.GetPromptRes{Prompt: m}, nil -} - -// ListPrompt 配置列表 -func (c *prompt) ListPrompt(ctx context.Context, req *dto.ListPromptReq) (res *dto.ListPromptRes, err error) { - list, total, err := service.Prompt.List(ctx, int(req.Page.PageNum), int(req.Page.PageSize), req.ModelTypeId, req.ModelType) - if err != nil { - return nil, err - } - return &dto.ListPromptRes{ - List: list, - Total: total, - }, nil -} diff --git a/dao/compose_session_dao.go b/dao/compose_session_dao.go index 8977603..c7a9787 100644 --- a/dao/compose_session_dao.go +++ b/dao/compose_session_dao.go @@ -2,82 +2,75 @@ package dao import ( "context" - "prompts-core/consts/public" "prompts-core/model/entity" "gitea.com/red-future/common/db/gfdb" + "github.com/gogf/gf/v2/util/gconv" ) var ComposeSession = &composeSessionDao{} type composeSessionDao struct{} -func (d *composeSessionDao) Insert(ctx context.Context, m *entity.ComposeSession) (id int64, err error) { - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession).Data(m).Insert() +// Insert 插入 +func (d *composeSessionDao) Insert(ctx context.Context, req *entity.ComposeSession) (id int64, err error) { + var m = new(entity.ComposeTask) + err = gconv.Struct(req, &m) if err != nil { - return 0, err + return + } + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession). + Insert(m) + if err != nil { + return } return r.LastInsertId() } -func (d *composeSessionDao) Update(ctx context.Context, m *entity.ComposeSession) (rows int64, err error) { - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession). - Where(entity.ComposeSessionCol.Id, m.Id). - Data(m). +// Update 更新 +func (d *composeSessionDao) Update(ctx context.Context, req *entity.ComposeSession) (rows int64, err error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession). OmitEmpty(). + Data(&req). + Where(entity.ComposeSessionCol.Id, req.Id). Update() if err != nil { - return 0, err + return } return r.RowsAffected() } -func (d *composeSessionDao) List(ctx context.Context, page, size int, where map[string]any) (list []*entity.ComposeSession, total int, err error) { - model := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession). - Where("deleted_at IS NULL") - - // 动态拼接查询条件 - for k, v := range where { - model = model.Where(k, v) +// List 查询编排会话列表 +func (d *composeSessionDao) List(ctx context.Context, req *entity.ComposeSession, page, size int, fields ...string) (list []*entity.ComposeSession, total int, err error) { + if page <= 0 { + page = 1 } - - // 查询总数 - total, err = model.Count() + if size <= 0 { + size = 10 + } + model := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession). + Fields(fields). + OmitEmpty() + model.Where(entity.ComposeSessionCol.Creator, req.Creator) + model.Where(entity.ComposeSessionCol.SessionId, req.SessionId) + model.OrderDesc(entity.ComposeSessionCol.CreatedAt) + model.Page(page, size) + r, total, err := model.AllAndCount(false) if err != nil { - return nil, 0, err + return } - - // 分页查询 - err = model.Order("created_at DESC"). - Page(page, size). - Scan(&list) - + err = r.Structs(&list) return } -func (d *composeSessionDao) GetListBySessionId(ctx context.Context, sessionId string, limit int) ([]*entity.ComposeSession, error) { - var sessions []*entity.ComposeSession - err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession). - Where(entity.ComposeSessionCol.SessionId, sessionId). - WhereNull(entity.ComposeSessionCol.DeletedAt). - OrderDesc(entity.ComposeSessionCol.Id). - Limit(limit). - Scan(&sessions) - if err != nil { - return nil, err - } - // 反转成时间正序 - for i, j := 0, len(sessions)-1; i < j; i, j = i+1, j-1 { - sessions[i], sessions[j] = sessions[j], sessions[i] - } - return sessions, nil -} - -func (d *composeSessionDao) GetById(ctx context.Context, Id int64) (m *entity.ComposeSession, err error) { - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession). - Where(entity.ComposeSessionCol.Id, Id). - One() +// Get 查询编排会话 +func (d *composeSessionDao) Get(ctx context.Context, req *entity.ComposeSession, fields ...string) (m *entity.ComposeSession, err error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession). + OmitEmpty(). + Where(entity.ComposeSessionCol.Id, req.Id). + Where(entity.ComposeSessionCol.SessionId, req.SessionId). + Fields(fields).One() if err != nil { return nil, err } @@ -88,29 +81,15 @@ func (d *composeSessionDao) GetById(ctx context.Context, Id int64) (m *entity.Co return } -func (d *composeSessionDao) GetBySessionId(ctx context.Context, sessionId string) (m *entity.ComposeSession, err error) { - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession). - Where(entity.ComposeSessionCol.SessionId, sessionId). - One() +// Delete 软删除编排会话 +func (d *composeSessionDao) Delete(ctx context.Context, req *entity.ComposeSession) (rows int64, err error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession). + OmitEmpty(). + Where(entity.ComposeSessionCol.Id, req.Id). + Where(entity.ComposeSessionCol.SessionId, req.SessionId). + Delete() if err != nil { - return nil, err - } - if r.IsEmpty() { - return nil, nil - } - err = r.Struct(&m) - return -} - -func (d *composeSessionDao) DeleteBySessionId(ctx context.Context, sessionId string) (rows int64, err error) { - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeSession). - Where(entity.ComposeSessionCol.SessionId, sessionId). - Data(map[string]any{ - entity.ComposeSessionCol.DeletedAt: "NOW()", - }). - Update() - if err != nil { - return 0, err + return } return r.RowsAffected() } diff --git a/dao/compose_task_dao.go b/dao/compose_task_dao.go index a52101f..7e16018 100644 --- a/dao/compose_task_dao.go +++ b/dao/compose_task_dao.go @@ -2,47 +2,54 @@ package dao import ( "context" - "prompts-core/consts/public" "prompts-core/model/entity" "gitea.com/red-future/common/db/gfdb" + "github.com/gogf/gf/v2/util/gconv" ) var ComposeTask = &composeTaskDao{} type composeTaskDao struct{} -func (d *composeTaskDao) Insert(ctx context.Context, m *entity.ComposeTask) (id int64, err error) { - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeTask).Data(m).Insert() +// Insert 插入 +func (d *composeTaskDao) Insert(ctx context.Context, req *entity.ComposeTask) (id int64, err error) { + var m = new(entity.ComposeTask) + err = gconv.Struct(req, &m) if err != nil { - return 0, err + return + } + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeTask). + Insert(m) + if err != nil { + return } return r.LastInsertId() } -func (d *composeTaskDao) GetByTaskId(ctx context.Context, taskId string) (m *entity.ComposeTask, err error) { - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeTask). - Where(entity.ComposeTaskCol.TaskId, taskId). - One() +// Get 获取 +func (d *composeTaskDao) Get(ctx context.Context, req *entity.ComposeTask, fields ...string) (m *entity.ComposeTask, err error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeTask). + OmitEmpty(). + Where(entity.ComposeTaskCol.TaskId, req.TaskId). + Fields(fields).One() if err != nil { - return nil, err - } - if r.IsEmpty() { - return nil, nil + return } err = r.Struct(&m) return } -func (d *composeTaskDao) UpdateByTaskId(ctx context.Context, taskId string, data map[string]any) (rows int64, err error) { - data[entity.ComposeTaskCol.Updater] = "" - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameComposeTask). - Where(entity.ComposeTaskCol.TaskId, taskId). - Data(data). +// Update 更新 +func (d *composeTaskDao) Update(ctx context.Context, req *entity.ComposeTask) (rows int64, err error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeTask). + OmitEmpty(). + Data(&req). + Where(entity.ComposeTaskCol.TaskId, req.TaskId). Update() if err != nil { - return 0, err + return } return r.RowsAffected() } diff --git a/dao/model_dao.go b/dao/model_dao.go index 6c8da8c..42cd0a9 100644 --- a/dao/model_dao.go +++ b/dao/model_dao.go @@ -2,62 +2,27 @@ package dao import ( "context" - "fmt" "prompts-core/consts/public" "prompts-core/model/entity" "gitea.com/red-future/common/db/gfdb" - "gitea.com/red-future/common/utils" ) var Model = &modelDao{} type modelDao struct{} -func (d *modelDao) GetByModelName(ctx context.Context, modelName string) (m *entity.AsynchModel, err error) { - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel). - Where(entity.AsynchModelCol.ModelName, modelName). - One() +// Get 获取模型 +func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). + OmitEmpty(). + Where(entity.AsynchModelCol.Creator, req.Creator). + Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel). + Where(entity.AsynchModelCol.ModelName, req.ModelName). + Fields(fields).One() if err != nil { - return nil, err - } - if r.IsEmpty() { - return nil, nil + return } err = r.Struct(&m) return } - -func (d *modelDao) GetByIsChatModel(ctx context.Context) (m *entity.AsynchModel, err error) { - userInfo, err := utils.GetUserInfo(ctx) - if err != nil { - return nil, err - } - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel). - Where(entity.AsynchModelCol.IsChatModel, 1). - Where(entity.AsynchModelCol.Creator, userInfo.UserName). - One() - if err != nil { - return nil, err - } - if r.IsEmpty() { - return nil, nil - } - err = r.Struct(&m) - return -} - -// GetBySuperAdmin 查询超级管理员(tenant_id=1)的模型 -func (d *modelDao) GetBySuperAdmin(ctx context.Context, modelName string) (m *entity.AsynchModel, err error) { - sql := fmt.Sprintf("SELECT * FROM %s WHERE model_name = ? AND tenant_id = 1 AND deleted_at IS NULL LIMIT 1", public.TableNameModel) - r, err := gfdb.DB(ctx).GetAll(ctx, sql, modelName) - if err != nil { - return nil, err - } - if len(r) == 0 { - return nil, nil - } - - err = r[0].Struct(&m) - return -} diff --git a/dao/prompt_dao.go b/dao/prompt_dao.go deleted file mode 100644 index aa5a884..0000000 --- a/dao/prompt_dao.go +++ /dev/null @@ -1,97 +0,0 @@ -package dao - -import ( - "context" - - "prompts-core/consts/public" - "prompts-core/model/entity" - - "gitea.com/red-future/common/db/gfdb" - "github.com/gogf/gf/v2/util/gconv" -) - -var Prompt = &promptDao{} - -type promptDao struct{} - -func (d *promptDao) Insert(ctx context.Context, m *entity.PromptConfig) (id int64, err error) { - r, err := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig).Data(m).Insert() - if err != nil { - return 0, err - } - return r.LastInsertId() -} - -func (d *promptDao) UpdateByID(ctx context.Context, id int64, data map[string]any) (rows int64, err error) { - // 触发 gfdb 的 updateHook 自动填充 updater,需要显式带 updater 字段 - data[entity.PromptConfigCol.Updater] = "" - r, err := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig). - Where(entity.PromptConfigCol.Id, id). - Data(data). - Update() - if err != nil { - return 0, err - } - return r.RowsAffected() -} - -func (d *promptDao) DeleteByID(ctx context.Context, id int64) (rows int64, err error) { - r, err := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig). - Where(entity.PromptConfigCol.Id, id). - Delete() - if err != nil { - return 0, err - } - return r.RowsAffected() -} - -func (d *promptDao) GetByID(ctx context.Context, id int64) (m *entity.PromptConfig, err error) { - r, err := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig). - Where(entity.PromptConfigCol.Id, id). - One() - if err != nil { - return nil, err - } - if r.IsEmpty() { - return nil, nil - } - err = r.Struct(&m) - return -} - -func (d *promptDao) GetLatestEnabledByModelTypeID(ctx context.Context, modelTypeID int) (m *entity.PromptConfig, err error) { - r, err := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig). - Where("deleted_at IS NULL"). - Where(entity.PromptConfigCol.ModelTypeId, modelTypeID). - Where(entity.PromptConfigCol.Enabled, 1). - OrderDesc(entity.PromptConfigCol.CreatedAt). - One() - if err != nil { - return nil, err - } - if r.IsEmpty() { - return nil, nil - } - err = r.Struct(&m) - return -} - -func (d *promptDao) List(ctx context.Context, pageNum, pageSize int, modelTypeID *int, modelTypeLike string) (list []*entity.PromptConfig, total int64, err error) { - model := gfdb.DB(ctx).Model(ctx, public.TableNamePromptConfig).Where("deleted_at IS NULL").OrderDesc(entity.PromptConfigCol.CreatedAt) - if modelTypeID != nil && *modelTypeID > 0 { - model = model.Where(entity.PromptConfigCol.ModelTypeId, *modelTypeID) - } - if modelTypeLike != "" { - model = model.WhereLike(entity.PromptConfigCol.ModelType, "%"+modelTypeLike+"%") - } - if pageNum > 0 && pageSize > 0 { - model = model.Page(pageNum, pageSize) - } - r, totalInt, err := model.AllAndCount(false) - if err != nil { - return nil, 0, err - } - total = gconv.Int64(totalInt) - err = r.Structs(&list) - return -} diff --git a/dao/provider_protocol_dao.go b/dao/provider_protocol_dao.go new file mode 100644 index 0000000..d4808db --- /dev/null +++ b/dao/provider_protocol_dao.go @@ -0,0 +1,91 @@ +package dao + +import ( + "context" + "prompts-core/consts/public" + "prompts-core/model/entity" + + "gitea.com/red-future/common/db/gfdb" +) + +var ProviderProtocol = &providerProtocolDao{} + +type providerProtocolDao struct{} + +// Insert 新增协议配置 +func (d *providerProtocolDao) Insert(ctx context.Context, req *entity.ProviderProtocol) (id int64, err error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).OmitEmpty().Data(req).Insert() + if err != nil { + return 0, err + } + return r.LastInsertId() +} + +// Get 查询协议配置 +func (d *providerProtocolDao) Get(ctx context.Context, req *entity.ProviderProtocol, fields ...string) (res *entity.ProviderProtocol, err error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol). + NoTenantId(ctx). + OmitEmpty(). + Where(entity.ProviderProtocolCol.Id, req.Id). + Where(entity.ProviderProtocolCol.ProviderName, req.ProviderName). //主要是根据运营商查询 + Where(entity.ProviderProtocolCol.Status, 1). + Fields(fields).One() + if err != nil { + return nil, err + } + if r.IsEmpty() { + return nil, nil + } + err = r.Struct(&res) + return +} + +// List 列表查询 +func (d *providerProtocolDao) List(ctx context.Context, req *entity.ProviderProtocol, page, size int, fields ...string) (list []*entity.ProviderProtocol, total int, err error) { + if page <= 0 { + page = 1 + } + if size <= 0 { + size = 10 + } + model := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol). + Fields(fields). + OmitEmpty() + model.Where(entity.ProviderProtocolCol.ProviderName, req.ProviderName) + model.Where(entity.ProviderProtocolCol.Status, req.Status) + model.OrderDesc(entity.ProviderProtocolCol.CreatedAt) + model.Page(page, size) + r, total, err := model.AllAndCount(false) + if err != nil { + return + } + err = r.Structs(&list) + return +} + +// Update 更新协议配置 +func (d *providerProtocolDao) Update(ctx context.Context, req *entity.ProviderProtocol) (rows int64, err error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol). + OmitEmpty(). + Where(entity.ProviderProtocolCol.Id, req.Id). + Data(req). + Update() + if err != nil { + return 0, err + } + return r.RowsAffected() +} + +// Delete 软删除协议配置 +func (d *providerProtocolDao) Delete(ctx context.Context, id int64) (rows int64, err error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol). + Where(entity.ProviderProtocolCol.Id, id). + Data(map[string]any{ + entity.ProviderProtocolCol.DeletedAt: "NOW()", + }). + Update() + if err != nil { + return 0, err + } + return r.RowsAffected() +} diff --git a/main.go b/main.go index 1c49d22..4c5ac6b 100644 --- a/main.go +++ b/main.go @@ -4,10 +4,9 @@ import ( "context" "os" "os/signal" + "prompts-core/controller/prompt" "syscall" - "prompts-core/controller" - "gitea.com/red-future/common/http" "gitea.com/red-future/common/jaeger" _ "gitea.com/red-future/common/swagger" @@ -20,14 +19,13 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() defer jaeger.ShutDown(ctx) - // 注册路由 http.RouteRegister([]interface{}{ - controller.Prompt, - controller.Session, + prompt.Prompt, + prompt.Session, }) - // 监听退出信号,确保 Ctrl+C 能完整退出并关闭 http server + // 监听退出信号,确保 Ctrl+C 能完整退出并关闭 gateway server quit := make(chan os.Signal, 1) signal.Notify(quit, os.Interrupt, syscall.SIGTERM) <-quit diff --git a/model/dto/compose_messages_dto.go b/model/dto/prompt/prompt_compose_dto.go similarity index 94% rename from model/dto/compose_messages_dto.go rename to model/dto/prompt/prompt_compose_dto.go index 2288b43..2e0b211 100644 --- a/model/dto/compose_messages_dto.go +++ b/model/dto/prompt/prompt_compose_dto.go @@ -1,12 +1,7 @@ -package dto +package prompt import "github.com/gogf/gf/v2/frame/g" -type Message struct { - Role string `json:"role" dc:"角色:system/user/assistant"` - Content any `json:"content" dc:"消息内容"` -} - type ComposeMessagesReq struct { g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schema;form 作为系统表单,userForm 作为用户表单,结合 userFiles 调用 model-gateway,并直接返回最终 messages"` ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"` @@ -35,6 +30,9 @@ type CallbackReq struct { EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` } +type CallbackRes struct { +} + type GetComposeTaskReq struct { g.Meta `path:"/getComposeTask" method:"get" tags:"提示词处理" summary:"查询拼接任务" dc:"按 taskId 查询提示词拼接任务结果"` TaskId string `p:"taskId" json:"taskId" v:"required#taskId不能为空" dc:"任务ID"` diff --git a/model/dto/compose_session_dto.go b/model/dto/prompt/prompt_session_dto.go similarity index 83% rename from model/dto/compose_session_dto.go rename to model/dto/prompt/prompt_session_dto.go index c9bcc61..5901ed7 100644 --- a/model/dto/compose_session_dto.go +++ b/model/dto/prompt/prompt_session_dto.go @@ -1,4 +1,4 @@ -package dto +package prompt import "github.com/gogf/gf/v2/frame/g" @@ -7,3 +7,6 @@ type SessionCallbackReq struct { Text string `json:"text" dc:"文本结果"` EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` } + +type SessionCallbackRes struct { +} diff --git a/model/dto/prompt_dto.go b/model/dto/prompt_dto.go deleted file mode 100644 index 07eb5d1..0000000 --- a/model/dto/prompt_dto.go +++ /dev/null @@ -1,63 +0,0 @@ -package dto - -import ( - "gitea.com/red-future/common/beans" - "github.com/gogf/gf/v2/frame/g" -) - -// CreatePromptReq 添加提示词配置(默认启用) -type CreatePromptReq struct { - g.Meta `path:"/createPrompt" method:"post" tags:"提示词管理" summary:"创建提示词配置" dc:"创建新的模型提示词配置(默认启用)"` - ModelTypeId int `p:"modelTypeId" json:"modelTypeId" v:"required#modelTypeId不能为空" dc:"模型分类ID"` - ModelType string `p:"modelType" json:"modelType" v:"required#modelType不能为空" dc:"模型类别/模型类型"` - PromptInfo any `p:"promptInfo" json:"promptInfo" v:"required#promptInfo不能为空" dc:"数据库定义的表单规则数据(JSON)"` - ResponseJsonSchema any `p:"responseJsonSchema" json:"responseJsonSchema" v:"required#responseJsonSchema不能为空" dc:"模型返回表单 JSON 格式约束"` - // Version 预留字段:先不使用,但表结构保留 - Version string `p:"version" json:"version" dc:"版本号(预留)"` -} - -type CreatePromptRes struct { - ID int64 `json:"id,string" dc:"配置ID"` -} - -// UpdatePromptReq 更新提示词配置 -type UpdatePromptReq struct { - g.Meta `path:"/updatePrompt" method:"put" tags:"提示词管理" summary:"更新提示词配置" dc:"更新指定ID的提示词配置"` - ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"` - - ModelTypeId *int `p:"modelTypeId" json:"modelTypeId" dc:"模型分类ID(可选更新)"` - ModelType *string `p:"modelType" json:"modelType" dc:"模型类别/模型类型(可选更新)"` - PromptInfo any `p:"promptInfo" json:"promptInfo" dc:"数据库定义的表单规则数据(JSON)(可选更新)"` - ResponseJsonSchema any `p:"responseJsonSchema" json:"responseJsonSchema" dc:"模型返回表单 JSON 格式约束(可选更新)"` - Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-禁用,1-启用(可选更新)"` - Version *string `p:"version" json:"version" dc:"版本号(预留,可选更新)"` -} - -// DeletePromptReq 删除提示词配置 -type DeletePromptReq struct { - g.Meta `path:"/deletePrompt" method:"delete" tags:"提示词管理" summary:"删除提示词配置" dc:"删除指定ID的提示词配置"` - ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"` -} - -// GetPromptReq 获取提示词配置详情 -type GetPromptReq struct { - g.Meta `path:"/getPrompt" method:"get" tags:"提示词管理" summary:"获取提示词配置" dc:"根据ID获取提示词配置详情"` - ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"` -} - -type GetPromptRes struct { - Prompt any `json:"prompt" dc:"提示词配置详情"` -} - -// ListPromptReq 配置列表 -type ListPromptReq struct { - g.Meta `path:"/listPrompt" method:"post" tags:"提示词管理" summary:"提示词配置列表" dc:"分页获取提示词配置列表"` - Page *beans.Page `p:"page" json:"page" dc:"分页参数"` - ModelTypeId *int `p:"modelTypeId" json:"modelTypeId" dc:"模型分类ID(可选)"` - ModelType string `p:"modelType" json:"modelType" dc:"模型类型名称(可选,模糊查询)"` -} - -type ListPromptRes struct { - List any `json:"list" dc:"列表数据"` - Total int64 `json:"total" dc:"总数"` -} diff --git a/model/entity/asynch_model.go b/model/entity/asynch_model.go index 9714cb9..c704254 100644 --- a/model/entity/asynch_model.go +++ b/model/entity/asynch_model.go @@ -2,6 +2,34 @@ package entity import "gitea.com/red-future/common/beans" +// AsynchModel 异步模型配置 +type AsynchModel struct { + beans.SQLBaseDO `orm:",inline"` + ModelName string `orm:"model_name" json:"modelName"` + ModelType int `orm:"model_type" json:"modelType"` + BaseURL string `orm:"base_url" json:"baseUrl"` + HttpMethod string `orm:"http_method" json:"httpMethod"` + HeadMsg string `orm:"head_msg" json:"headMsg"` + Form any `orm:"form_json" json:"form"` + RequestMapping any `orm:"request_mapping" json:"requestMapping"` + ResponseMapping any `orm:"response_mapping" json:"responseMapping"` + ResponseBody any `orm:"response_body" json:"responseBody"` + TokenMapping string `orm:"token_mapping" json:"tokenMapping"` + Prompt string `orm:"prompt" json:"prompt"` + IsPrivate int `orm:"is_private" json:"isPrivate"` + IsChatModel int `orm:"is_chat_model" json:"isChatModel"` + ApiKey string `orm:"api_key" json:"apiKey"` + Enabled int `orm:"enabled" json:"enabled"` + MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"` + QueueLimit int `orm:"queue_limit" json:"queueLimit"` + TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"` + ExpectedSeconds int `orm:"expected_seconds" json:"expectedSeconds"` + RetryTimes int `orm:"retry_times" json:"retryTimes"` + RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"` + AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"` + Remark string `orm:"remark" json:"remark"` +} + type asynchModelCol struct { beans.SQLBaseCol ModelName string @@ -55,31 +83,3 @@ var AsynchModelCol = asynchModelCol{ AutoCleanSeconds: "auto_clean_seconds", Remark: "remark", } - -// AsynchModel 异步模型配置 -type AsynchModel struct { - beans.SQLBaseDO `orm:",inline"` - ModelName string `orm:"model_name" json:"modelName"` - ModelType int `orm:"model_type" json:"modelType"` - BaseURL string `orm:"base_url" json:"baseUrl"` - HttpMethod string `orm:"http_method" json:"httpMethod"` - HeadMsg string `orm:"head_msg" json:"headMsg"` - Form any `orm:"form_json" json:"form"` - RequestMapping any `orm:"request_mapping" json:"requestMapping"` - ResponseMapping any `orm:"response_mapping" json:"responseMapping"` - ResponseBody any `orm:"response_body" json:"responseBody"` - TokenMapping string `orm:"token_mapping" json:"tokenMapping"` - Prompt string `orm:"prompt" json:"prompt"` - IsPrivate int `orm:"is_private" json:"isPrivate"` - IsChatModel int `orm:"is_chat_model" json:"isChatModel"` - ApiKey string `orm:"api_key" json:"apiKey"` - Enabled int `orm:"enabled" json:"enabled"` - MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"` - QueueLimit int `orm:"queue_limit" json:"queueLimit"` - TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"` - ExpectedSeconds int `orm:"expected_seconds" json:"expectedSeconds"` - RetryTimes int `orm:"retry_times" json:"retryTimes"` - RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"` - AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"` - Remark string `orm:"remark" json:"remark"` -} diff --git a/model/entity/prompt_config.go b/model/entity/prompt_config.go deleted file mode 100644 index 7a16dc8..0000000 --- a/model/entity/prompt_config.go +++ /dev/null @@ -1,39 +0,0 @@ -package entity - -import "gitea.com/red-future/common/beans" - -type promptConfigCol struct { - beans.SQLBaseCol - ModelTypeId string - ModelType string - PromptInfo string - ResponseJsonSchema string - Enabled string - Version string -} - -var PromptConfigCol = promptConfigCol{ - SQLBaseCol: beans.DefSQLBaseCol, - ModelTypeId: "model_type_id", - ModelType: "model_type", - PromptInfo: "prompt_info", - ResponseJsonSchema: "response_json_schema", - Enabled: "enabled", - Version: "version", -} - -// PromptConfig 模型提示词配置 -// -// 说明: -// - prompt_info 使用 JSONB 保存(对外用 json 传输) -// - response_json_schema 为模型返回 JSON 格式约束 -// - enabled:1启用/0禁用 -type PromptConfig struct { - beans.SQLBaseDO `orm:",inline"` - ModelTypeId int `orm:"model_type_id" json:"modelTypeId"` - ModelType string `orm:"model_type" json:"modelType"` - PromptInfo any `orm:"prompt_info" json:"promptInfo"` - ResponseJsonSchema any `orm:"response_json_schema" json:"responseJsonSchema"` - Enabled int `orm:"enabled" json:"enabled"` - Version string `orm:"version" json:"version"` -} diff --git a/model/entity/compose_session.go b/model/entity/prompts_compose_session.go similarity index 100% rename from model/entity/compose_session.go rename to model/entity/prompts_compose_session.go index 7f9579f..9dcb38a 100644 --- a/model/entity/compose_session.go +++ b/model/entity/prompts_compose_session.go @@ -2,6 +2,14 @@ package entity import "gitea.com/red-future/common/beans" +type ComposeSession struct { + beans.SQLBaseDO `orm:",inline"` + SessionId string `orm:"session_id" json:"sessionId"` + RequestContent any `orm:"request_content" json:"requestContent"` + ResponseContent any `orm:"response_content" json:"responseContent"` + Remark string `orm:"remark" json:"remark"` +} + type composeSessionCol struct { beans.SQLBaseCol SessionId string @@ -17,11 +25,3 @@ var ComposeSessionCol = composeSessionCol{ ResponseContent: "response_content", Remark: "remark", } - -type ComposeSession struct { - beans.SQLBaseDO `orm:",inline"` - SessionId string `orm:"session_id" json:"sessionId"` - RequestContent any `orm:"request_content" json:"requestContent"` - ResponseContent any `orm:"response_content" json:"responseContent"` - Remark string `orm:"remark" json:"remark"` -} diff --git a/model/entity/compose_task.go b/model/entity/prompts_compose_task.go similarity index 100% rename from model/entity/compose_task.go rename to model/entity/prompts_compose_task.go index 780bd54..b81715a 100644 --- a/model/entity/compose_task.go +++ b/model/entity/prompts_compose_task.go @@ -2,6 +2,20 @@ package entity import "gitea.com/red-future/common/beans" +type ComposeTask struct { + beans.SQLBaseDO `orm:",inline"` + TaskId string `orm:"task_id" json:"taskId"` + ModelName string `orm:"model_name" json:"modelName"` + SkillName string `orm:"skill_name" json:"skillName"` + LimitWords int `orm:"limit_words" json:"limitWords"` + RequestPayload any `orm:"request_payload" json:"requestPayload"` + CallbackPayload any `orm:"callback_payload" json:"callbackPayload"` + ModelResult any `orm:"model_result" json:"modelResult"` + Messages any `orm:"messages" json:"messages"` + Status string `orm:"status" json:"status"` + ErrorMessage string `orm:"error_message" json:"errorMessage"` +} + type composeTaskCol struct { beans.SQLBaseCol TaskId string @@ -29,17 +43,3 @@ var ComposeTaskCol = composeTaskCol{ Status: "status", ErrorMessage: "error_message", } - -type ComposeTask struct { - beans.SQLBaseDO `orm:",inline"` - TaskId string `orm:"task_id" json:"taskId"` - ModelName string `orm:"model_name" json:"modelName"` - SkillName string `orm:"skill_name" json:"skillName"` - LimitWords int `orm:"limit_words" json:"limitWords"` - RequestPayload any `orm:"request_payload" json:"requestPayload"` - CallbackPayload any `orm:"callback_payload" json:"callbackPayload"` - ModelResult any `orm:"model_result" json:"modelResult"` - Messages any `orm:"messages" json:"messages"` - Status string `orm:"status" json:"status"` - ErrorMessage string `orm:"error_message" json:"errorMessage"` -} diff --git a/model/entity/prompts_provider_protocol.go b/model/entity/prompts_provider_protocol.go new file mode 100644 index 0000000..bbf4488 --- /dev/null +++ b/model/entity/prompts_provider_protocol.go @@ -0,0 +1,49 @@ +package entity + +import "gitea.com/red-future/common/beans" + +// ProviderProtocol 模型协议映射配置 +type ProviderProtocol struct { + beans.SQLBaseDO `orm:",inherit"` + // 业务字段 + ProviderName string `orm:"provider_name" json:"providerName"` + TargetField string `orm:"target_field" json:"targetField"` + MergeOrder any `orm:"merge_order" json:"mergeOrder"` + RoleMapping any `orm:"role_mapping" json:"roleMapping"` + ContentMapping any `orm:"content_mapping" json:"contentMapping"` + Capabilities any `orm:"capabilities" json:"capabilities"` + RequestTemplate any `orm:"request_template" json:"requestTemplate"` + SystemPromptTemplate string `orm:"system_prompt_template" json:"systemPromptTemplate"` + Status int `orm:"status" json:"status"` + Remark string `orm:"remark" json:"remark"` +} + +// providerProtocolCol 列名 +type providerProtocolCol struct { + beans.SQLBaseCol + ProviderName string + TargetField string + MergeOrder string + RoleMapping string + ContentMapping string + Capabilities string + RequestTemplate string + SystemPromptTemplate string + Status string + Remark string +} + +// ProviderProtocolCol 列名常量 +var ProviderProtocolCol = providerProtocolCol{ + SQLBaseCol: beans.DefSQLBaseCol, + ProviderName: "provider_name", + TargetField: "target_field", + MergeOrder: "merge_order", + RoleMapping: "role_mapping", + ContentMapping: "content_mapping", + Capabilities: "capabilities", + RequestTemplate: "request_template", + SystemPromptTemplate: "system_prompt_template", + Status: "status", + Remark: "remark", +} diff --git a/service/build_prompt.go b/service/build_prompt.go deleted file mode 100644 index 4a4e32d..0000000 --- a/service/build_prompt.go +++ /dev/null @@ -1,144 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "prompts-core/model/dto" - "prompts-core/model/entity" - "strings" - - "github.com/gogf/gf/v2/frame/g" - "github.com/gogf/gf/v2/util/gconv" -) - -// 获取请求模型的提示词 -func GetModelPrompt(ctx context.Context, Type int) string { - return g.Cfg().MustGet(ctx, "modelPrompts.types."+gconv.String(Type), "").String() -} - -// 获取构建提示词 -func GetBuildPrompt(ctx context.Context, Type int) string { - return g.Cfg().MustGet(ctx, "buildProject.types."+gconv.String(Type), "").String() -} - -// buildInferenceRequest 构建返回请求 -func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (map[string]any, error) { - messages := []map[string]any{} - switch req.BuildType { - //构建提示词请求 - case 1: - //1. 构建系统提示词 - messages = append(messages, map[string]any{ - "role": "system", - "content": promptBuild(ctx, req, model), - }) - // 2. 构建历史会话提示词 - for _, msg := range history { - role := gconv.String(msg["role"]) - content := gconv.String(msg["content"]) - if role != "user" && role != "assistant" { - continue - } - messages = append(messages, map[string]any{ - "role": role, - "content": content, - }) - } - // 3. 当前用户问题(原来的最后一条) - messages = append(messages, map[string]any{ - "role": "user", - "content": buildUserPrompt(ctx, req, GetModelPrompt(ctx, model.ModelType)), - }) - //构建节点请求 - case 2: - messages = append(messages, map[string]any{ - "role": "user", - "content": NodeBuid(ctx, req), - }) - default: - return nil, errors.New("不支持的构建类型") - } - // 构建请求体 - return map[string]any{ - "modelName": chatModel.ModelName, - "bizName": "prompts-core", - "callbackUrl": "/prompt/callback", - "requestPayload": map[string]any{ - "model": chatModel.ModelName, - "messages": messages, - "stream": false, - }, - }, nil -} - -// ============================================ -// 构建用户提示词 -// ============================================ -func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string { - payload := map[string]any{ - "model": req.ModelName, - //数据库提示信息 - "promptInfo": prompt, - // 系统表单 - "form": req.Form, - // 用户表单 - "userForm": req.UserForm, - //文件url - "userFiles": req.UserFiles, - //解读文件(只支持可读类型 如:xml,json,yaml) - "userFilesText": FetchFileTexts(ctx, req.UserFiles), - //skill 相关(根据传入的 skillName 获取 zip 内所有 md 文件拼接内容) - "skills": SkillMdContent(ctx, req.SkillName), - } - return mustMarshal(payload) -} - -// promptBuild 提示词构建 -func promptBuild(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) string { - // 1. 从配置文件读取提示词模板 - promptTpl := GetBuildPrompt(ctx, req.BuildType) - if promptTpl == "" { - return "" - } - // 2. 构建字段映射说明 - mappingBytes, _ := json.Marshal(model.RequestMapping) - mappingStr := string(mappingBytes) - - var mapping map[string]string - _ = json.Unmarshal(mappingBytes, &mapping) - - var fieldDesc strings.Builder - for key, path := range mapping { - fieldDesc.WriteString(fmt.Sprintf("- %s → %s\n", key, path)) - } - - // 3. 拼接 UserForm 全文(必须完整阅读) - var userFormContent strings.Builder - for k, v := range req.UserForm { - userFormContent.WriteString(fmt.Sprintf("%s=%v;", k, v)) - } - userFormFullText := strings.TrimSuffix(userFormContent.String(), ";") - - // 4. 双表单信息 - formInfo := fmt.Sprintf(` -【系统表单(系统提示词/参数)】 -%s -【用户表单全文(必须完整阅读,全部作为用户提示词来源)】 -%s -`, formToJSON(req.Form), userFormFullText) - // 5. 格式化最终提示词(替换配置里的 %s) - return fmt.Sprintf(promptTpl, mappingStr, fieldDesc.String(), formInfo) -} - -// NodeBuid 节点构建 -func NodeBuid(ctx context.Context, req *dto.ComposeMessagesReq) string { - promptTpl := GetBuildPrompt(ctx, req.BuildType) - if promptTpl == "" { - return "" - } - formStr := formToJSON(req.Form) - userFormStr := formToJSON(req.UserForm) - return fmt.Sprintf(promptTpl, formStr, userFormStr) -} diff --git a/service/http_service.go b/service/gateway/gateway_http_service.go similarity index 75% rename from service/http_service.go rename to service/gateway/gateway_http_service.go index 9109937..9c35885 100644 --- a/service/http_service.go +++ b/service/gateway/gateway_http_service.go @@ -1,9 +1,10 @@ -package service +package gateway import ( "context" "encoding/json" "fmt" + "prompts-core/common/util" commonHttp "gitea.com/red-future/common/http" "github.com/gogf/gf/v2/os/gtime" @@ -19,10 +20,10 @@ type CreateTaskReq struct { ErrorMsg string `json:"error_msg"` } -// createGatewayTask 调用 model-gateway 异步任务并同步等待结果 -func createGatewayTask(ctx context.Context, payload map[string]any) (string, error) { +// CreateGatewayTask 创建网关异步任务 +func CreateGatewayTask(ctx context.Context, payload map[string]any) (string, error) { fullURL := "model-gateway/task/createTask" - headers := forwardHeaders(ctx) + headers := util.ForwardHeaders(ctx) var req CreateTaskReq body, err := json.Marshal(payload) if err != nil { @@ -34,15 +35,16 @@ func createGatewayTask(ctx context.Context, payload map[string]any) (string, err return req.TaskId, nil } +// GetTaskResultRes 任务结果响应 type GetTaskResultRes struct { OssFile string `json:"ossFile" dc:"结果文件OSS地址"` State int `json:"state" dc:"任务状态"` } -// queryGatewayTaskState 查询网关任务状态 -func queryGatewayTaskState(ctx context.Context, taskID string) (int, error) { +// QueryGatewayTaskState 查询网关任务状态 +func QueryGatewayTaskState(ctx context.Context, taskID string) (int, error) { fullURL := fmt.Sprintf("model-gateway/task/getTaskResult?taskId=%s", taskID) - headers := forwardHeaders(ctx) + headers := util.ForwardHeaders(ctx) var req GetTaskResultRes if err := commonHttp.Get(ctx, fullURL, headers, &req, nil); err != nil { return 0, err @@ -56,16 +58,16 @@ type SkillUserVO struct { Name string `json:"name"` Description string `json:"description"` FileName string `json:"fileName"` - FileUrl string `json:"fileUrl"` // html 后缀 + FileUrl string `json:"fileUrl"` CreatedAt *gtime.Time `json:"createdAt"` UpdatedAt *gtime.Time `json:"updatedAt"` - ImgAddressPrefix string `json:"imgAddressPrefix"` // htmml 前缀 + ImgAddressPrefix string `json:"imgAddressPrefix"` } -// GetSkillUser 根据 name 获取技能用户信息 +// GetSkillUser 获取技能用户信息 func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) { fullURL := fmt.Sprintf("ai-agent/skill/user/getUserOrTemplate?name=%s", name) - headers := forwardHeaders(ctx) + headers := util.ForwardHeaders(ctx) var resp SkillUserVO var req struct{} if err := commonHttp.Get(ctx, fullURL, headers, &resp, req); err != nil { diff --git a/service/prompt/prompt_build_service.go b/service/prompt/prompt_build_service.go new file mode 100644 index 0000000..412556d --- /dev/null +++ b/service/prompt/prompt_build_service.go @@ -0,0 +1,112 @@ +package prompt + +import ( + "context" + "errors" + "fmt" + "strings" + + "prompts-core/common/util" + "prompts-core/dao" + "prompts-core/model/dto/prompt" + "prompts-core/model/entity" + + "github.com/gogf/gf/v2/util/gconv" +) + +// buildInferenceRequest 构建返回请求 +func buildInferenceRequest(ctx context.Context, req *prompt.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (map[string]any, error) { + ir := NewPromptIR() + // 1. 统一 Prompt IR + switch req.BuildType { + case 1: //构建提示词请求 + ir.AddSystem(promptBuild(ctx, req, model)) + for _, msg := range history { + role := gconv.String(msg["role"]) + if role != "user" && role != "assistant" { + continue + } + ir.AddHistory(role, gconv.String(msg["content"])) + } + ir.AddUser(buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, model.ModelType))) + case 2: //构建节点请求 + ir.AddUser(NodeBuild(ctx, req)) + default: + return nil, errors.New("不支持的构建类型") + } + + // 2. 获取协议配置 + protocol, err := GetProtocolByProvider(ctx, "qwen") + if err != nil { + return nil, err + } + if protocol == nil { + return nil, errors.New("协议配置不存在") + } + + // 3. 编译为 Provider Request + providerReq, err := Compile(ir, protocol, chatModel) + if err != nil { + return nil, err + } + + // 4. 构建请求体 + return map[string]any{ + "modelName": chatModel.ModelName, + "bizName": "prompts-core", + "callbackUrl": "/prompt/callback", + "requestPayload": providerReq, + }, nil +} + +// promptBuild 构建系统提示词 +func promptBuild(ctx context.Context, req *prompt.ComposeMessagesReq, model *entity.AsynchModel) string { + providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{ + ProviderName: "qwen", + Status: 1, + }) + if err != nil || providerProtocol == nil { + return "" + } + + outputJSON := util.JSONPretty(model.RequestMapping) + var userFormContent strings.Builder + for k, v := range req.UserForm { + userFormContent.WriteString(fmt.Sprintf("%s=%v;", k, v)) + } + userFormFullText := strings.TrimSuffix(userFormContent.String(), ";") + + formInfo := fmt.Sprintf(` +【系统表单(系统提示词/参数)】 +%s +【用户表单全文(必须完整阅读,全部作为用户提示词来源)】 +%s +`, util.FormToJSON(req.Form), userFormFullText) + + return fmt.Sprintf(providerProtocol.SystemPromptTemplate, outputJSON, formInfo) +} + +// 构建用户提示词 +func buildUserPrompt(ctx context.Context, req *prompt.ComposeMessagesReq, prompt string) string { + payload := map[string]any{ + "model": req.ModelName, // 请求模型名称 + "promptInfo": prompt, // 数据库提示信息 + "form": req.Form, // 系统表单 + "userForm": req.UserForm, // 用户表单 + "userFiles": req.UserFiles, //文件url + "userFilesText": FetchFileTexts(ctx, req.UserFiles), //解读文件(只支持可读类型 如:xml,json,yaml) + "skills": SkillMdContent(ctx, req.SkillName), //skill 相关(根据传入的 skillName 获取 zip 内所有 md 文件拼接内容) + } + return util.MustMarshal(payload) +} + +// NodeBuild 节点构建 +func NodeBuild(ctx context.Context, req *prompt.ComposeMessagesReq) string { + promptTpl := util.GetBuildPrompt(ctx, req.BuildType) + if promptTpl == "" { + return "" + } + formStr := util.FormToJSON(req.Form) + userFormStr := util.FormToJSON(req.UserForm) + return fmt.Sprintf(promptTpl, formStr, userFormStr) +} diff --git a/service/compose_service.go b/service/prompt/prompt_compose_service.go similarity index 69% rename from service/compose_service.go rename to service/prompt/prompt_compose_service.go index 9fd7779..8f8a354 100644 --- a/service/compose_service.go +++ b/service/prompt/prompt_compose_service.go @@ -1,28 +1,28 @@ -package service +package prompt import ( "context" "encoding/json" "errors" "fmt" + "prompts-core/dao" + "prompts-core/model/entity" "strings" "time" + "prompts-core/common/util" "prompts-core/consts/public" - "prompts-core/dao" - "prompts-core/model/dto" - "prompts-core/model/entity" + promptDto "prompts-core/model/dto/prompt" + "prompts-core/service/gateway" + "gitea.com/red-future/common/beans" + "gitea.com/red-future/common/utils" "github.com/gogf/gf/v2/container/gvar" "github.com/gogf/gf/v2/frame/g" ) -// ============================================ -// 核心业务流程 -// ============================================ - -// ComposeMessages 拼接提示词主流程 -func (s *promptService) ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) { +// ComposeMessages 核心拼接提示词主流程 +func ComposeMessages(ctx context.Context, req *promptDto.ComposeMessagesReq) (*promptDto.ComposeMessagesRes, error) { var ( epicycleId int64 taskID string @@ -32,7 +32,7 @@ func (s *promptService) ComposeMessages(ctx context.Context, req *dto.ComposeMes taskRecord *entity.ComposeTask ) // 获取模型信息 - chatModel, model, err := s.GetModelMessage(ctx, req) + chatModel, aiModel, err := GetModelMessage(ctx, req) if err != nil { return nil, err } @@ -42,18 +42,18 @@ func (s *promptService) ComposeMessages(ctx context.Context, req *dto.ComposeMes case 1: maxRetryTimes := g.Cfg().MustGet(ctx, "promptsRetry.maxRetryTimes", 3).Int() //1. 获取历史会话 - history, err = Session.GetHistoryMessages(ctx, req.SessionId) + history, err = GetHistoryMessages(ctx, req.SessionId) if err != nil { g.Log().Errorf(ctx, "获取历史会话失败: %v,将不使用历史会话", err) history = nil // 出错就用空的,不影响主流程 } // 重试循环 - for attempt := 0; attempt <= maxRetryTimes; attempt++ { + for attempt := 0; attempt <= 0; attempt++ { if attempt > 0 { g.Log().Warningf(ctx, "[重试]第 %d/%d 次调用推理模型", attempt, maxRetryTimes) } // 2. 调用推理模型 - taskID, err = s.callInferenceModel(ctx, req, chatModel, model, history) + taskID, err = callInferenceModel(ctx, req, chatModel, aiModel, history) if err != nil { g.Log().Errorf(ctx, "调用推理模型失败(第%d次): %v", attempt+1, err) continue @@ -64,7 +64,7 @@ func (s *promptService) ComposeMessages(ctx context.Context, req *dto.ComposeMes TaskId: taskID, ModelName: req.ModelName, SkillName: req.SkillName, - RequestPayload: mustMarshal(req), + RequestPayload: util.MustMarshal(req), Status: public.ComposeStatusPending, }) if err != nil { @@ -73,14 +73,14 @@ func (s *promptService) ComposeMessages(ctx context.Context, req *dto.ComposeMes } // 4. 等待结果 - taskRecord, err = s.waitForResult(ctx, taskID) + taskRecord, err = waitForResult(ctx, taskID) if err != nil { g.Log().Errorf(ctx, "等待结果失败(第%d次): %v", attempt+1, err) continue } // 校验结果 - message = s.parsePromptBuild(taskRecord, chatModel) - if message != nil && isMessageValid(message) { + message = parsePromptBuild(taskRecord, chatModel) + if message != nil && util.IsMessageValid(message) { break } g.Log().Warningf(ctx, "[重试] 推理结果不合法(第%d次),准备重新请求", attempt+1) @@ -97,7 +97,7 @@ func (s *promptService) ComposeMessages(ctx context.Context, req *dto.ComposeMes //节点构建 case 2: //1. 调用推理模型 - taskID, err = s.callInferenceModel(ctx, req, chatModel, model, nil) + taskID, err = callInferenceModel(ctx, req, chatModel, aiModel, nil) if err != nil { return nil, err } @@ -106,115 +106,41 @@ func (s *promptService) ComposeMessages(ctx context.Context, req *dto.ComposeMes TaskId: taskID, ModelName: req.ModelName, SkillName: req.SkillName, - RequestPayload: mustMarshal(req), + RequestPayload: util.MustMarshal(req), Status: public.ComposeStatusPending, }) //5. 等待结果 - taskRecord, err := s.waitForResult(ctx, taskID) + taskRecord, err := waitForResult(ctx, taskID) if err != nil { return nil, err } - fmt.Println("构建节点前", taskRecord) - message = s.parseNodeBuild(taskRecord) - fmt.Println("构建节点后", message) + message = parseNodeBuild(taskRecord) default: epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ SessionId: req.SessionId, Remark: req.Cause, }) - return &dto.ComposeMessagesRes{ + return &promptDto.ComposeMessagesRes{ EpicycleId: epicycleId, }, nil } - return &dto.ComposeMessagesRes{ + return &promptDto.ComposeMessagesRes{ Messages: message, EpicycleId: epicycleId, }, nil } -func (s *promptService) Callback(ctx context.Context, req *dto.CallbackReq) error { - g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d", - req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text)) - - // ============ 先查任务是否存在 ============ - task, err := dao.ComposeTask.GetByTaskId(ctx, req.TaskId) - if err != nil { - return err - } - if task == nil { - return fmt.Errorf("任务不存在: %s", req.TaskId) - } - // ============ 根据状态区分处理 ============ - if req.State == 3 { - // 失败:直接更新状态 - _, err = dao.ComposeTask.UpdateByTaskId(ctx, req.TaskId, map[string]any{ - entity.ComposeTaskCol.Status: public.ComposeStatusFailed, - entity.ComposeTaskCol.ErrorMessage: req.ErrorMsg, - }) - return err - } - // ====================================== - // 成功:解析模型输出 - result, err := parseOutput(req.Text) - if err != nil { - _, updateErr := dao.ComposeTask.UpdateByTaskId(ctx, req.TaskId, map[string]any{ - entity.ComposeTaskCol.Status: public.ComposeStatusFailed, - entity.ComposeTaskCol.ErrorMessage: err.Error(), - }) - if updateErr != nil { - g.Log().Warningf(ctx, "[Callback] 更新失败状态出错 taskId=%s err=%v", req.TaskId, updateErr) - } - return err - } - - // ============ result 可能为 nil ============ - var messages any - if result != nil { - messages = result - } - // ======================================= - - _, err = dao.ComposeTask.UpdateByTaskId(ctx, req.TaskId, map[string]any{ - entity.ComposeTaskCol.Status: public.ComposeStatusSuccess, - entity.ComposeTaskCol.Messages: messages, - }) - if err != nil { - g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err) - } - return err -} - -// GetComposeTask 查询任务结果 -func (s *promptService) GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) { - record, err := dao.ComposeTask.GetByTaskId(ctx, taskID) - if err != nil { - return nil, err - } - if record == nil { - return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID) - } - - // 如果 Messages 是字符串,反序列化为 JSON 数组 - messages := record.Messages - if str, ok := messages.(string); ok && str != "" { - var parsed any - if err := json.Unmarshal([]byte(str), &parsed); err == nil { - messages = parsed - } - } - - return &dto.GetComposeTaskRes{ - TaskId: record.TaskId, - Status: record.Status, - ErrorMessage: record.ErrorMessage, - Messages: messages, - }, nil -} - // GetModelMessage 获取模型信息 -func (s *promptService) GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) { +func GetModelMessage(ctx context.Context, req *promptDto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) { + userInfo, err := utils.GetUserInfo(ctx) + if err != nil { + return nil, nil, err + } // 1. 获取当前用户的会话模型 - chatModel, err := dao.Model.GetByIsChatModel(ctx) + chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName}, + IsChatModel: 1, + }) if err != nil { return nil, nil, err } @@ -222,18 +148,21 @@ func (s *promptService) GetModelMessage(ctx context.Context, req *dto.ComposeMes return nil, nil, errors.New("当前没有对话模型,请添加") } // 2. 获取要构建的模型信息 - model, err := dao.Model.GetByModelName(ctx, req.ModelName) + aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName}, + ModelName: req.ModelName, + }) if err != nil { return nil, nil, err } - if model == nil { + if aiModel == nil { return nil, nil, fmt.Errorf("需要构建的模型 %s 不存在", req.ModelName) } - return chatModel, model, nil + return chatModel, aiModel, nil } // callInferenceModel 调用推理模型 -func (s *promptService) callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (string, error) { +func callInferenceModel(ctx context.Context, req *promptDto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (string, error) { // 构建推理模型请求 taskReq, err := buildInferenceRequest(ctx, req, chatModel, model, history) if err != nil { @@ -241,7 +170,7 @@ func (s *promptService) callInferenceModel(ctx context.Context, req *dto.Compose } // 创建网关任务 - taskID, err := createGatewayTask(ctx, taskReq) + taskID, err := gateway.CreateGatewayTask(ctx, taskReq) if err != nil { return "", fmt.Errorf("创建网关任务失败: %w", err) } @@ -253,10 +182,8 @@ func (s *promptService) callInferenceModel(ctx context.Context, req *dto.Compose return taskID, nil } -// ============================================ -// 步骤6:等待结果 -// ============================================ -func (s *promptService) waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) { +// waitForResult 等待结果 +func waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) { timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond deadline := time.Now().Add(timeout) @@ -271,7 +198,9 @@ func (s *promptService) waitForResult(ctx context.Context, taskID string) (*enti } // 1. 查数据库 - record, err := dao.ComposeTask.GetByTaskId(ctx, taskID) + record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ + TaskId: taskID, + }) if err != nil { // ===================== 修复点 2:如果是上下文取消,直接返回 ===================== if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { @@ -292,7 +221,7 @@ func (s *promptService) waitForResult(ctx context.Context, taskID string) (*enti } // 2. 查网关状态 - state, err := queryGatewayTaskState(ctx, taskID) + state, err := gateway.QueryGatewayTaskState(ctx, taskID) if err != nil { // 网关不可达不终止,继续轮询 g.Log().Warningf(ctx, "[waitForResult] 查询网关失败 taskId=%s err=%v", taskID, err) @@ -301,16 +230,24 @@ func (s *promptService) waitForResult(ctx context.Context, taskID string) (*enti case 2: // 网关成功 // 网关已成功,主动更新数据库 if record != nil { - dao.ComposeTask.UpdateByTaskId(ctx, taskID, map[string]any{ - entity.ComposeTaskCol.Status: public.ComposeStatusSuccess, + _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{ + TaskId: taskID, + Status: public.ComposeStatusSuccess, }) + if err != nil { + g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err) + } } case 3: // 网关失败 if record != nil { - dao.ComposeTask.UpdateByTaskId(ctx, taskID, map[string]any{ - entity.ComposeTaskCol.Status: public.ComposeStatusFailed, - entity.ComposeTaskCol.ErrorMessage: "model-gateway 任务执行失败", + _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{ + TaskId: taskID, + Status: public.ComposeStatusFailed, + ErrorMessage: "model-gateway 任务执行失败", }) + if err != nil { + g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err) + } } return nil, fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID) } @@ -331,7 +268,7 @@ func (s *promptService) waitForResult(ctx context.Context, taskID string) (*enti } // parsePromptBuild 解析提示词构建结果(BuildType == 1) -func (s *promptService) parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) map[string]any { +func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) map[string]any { if taskRecord == nil { return nil } @@ -394,7 +331,7 @@ func (s *promptService) parsePromptBuild(taskRecord *entity.ComposeTask, model * } // parseNodeBuild 解析节点构建结果(BuildType == 2) -func (s *promptService) parseNodeBuild(taskRecord *entity.ComposeTask) map[string]any { +func parseNodeBuild(taskRecord *entity.ComposeTask) map[string]any { if taskRecord == nil { return nil } @@ -414,3 +351,90 @@ func (s *promptService) parseNodeBuild(taskRecord *entity.ComposeTask) map[strin } return result } + +// Callback 回调处理 +func Callback(ctx context.Context, req *promptDto.CallbackReq) error { + g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d", + req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text)) + + // ============ 先查任务是否存在 ============ + task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ + TaskId: req.TaskId, + }) + if err != nil { + return err + } + if task == nil { + return fmt.Errorf("任务不存在: %s", req.TaskId) + } + // ============ 根据状态区分处理 ============ + if req.State == 3 { + // 失败:直接更新状态 + _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{ + TaskId: req.TaskId, + Status: public.ComposeStatusFailed, + ErrorMessage: req.ErrorMsg, + }) + return err + } + // ====================================== + // 成功:解析模型输出 + result, err := util.ParseOutput(req.Text) + if err != nil { + _, updateErr := dao.ComposeTask.Update(ctx, &entity.ComposeTask{ + TaskId: req.TaskId, + Status: public.ComposeStatusFailed, + ErrorMessage: req.ErrorMsg, + }) + if updateErr != nil { + g.Log().Warningf(ctx, "[Callback] 更新失败状态出错 taskId=%s err=%v", req.TaskId, updateErr) + } + return err + } + + // ============ result 可能为 nil ============ + var messages any + if result != nil { + messages = result + } + // ======================================= + + _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{ + TaskId: req.TaskId, + Status: public.ComposeStatusSuccess, + Messages: messages, + }) + if err != nil { + g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err) + } + return err +} + +// GetComposeTask 查询任务结果 +func GetComposeTask(ctx context.Context, taskID string) (*promptDto.GetComposeTaskRes, error) { + record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ + TaskId: taskID, + }) + if err != nil { + return nil, err + } + if record == nil { + return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID) + } + + // 如果 Messages 是字符串,反序列化为 JSON 数组 + messages := record.Messages + if str, ok := messages.(string); ok && str != "" { + var parsed any + if err := json.Unmarshal([]byte(str), &parsed); err == nil { + messages = parsed + } + } + + return &promptDto.GetComposeTaskRes{ + TaskId: record.TaskId, + Status: record.Status, + ErrorMessage: record.ErrorMessage, + Messages: messages, + }, nil +} diff --git a/service/files_handle.go b/service/prompt/prompt_files_handle_service.go similarity index 62% rename from service/files_handle.go rename to service/prompt/prompt_files_handle_service.go index 393fdb0..d8aef2f 100644 --- a/service/files_handle.go +++ b/service/prompt/prompt_files_handle_service.go @@ -1,4 +1,4 @@ -package service +package prompt import ( "archive/zip" @@ -7,52 +7,16 @@ import ( "fmt" "io" "net/http" - "path/filepath" - "regexp" "strings" "time" + "prompts-core/common/util" + "prompts-core/service/gateway" + "github.com/gogf/gf/v2/frame/g" ) -// ============================================ -// 文件处理(配置直接内联 + zip 支持) -// ============================================ - -// 允许的文本类 MIME 类型前缀 -var allowedMIMEPrefixes = []string{ - "text/", - "application/json", - "application/xml", - "application/javascript", - "application/x-yaml", - "application/yaml", - "application/toml", - "application/x-httpd-php", - "application/x-sh", - "application/x-python", - "application/x-perl", - "application/x-ruby", -} - -// 禁止的文件扩展名 -var bannedExtensions = map[string]bool{ - ".png": true, ".jpg": true, ".jpeg": true, ".gif": true, ".bmp": true, - ".webp": true, ".svg": true, ".ico": true, ".tiff": true, ".tif": true, - ".mp3": true, ".wav": true, ".ogg": true, ".flac": true, ".aac": true, - ".wma": true, ".m4a": true, - ".mp4": true, ".avi": true, ".mkv": true, ".mov": true, ".wmv": true, - ".flv": true, ".webm": true, - ".tar": true, ".gz": true, ".rar": true, ".7z": true, - ".exe": true, ".dll": true, ".so": true, ".bin": true, ".dat": true, - ".class": true, ".pyc": true, - ".pdf": true, ".doc": true, ".docx": true, ".xls": true, ".xlsx": true, - ".ppt": true, ".pptx": true, -} - -var symbolCleaner = regexp.MustCompile(`[\x00-\x08\x0B\x0C\x0E-\x1F]`) - -// FetchFileTexts 从 URL 列表获取文件内容(支持 zip 内文件) +// FetchFileTexts 从 URL 列表获取文件内容,支持 zip 内文件 func FetchFileTexts(ctx context.Context, urls []string) map[string]string { result := make(map[string]string) @@ -65,16 +29,16 @@ func FetchFileTexts(ctx context.Context, urls []string) map[string]string { } for _, rawURL := range urls { - url := sanitizeURL(rawURL) + url := util.SanitizeURL(rawURL) if url == "" { continue } - if isBannedExtension(url) { + if util.IsBannedExtension(url) { continue } - if isZipExtension(url) { + if util.IsZipExtension(url) { zipTexts := fetchZipFileTexts(ctx, client, url) for k, v := range zipTexts { result[k] = v @@ -91,21 +55,14 @@ func FetchFileTexts(ctx context.Context, urls []string) map[string]string { continue } - text = cleanSymbols(text) + text = util.CleanSymbols(text) result[url] = text } return result } -func isZipExtension(url string) bool { - ext := strings.ToLower(filepath.Ext(url)) - if idx := strings.Index(ext, "?"); idx != -1 { - ext = ext[:idx] - } - return ext == ".zip" -} - +// fetchZipFileTexts 下载并解压 zip 文件,提取可读文本内容 func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map[string]string { result := make(map[string]string) @@ -130,11 +87,11 @@ func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map fileName := file.Name - if isBannedExtension(fileName) { + if util.IsBannedExtension(fileName) { continue } - if isZipExtension(fileName) { + if util.IsZipExtension(fileName) { continue } @@ -150,11 +107,11 @@ func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map } contentType := http.DetectContentType(content) - if !isReadableContentType(contentType) { + if !util.IsReadableContentType(contentType) { continue } - text := cleanSymbols(string(content)) + text := util.CleanSymbols(string(content)) if text == "" { continue } @@ -166,6 +123,7 @@ func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map return result } +// downloadFile 下载文件,限制最大大小 func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { @@ -185,35 +143,7 @@ func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error return io.ReadAll(io.LimitReader(resp.Body, maxSize)) } -func isBannedExtension(url string) bool { - ext := strings.ToLower(filepath.Ext(url)) - if idx := strings.Index(ext, "?"); idx != -1 { - ext = ext[:idx] - } - return bannedExtensions[ext] -} - -func isReadableContentType(contentType string) bool { - if contentType == "" { - return false - } - ct := strings.ToLower(contentType) - for _, prefix := range allowedMIMEPrefixes { - if strings.HasPrefix(ct, prefix) { - return true - } - } - return false -} - -func cleanSymbols(text string) string { - text = symbolCleaner.ReplaceAllString(text, "") - text = strings.ReplaceAll(text, "\r\n", "\n") - text = strings.ReplaceAll(text, "\r", "\n") - text = regexp.MustCompile(`\n{3,}`).ReplaceAllString(text, "\n\n") - return strings.TrimSpace(text) -} - +// fetchFileContent 获取单个文本文件内容 func fetchFileContent(ctx context.Context, client *http.Client, url string) (string, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { @@ -231,7 +161,7 @@ func fetchFileContent(ctx context.Context, client *http.Client, url string) (str } contentType := resp.Header.Get("Content-Type") - if !isReadableContentType(contentType) { + if !util.IsReadableContentType(contentType) { return "", fmt.Errorf("unreadable content-type: %s", contentType) } @@ -247,22 +177,15 @@ func fetchFileContent(ctx context.Context, client *http.Client, url string) (str return strings.TrimSpace(string(body)), nil } -func sanitizeURL(raw string) string { - s := strings.TrimSpace(raw) - s = strings.Trim(s, "`\"") - return s -} - // SkillMdContent 根据 skillName 获取 zip 内所有 md 文件拼接内容 func SkillMdContent(ctx context.Context, skillName string) string { - // 1. 请求接口获取 SkillUserVO - skillResp, err := GetSkillUser(ctx, skillName) + skillResp, err := gateway.GetSkillUser(ctx, skillName) if err != nil { return "" } fullUrl := skillResp.ImgAddressPrefix + skillResp.FileUrl - // 2. 下载 zip 文件 + client := &http.Client{ Timeout: time.Duration(g.Cfg().MustGet(ctx, "skillFiles.httpTimeoutSec", 30).Int()) * time.Second, } @@ -274,7 +197,6 @@ func SkillMdContent(ctx context.Context, skillName string) string { return "" } - // 3. 解压 zip 并提取所有 md 文件内容 mdContents, err := extractMdFiles(ctx, zipBytes) if err != nil { return "" @@ -284,7 +206,6 @@ func SkillMdContent(ctx context.Context, skillName string) string { return "" } - // 4. 拼接所有 md 内容 var builder strings.Builder builder.WriteString(fmt.Sprintf("# Skill: %s\n\n", skillResp.Name)) if skillResp.Description != "" { diff --git a/service/prompt/prompt_ir_service.go b/service/prompt/prompt_ir_service.go new file mode 100644 index 0000000..cdca0ea --- /dev/null +++ b/service/prompt/prompt_ir_service.go @@ -0,0 +1,264 @@ +package prompt + +import ( + "context" + "encoding/json" + "fmt" + "prompts-core/common/util" + "strings" + + "prompts-core/dao" + "prompts-core/model/entity" +) + +// PromptIR 统一 Prompt 中间表示 +type PromptIR struct { + System []Segment `json:"system"` + History []Segment `json:"history"` + User []Segment `json:"user"` +} + +// Segment 消息片段 +type Segment struct { + Type string `json:"type"` // text/image + Content string `json:"content"` + Role string `json:"role,omitempty"` +} + +// NewPromptIR 创建空 PromptIR +func NewPromptIR() *PromptIR { + return &PromptIR{ + System: make([]Segment, 0), + History: make([]Segment, 0), + User: make([]Segment, 0), + } +} + +// AddSystem 添加系统提示 +func (ir *PromptIR) AddSystem(content string) *PromptIR { + if content != "" { + ir.System = append(ir.System, Segment{Type: "text", Content: content}) + } + return ir +} + +// AddUser 添加用户消息 +func (ir *PromptIR) AddUser(content string) *PromptIR { + if content != "" { + ir.User = append(ir.User, Segment{Type: "text", Content: content}) + } + return ir +} + +// AddHistory 添加历史消息 +func (ir *PromptIR) AddHistory(role, content string) *PromptIR { + if content != "" { + ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role}) + } + return ir +} + +// ToMessages 转换为 OpenAI 兼容的 messages 格式(MVP 默认) +func (ir *PromptIR) ToMessages() []map[string]any { + var messages []map[string]any + + // 1. 系统消息 + for _, seg := range ir.System { + messages = append(messages, map[string]any{ + "role": "system", + "content": seg.Content, + }) + } + + // 2. 历史消息 + for _, seg := range ir.History { + messages = append(messages, map[string]any{ + "role": seg.Role, + "content": seg.Content, + }) + } + + // 3. 用户消息 + for _, seg := range ir.User { + messages = append(messages, map[string]any{ + "role": "user", + "content": seg.Content, + }) + } + return messages +} + +// GetProtocolByProvider 根据 provider_name 获取协议配置 +func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderProtocol, error) { + entity, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{ + ProviderName: providerName, + Status: 1, + }) + if err != nil || entity == nil { + return nil, err + } + entity.MergeOrder = util.ParseJSONField(entity.MergeOrder) + entity.RoleMapping = util.ParseJSONField(entity.RoleMapping) + entity.ContentMapping = util.ParseJSONField(entity.ContentMapping) + entity.RequestTemplate = util.ParseJSONField(entity.RequestTemplate) + entity.ContentMapping = util.ParseJSONField(entity.ContentMapping) + return parseProtocol(entity), nil +} + +// parseProtocol 将 DB entity 转为编译用协议配置 +func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol { + p := &ProviderProtocol{ + TargetField: e.TargetField, + } + + // MergeOrder: any → []string + if e.MergeOrder != nil { + b, _ := json.Marshal(e.MergeOrder) + json.Unmarshal(b, &p.MergeOrder) + } + + // RoleMapping: any → map[string]string + if e.RoleMapping != nil { + b, _ := json.Marshal(e.RoleMapping) + json.Unmarshal(b, &p.RoleMapping) + } + + // ContentMapping: any → ContentMapping + if e.ContentMapping != nil { + b, _ := json.Marshal(e.ContentMapping) + json.Unmarshal(b, &p.ContentMapping) + } + + // RequestTemplate: any → map[string]any + if e.RequestTemplate != nil { + b, _ := json.Marshal(e.RequestTemplate) + json.Unmarshal(b, &p.RequestTemplate) + } + fmt.Printf("parseProtocol: %+v\n", p) + return p +} + +// ProviderProtocol 协议编译配置(从 DB JSONB 字段解析) +type ProviderProtocol struct { + TargetField string `json:"target_field"` + MergeOrder []string `json:"merge_order"` + RoleMapping map[string]string `json:"role_mapping"` + ContentMapping ContentMapping `json:"content_mapping"` + RequestTemplate map[string]any `json:"request_template"` +} + +// ContentMapping 内容字段映射 +type ContentMapping struct { + Type string `json:"type"` // direct/parts + Field string `json:"field"` // content/text +} + +// Compile 将 PromptIR 按协议配置编译为 Provider Request +func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (map[string]any, error) { + if ir == nil || p == nil { + return nil, fmt.Errorf("ir and protocol are required") + } + // 1. 按 merge_order 拼接消息 + messages := mergeByOrder(ir, p.MergeOrder) + // 2. 角色映射 + messages = mapRoles(messages, p.RoleMapping) + // 3. 内容字段映射 + messages = mapContent(messages, p.ContentMapping) + // 4. 按 target_field + request_template 构建请求体 + return buildRequest(messages, p, chatModel), nil +} + +// mergeByOrder 按协议配置顺序拼接消息 +func mergeByOrder(ir *PromptIR, order []string) []map[string]any { + var messages []map[string]any + + for _, part := range order { + switch part { + case "system": + for _, seg := range ir.System { + messages = append(messages, map[string]any{ + "role": "system", + "content": seg.Content, + }) + } + case "history": + for _, seg := range ir.History { + messages = append(messages, map[string]any{ + "role": seg.Role, + "content": seg.Content, + }) + } + case "user": + for _, seg := range ir.User { + messages = append(messages, map[string]any{ + "role": "user", + "content": seg.Content, + }) + } + } + } + return messages +} + +// mapRoles 角色映射 +func mapRoles(messages []map[string]any, mapping map[string]string) []map[string]any { + if len(mapping) == 0 { + return messages + } + for i, msg := range messages { + role, ok := msg["role"].(string) + if !ok { + continue + } + if mapped, exists := mapping[role]; exists { + messages[i]["role"] = mapped + } + } + return messages +} + +// mapContent 内容字段映射 +func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any { + for _, msg := range messages { + content := msg["content"] + delete(msg, "content") + + switch cm.Type { + case "parts": + // Gemini 格式: {"parts": [{"text": "..."}]} + msg["parts"] = []map[string]any{ + {cm.Field: content}, + } + default: + // direct: {"content": "..."} + msg[cm.Field] = content + } + } + return messages +} + +// buildRequest 按 target_field 和 request_template 构建请求体 +func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *entity.AsynchModel) map[string]any { + if len(p.RequestTemplate) > 0 { + return renderTemplate(p.RequestTemplate, messages, chatModel) + } + return map[string]any{ + p.TargetField: messages, + } +} + +// renderTemplate 简单的 {{key}} 模板替换 +func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *entity.AsynchModel) map[string]any { + b, _ := json.Marshal(tmpl) + str := string(b) + + // 替换 {{model}} + str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`) + // 替换 {{messages}} + msgBytes, _ := json.Marshal(messages) + str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes)) + + var result map[string]any + json.Unmarshal([]byte(str), &result) + return result +} diff --git a/service/session_redis_service.go b/service/prompt/prompt_session_redis_service.go similarity index 85% rename from service/session_redis_service.go rename to service/prompt/prompt_session_redis_service.go index d854c66..16ebb17 100644 --- a/service/session_redis_service.go +++ b/service/prompt/prompt_session_redis_service.go @@ -1,4 +1,4 @@ -package service +package prompt import ( "context" @@ -12,7 +12,7 @@ import ( // ==================== Redis 操作 ==================== // saveToRedis 保存会话数据到Redis -func (s *sessionService) saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error { +func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error { key := fmt.Sprintf("chat:session:%s", sessionId) maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int() @@ -50,7 +50,7 @@ func (s *sessionService) saveToRedis(ctx context.Context, sessionId string, requ } // getFromRedis 从Redis获取会话历史 -func (s *sessionService) getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) { +func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) { key := fmt.Sprintf("chat:session:%s", sessionId) result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1) @@ -82,8 +82,8 @@ func (s *sessionService) getFromRedis(ctx context.Context, sessionId string) ([] } // GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用) -func (s *sessionService) GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map[string]any, error) { - historyData, err := s.getFromRedis(ctx, sessionId) +func GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map[string]any, error) { + historyData, err := getFromRedis(ctx, sessionId) if err != nil { return nil, fmt.Errorf("获取历史会话失败: %w", err) } diff --git a/service/session_service.go b/service/prompt/prompt_session_service.go similarity index 60% rename from service/session_service.go rename to service/prompt/prompt_session_service.go index cae70bc..b0d9fce 100644 --- a/service/session_service.go +++ b/service/prompt/prompt_session_service.go @@ -1,24 +1,22 @@ -package service +package prompt import ( "context" "fmt" - "prompts-core/dao" - "prompts-core/model/dto" + sessionDao "prompts-core/dao" "prompts-core/model/entity" + "prompts-core/common/util" + sessionDto "prompts-core/model/dto/prompt" + "gitea.com/red-future/common/beans" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/util/gconv" ) -var Session = &sessionService{} - -type sessionService struct{} - -func (s *sessionService) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *beans.ResponseEmpty, err error) { +func SessionCallback(ctx context.Context, req *sessionDto.SessionCallbackReq) (res *sessionDto.SessionCallbackRes, err error) { // 1. 解析AI返回的文本 - result, err := parseOutput(req.Text) + result, err := util.ParseOutput(req.Text) if err != nil { g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err) return nil, err @@ -26,7 +24,7 @@ func (s *sessionService) SessionCallback(ctx context.Context, req *dto.SessionCa // 2. 更新数据库 result["role"] = "assistant" - _, err = dao.ComposeSession.Update(ctx, &entity.ComposeSession{ + _, err = sessionDao.ComposeSession.Update(ctx, &entity.ComposeSession{ SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId}, ResponseContent: result, }) @@ -36,17 +34,19 @@ func (s *sessionService) SessionCallback(ctx context.Context, req *dto.SessionCa } // 3. 获取当前轮次完整数据 - session, err := dao.ComposeSession.GetById(ctx, req.EpicycleId) + session, err := sessionDao.ComposeSession.Get(ctx, &entity.ComposeSession{ + SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId}, + }) if err != nil { g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err) return nil, err } // 4. 转换 json 并存入 Redis - requestMessages := convertToMessages(session.RequestContent) - responseMessages := convertToMessages(session.ResponseContent) + requestMessages := util.ConvertToMessages(session.RequestContent) + responseMessages := util.ConvertToMessages(session.ResponseContent) - if err = s.saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil { + if err = saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil { g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v", session.SessionId, session.Id, err) return nil, err @@ -54,21 +54,23 @@ func (s *sessionService) SessionCallback(ctx context.Context, req *dto.SessionCa g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d", session.SessionId, session.Id, len(requestMessages), len(responseMessages)) - return &beans.ResponseEmpty{}, nil + return &sessionDto.SessionCallbackRes{}, nil } // GetHistoryMessages 获取历史信息 -func (s *sessionService) GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) { +func GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) { maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int() // 1. 先从 Redis 拿 - redisHistory, err := s.GetSessionHistoryForInference(ctx, sessionId) + redisHistory, err := GetSessionHistoryForInference(ctx, sessionId) if err == nil && len(redisHistory) > 0 { return redisHistory, nil } // 2. Redis 没有 → fallback DB - sessions, err := dao.ComposeSession.GetListBySessionId(ctx, sessionId, maxRounds) + sessions, _, err := sessionDao.ComposeSession.List(ctx, &entity.ComposeSession{ + SessionId: sessionId, + }, 1, maxRounds) if err != nil { return nil, fmt.Errorf("DB获取历史失败: %w", err) } @@ -77,7 +79,7 @@ func (s *sessionService) GetHistoryMessages(ctx context.Context, sessionId strin for _, session := range sessions { // request - reqMsgs := convertToMessages(session.RequestContent) + reqMsgs := util.ConvertToMessages(session.RequestContent) for _, m := range reqMsgs { role := gconv.String(m["role"]) if role == "user" || role == "assistant" { @@ -86,7 +88,7 @@ func (s *sessionService) GetHistoryMessages(ctx context.Context, sessionId strin } // response - respMsgs := convertToMessages(session.ResponseContent) + respMsgs := util.ConvertToMessages(session.ResponseContent) for _, m := range respMsgs { if m["role"] == nil { m["role"] = "assistant" @@ -97,15 +99,15 @@ func (s *sessionService) GetHistoryMessages(ctx context.Context, sessionId strin // 3. 回写 Redis for _, session := range sessions { - reqMsgs := convertToMessages(session.RequestContent) - respMsgs := convertToMessages(session.ResponseContent) + reqMsgs := util.ConvertToMessages(session.RequestContent) + respMsgs := util.ConvertToMessages(session.ResponseContent) for i := range respMsgs { if respMsgs[i]["role"] == nil { respMsgs[i]["role"] = "assistant" } } if len(reqMsgs) > 0 || len(respMsgs) > 0 { - _ = s.saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs) + _ = saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs) } } return messages, nil diff --git a/service/prompt_service.go b/service/prompt_service.go deleted file mode 100644 index 5e35080..0000000 --- a/service/prompt_service.go +++ /dev/null @@ -1,92 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - "errors" - - "prompts-core/dao" - "prompts-core/model/dto" - "prompts-core/model/entity" -) - -var Prompt = &promptService{} - -type promptService struct{} - -func (s *promptService) Create(ctx context.Context, req *dto.CreatePromptReq) (res *dto.CreatePromptRes, err error) { - // promptInfo 兜底校验:必须可序列化为 JSON - if req.PromptInfo == nil { - return nil, errors.New("promptInfo不能为空") - } - if _, err := json.Marshal(req.PromptInfo); err != nil { - return nil, errors.New("promptInfo不是合法JSON") - } - if req.ResponseJsonSchema == nil { - return nil, errors.New("responseJsonSchema不能为空") - } - if _, err := json.Marshal(req.ResponseJsonSchema); err != nil { - return nil, errors.New("responseJsonSchema不是合法JSON") - } - - m := &entity.PromptConfig{ - ModelTypeId: req.ModelTypeId, - ModelType: req.ModelType, - PromptInfo: req.PromptInfo, - ResponseJsonSchema: req.ResponseJsonSchema, - Enabled: 1, - Version: req.Version, - } - - id, err := dao.Prompt.Insert(ctx, m) - if err != nil { - return nil, err - } - return &dto.CreatePromptRes{ID: id}, nil -} - -func (s *promptService) Update(ctx context.Context, req *dto.UpdatePromptReq) error { - data := map[string]any{} - if req.ModelTypeId != nil && *req.ModelTypeId > 0 { - data[entity.PromptConfigCol.ModelTypeId] = *req.ModelTypeId - } - if req.ModelType != nil && *req.ModelType != "" { - data[entity.PromptConfigCol.ModelType] = *req.ModelType - } - if req.PromptInfo != nil { - if _, err := json.Marshal(req.PromptInfo); err != nil { - return errors.New("promptInfo不是合法JSON") - } - data[entity.PromptConfigCol.PromptInfo] = req.PromptInfo - } - if req.ResponseJsonSchema != nil { - if _, err := json.Marshal(req.ResponseJsonSchema); err != nil { - return errors.New("responseJsonSchema不是合法JSON") - } - data[entity.PromptConfigCol.ResponseJsonSchema] = req.ResponseJsonSchema - } - if req.Enabled != nil { - data[entity.PromptConfigCol.Enabled] = *req.Enabled - } - if req.Version != nil { - data[entity.PromptConfigCol.Version] = *req.Version - } - if len(data) == 0 { - return errors.New("无可更新字段") - } - _, err := dao.Prompt.UpdateByID(ctx, req.ID, data) - return err -} - -func (s *promptService) Delete(ctx context.Context, id int64) error { - _, err := dao.Prompt.DeleteByID(ctx, id) - return err -} - -func (s *promptService) Get(ctx context.Context, id int64) (*entity.PromptConfig, error) { - return dao.Prompt.GetByID(ctx, id) -} - -func (s *promptService) List(ctx context.Context, pageNum, pageSize int, modelTypeID *int, modelTypeLike string) (list []*entity.PromptConfig, total int64, err error) { - return dao.Prompt.List(ctx, pageNum, pageSize, modelTypeID, modelTypeLike) -} diff --git a/service/utils.go b/service/utils.go deleted file mode 100644 index 37e30dc..0000000 --- a/service/utils.go +++ /dev/null @@ -1,65 +0,0 @@ -// utils 工具函数 -package service - -import ( - "encoding/json" - "fmt" - - "github.com/gogf/gf/v2/encoding/gjson" - "github.com/gogf/gf/v2/util/gconv" -) - -// ============================================ -// json 相关处理 -// ============================================ -// parseOutput 解析模型输出为 JSON 格式 -func parseOutput(text string) (map[string]any, error) { - j, err := gjson.LoadJson([]byte(text)) - if err != nil { - return nil, fmt.Errorf("解析模型输出失败: %w", err) - } - - return j.Map(), nil -} - -func convertToMessages(raw any) []map[string]any { - if raw == nil { - return nil - } - j, err := gjson.LoadJson(gconv.Bytes(raw)) - if err != nil { - return nil - } - // 1. 如果有 messages - if j.Contains("messages") { - return gconv.Maps(j.Get("messages").Array()) - } - // 2. 否则当成单条 message - return []map[string]any{ - j.Map(), - } -} - -// isMessageValid 校验推理结果是否合法 -func isMessageValid(message map[string]any) bool { - if message == nil { - return false - } - return true -} - -func formToJSON(form map[string]any) string { - if form == nil { - return "{}" - } - b, _ := json.Marshal(form) - return string(b) -} - -func mustMarshal(v any) string { - b, err := json.Marshal(v) - if err != nil { - return "{}" - } - return string(b) -} diff --git a/update.sql b/update.sql index 3217134..ac9e6b1 100644 --- a/update.sql +++ b/update.sql @@ -1,117 +1,130 @@ --- prompts-core 核心表(pgsql) --- 说明:字段风格尽量与参考项目一致(tenant/creator/updater/created_at/updated_at/deleted_at) - --- prompts_model_prompt 模型提示词配置表 -CREATE TABLE IF NOT EXISTS prompts_model_prompt ( - -- 基础字段(与 common/db/gfdb 的 Hook 约定保持一致) - id BIGINT PRIMARY KEY, -- 主键ID(非自增) - tenant_id BIGINT NOT NULL DEFAULT 0, -- 租户ID - creator VARCHAR(64) NOT NULL, -- 创建人 - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 创建时间 - updater VARCHAR(64) NOT NULL, -- 更新人 - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 更新时间 - deleted_at TIMESTAMP(6), -- 删除时间(软删) - - -- 业务字段(按你当前的最小字段集) - model_type_id INT NOT NULL DEFAULT 0, -- 模型分类ID - model_type VARCHAR(64) NOT NULL, -- 模型类别 - prompt_info JSONB NOT NULL DEFAULT '{}'::jsonb, -- 提示词信息(JSON) - response_json_schema JSONB NOT NULL DEFAULT '{}'::jsonb, -- 模型返回表单 JSON 格式约束 - enabled SMALLINT NOT NULL DEFAULT 1, -- 是否启用:1启用/0禁用 - version VARCHAR(64) NOT NULL DEFAULT '' -- 版本号(预留) -); - -CREATE INDEX IF NOT EXISTS idx_prompts_model_prompt_tenant_id ON prompts_model_prompt(tenant_id); -CREATE INDEX IF NOT EXISTS idx_prompts_model_prompt_model_type_id ON prompts_model_prompt(model_type_id); -CREATE INDEX IF NOT EXISTS idx_prompts_model_prompt_model_type ON prompts_model_prompt(model_type); -CREATE INDEX IF NOT EXISTS idx_prompts_model_prompt_enabled ON prompts_model_prompt(enabled); -CREATE INDEX IF NOT EXISTS idx_prompts_model_prompt_deleted_at ON prompts_model_prompt(deleted_at); - -COMMENT ON TABLE prompts_model_prompt IS '模型提示词配置表'; -COMMENT ON COLUMN prompts_model_prompt.id IS '主键ID(非自增)'; -COMMENT ON COLUMN prompts_model_prompt.tenant_id IS '租户ID'; -COMMENT ON COLUMN prompts_model_prompt.creator IS '创建人'; -COMMENT ON COLUMN prompts_model_prompt.created_at IS '创建时间'; -COMMENT ON COLUMN prompts_model_prompt.updater IS '更新人'; -COMMENT ON COLUMN prompts_model_prompt.updated_at IS '更新时间'; -COMMENT ON COLUMN prompts_model_prompt.deleted_at IS '删除时间(软删)'; -COMMENT ON COLUMN prompts_model_prompt.model_type_id IS '模型分类ID'; -COMMENT ON COLUMN prompts_model_prompt.model_type IS '模型类别'; -COMMENT ON COLUMN prompts_model_prompt.prompt_info IS '提示词信息(JSON)'; -COMMENT ON COLUMN prompts_model_prompt.response_json_schema IS '模型返回表单 JSON 格式约束'; -COMMENT ON COLUMN prompts_model_prompt.enabled IS '是否启用:1启用/0禁用'; -COMMENT ON COLUMN prompts_model_prompt.version IS '版本号(预留)'; - -- prompts_compose_task 拼接提示词任务记录表 CREATE TABLE IF NOT EXISTS prompts_compose_task ( - id BIGINT PRIMARY KEY, - tenant_id BIGINT NOT NULL DEFAULT 0, - creator VARCHAR(64) NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updater VARCHAR(64) NOT NULL, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - deleted_at TIMESTAMP(6), + id BIGINT PRIMARY KEY, + tenant_id BIGINT NOT NULL DEFAULT 0, + creator VARCHAR(64) NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updater VARCHAR(64) NOT NULL, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + deleted_at TIMESTAMP(6), - task_id VARCHAR(64) NOT NULL, - model_name VARCHAR(128) NOT NULL DEFAULT '', - skill_name VARCHAR(128) NOT NULL DEFAULT '', - gateway_state INT NOT NULL DEFAULT 0, - limit_words INT NOT NULL DEFAULT 0, - request_payload JSONB NOT NULL DEFAULT '{}'::jsonb, - result_text TEXT NOT NULL DEFAULT '', - messages JSONB NOT NULL DEFAULT '[]'::jsonb, - status VARCHAR(32) NOT NULL DEFAULT 'pending', - error_message TEXT NOT NULL DEFAULT '', - oss_file VARCHAR(1024) NOT NULL DEFAULT '', - file_type VARCHAR(64) NOT NULL DEFAULT '' + task_id VARCHAR(64) NOT NULL, + model_name VARCHAR(128) NOT NULL DEFAULT '', + skill_name VARCHAR(128) NOT NULL DEFAULT '', + gateway_state INT NOT NULL DEFAULT 0, + limit_words INT NOT NULL DEFAULT 0, + request_payload JSONB NOT NULL DEFAULT '{}'::jsonb, + result_text TEXT NOT NULL DEFAULT '', + messages JSONB NOT NULL DEFAULT '{}'::jsonb, + status VARCHAR(32) NOT NULL DEFAULT 'pending', + error_message TEXT NOT NULL DEFAULT '', + oss_file VARCHAR(1024) NOT NULL DEFAULT '', + file_type VARCHAR(64) NOT NULL DEFAULT '' ); - +-- 索引 CREATE UNIQUE INDEX IF NOT EXISTS uk_prompts_compose_task_task_id ON prompts_compose_task(task_id); CREATE INDEX IF NOT EXISTS idx_prompts_compose_task_status ON prompts_compose_task(status); -CREATE INDEX IF NOT EXISTS idx_prompts_compose_task_deleted_at ON prompts_compose_task(deleted_at); - -COMMENT ON TABLE prompts_compose_task IS '拼接提示词任务记录表'; -COMMENT ON COLUMN prompts_compose_task.task_id IS 'model-gateway 任务ID'; -COMMENT ON COLUMN prompts_compose_task.model_name IS '业务模型名称'; -COMMENT ON COLUMN prompts_compose_task.skill_name IS '技能名称'; -COMMENT ON COLUMN prompts_compose_task.gateway_state IS 'model-gateway 状态:0排队/1执行/2成功/3失败/4已下载'; -COMMENT ON COLUMN prompts_compose_task.limit_words IS '提示词限制字数'; +CREATE INDEX IF NOT EXISTS idx_prompts_compose_task_deleted_at ON prompts_compose_task +-- 注释 +COMMENT ON TABLE prompts_compose_task IS '拼接提示词任务记录表'; +COMMENT ON COLUMN prompts_compose_task.id IS '主键ID'; +COMMENT ON COLUMN prompts_compose_task.tenant_id IS '租户ID'; +COMMENT ON COLUMN prompts_compose_task.creator IS '创建人'; +COMMENT ON COLUMN prompts_compose_task.created_at IS '创建时间'; +COMMENT ON COLUMN prompts_compose_task.updater IS '更新人'; +COMMENT ON COLUMN prompts_compose_task.updated_at IS '更新时间'; +COMMENT ON COLUMN prompts_compose_task.deleted_at IS '删除时间(软删)'; +COMMENT ON COLUMN prompts_compose_task.task_id IS 'model-gateway 任务ID'; +COMMENT ON COLUMN prompts_compose_task.model_name IS '业务模型名称'; +COMMENT ON COLUMN prompts_compose_task.skill_name IS '技能名称'; +COMMENT ON COLUMN prompts_compose_task.gateway_state IS 'model-gateway 状态:0排队/1执行/2成功/3失败/4已下载'; +COMMENT ON COLUMN prompts_compose_task.limit_words IS '提示词限制字数'; COMMENT ON COLUMN prompts_compose_task.request_payload IS '发给 model-gateway 的请求内容'; -COMMENT ON COLUMN prompts_compose_task.result_text IS '回调返回的文本结果'; -COMMENT ON COLUMN prompts_compose_task.messages IS '最终解析后的 messages'; -COMMENT ON COLUMN prompts_compose_task.status IS '业务状态:pending/success/failed'; -COMMENT ON COLUMN prompts_compose_task.error_message IS '业务错误信息'; -COMMENT ON COLUMN prompts_compose_task.oss_file IS '网关返回的结果文件地址'; -COMMENT ON COLUMN prompts_compose_task.file_type IS '结果文件类型'; +COMMENT ON COLUMN prompts_compose_task.result_text IS '回调返回的文本结果'; +COMMENT ON COLUMN prompts_compose_task.messages IS '最终解析后的 messages'; +COMMENT ON COLUMN prompts_compose_task.status IS '业务状态:pending/success/failed'; +COMMENT ON COLUMN prompts_compose_task.error_message IS '业务错误信息'; +COMMENT ON COLUMN prompts_compose_task.oss_file IS '网关返回的结果文件地址'; +COMMENT ON COLUMN prompts_compose_task.file_type IS '结果文件类型'; + + -- prompts_compose_session 提示词历史会话表 CREATE TABLE IF NOT EXISTS prompts_compose_session ( - id BIGINT PRIMARY KEY, - tenant_id BIGINT NOT NULL DEFAULT 0, - creator VARCHAR(64) NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updater VARCHAR(64) NOT NULL, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - deleted_at TIMESTAMP(6), + id BIGINT NOT NULL, + tenant_id BIGINT NOT NULL DEFAULT 0, + creator VARCHAR(64) NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updater VARCHAR(64) NOT NULL, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + deleted_at TIMESTAMP(6), - session_id VARCHAR(64) NOT NULL, - request_content JSONB NOT NULL DEFAULT '{}'::jsonb, - response_content JSONB NOT NULL DEFAULT '{}'::jsonb, - remark VARCHAR(500) NOT NULL DEFAULT '' + session_id VARCHAR(64) NOT NULL, + request_content JSONB NOT NULL DEFAULT '{}'::jsonb, + response_content JSONB NOT NULL DEFAULT '{}'::jsonb, + remark VARCHAR(500) NOT NULL DEFAULT '' ); - +-- 索引 CREATE INDEX IF NOT EXISTS idx_prompts_compose_session_session_id ON prompts_compose_session(session_id); CREATE INDEX IF NOT EXISTS idx_prompts_compose_session_deleted_at ON prompts_compose_session(deleted_at); - -COMMENT ON TABLE prompts_compose_session IS '提示词历史会话表'; -COMMENT ON COLUMN prompts_compose_session.id IS '主键ID(非自增)'; -COMMENT ON COLUMN prompts_compose_session.tenant_id IS '租户ID'; -COMMENT ON COLUMN prompts_compose_session.creator IS '创建人'; -COMMENT ON COLUMN prompts_compose_session.created_at IS '创建时间'; -COMMENT ON COLUMN prompts_compose_session.updater IS '更新人'; -COMMENT ON COLUMN prompts_compose_session.updated_at IS '更新时间'; -COMMENT ON COLUMN prompts_compose_session.deleted_at IS '删除时间(软删)'; -COMMENT ON COLUMN prompts_compose_session.session_id IS '会话ID'; -COMMENT ON COLUMN prompts_compose_session.request_content IS '请求内容(JSON格式)'; +-- 注释 +COMMENT ON TABLE prompts_compose_session IS '提示词历史会话表'; +COMMENT ON COLUMN prompts_compose_session.id IS '主键ID'; +COMMENT ON COLUMN prompts_compose_session.tenant_id IS '租户ID'; +COMMENT ON COLUMN prompts_compose_session.creator IS '创建人'; +COMMENT ON COLUMN prompts_compose_session.created_at IS '创建时间'; +COMMENT ON COLUMN prompts_compose_session.updater IS '更新人'; +COMMENT ON COLUMN prompts_compose_session.updated_at IS '更新时间'; +COMMENT ON COLUMN prompts_compose_session.deleted_at IS '删除时间(软删)'; +COMMENT ON COLUMN prompts_compose_session.session_id IS '会话ID'; +COMMENT ON COLUMN prompts_compose_session.request_content IS '请求内容(JSON格式)'; COMMENT ON COLUMN prompts_compose_session.response_content IS '返回内容(JSON格式)'; -COMMENT ON COLUMN prompts_compose_session.remark IS '备注'; \ No newline at end of file +COMMENT ON COLUMN prompts_compose_session.remark IS '备注'; + + + +-- prompts_provider_protocol 模型协议映射配置表 +CREATE TABLE IF NOT EXISTS prompts_provider_protocol ( + id BIGINT PRIMARY KEY, + tenant_id BIGINT NOT NULL DEFAULT 0, + creator VARCHAR(64) NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updater VARCHAR(64) NOT NULL, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + deleted_at TIMESTAMP(6), + + provider_name VARCHAR(64) NOT NULL DEFAULT '', + target_field VARCHAR(64) NOT NULL DEFAULT '', + merge_order JSONB NOT NULL DEFAULT '[]'::jsonb, + role_mapping JSONB NOT NULL DEFAULT '{}'::jsonb, + content_mapping JSONB NOT NULL DEFAULT '{}'::jsonb, + capabilities JSONB NOT NULL DEFAULT '{}'::jsonb, + request_template JSONB NOT NULL DEFAULT '{}'::jsonb, + system_prompt_template TEXT NOT NULL DEFAULT '', + user_prompt_template TEXT NOT NULL DEFAULT '', + status INT NOT NULL DEFAULT 1, + remark VARCHAR(500) NOT NULL DEFAULT '' +); +-- 索引 +CREATE INDEX IF NOT EXISTS idx_prompts_provider_protocol_provider_name ON prompts_provider_protocol(provider_name); +CREATE INDEX IF NOT EXISTS idx_prompts_provider_protocol_status ON prompts_provider_protocol(status); +CREATE INDEX IF NOT EXISTS idx_prompts_provider_protocol_deleted_at ON prompts_provider_protocol(deleted_at); +-- 注释 +COMMENT ON TABLE prompts_provider_protocol IS '模型协议映射配置表'; +COMMENT ON COLUMN prompts_provider_protocol.id IS '主键ID'; +COMMENT ON COLUMN prompts_provider_protocol.tenant_id IS '租户ID'; +COMMENT ON COLUMN prompts_provider_protocol.creator IS '创建人'; +COMMENT ON COLUMN prompts_provider_protocol.created_at IS '创建时间'; +COMMENT ON COLUMN prompts_provider_protocol.updater IS '更新人'; +COMMENT ON COLUMN prompts_provider_protocol.updated_at IS '更新时间'; +COMMENT ON COLUMN prompts_provider_protocol.deleted_at IS '删除时间(软删)'; +COMMENT ON COLUMN prompts_provider_protocol.provider_name IS '运营商名称(openai/deepseek/qwen/anthropic/gemini等)'; +COMMENT ON COLUMN prompts_provider_protocol.target_field IS '目标字段(messages/contents/prompt)'; +COMMENT ON COLUMN prompts_provider_protocol.merge_order IS 'Prompt IR 拼接顺序(system/history/user)'; +COMMENT ON COLUMN prompts_provider_protocol.role_mapping IS '角色映射(system/user/assistant -> provider role)'; +COMMENT ON COLUMN prompts_provider_protocol.content_mapping IS '内容字段映射(content/parts.text等)'; +COMMENT ON COLUMN prompts_provider_protocol.capabilities IS '协议能力配置(system/history/tools/stream等支持情况)'; +COMMENT ON COLUMN prompts_provider_protocol.request_template IS '请求模板(JSON结构模板)'; +COMMENT ON COLUMN prompts_provider_protocol.system_prompt_template IS '系统提示词模板'; +COMMENT ON COLUMN prompts_provider_protocol.status IS '状态:1启用/0禁用'; +COMMENT ON COLUMN prompts_provider_protocol.remark IS '备注'; \ No newline at end of file