refactor(files): 优化文件处理和任务服务逻辑
This commit is contained in:
@@ -1,81 +0,0 @@
|
|||||||
package util
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/encoding/gjson"
|
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MergeConsult 将 consult 附件合并到模型生成的 messages 结构中
|
|
||||||
func MergeConsult(req map[string]any, messages map[string]any, extendMapping map[string]any) map[string]any {
|
|
||||||
if len(req) == 0 || len(messages) == 0 || len(extendMapping) == 0 {
|
|
||||||
return messages
|
|
||||||
}
|
|
||||||
|
|
||||||
consult := gconv.Interfaces(req["consult"])
|
|
||||||
if len(consult) == 0 {
|
|
||||||
return messages
|
|
||||||
}
|
|
||||||
|
|
||||||
targetPath := gconv.String(extendMapping["target_content_path"])
|
|
||||||
templates := gconv.Map(extendMapping["attachment_templates"])
|
|
||||||
if targetPath == "" || len(templates) == 0 {
|
|
||||||
return messages
|
|
||||||
}
|
|
||||||
|
|
||||||
msgJson := gjson.New(messages)
|
|
||||||
|
|
||||||
// rounds 路径修正
|
|
||||||
if !msgJson.Get("rounds.0").IsNil() {
|
|
||||||
targetPath = "rounds.0." + targetPath
|
|
||||||
}
|
|
||||||
|
|
||||||
// 遍历追加
|
|
||||||
for _, item := range consult {
|
|
||||||
itemJson := gjson.New(item)
|
|
||||||
itemType := itemJson.Get("type").String()
|
|
||||||
tmpl := gconv.Map(templates[itemType])
|
|
||||||
if itemType == "" || len(tmpl) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
attachment := buildAttachment(tmpl, itemJson.Get("url").String())
|
|
||||||
if attachment == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
idx := len(msgJson.Get(targetPath).Array())
|
|
||||||
_ = msgJson.Set(fmt.Sprintf("%s.%d", targetPath, idx), attachment)
|
|
||||||
}
|
|
||||||
|
|
||||||
return msgJson.Map()
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildAttachment(tmpl map[string]any, url string) map[string]any {
|
|
||||||
typ := gconv.String(tmpl["type"])
|
|
||||||
if typ == "" || url == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
body := gconv.Map(tmpl["body"])
|
|
||||||
fillEmptyInPlace(body, url)
|
|
||||||
|
|
||||||
return map[string]any{
|
|
||||||
"type": typ,
|
|
||||||
typ: body,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func fillEmptyInPlace(m map[string]any, value string) {
|
|
||||||
for k, v := range m {
|
|
||||||
switch vv := v.(type) {
|
|
||||||
case string:
|
|
||||||
if vv == "" {
|
|
||||||
m[k] = value
|
|
||||||
}
|
|
||||||
case map[string]any:
|
|
||||||
fillEmptyInPlace(vv, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,13 +1,16 @@
|
|||||||
package util
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/encoding/gjson"
|
"github.com/gogf/gf/v2/encoding/gjson"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ReverseMap 映射 payload 到 mapping
|
// ======================== 请求映射 ========================
|
||||||
|
|
||||||
|
// ReverseMap 将 payload 按 mapping 路径映射为嵌套结构
|
||||||
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
|
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
|
||||||
jsonObj := gjson.New("{}")
|
jsonObj := gjson.New("{}")
|
||||||
for path, defaultValue := range mapping {
|
for path, defaultValue := range mapping {
|
||||||
@@ -21,6 +24,8 @@ func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
|
|||||||
return jsonObj.Map()
|
return jsonObj.Map()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ======================== 用户文本提取 ========================
|
||||||
|
|
||||||
// ExtractUserText 从 messages 中提取所有 user 文本
|
// ExtractUserText 从 messages 中提取所有 user 文本
|
||||||
func ExtractUserText(messages map[string]any) map[string]any {
|
func ExtractUserText(messages map[string]any) map[string]any {
|
||||||
msgJson := gjson.New(messages)
|
msgJson := gjson.New(messages)
|
||||||
@@ -29,6 +34,7 @@ func ExtractUserText(messages map[string]any) map[string]any {
|
|||||||
if msgs.IsNil() {
|
if msgs.IsNil() {
|
||||||
msgs = msgJson.Get("messages")
|
msgs = msgJson.Get("messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
var texts []string
|
var texts []string
|
||||||
for _, m := range msgs.Array() {
|
for _, m := range msgs.Array() {
|
||||||
msg := gjson.New(m)
|
msg := gjson.New(m)
|
||||||
@@ -55,3 +61,128 @@ func ExtractUserText(messages map[string]any) map[string]any {
|
|||||||
"content": strings.Join(texts, "\n"),
|
"content": strings.Join(texts, "\n"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ======================== 附件合并 ========================
|
||||||
|
|
||||||
|
// MergeConsult 将 consult 附件合并到每个 round 的 content 数组中
|
||||||
|
func MergeConsult(req map[string]any, messages map[string]any, extendMapping map[string]any) map[string]any {
|
||||||
|
if len(req) == 0 || len(messages) == 0 || len(extendMapping) == 0 {
|
||||||
|
return messages
|
||||||
|
}
|
||||||
|
|
||||||
|
consult := gconv.Interfaces(req["consult"])
|
||||||
|
if len(consult) == 0 {
|
||||||
|
return messages
|
||||||
|
}
|
||||||
|
|
||||||
|
targetPath := gconv.String(extendMapping["target_content_path"])
|
||||||
|
templates := gconv.Map(extendMapping["attachment_templates"])
|
||||||
|
if targetPath == "" || len(templates) == 0 {
|
||||||
|
return messages
|
||||||
|
}
|
||||||
|
|
||||||
|
msgJson := gjson.New(messages)
|
||||||
|
|
||||||
|
rounds := msgJson.Get("rounds").Array()
|
||||||
|
for i := range rounds {
|
||||||
|
roundPath := fmt.Sprintf("rounds.%d.%s", i, targetPath)
|
||||||
|
for _, item := range consult {
|
||||||
|
itemJson := gjson.New(item)
|
||||||
|
itemType := itemJson.Get("type").String()
|
||||||
|
tmpl := gconv.Map(templates[itemType])
|
||||||
|
if itemType == "" || len(tmpl) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
attachment := buildAttachment(tmpl, itemJson.Get("url").String())
|
||||||
|
if attachment == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
idx := len(msgJson.Get(roundPath).Array())
|
||||||
|
_ = msgJson.Set(fmt.Sprintf("%s.%d", roundPath, idx), attachment)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return msgJson.Map()
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildAttachment 根据模板和 url 生成附件对象
|
||||||
|
func buildAttachment(tmpl map[string]any, url string) map[string]any {
|
||||||
|
typ := gconv.String(tmpl["type"])
|
||||||
|
if typ == "" || url == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
body := gconv.Map(tmpl["body"])
|
||||||
|
fillEmptyInPlace(body, url)
|
||||||
|
|
||||||
|
return map[string]any{
|
||||||
|
"type": typ,
|
||||||
|
typ: body,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fillEmptyInPlace 递归填充空字符串
|
||||||
|
func fillEmptyInPlace(m map[string]any, value string) {
|
||||||
|
for k, v := range m {
|
||||||
|
switch vv := v.(type) {
|
||||||
|
case string:
|
||||||
|
if vv == "" {
|
||||||
|
m[k] = value
|
||||||
|
}
|
||||||
|
case map[string]any:
|
||||||
|
fillEmptyInPlace(vv, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ======================== 系统提示词合并 ========================
|
||||||
|
|
||||||
|
// MergeSystemPrompt 将系统提示词和技能内容拼接到 system role 的 content 中
|
||||||
|
func MergeSystemPrompt(messages map[string]any, prompt, skills string, requestMapping map[string]any) map[string]any {
|
||||||
|
var parts []string
|
||||||
|
if prompt != "" {
|
||||||
|
parts = append(parts, prompt)
|
||||||
|
}
|
||||||
|
if skills != "" {
|
||||||
|
parts = append(parts, skills)
|
||||||
|
}
|
||||||
|
if len(parts) == 0 {
|
||||||
|
return messages
|
||||||
|
}
|
||||||
|
|
||||||
|
systemContent := strings.Join(parts, "\n")
|
||||||
|
systemPath := getSystemPromptPath(requestMapping)
|
||||||
|
if systemPath == "" {
|
||||||
|
return messages
|
||||||
|
}
|
||||||
|
|
||||||
|
msgJson := gjson.New(messages)
|
||||||
|
|
||||||
|
existing := msgJson.Get(systemPath).String()
|
||||||
|
if existing != "" {
|
||||||
|
systemContent = existing + "\n" + systemContent
|
||||||
|
}
|
||||||
|
_ = msgJson.Set(systemPath, systemContent)
|
||||||
|
|
||||||
|
return msgJson.Map()
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSystemPromptPath 从 RequestMapping 中提取 system content 的路径
|
||||||
|
func getSystemPromptPath(requestMapping map[string]any) string {
|
||||||
|
for key, val := range requestMapping {
|
||||||
|
if !strings.Contains(key, ".role") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if gconv.String(val) != "system" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
prefix := strings.TrimSuffix(key, ".role")
|
||||||
|
contentKey := prefix + ".content"
|
||||||
|
if _, ok := requestMapping[contentKey]; ok {
|
||||||
|
return contentKey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
|
|||||||
//1) 构建系统提示词
|
//1) 构建系统提示词
|
||||||
systemPrompt := promptBuildWithRounds(ctx, chatModel, aiModel)
|
systemPrompt := promptBuildWithRounds(ctx, chatModel, aiModel)
|
||||||
ir.AddSystem(systemPrompt)
|
ir.AddSystem(systemPrompt)
|
||||||
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType))
|
userPrompt := buildUserPrompt(ctx, req)
|
||||||
ir.AddUser(userPrompt)
|
ir.AddUser(userPrompt)
|
||||||
//2) 检查整体内容是否超出窗口
|
//2) 检查整体内容是否超出窗口
|
||||||
if !checkOverallContent(ir, aiModel) {
|
if !checkOverallContent(ir, aiModel) {
|
||||||
@@ -40,7 +40,7 @@ func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chat
|
|||||||
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
|
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
|
||||||
customPrompt := gjson.New(req.UserForm).Get("0.prompt").String()
|
customPrompt := gjson.New(req.UserForm).Get("0.prompt").String()
|
||||||
ir.AddSystem(customPrompt)
|
ir.AddSystem(customPrompt)
|
||||||
ir.AddUser(buildUserPrompt(ctx, req, ""))
|
ir.AddUser(buildUserPrompt(ctx, req))
|
||||||
return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt)
|
return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -81,9 +81,14 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel,
|
|||||||
if err != nil || providerProtocol == nil {
|
if err != nil || providerProtocol == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{})).MustToJsonString()
|
|
||||||
|
|
||||||
return fmt.Sprintf(providerProtocol.SystemPromptTemplate, outputJSON)
|
outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{
|
||||||
|
"model": aiModel.ModelName,
|
||||||
|
})).MustToJsonString()
|
||||||
|
|
||||||
|
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
|
||||||
|
outputJSON, //%s【输出结构】
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkOverallContent 检查整体内容是否超出窗口
|
// checkOverallContent 检查整体内容是否超出窗口
|
||||||
@@ -93,15 +98,8 @@ func checkOverallContent(ir *IR, model *gateway.AsynchModel) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// buildUserPrompt 构建用户提示词
|
// buildUserPrompt 构建用户提示词
|
||||||
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
|
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq) string {
|
||||||
var b strings.Builder
|
var b strings.Builder
|
||||||
b.WriteString(fmt.Sprintf("目标模型:%s\n", req.ModelName))
|
|
||||||
if prompt != "" {
|
|
||||||
b.WriteString(fmt.Sprintf("系统提示词:%s\n", prompt))
|
|
||||||
}
|
|
||||||
if skills := SkillMdContent(ctx, req.SkillName); skills != "" {
|
|
||||||
b.WriteString(fmt.Sprintf("技能内容:\n%s\n", skills))
|
|
||||||
}
|
|
||||||
if formText := buildUserFormText(req.Form); formText != "" {
|
if formText := buildUserFormText(req.Form); formText != "" {
|
||||||
b.WriteString(fmt.Sprintf("系统参数:\n%s\n", formText))
|
b.WriteString(fmt.Sprintf("系统参数:\n%s\n", formText))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -232,15 +232,19 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// 4) 合并系统提示词
|
||||||
|
systemPrompt := util.GetModelPrompt(ctx, model.ModelType)
|
||||||
|
skillContent := SkillMdContent(ctx, composeTask.SkillName)
|
||||||
|
messages = util.MergeSystemPrompt(messages, systemPrompt, skillContent, model.RequestMapping)
|
||||||
|
|
||||||
// 4) 合并附加结构
|
// 5) 合并附加结构
|
||||||
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
|
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
|
||||||
// 5) 注入历史
|
// 6) 注入历史
|
||||||
if len(history) > 0 {
|
if len(history) > 0 {
|
||||||
messages = InjectHistory(messages, history, protocol)
|
messages = InjectHistory(messages, history, protocol)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 6) 更新数据库
|
// 7) 更新数据库
|
||||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||||
TaskId: req.TaskId,
|
TaskId: req.TaskId,
|
||||||
Status: public.ComposeStatusSuccess,
|
Status: public.ComposeStatusSuccess,
|
||||||
|
|||||||
Reference in New Issue
Block a user