From 15f5761000467b41e5012098f47bc905a66d4816 Mon Sep 17 00:00:00 2001 From: WangLiZhao <1838393649@qq.com> Date: Thu, 21 May 2026 10:53:58 +0800 Subject: [PATCH] =?UTF-8?q?refactor(prompt):=20=E9=87=8D=E6=9E=84=E5=BC=82?= =?UTF-8?q?=E6=AD=A5=E6=A8=A1=E5=9E=8B=E5=AD=97=E6=AE=B5=E5=92=8C=E6=8F=90?= =?UTF-8?q?=E7=A4=BA=E8=AF=8D=E6=9E=84=E5=BB=BA=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/util/config.go | 11 ++ common/util/network.go | 130 +++++++++++++++++++++++ dao/compose_session_dao.go | 4 +- model/entity/asynch_model.go | 108 +++++++++---------- service/prompt/prompt_build_service.go | 41 ++----- service/prompt/prompt_compose_service.go | 117 ++++++++++---------- service/prompt/prompt_ir_service.go | 6 +- 7 files changed, 266 insertions(+), 151 deletions(-) create mode 100644 common/util/network.go diff --git a/common/util/config.go b/common/util/config.go index 417de8a..4f27f97 100644 --- a/common/util/config.go +++ b/common/util/config.go @@ -2,11 +2,22 @@ package util import ( "context" + "strings" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/util/gconv" ) +// GetServerPort 从配置获取服务端口 +func GetServerPort(ctx context.Context) string { + address := g.Cfg().MustGet(ctx, "server.address", ":8080").String() + // address 格式如 ":3009",去掉冒号 + if strings.HasPrefix(address, ":") { + return address[1:] + } + return "8080" +} + // GetModelPrompt 获取请求模型的提示词 func GetModelPrompt(ctx context.Context, modelType int) string { key := "modelPrompts.types." + gconv.String(modelType) diff --git a/common/util/network.go b/common/util/network.go new file mode 100644 index 0000000..62acd7b --- /dev/null +++ b/common/util/network.go @@ -0,0 +1,130 @@ +package util + +import ( + "context" + "net" + "strings" + + "github.com/gogf/gf/v2/frame/g" +) + +// GetLocalIP 获取本机有效的局域网 IPv4 地址 +func GetLocalIP() string { + addrs, err := net.InterfaceAddrs() + if err != nil { + return "127.0.0.1" + } + + var validIPs []string + + for _, addr := range addrs { + ipnet, ok := addr.(*net.IPNet) + if !ok { + continue + } + + ip := ipnet.IP + + if isIPValid(ip) { + validIPs = append(validIPs, ip.String()) + } + } + + // 优先返回非 169.254.x.x 的 IP + for _, ip := range validIPs { + if !strings.HasPrefix(ip, "169.254.") { + return ip + } + } + + // 其次返回 169.254.x.x(最后的选择) + if len(validIPs) > 0 { + return validIPs[0] + } + + return "127.0.0.1" +} + +// isIPValid 判断 IP 是否有效 +func isIPValid(ip net.IP) bool { + // 不是 loopback (127.0.0.1) + if ip.IsLoopback() { + return false + } + + // 是 IPv4 + if ip.To4() == nil { + return false + } + + // 不是链路本地地址 (169.254.0.0/16) + if ip[0] == 169 && ip[1] == 254 { + return false + } + + // 不是组播地址 + if ip.IsMulticast() { + return false + } + + // 不是未指定地址 (0.0.0.0) + if ip.IsUnspecified() { + return false + } + + return true +} + +// GetLocalAddress 获取局域网地址(IP:端口) +func GetLocalAddress(ctx context.Context) string { + ip := GetLocalIP() + port := GetServerPort(ctx) + + if port == "80" || port == "443" { + return ip + } + return ip + ":" + port +} + +// GetSchemaFromRequest 从当前请求中获取协议(http/https) +func GetSchemaFromRequest(ctx context.Context) string { + r := g.RequestFromCtx(ctx) + if r == nil { + return "http" + } + + // 1. 代理场景:X-Forwarded-Proto + if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" { + return proto + } + + // 2. 代理场景:X-Forwarded-Scheme + if proto := r.Header.Get("X-Forwarded-Scheme"); proto != "" { + return proto + } + + // 3. TLS 连接(直接 HTTPS) + if r.TLS != nil { + return "https" + } + + // 4. 默认 HTTP(这行很重要!) + return "http" // ← 确保有这行 +} + +// GetLocalBaseURL 获取局域网基础 URL(动态协议 + IP + 端口) +func GetLocalBaseURL(ctx context.Context) string { + schema := GetSchemaFromRequest(ctx) + addr := GetLocalAddress(ctx) + return schema + "://" + addr +} + +// GetCallbackURL 获取回调地址(完整 URL) +func GetCallbackURL(ctx context.Context, path string) string { + baseURL := GetLocalBaseURL(ctx) + // 确保 path 以 / 开头 + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + return baseURL + path +} diff --git a/dao/compose_session_dao.go b/dao/compose_session_dao.go index 4477e29..cdd4caf 100644 --- a/dao/compose_session_dao.go +++ b/dao/compose_session_dao.go @@ -75,13 +75,13 @@ func (d *composeSessionDao) Get(ctx context.Context, req *entity.ComposeSession, return nil, err } if r.IsEmpty() { - return nil, nil + return } err = r.Struct(&m) return } -// Delete 软删除编排会话 +// Delete 删除编排会话 func (d *composeSessionDao) Delete(ctx context.Context, req *entity.ComposeSession) (rows int64, err error) { r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession). OmitEmpty(). diff --git a/model/entity/asynch_model.go b/model/entity/asynch_model.go index 3232cbf..0166240 100644 --- a/model/entity/asynch_model.go +++ b/model/entity/asynch_model.go @@ -14,7 +14,7 @@ type AsynchModel struct { RequestMapping any `orm:"request_mapping" json:"requestMapping"` ResponseMapping any `orm:"response_mapping" json:"responseMapping"` ResponseBody any `orm:"response_body" json:"responseBody"` - TokenMapping string `orm:"token_mapping" json:"tokenMapping"` + ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"` Prompt string `orm:"prompt" json:"prompt"` IsPrivate *int `orm:"is_private" json:"isPrivate"` IsChatModel *int `orm:"is_chat_model" json:"isChatModel"` @@ -35,60 +35,60 @@ type AsynchModel struct { type asynchModelCol struct { beans.SQLBaseCol - ModelName string - ModelType string - BaseURL string - HttpMethod string - HeadMsg string - FormJSON string - RequestMapping string - ResponseMapping string - ResponseBody string - TokenMapping string - Prompt string - IsPrivate string - IsChatModel string - ApiKey string - Enabled string - MaxConcurrency string - QueueLimit string - TimeoutSeconds string - ExpectedSeconds string - RetryTimes string - RetryQueueMaxSecs string - AutoCleanSeconds string - Remark string - IsOwner string - OperatorName string - TokenConfig string + ModelName string + ModelType string + BaseURL string + HttpMethod string + HeadMsg string + FormJSON string + RequestMapping string + ResponseMapping string + ResponseBody string + ResponseTokenField string + Prompt string + IsPrivate string + IsChatModel string + ApiKey string + Enabled string + MaxConcurrency string + QueueLimit string + TimeoutSeconds string + ExpectedSeconds string + RetryTimes string + RetryQueueMaxSecs string + AutoCleanSeconds string + Remark string + IsOwner string + OperatorName string + TokenConfig string } var AsynchModelCol = asynchModelCol{ - SQLBaseCol: beans.DefSQLBaseCol, - ModelName: "model_name", - ModelType: "model_type", - BaseURL: "base_url", - HttpMethod: "http_method", - HeadMsg: "head_msg", - FormJSON: "form_json", - RequestMapping: "request_mapping", - ResponseMapping: "response_mapping", - ResponseBody: "response_body", - TokenMapping: "token_mapping", - Prompt: "prompt", - IsPrivate: "is_private", - IsChatModel: "is_chat_model", - ApiKey: "api_key", - Enabled: "enabled", - MaxConcurrency: "max_concurrency", - QueueLimit: "queue_limit", - TimeoutSeconds: "timeout_seconds", - ExpectedSeconds: "expected_seconds", - RetryTimes: "retry_times", - RetryQueueMaxSecs: "retry_queue_max_seconds", - AutoCleanSeconds: "auto_clean_seconds", - Remark: "remark", - IsOwner: "is_owner", - OperatorName: "operator_name", - TokenConfig: "token_config", + SQLBaseCol: beans.DefSQLBaseCol, + ModelName: "model_name", + ModelType: "model_type", + BaseURL: "base_url", + HttpMethod: "http_method", + HeadMsg: "head_msg", + FormJSON: "form_json", + RequestMapping: "request_mapping", + ResponseMapping: "response_mapping", + ResponseBody: "response_body", + ResponseTokenField: "response_token_field", + Prompt: "prompt", + IsPrivate: "is_private", + IsChatModel: "is_chat_model", + ApiKey: "api_key", + Enabled: "enabled", + MaxConcurrency: "max_concurrency", + QueueLimit: "queue_limit", + TimeoutSeconds: "timeout_seconds", + ExpectedSeconds: "expected_seconds", + RetryTimes: "retry_times", + RetryQueueMaxSecs: "retry_queue_max_seconds", + AutoCleanSeconds: "auto_clean_seconds", + Remark: "remark", + IsOwner: "is_owner", + OperatorName: "operator_name", + TokenConfig: "token_config", } diff --git a/service/prompt/prompt_build_service.go b/service/prompt/prompt_build_service.go index bd62a98..0859454 100644 --- a/service/prompt/prompt_build_service.go +++ b/service/prompt/prompt_build_service.go @@ -26,16 +26,16 @@ func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, cha switch req.BuildType { case public.BuildTypePrompt: - return buildPromptTypeRequest(ctx, processedReq, targetModel, history, ir, totalBatches) + return buildPromptTypeRequest(ctx, processedReq, targetModel, chatModel, history, ir, totalBatches) case public.BuildTypeNode: - return buildNodeTypeRequest(ctx, req, ir) + return buildNodeTypeRequest(ctx, req, chatModel, ir) default: return nil, errors.New("不支持的构建类型") } } // buildPromptTypeRequest 构建提示词类型请求(BuildType=1) -func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, targetModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) { +func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, targetModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) { systemPrompt := promptBuildWithRounds(ctx, req, targetModel, totalBatches) ir.AddSystem(systemPrompt) @@ -49,42 +49,23 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ta userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, targetModel.ModelType)) ir.AddUser(userPrompt) - if !checkOverallContent(ir, targetModel) { availableWindow := util.GetAvailableWindow(targetModel.TokenConfig) return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow) } - return compileToProviderRequest(ctx, ir, targetModel.OperatorName, targetModel) + return compileToProviderRequest(ctx, ir, targetModel.OperatorName, targetModel.ModelName, chatModel) } // buildNodeTypeRequest 构建节点类型请求(BuildType=2) -func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ir *PromptIR) (map[string]any, error) { +func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, ir *PromptIR) (map[string]any, error) { ir.AddUser(NodeBuild(ctx, req)) - protocol, err := GetProtocolByProvider(ctx, req.ModelName) - if err != nil { - return nil, fmt.Errorf("获取协议配置失败: %w", err) - } - if protocol == nil { - return nil, errors.New("协议配置不存在") - } - - providerReq, err := Compile(ir, protocol, nil) - if err != nil { - return nil, fmt.Errorf("编译请求失败: %w", err) - } - - return map[string]any{ - "modelName": req.ModelName, - "bizName": "prompts-core", - "callbackUrl": "/prompt/callback", - "requestPayload": providerReq, - }, nil + return compileToProviderRequest(ctx, ir, req.ModelName, req.ModelName, chatModel) } // compileToProviderRequest 编译为 Provider 请求 -func compileToProviderRequest(ctx context.Context, ir *PromptIR, providerName string, model *entity.AsynchModel) (map[string]any, error) { +func compileToProviderRequest(ctx context.Context, ir *PromptIR, providerName string, modelName string, chatModel *entity.AsynchModel) (map[string]any, error) { protocol, err := GetProtocolByProvider(ctx, providerName) if err != nil { return nil, fmt.Errorf("获取协议配置失败: %w", err) @@ -92,17 +73,15 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, providerName st if protocol == nil { return nil, errors.New("协议配置不存在") } - - providerReq, err := Compile(ir, protocol, model) + providerReq, err := Compile(ir, protocol, chatModel) if err != nil { return nil, fmt.Errorf("编译请求失败: %w", err) } - fmt.Println("providerReq打印:", util.MustMarshal(providerReq)) return map[string]any{ - "modelName": model.ModelName, + "modelName": modelName, "bizName": "prompts-core", - "callbackUrl": "/prompt/callback", + "callbackUrl": util.GetCallbackURL(ctx, "/prompt/callback"), "requestPayload": providerReq, }, nil } diff --git a/service/prompt/prompt_compose_service.go b/service/prompt/prompt_compose_service.go index 69dc494..ef9a19f 100644 --- a/service/prompt/prompt_compose_service.go +++ b/service/prompt/prompt_compose_service.go @@ -30,6 +30,7 @@ func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.Com if err = validateUserForm(ctx, req, aiModel); err != nil { return nil, err } + fmt.Printf("req打印%+v", req) switch req.BuildType { case public.BuildTypePrompt: return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建 @@ -85,13 +86,13 @@ func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatMod g.Log().Errorf(ctx, "保存任务记录失败(第%d次): %v", attempt+1, err) continue } - + //等待结果 taskRecord, err = waitForResult(ctx, taskID) if err != nil { g.Log().Errorf(ctx, "等待结果失败(第%d次): %v", attempt+1, err) continue } - + //处理结果 message = parsePromptBuild(taskRecord, chatModel) if message != nil { break @@ -244,93 +245,87 @@ func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatMo func waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) { timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond - deadline := time.Now().Add(timeout) - ticker := time.NewTicker(pollInterval) - defer ticker.Stop() for { + // ===================== 修复点 1:检查上下文是否取消 ===================== + select { + case <-ctx.Done(): + // 请求已被取消,直接返回,不继续查库 + return nil, ctx.Err() + default: + } + + // 1. 查数据库 record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ TaskId: taskID, }) if err != nil { + // ===================== 修复点 2:如果是上下文取消,直接返回 ===================== if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return nil, err } - return nil, fmt.Errorf("查询任务失败: %w", err) + return nil, err } - if record != nil { - if completed, result := checkTaskCompletion(record); completed { - return result, nil + switch record.Status { + case public.ComposeStatusSuccess: + return record, nil + case public.ComposeStatusFailed: + if strings.TrimSpace(record.ErrorMessage) == "" { + return nil, fmt.Errorf("任务失败(taskId=%s)", taskID) + } + return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage) } } - if err = syncGatewayTaskState(ctx, taskID, record); err != nil { - g.Log().Warningf(ctx, "[waitForResult] 同步网关状态失败 taskId=%s err=%v", taskID, err) + // 2. 查网关状态 + state, err := gateway.QueryGatewayTaskState(ctx, taskID) + if err != nil { + // 网关不可达不终止,继续轮询 + g.Log().Warningf(ctx, "[waitForResult] 查询网关失败 taskId=%s err=%v", taskID, err) + } else { + switch state { + case 2: // 网关成功 + // 网关已成功,主动更新数据库 + if record != nil { + _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{ + TaskId: taskID, + Status: public.ComposeStatusSuccess, + }) + if err != nil { + g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err) + } + } + case 3: // 网关失败 + if record != nil { + _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{ + TaskId: taskID, + Status: public.ComposeStatusFailed, + ErrorMessage: "model-gateway 任务执行失败", + }) + if err != nil { + g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err) + } + } + return nil, fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID) + } } + // 3. 超时检查 if time.Now().After(deadline) { return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID) } + // ===================== 修复点3:sleep 也要监听 ctx 取消 ===================== select { case <-ctx.Done(): return nil, ctx.Err() - case <-ticker.C: + case <-time.After(pollInterval): } } } -// checkTaskCompletion 检查任务是否完成 -func checkTaskCompletion(record *entity.ComposeTask) (bool, *entity.ComposeTask) { - if record == nil { - return false, nil - } - switch record.Status { - case public.ComposeStatusSuccess: - return true, record - case public.ComposeStatusFailed: - errMsg := strings.TrimSpace(record.ErrorMessage) - if errMsg == "" { - return true, nil - } - return true, nil - default: - return false, nil - } -} - -// syncGatewayTaskState 同步网关任务状态 -func syncGatewayTaskState(ctx context.Context, taskID string, record *entity.ComposeTask) error { - state, err := gateway.QueryGatewayTaskState(ctx, taskID) - if err != nil { - return fmt.Errorf("查询网关状态失败: %w", err) - } - switch state { - case 2: - return updateTaskStatus(ctx, taskID, public.ComposeStatusSuccess, "") - case 3: - updateTaskStatus(ctx, taskID, public.ComposeStatusFailed, "model-gateway 任务执行失败") - return fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID) - } - return nil -} - -// updateTaskStatus 更新任务状态 -func updateTaskStatus(ctx context.Context, taskID string, status string, errorMsg string) error { - task := &entity.ComposeTask{ - TaskId: taskID, - Status: status, - } - if errorMsg != "" { - task.ErrorMessage = errorMsg - } - - _, err := dao.ComposeTask.Update(ctx, task) - return err -} - // parsePromptBuild 解析提示词构建结果(BuildType == 1) func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) *dto.MultiRoundResult { if taskRecord == nil { diff --git a/service/prompt/prompt_ir_service.go b/service/prompt/prompt_ir_service.go index 33d22b3..d9d554e 100644 --- a/service/prompt/prompt_ir_service.go +++ b/service/prompt/prompt_ir_service.go @@ -159,7 +159,6 @@ func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderP if err != nil || entity == nil { return nil, err } - fmt.Println("entity打印", entity) return parseProtocol(entity), nil } @@ -183,7 +182,6 @@ func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) ( if ir == nil || p == nil { return nil, fmt.Errorf("ir and protocol are required") } - messages := mergeByOrder(ir, p.MergeOrder) messages = mapRoles(messages, p.RoleMapping) messages = mapContent(messages, p.ContentMapping) @@ -279,7 +277,9 @@ func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *e b, _ := json.Marshal(tmpl) str := string(b) - str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`) + if chatModel != nil { + str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`) + } msgBytes, _ := json.Marshal(messages) str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes))