refactor(prompt): 重构异步模型字段和提示词构建服务

This commit is contained in:
2026-05-21 10:53:58 +08:00
parent fee6528f93
commit 15f5761000
7 changed files with 266 additions and 151 deletions

View File

@@ -2,11 +2,22 @@ package util
import ( import (
"context" "context"
"strings"
"github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv" "github.com/gogf/gf/v2/util/gconv"
) )
// GetServerPort 从配置获取服务端口
func GetServerPort(ctx context.Context) string {
address := g.Cfg().MustGet(ctx, "server.address", ":8080").String()
// address 格式如 ":3009",去掉冒号
if strings.HasPrefix(address, ":") {
return address[1:]
}
return "8080"
}
// GetModelPrompt 获取请求模型的提示词 // GetModelPrompt 获取请求模型的提示词
func GetModelPrompt(ctx context.Context, modelType int) string { func GetModelPrompt(ctx context.Context, modelType int) string {
key := "modelPrompts.types." + gconv.String(modelType) key := "modelPrompts.types." + gconv.String(modelType)

130
common/util/network.go Normal file
View File

@@ -0,0 +1,130 @@
package util
import (
"context"
"net"
"strings"
"github.com/gogf/gf/v2/frame/g"
)
// GetLocalIP 获取本机有效的局域网 IPv4 地址
func GetLocalIP() string {
addrs, err := net.InterfaceAddrs()
if err != nil {
return "127.0.0.1"
}
var validIPs []string
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok {
continue
}
ip := ipnet.IP
if isIPValid(ip) {
validIPs = append(validIPs, ip.String())
}
}
// 优先返回非 169.254.x.x 的 IP
for _, ip := range validIPs {
if !strings.HasPrefix(ip, "169.254.") {
return ip
}
}
// 其次返回 169.254.x.x最后的选择
if len(validIPs) > 0 {
return validIPs[0]
}
return "127.0.0.1"
}
// isIPValid 判断 IP 是否有效
func isIPValid(ip net.IP) bool {
// 不是 loopback (127.0.0.1)
if ip.IsLoopback() {
return false
}
// 是 IPv4
if ip.To4() == nil {
return false
}
// 不是链路本地地址 (169.254.0.0/16)
if ip[0] == 169 && ip[1] == 254 {
return false
}
// 不是组播地址
if ip.IsMulticast() {
return false
}
// 不是未指定地址 (0.0.0.0)
if ip.IsUnspecified() {
return false
}
return true
}
// GetLocalAddress 获取局域网地址IP:端口)
func GetLocalAddress(ctx context.Context) string {
ip := GetLocalIP()
port := GetServerPort(ctx)
if port == "80" || port == "443" {
return ip
}
return ip + ":" + port
}
// GetSchemaFromRequest 从当前请求中获取协议http/https
func GetSchemaFromRequest(ctx context.Context) string {
r := g.RequestFromCtx(ctx)
if r == nil {
return "http"
}
// 1. 代理场景X-Forwarded-Proto
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
return proto
}
// 2. 代理场景X-Forwarded-Scheme
if proto := r.Header.Get("X-Forwarded-Scheme"); proto != "" {
return proto
}
// 3. TLS 连接(直接 HTTPS
if r.TLS != nil {
return "https"
}
// 4. 默认 HTTP这行很重要
return "http" // ← 确保有这行
}
// GetLocalBaseURL 获取局域网基础 URL动态协议 + IP + 端口)
func GetLocalBaseURL(ctx context.Context) string {
schema := GetSchemaFromRequest(ctx)
addr := GetLocalAddress(ctx)
return schema + "://" + addr
}
// GetCallbackURL 获取回调地址(完整 URL
func GetCallbackURL(ctx context.Context, path string) string {
baseURL := GetLocalBaseURL(ctx)
// 确保 path 以 / 开头
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return baseURL + path
}

View File

@@ -75,13 +75,13 @@ func (d *composeSessionDao) Get(ctx context.Context, req *entity.ComposeSession,
return nil, err return nil, err
} }
if r.IsEmpty() { if r.IsEmpty() {
return nil, nil return
} }
err = r.Struct(&m) err = r.Struct(&m)
return return
} }
// Delete 删除编排会话 // Delete 删除编排会话
func (d *composeSessionDao) Delete(ctx context.Context, req *entity.ComposeSession) (rows int64, err error) { func (d *composeSessionDao) Delete(ctx context.Context, req *entity.ComposeSession) (rows int64, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession). r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession).
OmitEmpty(). OmitEmpty().

View File

@@ -14,7 +14,7 @@ type AsynchModel struct {
RequestMapping any `orm:"request_mapping" json:"requestMapping"` RequestMapping any `orm:"request_mapping" json:"requestMapping"`
ResponseMapping any `orm:"response_mapping" json:"responseMapping"` ResponseMapping any `orm:"response_mapping" json:"responseMapping"`
ResponseBody any `orm:"response_body" json:"responseBody"` ResponseBody any `orm:"response_body" json:"responseBody"`
TokenMapping string `orm:"token_mapping" json:"tokenMapping"` ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
Prompt string `orm:"prompt" json:"prompt"` Prompt string `orm:"prompt" json:"prompt"`
IsPrivate *int `orm:"is_private" json:"isPrivate"` IsPrivate *int `orm:"is_private" json:"isPrivate"`
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"` IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
@@ -35,60 +35,60 @@ type AsynchModel struct {
type asynchModelCol struct { type asynchModelCol struct {
beans.SQLBaseCol beans.SQLBaseCol
ModelName string ModelName string
ModelType string ModelType string
BaseURL string BaseURL string
HttpMethod string HttpMethod string
HeadMsg string HeadMsg string
FormJSON string FormJSON string
RequestMapping string RequestMapping string
ResponseMapping string ResponseMapping string
ResponseBody string ResponseBody string
TokenMapping string ResponseTokenField string
Prompt string Prompt string
IsPrivate string IsPrivate string
IsChatModel string IsChatModel string
ApiKey string ApiKey string
Enabled string Enabled string
MaxConcurrency string MaxConcurrency string
QueueLimit string QueueLimit string
TimeoutSeconds string TimeoutSeconds string
ExpectedSeconds string ExpectedSeconds string
RetryTimes string RetryTimes string
RetryQueueMaxSecs string RetryQueueMaxSecs string
AutoCleanSeconds string AutoCleanSeconds string
Remark string Remark string
IsOwner string IsOwner string
OperatorName string OperatorName string
TokenConfig string TokenConfig string
} }
var AsynchModelCol = asynchModelCol{ var AsynchModelCol = asynchModelCol{
SQLBaseCol: beans.DefSQLBaseCol, SQLBaseCol: beans.DefSQLBaseCol,
ModelName: "model_name", ModelName: "model_name",
ModelType: "model_type", ModelType: "model_type",
BaseURL: "base_url", BaseURL: "base_url",
HttpMethod: "http_method", HttpMethod: "http_method",
HeadMsg: "head_msg", HeadMsg: "head_msg",
FormJSON: "form_json", FormJSON: "form_json",
RequestMapping: "request_mapping", RequestMapping: "request_mapping",
ResponseMapping: "response_mapping", ResponseMapping: "response_mapping",
ResponseBody: "response_body", ResponseBody: "response_body",
TokenMapping: "token_mapping", ResponseTokenField: "response_token_field",
Prompt: "prompt", Prompt: "prompt",
IsPrivate: "is_private", IsPrivate: "is_private",
IsChatModel: "is_chat_model", IsChatModel: "is_chat_model",
ApiKey: "api_key", ApiKey: "api_key",
Enabled: "enabled", Enabled: "enabled",
MaxConcurrency: "max_concurrency", MaxConcurrency: "max_concurrency",
QueueLimit: "queue_limit", QueueLimit: "queue_limit",
TimeoutSeconds: "timeout_seconds", TimeoutSeconds: "timeout_seconds",
ExpectedSeconds: "expected_seconds", ExpectedSeconds: "expected_seconds",
RetryTimes: "retry_times", RetryTimes: "retry_times",
RetryQueueMaxSecs: "retry_queue_max_seconds", RetryQueueMaxSecs: "retry_queue_max_seconds",
AutoCleanSeconds: "auto_clean_seconds", AutoCleanSeconds: "auto_clean_seconds",
Remark: "remark", Remark: "remark",
IsOwner: "is_owner", IsOwner: "is_owner",
OperatorName: "operator_name", OperatorName: "operator_name",
TokenConfig: "token_config", TokenConfig: "token_config",
} }

View File

@@ -26,16 +26,16 @@ func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, cha
switch req.BuildType { switch req.BuildType {
case public.BuildTypePrompt: case public.BuildTypePrompt:
return buildPromptTypeRequest(ctx, processedReq, targetModel, history, ir, totalBatches) return buildPromptTypeRequest(ctx, processedReq, targetModel, chatModel, history, ir, totalBatches)
case public.BuildTypeNode: case public.BuildTypeNode:
return buildNodeTypeRequest(ctx, req, ir) return buildNodeTypeRequest(ctx, req, chatModel, ir)
default: default:
return nil, errors.New("不支持的构建类型") return nil, errors.New("不支持的构建类型")
} }
} }
// buildPromptTypeRequest 构建提示词类型请求BuildType=1 // buildPromptTypeRequest 构建提示词类型请求BuildType=1
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, targetModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) { func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, targetModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
systemPrompt := promptBuildWithRounds(ctx, req, targetModel, totalBatches) systemPrompt := promptBuildWithRounds(ctx, req, targetModel, totalBatches)
ir.AddSystem(systemPrompt) ir.AddSystem(systemPrompt)
@@ -49,42 +49,23 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ta
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, targetModel.ModelType)) userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, targetModel.ModelType))
ir.AddUser(userPrompt) ir.AddUser(userPrompt)
if !checkOverallContent(ir, targetModel) { if !checkOverallContent(ir, targetModel) {
availableWindow := util.GetAvailableWindow(targetModel.TokenConfig) availableWindow := util.GetAvailableWindow(targetModel.TokenConfig)
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow) return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
} }
return compileToProviderRequest(ctx, ir, targetModel.OperatorName, targetModel) return compileToProviderRequest(ctx, ir, targetModel.OperatorName, targetModel.ModelName, chatModel)
} }
// buildNodeTypeRequest 构建节点类型请求BuildType=2 // buildNodeTypeRequest 构建节点类型请求BuildType=2
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ir *PromptIR) (map[string]any, error) { func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, ir *PromptIR) (map[string]any, error) {
ir.AddUser(NodeBuild(ctx, req)) ir.AddUser(NodeBuild(ctx, req))
protocol, err := GetProtocolByProvider(ctx, req.ModelName) return compileToProviderRequest(ctx, ir, req.ModelName, req.ModelName, chatModel)
if err != nil {
return nil, fmt.Errorf("获取协议配置失败: %w", err)
}
if protocol == nil {
return nil, errors.New("协议配置不存在")
}
providerReq, err := Compile(ir, protocol, nil)
if err != nil {
return nil, fmt.Errorf("编译请求失败: %w", err)
}
return map[string]any{
"modelName": req.ModelName,
"bizName": "prompts-core",
"callbackUrl": "/prompt/callback",
"requestPayload": providerReq,
}, nil
} }
// compileToProviderRequest 编译为 Provider 请求 // compileToProviderRequest 编译为 Provider 请求
func compileToProviderRequest(ctx context.Context, ir *PromptIR, providerName string, model *entity.AsynchModel) (map[string]any, error) { func compileToProviderRequest(ctx context.Context, ir *PromptIR, providerName string, modelName string, chatModel *entity.AsynchModel) (map[string]any, error) {
protocol, err := GetProtocolByProvider(ctx, providerName) protocol, err := GetProtocolByProvider(ctx, providerName)
if err != nil { if err != nil {
return nil, fmt.Errorf("获取协议配置失败: %w", err) return nil, fmt.Errorf("获取协议配置失败: %w", err)
@@ -92,17 +73,15 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, providerName st
if protocol == nil { if protocol == nil {
return nil, errors.New("协议配置不存在") return nil, errors.New("协议配置不存在")
} }
providerReq, err := Compile(ir, protocol, chatModel)
providerReq, err := Compile(ir, protocol, model)
if err != nil { if err != nil {
return nil, fmt.Errorf("编译请求失败: %w", err) return nil, fmt.Errorf("编译请求失败: %w", err)
} }
fmt.Println("providerReq打印:", util.MustMarshal(providerReq))
return map[string]any{ return map[string]any{
"modelName": model.ModelName, "modelName": modelName,
"bizName": "prompts-core", "bizName": "prompts-core",
"callbackUrl": "/prompt/callback", "callbackUrl": util.GetCallbackURL(ctx, "/prompt/callback"),
"requestPayload": providerReq, "requestPayload": providerReq,
}, nil }, nil
} }

View File

@@ -30,6 +30,7 @@ func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.Com
if err = validateUserForm(ctx, req, aiModel); err != nil { if err = validateUserForm(ctx, req, aiModel); err != nil {
return nil, err return nil, err
} }
fmt.Printf("req打印%+v", req)
switch req.BuildType { switch req.BuildType {
case public.BuildTypePrompt: case public.BuildTypePrompt:
return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建 return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建
@@ -85,13 +86,13 @@ func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatMod
g.Log().Errorf(ctx, "保存任务记录失败(第%d次): %v", attempt+1, err) g.Log().Errorf(ctx, "保存任务记录失败(第%d次): %v", attempt+1, err)
continue continue
} }
//等待结果
taskRecord, err = waitForResult(ctx, taskID) taskRecord, err = waitForResult(ctx, taskID)
if err != nil { if err != nil {
g.Log().Errorf(ctx, "等待结果失败(第%d次): %v", attempt+1, err) g.Log().Errorf(ctx, "等待结果失败(第%d次): %v", attempt+1, err)
continue continue
} }
//处理结果
message = parsePromptBuild(taskRecord, chatModel) message = parsePromptBuild(taskRecord, chatModel)
if message != nil { if message != nil {
break break
@@ -244,93 +245,87 @@ func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatMo
func waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) { func waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) {
timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second
pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond
deadline := time.Now().Add(timeout) deadline := time.Now().Add(timeout)
ticker := time.NewTicker(pollInterval)
defer ticker.Stop()
for { for {
// ===================== 修复点 1检查上下文是否取消 =====================
select {
case <-ctx.Done():
// 请求已被取消,直接返回,不继续查库
return nil, ctx.Err()
default:
}
// 1. 查数据库
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: taskID, TaskId: taskID,
}) })
if err != nil { if err != nil {
// ===================== 修复点 2如果是上下文取消直接返回 =====================
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err return nil, err
} }
return nil, fmt.Errorf("查询任务失败: %w", err) return nil, err
} }
if record != nil { if record != nil {
if completed, result := checkTaskCompletion(record); completed { switch record.Status {
return result, nil case public.ComposeStatusSuccess:
return record, nil
case public.ComposeStatusFailed:
if strings.TrimSpace(record.ErrorMessage) == "" {
return nil, fmt.Errorf("任务失败(taskId=%s)", taskID)
}
return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage)
} }
} }
if err = syncGatewayTaskState(ctx, taskID, record); err != nil { // 2. 查网关状态
g.Log().Warningf(ctx, "[waitForResult] 同步网关状态失败 taskId=%s err=%v", taskID, err) state, err := gateway.QueryGatewayTaskState(ctx, taskID)
if err != nil {
// 网关不可达不终止,继续轮询
g.Log().Warningf(ctx, "[waitForResult] 查询网关失败 taskId=%s err=%v", taskID, err)
} else {
switch state {
case 2: // 网关成功
// 网关已成功,主动更新数据库
if record != nil {
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: taskID,
Status: public.ComposeStatusSuccess,
})
if err != nil {
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
}
}
case 3: // 网关失败
if record != nil {
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: taskID,
Status: public.ComposeStatusFailed,
ErrorMessage: "model-gateway 任务执行失败",
})
if err != nil {
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
}
}
return nil, fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID)
}
} }
// 3. 超时检查
if time.Now().After(deadline) { if time.Now().After(deadline) {
return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID) return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID)
} }
// ===================== 修复点3sleep 也要监听 ctx 取消 =====================
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() return nil, ctx.Err()
case <-ticker.C: case <-time.After(pollInterval):
} }
} }
} }
// checkTaskCompletion 检查任务是否完成
func checkTaskCompletion(record *entity.ComposeTask) (bool, *entity.ComposeTask) {
if record == nil {
return false, nil
}
switch record.Status {
case public.ComposeStatusSuccess:
return true, record
case public.ComposeStatusFailed:
errMsg := strings.TrimSpace(record.ErrorMessage)
if errMsg == "" {
return true, nil
}
return true, nil
default:
return false, nil
}
}
// syncGatewayTaskState 同步网关任务状态
func syncGatewayTaskState(ctx context.Context, taskID string, record *entity.ComposeTask) error {
state, err := gateway.QueryGatewayTaskState(ctx, taskID)
if err != nil {
return fmt.Errorf("查询网关状态失败: %w", err)
}
switch state {
case 2:
return updateTaskStatus(ctx, taskID, public.ComposeStatusSuccess, "")
case 3:
updateTaskStatus(ctx, taskID, public.ComposeStatusFailed, "model-gateway 任务执行失败")
return fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID)
}
return nil
}
// updateTaskStatus 更新任务状态
func updateTaskStatus(ctx context.Context, taskID string, status string, errorMsg string) error {
task := &entity.ComposeTask{
TaskId: taskID,
Status: status,
}
if errorMsg != "" {
task.ErrorMessage = errorMsg
}
_, err := dao.ComposeTask.Update(ctx, task)
return err
}
// parsePromptBuild 解析提示词构建结果BuildType == 1 // parsePromptBuild 解析提示词构建结果BuildType == 1
func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) *dto.MultiRoundResult { func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) *dto.MultiRoundResult {
if taskRecord == nil { if taskRecord == nil {

View File

@@ -159,7 +159,6 @@ func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderP
if err != nil || entity == nil { if err != nil || entity == nil {
return nil, err return nil, err
} }
fmt.Println("entity打印", entity)
return parseProtocol(entity), nil return parseProtocol(entity), nil
} }
@@ -183,7 +182,6 @@ func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (
if ir == nil || p == nil { if ir == nil || p == nil {
return nil, fmt.Errorf("ir and protocol are required") return nil, fmt.Errorf("ir and protocol are required")
} }
messages := mergeByOrder(ir, p.MergeOrder) messages := mergeByOrder(ir, p.MergeOrder)
messages = mapRoles(messages, p.RoleMapping) messages = mapRoles(messages, p.RoleMapping)
messages = mapContent(messages, p.ContentMapping) messages = mapContent(messages, p.ContentMapping)
@@ -279,7 +277,9 @@ func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *e
b, _ := json.Marshal(tmpl) b, _ := json.Marshal(tmpl)
str := string(b) str := string(b)
str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`) if chatModel != nil {
str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`)
}
msgBytes, _ := json.Marshal(messages) msgBytes, _ := json.Marshal(messages)
str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes)) str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes))