refactor(service): 重构服务模块结构并优化模型配置
This commit is contained in:
@@ -8,6 +8,11 @@ import (
|
|||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// GetServerName 获取服务名称
|
||||||
|
func GetServerName(ctx context.Context) string {
|
||||||
|
return g.Cfg().MustGet(ctx, "server.name", "").String()
|
||||||
|
}
|
||||||
|
|
||||||
// GetServerPort 从配置获取服务端口
|
// GetServerPort 从配置获取服务端口
|
||||||
func GetServerPort(ctx context.Context) string {
|
func GetServerPort(ctx context.Context) string {
|
||||||
address := g.Cfg().MustGet(ctx, "server.address", ":8080").String()
|
address := g.Cfg().MustGet(ctx, "server.address", ":8080").String()
|
||||||
|
|||||||
@@ -2,51 +2,34 @@ package util
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/container/gvar"
|
"github.com/gogf/gf/v2/container/gvar"
|
||||||
"github.com/gogf/gf/v2/encoding/gjson"
|
gfgjson "github.com/gogf/gf/v2/encoding/gjson"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
|
|
||||||
tGjson "github.com/tidwall/gjson"
|
tGjson "github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ParseOutput 解析模型输出为 JSON 格式
|
|
||||||
func ParseOutput(text string) (map[string]any, error) {
|
|
||||||
j, err := gjson.LoadJson([]byte(text))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("解析模型输出失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return j.Map(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConvertToMessages 将原始数据转换为消息列表
|
// ConvertToMessages 将原始数据转换为消息列表
|
||||||
func ConvertToMessages(raw any) []map[string]any {
|
func ConvertToMessages(raw any) []map[string]any {
|
||||||
if raw == nil {
|
if raw == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
j, err := gjson.LoadJson(gconv.Bytes(raw))
|
j := gfgjson.New(raw)
|
||||||
if err != nil {
|
messages := j.Get("messages")
|
||||||
return nil
|
if !messages.IsNil() {
|
||||||
|
return gconv.Maps(messages.Val())
|
||||||
}
|
}
|
||||||
|
|
||||||
if j.Contains("messages") {
|
|
||||||
return gconv.Maps(j.Get("messages").Array())
|
|
||||||
}
|
|
||||||
|
|
||||||
return []map[string]any{j.Map()}
|
return []map[string]any{j.Map()}
|
||||||
}
|
}
|
||||||
|
|
||||||
// FormToJSON 将表单数据转换为 JSON 字符串
|
// FormToJSON 将表单数据转换为 JSON 字符串
|
||||||
func FormToJSON(form map[string]any) string {
|
func FormToJSON(form []map[string]any) string {
|
||||||
if form == nil {
|
if form == nil {
|
||||||
return "{}"
|
return "[]"
|
||||||
}
|
}
|
||||||
|
|
||||||
b, _ := json.Marshal(form)
|
b, _ := json.Marshal(form)
|
||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|||||||
148
common/util/mapping.go
Normal file
148
common/util/mapping.go
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"prompts-core/model/entity"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gogf/gf/v2/encoding/gjson"
|
||||||
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性
|
||||||
|
// 校验逻辑:只校验 requestMapping 中默认值为空的必填字段
|
||||||
|
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
|
||||||
|
// 1) 获取校验配置,并取值
|
||||||
|
requestMapping := model.RequestMapping
|
||||||
|
contentKey := ""
|
||||||
|
for k := range model.ResponseBody {
|
||||||
|
contentKey = k
|
||||||
|
break
|
||||||
|
}
|
||||||
|
contentStr, ok := raw[contentKey].(string)
|
||||||
|
if !ok || contentStr == "" {
|
||||||
|
return fmt.Errorf("%s 字段为空或不是字符串", contentKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) 解析 content 为 JSON 数组
|
||||||
|
var rounds []map[string]any
|
||||||
|
if err := gjson.DecodeTo(contentStr, &rounds); err != nil {
|
||||||
|
return fmt.Errorf("解析 content JSON 数组失败: %w", err)
|
||||||
|
}
|
||||||
|
if len(rounds) == 0 {
|
||||||
|
return fmt.Errorf("content 数组为空")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3) 逐条校验:只检查默认值为空的必填字段是否存在
|
||||||
|
for i, round := range rounds {
|
||||||
|
for path, defaultValue := range requestMapping {
|
||||||
|
if !g.IsEmpty(defaultValue) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if gjson.New(round).Get(path).IsNil() {
|
||||||
|
return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReverseMap 映射 payload 到 mapping
|
||||||
|
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
|
||||||
|
jsonObj := gjson.New("{}")
|
||||||
|
for path, defaultValue := range mapping {
|
||||||
|
val := gjson.New(payload).Get(path)
|
||||||
|
if !val.IsNil() {
|
||||||
|
_ = jsonObj.Set(path, val.Val())
|
||||||
|
} else if defaultValue != nil {
|
||||||
|
_ = jsonObj.Set(path, defaultValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jsonObj.Map()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapResponsePayload 映射模型响应为标准格式
|
||||||
|
func MapResponsePayload(mapping map[string]any, responseBytes []byte) ([]byte, error) {
|
||||||
|
if len(mapping) == 0 {
|
||||||
|
return responseBytes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
responseJson := gjson.New(responseBytes)
|
||||||
|
resultJson := gjson.New("{}")
|
||||||
|
|
||||||
|
for standardField, modelPath := range mapping {
|
||||||
|
path := gconv.String(modelPath)
|
||||||
|
if path == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
val := responseJson.Get(path)
|
||||||
|
if val.IsNil() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
resultJson.Set(standardField, val.Val())
|
||||||
|
}
|
||||||
|
|
||||||
|
return []byte(resultJson.String()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
|
||||||
|
// 示例:
|
||||||
|
// - X-API-Key:qwen3-tts-key,operation:true,count:123
|
||||||
|
// - X-API-Key:"qwen3-tts-key",operation:"true"
|
||||||
|
//
|
||||||
|
// 说明:
|
||||||
|
// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。
|
||||||
|
// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。
|
||||||
|
func ParseHeadMsgHeaders(headMsg string) map[string]string {
|
||||||
|
headMsg = strings.TrimSpace(headMsg)
|
||||||
|
if headMsg == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := map[string]string{}
|
||||||
|
parts := strings.Split(headMsg, ",")
|
||||||
|
for _, p := range parts {
|
||||||
|
p = strings.TrimSpace(p)
|
||||||
|
if p == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// HeaderName:HeaderValue(推荐) / HeaderName=HeaderValue(兼容)
|
||||||
|
if strings.Contains(p, ":") {
|
||||||
|
kv := strings.SplitN(p, ":", 2)
|
||||||
|
k := strings.TrimSpace(kv[0])
|
||||||
|
v := strings.TrimSpace(kv[1])
|
||||||
|
v = strings.Trim(v, "\"")
|
||||||
|
if k != "" && v != "" {
|
||||||
|
out[k] = v
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.Contains(p, "=") {
|
||||||
|
kv := strings.SplitN(p, "=", 2)
|
||||||
|
k := strings.TrimSpace(kv[0])
|
||||||
|
v := strings.TrimSpace(kv[1])
|
||||||
|
v = strings.Trim(v, "\"")
|
||||||
|
if k != "" && v != "" {
|
||||||
|
out[k] = v
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// PayloadToQuery 将 payload 转为 url.Values
|
||||||
|
func PayloadToQuery(payload map[string]any) (url.Values, error) {
|
||||||
|
q := url.Values{}
|
||||||
|
for k, v := range payload {
|
||||||
|
if v == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
q.Set(k, gconv.String(v))
|
||||||
|
}
|
||||||
|
return q, nil
|
||||||
|
}
|
||||||
@@ -26,3 +26,15 @@ func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...s
|
|||||||
err = r.Struct(&m)
|
err = r.Struct(&m)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetsByModelName 批量获取模型
|
||||||
|
func (d *modelDao) GetsByModelName(ctx context.Context, creator string, modelNames []string, fields ...string) (list []*entity.AsynchModel, err error) {
|
||||||
|
err = gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||||
|
OmitEmpty().
|
||||||
|
Where(entity.AsynchModelCol.Creator, creator).
|
||||||
|
WhereIn(entity.AsynchModelCol.ModelName, modelNames).
|
||||||
|
Fields(fields).
|
||||||
|
Scan(&list)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,10 +6,10 @@ type ComposeMessagesReq struct {
|
|||||||
g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schema;form 作为系统表单,userForm 作为用户表单,结合 userFiles 调用 model-gateway,并直接返回最终 messages"`
|
g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schema;form 作为系统表单,userForm 作为用户表单,结合 userFiles 调用 model-gateway,并直接返回最终 messages"`
|
||||||
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"`
|
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"`
|
||||||
BuildType int `p:"buildType" json:"buildType" v:"required#buildType不能为空" dc:"构建类型"` //判断节点
|
BuildType int `p:"buildType" json:"buildType" v:"required#buildType不能为空" dc:"构建类型"` //判断节点
|
||||||
SessionId string `p:"sessionId" json:"sessionId" v:"required#sessionId不能为空" dc:"会话ID"`
|
SessionId string `p:"sessionId" json:"sessionId" dc:"会话ID"` //v:"required#sessionId不能为空"
|
||||||
Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"`
|
Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"`
|
||||||
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
|
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
|
||||||
Form map[string]any `p:"form" json:"form" dc:"系统表单:form 下所有字段都作为系统提示词来源"`
|
Form []map[string]any `p:"form" json:"form" dc:"系统表单:form 下所有字段都作为系统提示词来源"`
|
||||||
UserForm []map[string]any `p:"userForm" json:"userForm" dc:"用户表单:userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"`
|
UserForm []map[string]any `p:"userForm" json:"userForm" dc:"用户表单:userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"`
|
||||||
Consult []ConsultItem `json:"consult" dc:"附件列表(图片/视频/音频)"`
|
Consult []ConsultItem `json:"consult" dc:"附件列表(图片/视频/音频)"`
|
||||||
SkillName string `p:"skillName" json:"skillName" dc:"技能名称"`
|
SkillName string `p:"skillName" json:"skillName" dc:"技能名称"`
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
type UserPromptPayload struct {
|
type UserPromptPayload struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
PromptInfo string `json:"promptInfo"`
|
PromptInfo string `json:"promptInfo"`
|
||||||
Form map[string]any `json:"form"`
|
Form any `json:"form"`
|
||||||
UserForm any `json:"userForm"`
|
UserForm any `json:"userForm"`
|
||||||
Consult []dto.ConsultItem `json:"consult"`
|
Consult []dto.ConsultItem `json:"consult"`
|
||||||
UserFilesText map[string]string `json:"userFilesText"`
|
UserFilesText map[string]string `json:"userFilesText"`
|
||||||
@@ -30,6 +30,7 @@ type UserPromptPayload struct {
|
|||||||
|
|
||||||
// buildInferenceRequest 构建推理请求
|
// buildInferenceRequest 构建推理请求
|
||||||
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
|
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
|
||||||
|
//1) 处理表单分批
|
||||||
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
|
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
|
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
|
||||||
@@ -47,9 +48,10 @@ func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, cha
|
|||||||
|
|
||||||
// buildPromptTypeRequest 构建提示词类型请求(BuildType=1)
|
// buildPromptTypeRequest 构建提示词类型请求(BuildType=1)
|
||||||
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
|
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
|
||||||
systemPrompt := promptBuildWithRounds(ctx, req, aiModel, totalBatches)
|
//1) 构建系统提示词
|
||||||
|
systemPrompt := promptBuildWithRounds(ctx, req, chatModel, aiModel, totalBatches)
|
||||||
ir.AddSystem(systemPrompt)
|
ir.AddSystem(systemPrompt)
|
||||||
|
//2) 构建历史对话
|
||||||
for _, msg := range history {
|
for _, msg := range history {
|
||||||
role := gconv.String(msg["role"])
|
role := gconv.String(msg["role"])
|
||||||
if role != "user" && role != "assistant" {
|
if role != "user" && role != "assistant" {
|
||||||
@@ -57,7 +59,6 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
|
|||||||
}
|
}
|
||||||
ir.AddHistory(role, gconv.String(msg["content"]))
|
ir.AddHistory(role, gconv.String(msg["content"]))
|
||||||
}
|
}
|
||||||
|
|
||||||
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType))
|
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType))
|
||||||
ir.AddUser(userPrompt)
|
ir.AddUser(userPrompt)
|
||||||
if !checkOverallContent(ir, aiModel) {
|
if !checkOverallContent(ir, aiModel) {
|
||||||
@@ -70,7 +71,6 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
|
|||||||
// buildNodeTypeRequest 构建节点类型请求(BuildType=2)
|
// buildNodeTypeRequest 构建节点类型请求(BuildType=2)
|
||||||
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, ir *PromptIR) (map[string]any, error) {
|
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, ir *PromptIR) (map[string]any, error) {
|
||||||
ir.AddUser(NodeBuild(ctx, req))
|
ir.AddUser(NodeBuild(ctx, req))
|
||||||
|
|
||||||
return compileToProviderRequest(ctx, ir, chatModel)
|
return compileToProviderRequest(ctx, ir, chatModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -90,33 +90,33 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *enti
|
|||||||
|
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"modelName": chatModel.ModelName,
|
"modelName": chatModel.ModelName,
|
||||||
"bizName": "prompts-core",
|
"bizName": util.GetServerName(ctx),
|
||||||
"callbackUrl": util.GetCallbackURL(ctx, "/prompt/callback"),
|
"callbackUrl": util.GetCallbackURL(ctx, "/prompt/callback"),
|
||||||
"requestPayload": providerReq,
|
"requestPayload": providerReq,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// promptBuildWithRounds 构建系统提示词(包含轮次信息)
|
// promptBuildWithRounds 构建系统提示词
|
||||||
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel, totalRounds int) string {
|
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, batches int) string {
|
||||||
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||||
ProviderName: model.OperatorName,
|
ProviderName: chatModel.OperatorName,
|
||||||
Status: 1,
|
Status: 1,
|
||||||
})
|
})
|
||||||
if err != nil || providerProtocol == nil {
|
if err != nil || providerProtocol == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
outputJSON := util.JSONPretty(model.RequestMapping)
|
outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{}))
|
||||||
maxWindowSize := util.GetMaxWindowSize(model.TokenConfig)
|
maxWindowSize := util.GetMaxWindowSize(chatModel.TokenConfig)
|
||||||
availableWindow := util.GetAvailableWindow(model.TokenConfig)
|
availableWindow := util.GetAvailableWindow(chatModel.TokenConfig)
|
||||||
|
formContent := buildUserFormContent(req.Form)
|
||||||
userFormContent := buildUserFormContent(req.UserForm)
|
userFormContent := buildUserFormContent(req.UserForm)
|
||||||
formInfo := fmt.Sprintf(`
|
formInfo := fmt.Sprintf(`
|
||||||
【系统表单(系统提示词/参数)】
|
【系统表单(系统提示词/参数)】
|
||||||
%s
|
%s
|
||||||
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
|
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
|
||||||
%s
|
%s
|
||||||
`, util.FormToJSON(req.Form), userFormContent)
|
`, formContent, userFormContent)
|
||||||
|
|
||||||
inputInfo := fmt.Sprintf(`
|
inputInfo := fmt.Sprintf(`
|
||||||
目标模型: %s
|
目标模型: %s
|
||||||
@@ -129,11 +129,8 @@ func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, mod
|
|||||||
req.ModelName, // %s 目标模型名称
|
req.ModelName, // %s 目标模型名称
|
||||||
maxWindowSize, // %d 最大窗口
|
maxWindowSize, // %d 最大窗口
|
||||||
availableWindow, // %d 可用窗口
|
availableWindow, // %d 可用窗口
|
||||||
totalRounds, // %d 数组长度(多轮输出要求)
|
|
||||||
totalRounds, // %d 数组长度(结构铁律)
|
|
||||||
outputJSON, // %s 输出结构
|
outputJSON, // %s 输出结构
|
||||||
inputInfo, // %s 完整输入信息
|
inputInfo, // %s 完整输入信息
|
||||||
totalRounds, // %d 数组长度(最后一行)
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,7 +154,7 @@ func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt st
|
|||||||
payload := UserPromptPayload{
|
payload := UserPromptPayload{
|
||||||
Model: req.ModelName,
|
Model: req.ModelName,
|
||||||
PromptInfo: prompt,
|
PromptInfo: prompt,
|
||||||
Form: req.Form,
|
Form: prepareUserFormPayload(req.Form),
|
||||||
UserForm: prepareUserFormPayload(req.UserForm),
|
UserForm: prepareUserFormPayload(req.UserForm),
|
||||||
Consult: req.Consult,
|
Consult: req.Consult,
|
||||||
UserFilesText: ExtractFileTexts(ctx, req.Consult),
|
UserFilesText: ExtractFileTexts(ctx, req.Consult),
|
||||||
|
|||||||
@@ -21,13 +21,16 @@ import (
|
|||||||
|
|
||||||
// ComposeMessages 核心拼接提示词主流程
|
// ComposeMessages 核心拼接提示词主流程
|
||||||
func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
|
func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
|
||||||
|
//1) 获取模型信息
|
||||||
chatModel, aiModel, err := GetModelMessage(ctx, req)
|
chatModel, aiModel, err := GetModelMessage(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
//2) 校验用户表单
|
||||||
if err = validateUserForm(req, aiModel); err != nil {
|
if err = validateUserForm(req, aiModel); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
//3) 处理不同类型
|
||||||
switch req.BuildType {
|
switch req.BuildType {
|
||||||
case public.BuildTypePrompt:
|
case public.BuildTypePrompt:
|
||||||
return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建
|
return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建
|
||||||
@@ -54,7 +57,7 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.
|
|||||||
if chatModel == nil {
|
if chatModel == nil {
|
||||||
return nil, nil, errors.New("当前没有对话模型,请添加")
|
return nil, nil, errors.New("当前没有对话模型,请添加")
|
||||||
}
|
}
|
||||||
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
aiModels, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||||
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
|
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
|
||||||
ModelName: req.ModelName,
|
ModelName: req.ModelName,
|
||||||
})
|
})
|
||||||
@@ -62,10 +65,10 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if aiModel == nil {
|
if aiModels == nil {
|
||||||
return nil, nil, errors.New("需要构建的模型不存在")
|
return nil, nil, errors.New("需要构建的模型不存在")
|
||||||
}
|
}
|
||||||
return chatModel, aiModel, nil
|
return chatModel, aiModels, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateUserForm 校验用户表单
|
// validateUserForm 校验用户表单
|
||||||
@@ -150,13 +153,16 @@ func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatMo
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", 0, fmt.Errorf("构建推理请求失败: %w", err)
|
return "", 0, fmt.Errorf("构建推理请求失败: %w", err)
|
||||||
}
|
}
|
||||||
id, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
id := int64(0)
|
||||||
|
if req.SessionId != "" {
|
||||||
|
id, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||||
SessionId: req.SessionId,
|
SessionId: req.SessionId,
|
||||||
RequestContent: util.GetUserMessage(taskReq),
|
RequestContent: util.GetUserMessage(taskReq),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", 0, fmt.Errorf("保存历史会话失败: %w", err)
|
return "", 0, fmt.Errorf("保存历史会话失败: %w", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
|
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", 0, fmt.Errorf("创建网关任务失败: %w", err)
|
return "", 0, fmt.Errorf("创建网关任务失败: %w", err)
|
||||||
|
|||||||
Reference in New Issue
Block a user