数字人项目迁移
This commit is contained in:
219
digitalhuman/service/http_wrapper.go
Normal file
219
digitalhuman/service/http_wrapper.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
stdhttp "net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
commonHttp "gitea.com/red-future/common/http"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
var commonHttpTransportMu sync.Mutex
|
||||
|
||||
// asyncCtx 异步上下文处理
|
||||
func asyncCtx(ctx context.Context) context.Context {
|
||||
asyncCtx := context.WithoutCancel(ctx)
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
asyncCtx = context.WithValue(asyncCtx, "token", token)
|
||||
}
|
||||
}
|
||||
if user, uErr := utils.GetUserInfo(ctx); uErr == nil && user != nil {
|
||||
asyncCtx = context.WithValue(asyncCtx, "user", user)
|
||||
}
|
||||
return asyncCtx
|
||||
}
|
||||
|
||||
// setCommonHttpResponseHeaderTimeout 调整公共 HTTP 客户端响应头超时,避免长时推理被 30s 默认值打断。
|
||||
func setCommonHttpResponseHeaderTimeout(d time.Duration) {
|
||||
if d <= 0 {
|
||||
return
|
||||
}
|
||||
commonHttpTransportMu.Lock()
|
||||
defer commonHttpTransportMu.Unlock()
|
||||
if tr, ok := commonHttp.Httpclient.Transport.(*stdhttp.Transport); ok && tr != nil {
|
||||
if tr.ResponseHeaderTimeout < d {
|
||||
tr.ResponseHeaderTimeout = d
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// forwardHeaders 透传调用链路中必须的头信息,优先使用异步上下文里固化的 token。
|
||||
func forwardHeaders(ctx context.Context) map[string]string {
|
||||
headers := make(map[string]string)
|
||||
if token, ok := ctx.Value("token").(string); ok && token != "" {
|
||||
headers["Authorization"] = token
|
||||
}
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
if headers["Authorization"] == "" {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
headers["Authorization"] = token
|
||||
}
|
||||
}
|
||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||
headers["X-User-Info"] = userInfo
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
// commonPostJSON 使用 common/http 的底层客户端直连 JSON 接口,适配非统一响应包装结构。
|
||||
func commonPostJSON(ctx context.Context, url string, headers map[string]string, req any, resp any) error {
|
||||
client := commonHttp.Httpclient.Clone().ContentJson()
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if d := time.Until(deadline); d > 0 {
|
||||
client.SetTimeout(d)
|
||||
}
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
client.SetHeaderMap(headers)
|
||||
}
|
||||
r, err := client.DoRequest(ctx, stdhttp.MethodPost, url, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return gerror.Wrap(err, "读取响应失败")
|
||||
}
|
||||
if r.StatusCode != stdhttp.StatusOK {
|
||||
return gerror.Newf("HTTP状态码异常: %d, body: %s", r.StatusCode, string(body))
|
||||
}
|
||||
if err := json.Unmarshal(body, resp); err != nil {
|
||||
return gerror.Wrapf(err, "解析响应失败, body: %s", string(body))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func commonPostMultipartFile(ctx context.Context, url string, headers map[string]string, form map[string]string, fileField string, filePath string, resp any) error {
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
for k, v := range form {
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
if err := writer.WriteField(k, v); err != nil {
|
||||
return gerror.Wrapf(err, "写入表单字段失败: %s", k)
|
||||
}
|
||||
}
|
||||
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return gerror.Wrapf(err, "打开文件失败: %s", filePath)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
part, err := writer.CreateFormFile(fileField, filepath.Base(filePath))
|
||||
if err != nil {
|
||||
return gerror.Wrapf(err, "创建表单文件失败: %s", fileField)
|
||||
}
|
||||
if _, err := io.Copy(part, f); err != nil {
|
||||
return gerror.Wrap(err, "写入文件内容失败")
|
||||
}
|
||||
|
||||
contentType := writer.FormDataContentType()
|
||||
if err := writer.Close(); err != nil {
|
||||
return gerror.Wrap(err, "关闭表单写入器失败")
|
||||
}
|
||||
|
||||
client := commonHttp.Httpclient.Clone()
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if d := time.Until(deadline); d > 0 {
|
||||
client.SetTimeout(d)
|
||||
}
|
||||
}
|
||||
if headers == nil {
|
||||
headers = make(map[string]string)
|
||||
}
|
||||
headers["Content-Type"] = contentType
|
||||
client.SetHeaderMap(headers)
|
||||
|
||||
r, err := client.DoRequest(ctx, stdhttp.MethodPost, url, body.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
raw, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return gerror.Wrap(err, "读取响应失败")
|
||||
}
|
||||
if r.StatusCode != stdhttp.StatusOK {
|
||||
return gerror.Newf("HTTP状态码异常: %d, body: %s", r.StatusCode, string(raw))
|
||||
}
|
||||
if err := json.Unmarshal(raw, resp); err != nil {
|
||||
return gerror.Wrapf(err, "解析响应失败, body: %s", string(raw))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// -------------------------- model-asynch 调用封装 --------------------------
|
||||
|
||||
const modelAsynchServiceName = "model-asynch"
|
||||
|
||||
type modelAsynchCreateTaskReq struct {
|
||||
ModelName string `json:"modelName"`
|
||||
InputRef string `json:"inputRef,omitempty"`
|
||||
RequestPayload any `json:"requestPayload"`
|
||||
}
|
||||
|
||||
type modelAsynchCreateTaskRes struct {
|
||||
TaskID string `json:"taskId"`
|
||||
}
|
||||
|
||||
// createModelAsynchTask 调用 model-asynch 创建任务
|
||||
// 注意:路由以 GoFrame 默认输出为准(通常为 /task/create-task)
|
||||
func createModelAsynchTask(ctx context.Context, modelName string, payload any, inputRef string) (taskID string, err error) {
|
||||
taskUrl := g.Cfg().MustGet(ctx, "model-asynch.addr", "127.0.0.1:8080")
|
||||
headers := forwardHeaders(ctx)
|
||||
req := &modelAsynchCreateTaskReq{
|
||||
ModelName: modelName,
|
||||
InputRef: inputRef,
|
||||
RequestPayload: payload,
|
||||
}
|
||||
var res modelAsynchCreateTaskRes
|
||||
if err := commonHttp.Post(ctx, fmt.Sprintf("%s/task/createTask", taskUrl), headers, &res, req); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return res.TaskID, nil
|
||||
}
|
||||
|
||||
type modelAsynchBatchReq struct {
|
||||
TaskIDs []string `json:"taskIds"`
|
||||
}
|
||||
|
||||
type modelAsynchBatchItem struct {
|
||||
TaskID string `json:"taskId"`
|
||||
State int `json:"state"`
|
||||
OssFile string `json:"ossFile"`
|
||||
}
|
||||
|
||||
type modelAsynchBatchRes struct {
|
||||
List []modelAsynchBatchItem `json:"list"`
|
||||
}
|
||||
|
||||
// getModelAsynchTaskBatch 批量查询任务(成功 2->4 的逻辑由中间件内部处理)
|
||||
func getModelAsynchTaskBatch(ctx context.Context, taskIDs []string) (items []modelAsynchBatchItem, err error) {
|
||||
taskUrl := g.Cfg().MustGet(ctx, "model-asynch.addr", "127.0.0.1:8080")
|
||||
headers := forwardHeaders(ctx)
|
||||
req := &modelAsynchBatchReq{TaskIDs: taskIDs}
|
||||
var res modelAsynchBatchRes
|
||||
if err := commonHttp.Post(ctx, fmt.Sprintf("%s/task/getTaskBatch", taskUrl), headers, &res, req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.List, nil
|
||||
}
|
||||
Reference in New Issue
Block a user