183 lines
6.9 KiB
Go
183 lines
6.9 KiB
Go
package gateway
|
||
|
||
import (
	"context"
	"encoding/json"
	"fmt"
	"net/url"
	"strings"

	"prompts-core/common/util"
	"prompts-core/model/entity"

	"gitea.com/red-future/common/beans"
	commonHttp "gitea.com/red-future/common/http"
	"github.com/gogf/gf/v2/encoding/gjson"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/os/gtime"
)
|
||
|
||
// CreateTaskReq is the JSON shape used with the model-gateway createTask
// call. Within this file it is the value CreateGatewayTask asks
// commonHttp.Post to fill in; only TaskId is read afterwards.
//
// All keys are snake_case, as spelled in the json tags. The meaning of the
// individual State values is not defined in this file.
type CreateTaskReq struct {
	TaskId string `json:"task_id"`
	State  int    `json:"state"`

	OssFile  string `json:"oss_file"`
	FileType string `json:"file_type"`
	Text     string `json:"text"`

	ErrorMsg string `json:"error_msg"`
}
|
||
|
||
// CreateGatewayTask 创建网关异步任务
|
||
func CreateGatewayTask(ctx context.Context, payload map[string]any) (string, error) {
|
||
fullURL := "model-gateway/task/createTask"
|
||
headers := util.ForwardHeaders(ctx)
|
||
var req CreateTaskReq
|
||
body, err := json.Marshal(payload)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
if err := commonHttp.Post(ctx, fullURL, headers, &req, body); err != nil {
|
||
return "", err
|
||
}
|
||
return req.TaskId, nil
|
||
}
|
||
|
||
type GetModelConfigResp struct {
|
||
Model *AsynchModel `json:"model"`
|
||
}
|
||
|
||
type AsynchModel struct {
|
||
beans.SQLBaseDO `orm:",inline"`
|
||
ModelName string `orm:"model_name" json:"modelName"`
|
||
ModelType int `orm:"model_type" json:"modelType"`
|
||
BaseURL string `orm:"base_url" json:"baseUrl"`
|
||
HttpMethod string `orm:"http_method" json:"httpMethod"`
|
||
HeadMsg map[string]any `orm:"head_msg" json:"headMsg"`
|
||
Form []map[string]any `orm:"form_json" json:"form"`
|
||
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
|
||
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
|
||
ResponseBody string `orm:"response_body" json:"responseBody"`
|
||
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
|
||
IsPrivate *int `orm:"is_private" json:"isPrivate"`
|
||
IsChatModel int `orm:"is_chat_model" json:"isChatModel"`
|
||
IsAsync *int `orm:"is_async" json:"isAsync"`
|
||
IsStream *int `orm:"is_stream" json:"isStream"`
|
||
ApiKey string `orm:"api_key" json:"apiKey"`
|
||
Enabled *int `orm:"enabled" json:"enabled"`
|
||
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
|
||
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
|
||
RetryTimes int `orm:"retry_times" json:"retryTimes"`
|
||
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
|
||
IsOwner *int `json:"isOwner" orm:"is_owner"`
|
||
OperatorName string `orm:"operator_name" json:"operatorName"`
|
||
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
|
||
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
|
||
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
|
||
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
|
||
FirstFrame string `orm:"first_frame" json:"firstFrame"`
|
||
LastFrame string `orm:"last_frame" json:"lastFrame"`
|
||
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
|
||
}
|
||
|
||
// GetModelConfig 获取模型配置
|
||
func GetModelConfig(ctx context.Context, req *AsynchModel) (model *AsynchModel, err error) {
|
||
fullURL := "model-gateway/model/getModel"
|
||
// 拼接 query 参数
|
||
var params []string
|
||
if req.Creator != "" {
|
||
params = append(params, fmt.Sprintf("creator=%s", req.Creator))
|
||
}
|
||
if req.ModelName != "" {
|
||
params = append(params, fmt.Sprintf("modelName=%s", req.ModelName))
|
||
}
|
||
if req.IsChatModel != 0 {
|
||
params = append(params, fmt.Sprintf("isChatModel=%d", req.IsChatModel))
|
||
}
|
||
if len(params) > 0 {
|
||
fullURL += "?" + strings.Join(params, "&")
|
||
}
|
||
headers := util.ForwardHeaders(ctx)
|
||
var resp GetModelConfigResp
|
||
if err = commonHttp.Get(ctx, fullURL, headers, &resp, nil); err != nil {
|
||
return nil, fmt.Errorf("获取模型配置失败: %w", err)
|
||
}
|
||
if resp.Model == nil {
|
||
return nil, fmt.Errorf("模型不存在")
|
||
}
|
||
return resp.Model, nil
|
||
}
|
||
|
||
// GetTaskResultRes is the task-result reply decoded by QueryGatewayTaskState.
type GetTaskResultRes struct {
	// OssFile is the OSS address of the result file (per its dc tag).
	OssFile string `json:"ossFile" dc:"结果文件OSS地址"`

	// State is the task state (per its dc tag); the meaning of each value is
	// not defined in this file.
	State int `json:"state" dc:"任务状态"`
}
|
||
|
||
// QueryGatewayTaskState 查询网关任务状态
|
||
func QueryGatewayTaskState(ctx context.Context, taskID string) (int, error) {
|
||
fullURL := fmt.Sprintf("model-gateway/task/getTaskResult?taskId=%s", taskID)
|
||
headers := util.ForwardHeaders(ctx)
|
||
var req GetTaskResultRes
|
||
if err := commonHttp.Get(ctx, fullURL, headers, &req, nil); err != nil {
|
||
return 0, err
|
||
}
|
||
return req.State, nil
|
||
}
|
||
|
||
// SkillUserVO 技能用户视图对象
|
||
type SkillUserVO struct {
|
||
Id int64 `json:"id,string"`
|
||
Name string `json:"name"`
|
||
Description string `json:"description"`
|
||
FileName string `json:"fileName"`
|
||
FileUrl string `json:"fileUrl"`
|
||
CreatedAt *gtime.Time `json:"createdAt"`
|
||
UpdatedAt *gtime.Time `json:"updatedAt"`
|
||
ImgAddressPrefix string `json:"imgAddressPrefix"`
|
||
}
|
||
|
||
// GetSkillUser 获取技能用户信息
|
||
func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) {
|
||
fullURL := fmt.Sprintf("ai-agent/skill/user/getUserOrTemplate?name=%s", name)
|
||
headers := util.ForwardHeaders(ctx)
|
||
var resp SkillUserVO
|
||
var req struct{}
|
||
if err := commonHttp.Get(ctx, fullURL, headers, &resp, req); err != nil {
|
||
return nil, err
|
||
}
|
||
return &resp, nil
|
||
}
|
||
|
||
// SendCallbackReq is the JSON body that SendCallback posts to a task's
// callback URL.
//
// Messages and ErrorMsg carry "omitempty", so they are left out of the
// encoded body when empty; the other fields are always present.
type SendCallbackReq struct {
	TaskId string `json:"taskId"`
	Status string `json:"status"`

	Messages map[string]any `json:"messages,omitempty"`

	EpicycleId int64  `json:"epicycleId"`
	ErrorMsg   string `json:"errorMsg,omitempty"`
}
|
||
|
||
// SendCallback 向业务方发送回调
|
||
func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycleId int64) error {
|
||
// 1. 检查回调地址
|
||
if composeTask.CallbackUrl == "" {
|
||
return fmt.Errorf("回调地址为空,taskId=%s", composeTask.TaskId)
|
||
}
|
||
// 2. 构造请求体
|
||
req := SendCallbackReq{
|
||
TaskId: composeTask.TaskId,
|
||
Status: composeTask.Status,
|
||
Messages: composeTask.Messages,
|
||
ErrorMsg: composeTask.ErrorMessage,
|
||
EpicycleId: epicycleId,
|
||
}
|
||
// 3. 发送 POST 请求
|
||
headers := util.ForwardHeaders(ctx)
|
||
var resp struct{}
|
||
g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s 消息=%v",
|
||
composeTask.TaskId, composeTask.CallbackUrl, gjson.New(req.Messages).String())
|
||
if err := commonHttp.Post(ctx, composeTask.CallbackUrl, headers, &resp, req); err != nil {
|
||
return fmt.Errorf("[回调业务] 发送失败 taskId=%s url=%s err=%w", composeTask.TaskId, composeTask.CallbackUrl, err)
|
||
}
|
||
g.Log().Infof(ctx, "[回调业务] 发送成功 taskId=%s 回调地址=%s ", composeTask.TaskId, composeTask.CallbackUrl)
|
||
return nil
|
||
}
|