refactor(prompt): 重构提示词构建服务和回调处理
This commit is contained in:
@@ -22,8 +22,7 @@ type ConsultItem struct {
|
|||||||
Url string `json:"url" dc:"附件地址"`
|
Url string `json:"url" dc:"附件地址"`
|
||||||
}
|
}
|
||||||
type ComposeMessagesRes struct {
|
type ComposeMessagesRes struct {
|
||||||
TaskId string `json:"taskId" dc:"任务ID"`
|
TaskId string `json:"taskId" dc:"任务ID"`
|
||||||
EpicycleId int64 `json:"epicycle_id" dc:"轮次ID"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MultiRoundResult 多轮返回结果
|
// MultiRoundResult 多轮返回结果
|
||||||
|
|||||||
@@ -156,17 +156,18 @@ type SendCallbackReq struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SendCallback 向业务方发送回调
|
// SendCallback 向业务方发送回调
|
||||||
func SendCallback(ctx context.Context, composeTask *entity.ComposeTask) error {
|
func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycleId int64) error {
|
||||||
// 1. 检查回调地址
|
// 1. 检查回调地址
|
||||||
if composeTask.CallbackUrl == "" {
|
if composeTask.CallbackUrl == "" {
|
||||||
return fmt.Errorf("回调地址为空,taskId=%s", composeTask.TaskId)
|
return fmt.Errorf("回调地址为空,taskId=%s", composeTask.TaskId)
|
||||||
}
|
}
|
||||||
// 2. 构造请求体
|
// 2. 构造请求体
|
||||||
req := SendCallbackReq{
|
req := SendCallbackReq{
|
||||||
TaskId: composeTask.TaskId,
|
TaskId: composeTask.TaskId,
|
||||||
Status: composeTask.Status,
|
Status: composeTask.Status,
|
||||||
Messages: composeTask.Messages,
|
Messages: composeTask.Messages,
|
||||||
ErrorMsg: composeTask.ErrorMessage,
|
ErrorMsg: composeTask.ErrorMessage,
|
||||||
|
EpicycleId: epicycleId,
|
||||||
}
|
}
|
||||||
// 3. 发送 POST 请求
|
// 3. 发送 POST 请求
|
||||||
headers := util.ForwardHeaders(ctx)
|
headers := util.ForwardHeaders(ctx)
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"prompts-core/service/gateway"
|
"prompts-core/service/gateway"
|
||||||
"prompts-core/service/session"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"prompts-core/common/util"
|
"prompts-core/common/util"
|
||||||
@@ -17,34 +16,14 @@ import (
|
|||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// UserPromptPayload 用户提示词请求体
|
|
||||||
type UserPromptPayload struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
PromptInfo string `json:"promptInfo"`
|
|
||||||
Form any `json:"form"`
|
|
||||||
UserForm any `json:"userForm"`
|
|
||||||
Consult []dto.ConsultItem `json:"consult"`
|
|
||||||
UserFilesText map[string]string `json:"userFilesText"`
|
|
||||||
Skills string `json:"skills"`
|
|
||||||
BuildType int `json:"buildType"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildPromptTypeRequest 构建提示词类型请求(BuildType=1)
|
// buildPromptTypeRequest 构建提示词类型请求(BuildType=1)
|
||||||
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *PromptIR, totalBatches int) (map[string]any, error) {
|
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *PromptIR, totalBatches int) (map[string]any, error) {
|
||||||
//1) 构建系统提示词
|
//1) 构建系统提示词
|
||||||
systemPrompt := promptBuildWithRounds(ctx, req, chatModel, aiModel, totalBatches)
|
systemPrompt := promptBuildWithRounds(ctx, chatModel, aiModel)
|
||||||
ir.AddSystem(systemPrompt)
|
ir.AddSystem(systemPrompt)
|
||||||
//2) 构建历史对话
|
|
||||||
history, _ := session.GetHistoryMessages(ctx, req.SessionId)
|
|
||||||
for _, msg := range history {
|
|
||||||
role := gconv.String(msg["role"])
|
|
||||||
if role != "user" && role != "assistant" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ir.AddHistory(role, gconv.String(msg["content"]))
|
|
||||||
}
|
|
||||||
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType))
|
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType))
|
||||||
ir.AddUser(userPrompt)
|
ir.AddUser(userPrompt)
|
||||||
|
//2) 检查整体内容是否超出窗口
|
||||||
if !checkOverallContent(ir, aiModel) {
|
if !checkOverallContent(ir, aiModel) {
|
||||||
availableWindow := util.GetAvailableWindow(aiModel.TokenConfig)
|
availableWindow := util.GetAvailableWindow(aiModel.TokenConfig)
|
||||||
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
|
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
|
||||||
@@ -96,8 +75,7 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gate
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// promptBuildWithRounds 构建系统提示词
|
func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel) string {
|
||||||
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, batches int) string {
|
|
||||||
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||||
ProviderName: chatModel.OperatorName,
|
ProviderName: chatModel.OperatorName,
|
||||||
Status: 1,
|
Status: 1,
|
||||||
@@ -105,32 +83,9 @@ func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, cha
|
|||||||
if err != nil || providerProtocol == nil {
|
if err != nil || providerProtocol == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{}))
|
outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{}))
|
||||||
maxWindowSize := util.GetMaxWindowSize(chatModel.TokenConfig)
|
|
||||||
availableWindow := util.GetAvailableWindow(chatModel.TokenConfig)
|
|
||||||
formContent := buildUserFormContent(req.Form)
|
|
||||||
userFormContent := buildUserFormContent(req.UserForm)
|
|
||||||
formInfo := fmt.Sprintf(`
|
|
||||||
【系统表单(系统提示词/参数)】
|
|
||||||
%s
|
|
||||||
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
|
|
||||||
%s
|
|
||||||
`, formContent, userFormContent)
|
|
||||||
|
|
||||||
inputInfo := fmt.Sprintf(`
|
|
||||||
目标模型: %s
|
|
||||||
%s
|
|
||||||
技能名称: %s
|
|
||||||
用户文件: %v
|
|
||||||
`, req.ModelName, formInfo, req.SkillName, req.Consult)
|
|
||||||
|
|
||||||
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
|
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
|
||||||
req.ModelName, // %s 目标模型名称
|
outputJSON, //【输出结构】 %s
|
||||||
maxWindowSize, // %d 最大窗口
|
|
||||||
availableWindow, // %d 可用窗口
|
|
||||||
outputJSON, // %s 输出结构
|
|
||||||
inputInfo, // %s 完整输入信息
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -151,43 +106,67 @@ func checkOverallContent(ir *PromptIR, model *gateway.AsynchModel) bool {
|
|||||||
|
|
||||||
// buildUserPrompt 构建用户提示词
|
// buildUserPrompt 构建用户提示词
|
||||||
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
|
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
|
||||||
payload := UserPromptPayload{
|
var b strings.Builder
|
||||||
Model: req.ModelName,
|
b.WriteString(fmt.Sprintf("目标模型:%s\n", req.ModelName))
|
||||||
PromptInfo: prompt,
|
if prompt != "" {
|
||||||
Form: prepareUserFormPayload(req.Form),
|
b.WriteString(fmt.Sprintf("系统提示词:%s\n", prompt))
|
||||||
UserForm: prepareUserFormPayload(req.UserForm),
|
|
||||||
Consult: req.Consult,
|
|
||||||
UserFilesText: ExtractFileTexts(ctx, req.Consult),
|
|
||||||
Skills: SkillMdContent(ctx, req.SkillName),
|
|
||||||
BuildType: req.BuildType,
|
|
||||||
}
|
}
|
||||||
return gjson.New(payload).String()
|
if skills := SkillMdContent(ctx, req.SkillName); skills != "" {
|
||||||
|
b.WriteString(fmt.Sprintf("技能内容:\n%s\n", skills))
|
||||||
|
}
|
||||||
|
if formText := buildUserFormText(req.Form); formText != "" {
|
||||||
|
b.WriteString(fmt.Sprintf("系统参数:\n%s\n", formText))
|
||||||
|
}
|
||||||
|
if userFormText := buildUserFormText(req.UserForm); userFormText != "" {
|
||||||
|
b.WriteString(fmt.Sprintf("用户需求:\n%s\n", userFormText))
|
||||||
|
}
|
||||||
|
if len(req.Consult) > 0 {
|
||||||
|
b.WriteString(fmt.Sprintf("参考附件:%s\n", gjson.New(req.Consult).String()))
|
||||||
|
}
|
||||||
|
if fileTexts := ExtractFileTexts(ctx, req.Consult); fileTexts != "" {
|
||||||
|
b.WriteString(fmt.Sprintf("附件内容:\n%s\n", fileTexts))
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareUserFormPayload 准备用户表单载荷
|
// buildUserFormText 构建用户表单内容字符串
|
||||||
func prepareUserFormPayload(userForm []map[string]any) any {
|
func buildUserFormText(form []map[string]any) string {
|
||||||
if len(userForm) == 0 {
|
if len(form) == 0 {
|
||||||
return nil
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := userForm[0]["batch_index"]; ok {
|
|
||||||
return userForm
|
|
||||||
}
|
|
||||||
|
|
||||||
return mergeUserFormTexts(userForm)
|
|
||||||
}
|
|
||||||
|
|
||||||
// mergeUserFormTexts 合并 UserForm 中的所有文本内容
|
|
||||||
func mergeUserFormTexts(userForm []map[string]any) string {
|
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
for i, item := range userForm {
|
for _, item := range form {
|
||||||
text := getItemText(item)
|
for k, v := range item {
|
||||||
if i > 0 {
|
switch val := v.(type) {
|
||||||
builder.WriteString("\n\n")
|
case []any:
|
||||||
|
// 数组类型:逐条列出
|
||||||
|
builder.WriteString(fmt.Sprintf("%s:\n", k))
|
||||||
|
for i, elem := range val {
|
||||||
|
if m, ok := elem.(map[string]any); ok {
|
||||||
|
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
|
||||||
|
for mk, mv := range m {
|
||||||
|
builder.WriteString(fmt.Sprintf("%s:%v ", mk, mv))
|
||||||
|
}
|
||||||
|
builder.WriteString("\n")
|
||||||
|
} else {
|
||||||
|
builder.WriteString(fmt.Sprintf(" %d. %v\n", i+1, elem))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case []map[string]any:
|
||||||
|
builder.WriteString(fmt.Sprintf("%s:\n", k))
|
||||||
|
for i, m := range val {
|
||||||
|
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
|
||||||
|
for mk, mv := range m {
|
||||||
|
builder.WriteString(fmt.Sprintf("%s:%v ", mk, mv))
|
||||||
|
}
|
||||||
|
builder.WriteString("\n")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
builder.WriteString(fmt.Sprintf("%s:%v\n", k, v))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
builder.WriteString(text)
|
|
||||||
}
|
}
|
||||||
return builder.String()
|
return strings.TrimSpace(builder.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// NodeBuild 节点构建
|
// NodeBuild 节点构建
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"gitea.com/red-future/common/beans"
|
"gitea.com/red-future/common/beans"
|
||||||
"gitea.com/red-future/common/utils"
|
"gitea.com/red-future/common/utils"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ComposeMessages 核心拼接提示词主流程
|
// ComposeMessages 核心拼接提示词主流程
|
||||||
@@ -157,14 +158,14 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
|
|||||||
if composeTask.CallbackUrl != "" {
|
if composeTask.CallbackUrl != "" {
|
||||||
composeTask.Status = public.ComposeStatusFailed
|
composeTask.Status = public.ComposeStatusFailed
|
||||||
composeTask.ErrorMessage = req.ErrorMsg
|
composeTask.ErrorMessage = req.ErrorMsg
|
||||||
_ = gateway.SendCallback(ctx, composeTask)
|
_ = gateway.SendCallback(ctx, composeTask, 0)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleCallbackSuccess 处理回调成功
|
// handleCallbackSuccess 处理回调成功
|
||||||
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
|
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
|
||||||
// 1) 查模型配置
|
// 1) 获取模型配置
|
||||||
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||||
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
|
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
|
||||||
ModelName: composeTask.ModelName,
|
ModelName: composeTask.ModelName,
|
||||||
@@ -172,6 +173,10 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("查询模型失败: %w", err)
|
return fmt.Errorf("查询模型失败: %w", err)
|
||||||
}
|
}
|
||||||
|
// 2) 根据运营商获取协议配置
|
||||||
|
//protocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||||
|
// ProviderName: model.OperatorName,
|
||||||
|
//})
|
||||||
|
|
||||||
// 2) 解析结果
|
// 2) 解析结果
|
||||||
var messages map[string]any
|
var messages map[string]any
|
||||||
@@ -179,10 +184,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
|||||||
case public.BuildTypePrompt, public.BuildTypeNode:
|
case public.BuildTypePrompt, public.BuildTypeNode:
|
||||||
messages = ParseResult(req.Messages, model.ResponseBody)
|
messages = ParseResult(req.Messages, model.ResponseBody)
|
||||||
case public.BuildTypeStruct:
|
case public.BuildTypeStruct:
|
||||||
messages = map[string]any{
|
messages = ParseStructResult(req.Messages, model.ResponseBody)
|
||||||
"total_rounds": 1,
|
|
||||||
"rounds": []map[string]any{req.Messages},
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
messages = req.Messages
|
messages = req.Messages
|
||||||
}
|
}
|
||||||
@@ -201,80 +203,113 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// 5) 存储历史结果
|
// 5) 存储提示词结果作为历史请求
|
||||||
|
var epicycleId int64
|
||||||
// 6) 回调业务方
|
payload := composeTask.RequestPayload
|
||||||
|
sessionId := gconv.String(payload["sessionId"])
|
||||||
|
nodeId := gconv.String(payload["nodeId"])
|
||||||
|
buildType := gconv.Int(payload["buildType"])
|
||||||
|
if buildType == public.BuildTypePrompt && sessionId != "" && nodeId != "" {
|
||||||
|
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||||
|
NodeId: nodeId,
|
||||||
|
SessionId: sessionId,
|
||||||
|
RequestContent: messages,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
// 6) 拼接历史内容
|
||||||
|
// 7) 回调业务方
|
||||||
if composeTask.CallbackUrl != "" {
|
if composeTask.CallbackUrl != "" {
|
||||||
composeTask.Status = public.ComposeStatusSuccess
|
composeTask.Status = public.ComposeStatusSuccess
|
||||||
composeTask.Messages = messages
|
composeTask.Messages = messages
|
||||||
_ = gateway.SendCallback(ctx, composeTask)
|
_ = gateway.SendCallback(ctx, composeTask, epicycleId)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseResult 解析回调结果
|
// ParseResult 解析结果
|
||||||
func ParseResult(raw map[string]any, responseBody string) map[string]any {
|
func ParseResult(raw map[string]any, responseBody string) map[string]any {
|
||||||
// responseBody 为空,直接返回原始数据
|
|
||||||
if responseBody == "" {
|
if responseBody == "" {
|
||||||
return raw
|
return raw
|
||||||
}
|
}
|
||||||
|
|
||||||
// 按 responseBody 路径取值
|
contentVal := raw[responseBody]
|
||||||
contentStr, ok := raw[responseBody].(string)
|
if contentVal == nil {
|
||||||
if !ok || contentStr == "" {
|
|
||||||
return raw
|
return raw
|
||||||
}
|
}
|
||||||
|
|
||||||
if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil {
|
// 已经是数组
|
||||||
return map[string]any{
|
if arr, ok := contentVal.([]any); ok {
|
||||||
"total_rounds": len(roundsArray),
|
rounds := gconv.Maps(arr)
|
||||||
"rounds": roundsArray,
|
if len(rounds) > 0 {
|
||||||
|
return map[string]any{"total_rounds": len(rounds), "rounds": rounds}
|
||||||
}
|
}
|
||||||
|
return raw
|
||||||
}
|
}
|
||||||
|
|
||||||
if singleRound := tryParseAsMap(contentStr); singleRound != nil {
|
// 是字符串
|
||||||
return map[string]any{
|
contentStr := gconv.String(contentVal)
|
||||||
"total_rounds": 1,
|
if contentStr == "" {
|
||||||
"rounds": []map[string]any{singleRound},
|
return raw
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 尝试解析为数组
|
||||||
|
var arr []map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 {
|
||||||
|
return map[string]any{"total_rounds": len(arr), "rounds": arr}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试解析为单对象
|
||||||
|
var obj map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(contentStr), &obj); err == nil && len(obj) > 0 {
|
||||||
|
return map[string]any{"total_rounds": 1, "rounds": []map[string]any{obj}}
|
||||||
}
|
}
|
||||||
|
|
||||||
return map[string]any{"content": contentStr}
|
return map[string]any{"content": contentStr}
|
||||||
}
|
}
|
||||||
|
|
||||||
func tryParseAsMapArray(jsonStr string) []map[string]any {
|
func ParseStructResult(raw map[string]any, responseBody string) map[string]any {
|
||||||
var arr []map[string]any
|
// 如果外层已有 rounds,直接返回
|
||||||
if err := json.Unmarshal([]byte(jsonStr), &arr); err != nil {
|
if _, ok := raw["rounds"]; ok {
|
||||||
return nil
|
return raw
|
||||||
}
|
}
|
||||||
if len(arr) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return arr
|
|
||||||
}
|
|
||||||
|
|
||||||
func tryParseAsMap(jsonStr string) map[string]any {
|
contentVal := raw[responseBody]
|
||||||
var obj map[string]any
|
if contentVal == nil {
|
||||||
if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil {
|
return map[string]any{
|
||||||
return nil
|
"total_rounds": 1,
|
||||||
}
|
"rounds": []map[string]any{raw},
|
||||||
if len(obj) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return obj
|
|
||||||
}
|
|
||||||
|
|
||||||
func ParseNodeResult(raw map[string]any) map[string]any {
|
|
||||||
contentStr, ok := raw["content"].(string)
|
|
||||||
if ok && contentStr != "" {
|
|
||||||
var inner map[string]any
|
|
||||||
if err := json.Unmarshal([]byte(contentStr), &inner); err == nil {
|
|
||||||
return map[string]any{
|
|
||||||
"total_rounds": 1,
|
|
||||||
"rounds": []map[string]any{inner},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
contentStr := gconv.String(contentVal)
|
||||||
|
if contentStr == "" || contentStr == "0" {
|
||||||
|
return map[string]any{
|
||||||
|
"total_rounds": 1,
|
||||||
|
"rounds": []map[string]any{raw},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试解析为数组
|
||||||
|
var arr []map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 {
|
||||||
|
// 数组的每个元素包一层 content
|
||||||
|
var rounds []map[string]any
|
||||||
|
for _, item := range arr {
|
||||||
|
rounds = append(rounds, map[string]any{"content": item})
|
||||||
|
}
|
||||||
|
return map[string]any{
|
||||||
|
"total_rounds": len(rounds),
|
||||||
|
"rounds": rounds,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试解析为单个对象
|
||||||
|
var parsed map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(contentStr), &parsed); err == nil {
|
||||||
|
raw[responseBody] = parsed
|
||||||
|
}
|
||||||
|
|
||||||
|
// 兜底:包标准结构
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"total_rounds": 1,
|
"total_rounds": 1,
|
||||||
"rounds": []map[string]any{raw},
|
"rounds": []map[string]any{raw},
|
||||||
|
|||||||
@@ -22,26 +22,25 @@ const (
|
|||||||
bytesPerMB = 1024 * 1024
|
bytesPerMB = 1024 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
// ExtractFileTexts 从 ConsultItem 列表中提取文件内容
|
// ExtractFileTexts 从 ConsultItem 列表中提取文件内容,返回拼接文本
|
||||||
func ExtractFileTexts(ctx context.Context, consult []dto.ConsultItem) map[string]string {
|
func ExtractFileTexts(ctx context.Context, consult []dto.ConsultItem) string {
|
||||||
urls := make([]string, 0, len(consult))
|
urls := make([]string, 0, len(consult))
|
||||||
for _, item := range consult {
|
for _, item := range consult {
|
||||||
if item.Url != "" {
|
if item.Url != "" {
|
||||||
urls = append(urls, item.Url)
|
urls = append(urls, item.Url)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return FetchFileTexts(ctx, urls)
|
return FetchFileTextsAsString(ctx, urls)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchFileTexts 从 URL 列表获取文件内容,支持 zip 内文件
|
// FetchFileTextsAsString 从 URL 列表获取文件内容,拼接为字符串
|
||||||
func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
|
func FetchFileTextsAsString(ctx context.Context, urls []string) string {
|
||||||
result := make(map[string]string)
|
|
||||||
|
|
||||||
if len(urls) == 0 {
|
if len(urls) == 0 {
|
||||||
return result
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
client := createHTTPClient(ctx, "userFiles.httpTimeoutSec", 8)
|
client := createHTTPClient(ctx, "userFiles.httpTimeoutSec", 8)
|
||||||
|
var builder strings.Builder
|
||||||
|
|
||||||
for _, rawURL := range urls {
|
for _, rawURL := range urls {
|
||||||
url := util.SanitizeURL(rawURL)
|
url := util.SanitizeURL(rawURL)
|
||||||
@@ -50,23 +49,19 @@ func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if util.IsZipExtension(url) {
|
if util.IsZipExtension(url) {
|
||||||
mergeMap(result, fetchZipFileTexts(ctx, client, url))
|
for _, text := range fetchZipFileTexts(ctx, client, url) {
|
||||||
|
builder.WriteString(text)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if text := fetchAndCleanFileContent(ctx, client, url); text != "" {
|
if text := fetchAndCleanFileContent(ctx, client, url); text != "" {
|
||||||
result[url] = text
|
builder.WriteString(fmt.Sprintf("【文件:%s】\n%s\n", url, text))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return builder.String()
|
||||||
}
|
|
||||||
|
|
||||||
// mergeMap 合并 map
|
|
||||||
func mergeMap(dst, src map[string]string) {
|
|
||||||
for k, v := range src {
|
|
||||||
dst[k] = v
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// fetchAndCleanFileContent 获取并清理文件内容
|
// fetchAndCleanFileContent 获取并清理文件内容
|
||||||
|
|||||||
Reference in New Issue
Block a user