Files
prompts-core/service/gateway/gateway_http_service.go

158 lines
4.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package gateway
import (
	"context"
	"encoding/json"
	"fmt"
	"net/url"

	commonHttp "gitea.com/red-future/common/http"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/os/gtime"

	"prompts-core/common/util"
	"prompts-core/model/entity"
)
// CreateTaskReq is the type CreateGatewayTask decodes the model-gateway
// "task/createTask" response into.
//
// NOTE(review): despite the "Req" suffix it is used as a response target, and
// its JSON tags are snake_case while GetTaskResultRes uses camelCase — confirm
// both against the gateway API.
type CreateTaskReq struct {
	TaskId   string `json:"task_id"`
	State    int    `json:"state"`
	OssFile  string `json:"oss_file"`
	FileType string `json:"file_type"`
	Text     string `json:"text"`
	ErrorMsg string `json:"error_msg"`
}
// CreateGatewayTask creates an asynchronous task on the model gateway and
// returns the task ID it assigned.
//
// payload is JSON-encoded and POSTed to "model-gateway/task/createTask" with
// the headers produced by util.ForwardHeaders(ctx); the response is decoded
// into a CreateTaskReq.
//
// It returns an error when the payload cannot be encoded, when the HTTP call
// fails, or when the gateway answers without a task ID (a task without an ID
// could never be queried afterwards).
func CreateGatewayTask(ctx context.Context, payload map[string]any) (string, error) {
	fullURL := "model-gateway/task/createTask"
	headers := util.ForwardHeaders(ctx)
	// The gateway's createTask response is decoded into CreateTaskReq.
	var resp CreateTaskReq
	body, err := json.Marshal(payload)
	if err != nil {
		return "", fmt.Errorf("[网关任务] 序列化请求失败 err=%w", err)
	}
	// NOTE(review): a pre-encoded []byte is passed here, while SendCallback passes
	// a struct to commonHttp.Post — confirm how Post serializes each form.
	if err := commonHttp.Post(ctx, fullURL, headers, &resp, body); err != nil {
		return "", fmt.Errorf("[网关任务] 创建任务失败 url=%s err=%w", fullURL, err)
	}
	if resp.TaskId == "" {
		return "", fmt.Errorf("[网关任务] 创建任务未返回taskId state=%d errorMsg=%s", resp.State, resp.ErrorMsg)
	}
	return resp.TaskId, nil
}
// GetTaskResultRes is the response of the model-gateway "task/getTaskResult"
// endpoint; QueryGatewayTaskState reads only State from it.
type GetTaskResultRes struct {
	OssFile string `json:"ossFile" dc:"结果文件OSS地址"` // OSS address of the result file
	State   int    `json:"state" dc:"任务状态"`         // task state code
}
// QueryGatewayTaskState returns the current state code of a model-gateway task.
//
// It GETs "model-gateway/task/getTaskResult" with taskID as the query-escaped
// taskId parameter, decodes the response into GetTaskResultRes and returns its
// State. On failure it returns 0 and the wrapped error.
func QueryGatewayTaskState(ctx context.Context, taskID string) (int, error) {
	// Escape the ID so characters such as '&', '#' or spaces cannot corrupt the query string.
	fullURL := fmt.Sprintf("model-gateway/task/getTaskResult?taskId=%s", url.QueryEscape(taskID))
	headers := util.ForwardHeaders(ctx)
	var resp GetTaskResultRes
	if err := commonHttp.Get(ctx, fullURL, headers, &resp, nil); err != nil {
		return 0, fmt.Errorf("[网关任务] 查询任务状态失败 taskId=%s err=%w", taskID, err)
	}
	return resp.State, nil
}
// SkillUserVO is the skill record that GetSkillUser decodes from the ai-agent
// "skill/user/getUserOrTemplate" endpoint.
//
// The ",string" option on Id makes the int64 travel as a JSON string.
type SkillUserVO struct {
	Id               int64       `json:"id,string"`
	Name             string      `json:"name"`
	Description      string      `json:"description"`
	FileName         string      `json:"fileName"`
	FileUrl          string      `json:"fileUrl"`
	CreatedAt        *gtime.Time `json:"createdAt"`
	UpdatedAt        *gtime.Time `json:"updatedAt"`
	ImgAddressPrefix string      `json:"imgAddressPrefix"`
}
// GetSkillUser fetches the skill record named name from the ai-agent service.
//
// It GETs "ai-agent/skill/user/getUserOrTemplate" with name as the
// query-escaped name parameter and decodes the response into a SkillUserVO.
// On failure it returns nil and the wrapped error.
func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) {
	// name is arbitrary caller text (spaces, '&', '#', non-ASCII); escape it so
	// the query parameter arrives intact.
	fullURL := fmt.Sprintf("ai-agent/skill/user/getUserOrTemplate?name=%s", url.QueryEscape(name))
	headers := util.ForwardHeaders(ctx)
	var resp SkillUserVO
	// NOTE(review): an empty struct is passed as the request data of this GET,
	// whereas QueryGatewayTaskState passes nil — confirm what commonHttp.Get expects.
	var req struct{}
	if err := commonHttp.Get(ctx, fullURL, headers, &resp, req); err != nil {
		return nil, fmt.Errorf("[技能用户] 查询失败 name=%s err=%w", name, err)
	}
	return &resp, nil
}
// SendCallbackReq is the JSON body SendCallback POSTs to a compose task's
// callback URL. Messages and ErrorMsg are omitted when empty; EpicycleId is
// always serialized.
//
// NOTE(review): SendCallback in this file never sets EpicycleId, so it is sent
// as 0 — confirm that is intended.
type SendCallbackReq struct {
	TaskId     string            `json:"taskId"`
	Status     string            `json:"status"`
	Messages   *MultiRoundResult `json:"messages,omitempty"`
	EpicycleId int64             `json:"epicycleId"`
	ErrorMsg   string            `json:"errorMsg,omitempty"`
}
// MultiRoundResult is the multi-round message payload carried in a callback.
type MultiRoundResult struct {
	TotalRounds int              `json:"total_rounds"` // total number of rounds
	Rounds      []map[string]any `json:"rounds"`       // per-round details (dynamically typed)
}
// SendCallback POSTs the outcome of composeTask to its CallbackUrl.
//
// The body is a SendCallbackReq built from the task's TaskId, Status,
// Messages (normalized by parseMessagesToResult) and ErrorMessage. The attempt
// and a successful send are logged at info level.
//
// It returns an error when composeTask is nil, when CallbackUrl is empty, or
// when the HTTP call fails (the latter wraps the underlying error).
func SendCallback(ctx context.Context, composeTask *entity.ComposeTask) error {
	// Guard against a nil task instead of panicking on the field accesses below.
	if composeTask == nil {
		return fmt.Errorf("[回调业务] composeTask 为空")
	}
	// 1. Check the callback URL.
	if composeTask.CallbackUrl == "" {
		return fmt.Errorf("回调地址为空taskId=%s", composeTask.TaskId)
	}
	// 2. Build the request body; Messages may be stored as a JSON string and
	// must be converted into a struct first.
	req := SendCallbackReq{
		TaskId:   composeTask.TaskId,
		Status:   composeTask.Status,
		Messages: parseMessagesToResult(composeTask.Messages),
		ErrorMsg: composeTask.ErrorMessage,
	}
	// 3. Send the POST request.
	// NOTE(review): headers derived from ctx are forwarded to the task's callback
	// URL; if util.ForwardHeaders includes credentials (e.g. Authorization),
	// confirm they are safe to disclose to that target.
	headers := util.ForwardHeaders(ctx)
	var resp struct{}
	g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s 消息=%v",
		composeTask.TaskId, composeTask.CallbackUrl, req.Messages)
	if err := commonHttp.Post(ctx, composeTask.CallbackUrl, headers, &resp, req); err != nil {
		return fmt.Errorf("[回调业务] 发送失败 taskId=%s url=%s err=%w", composeTask.TaskId, composeTask.CallbackUrl, err)
	}
	g.Log().Infof(ctx, "[回调业务] 发送成功 taskId=%s 回调地址=%s", composeTask.TaskId, composeTask.CallbackUrl)
	return nil
}
// parseMessagesToResult normalizes the loosely typed Messages value of a
// compose task into a *MultiRoundResult for the callback body.
//
// Supported inputs:
//   - nil: returns nil;
//   - *MultiRoundResult: returned unchanged;
//   - MultiRoundResult: a pointer to a copy is returned;
//   - string / []byte: decoded as JSON;
//   - any other value (e.g. map[string]any): JSON-encoded, then decoded.
//
// It returns nil when the value cannot be JSON-encoded, is not valid JSON for
// a MultiRoundResult, or is the JSON literal null. The null case matters
// because SendCallbackReq.Messages is omitempty: nil drops the field instead
// of sending an empty result.
func parseMessagesToResult(messages any) *MultiRoundResult {
	var raw []byte
	switch v := messages.(type) {
	case nil:
		return nil
	case *MultiRoundResult:
		return v
	case MultiRoundResult:
		return &v
	case string:
		raw = []byte(v)
	case []byte:
		raw = v
	default:
		// Maps and other values are round-tripped through JSON; an encoding
		// failure (e.g. a channel inside a map) yields nil rather than being ignored.
		data, err := json.Marshal(v)
		if err != nil {
			return nil
		}
		raw = data
	}
	return decodeMultiRoundResult(raw)
}

// decodeMultiRoundResult decodes raw JSON into a *MultiRoundResult, returning
// nil for invalid JSON or the literal null.
func decodeMultiRoundResult(raw []byte) *MultiRoundResult {
	// Decoding into a pointer lets encoding/json map a JSON null to nil.
	var result *MultiRoundResult
	if err := json.Unmarshal(raw, &result); err != nil {
		return nil
	}
	return result
}