From aae46a4f296407370348cace567a9018e7a4a5d9 Mon Sep 17 00:00:00 2001 From: WangLiZhao <1838393649@qq.com> Date: Wed, 3 Jun 2026 18:37:17 +0800 Subject: [PATCH] =?UTF-8?q?refactor(model-gateway):=20=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E7=BB=93=E6=9E=84=E5=B9=B6=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E6=9F=A5=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/util/mapping.go | 3 +- common/util/network.go | 140 ----------------------------------- dao/model_dao.go | 76 +++++++++++-------- service/task/task_service.go | 2 +- 4 files changed, 48 insertions(+), 173 deletions(-) delete mode 100644 common/util/network.go diff --git a/common/util/mapping.go b/common/util/mapping.go index d092e1a..7d1d410 100644 --- a/common/util/mapping.go +++ b/common/util/mapping.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "gitea.com/red-future/common/utils" "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/util/gconv" @@ -263,6 +264,6 @@ func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL if callbackURL == "" { return payload } - payload[callbackURL] = GetCallbackURL(ctx, "/task/modelCallback") + payload[callbackURL] = utils.GetCallbackURL(ctx, "/task/modelCallback") return payload } diff --git a/common/util/network.go b/common/util/network.go deleted file mode 100644 index ffc98d6..0000000 --- a/common/util/network.go +++ /dev/null @@ -1,140 +0,0 @@ -package util - -import ( - "context" - "net" - "strings" - - "github.com/gogf/gf/v2/frame/g" -) - -// GetLocalIP 获取本机有效的局域网 IPv4 地址 -func GetLocalIP() string { - addrs, err := net.InterfaceAddrs() - if err != nil { - return "127.0.0.1" - } - - var validIPs []string - - for _, addr := range addrs { - ipnet, ok := addr.(*net.IPNet) - if !ok { - continue - } - - ip := ipnet.IP - - if isIPValid(ip) { - validIPs = append(validIPs, ip.String()) - } - } - - // 优先返回非 169.254.x.x 的 IP - for _, ip := range validIPs { - if !strings.HasPrefix(ip, "169.254.") { - return ip - } - } - - // 其次返回 169.254.x.x(最后的选择) - if len(validIPs) > 0 { - return validIPs[0] - } - - return "127.0.0.1" -} - -// isIPValid 判断 IP 是否有效 -func isIPValid(ip net.IP) bool { - // 不是 loopback (127.0.0.1) - if ip.IsLoopback() { - return false - } - - // 是 IPv4 - if ip.To4() == nil { - return false - } - - // 不是链路本地地址 (169.254.0.0/16) - if ip[0] == 169 && ip[1] == 254 { - return false - } - - // 不是组播地址 - if ip.IsMulticast() { - return false - } - - // 不是未指定地址 (0.0.0.0) - if ip.IsUnspecified() { - return false - } - - return true -} - -// GetLocalAddress 获取局域网地址(IP:端口) -func GetLocalAddress(ctx context.Context) string { - ip := GetLocalIP() - port := GetServerPort(ctx) - - if port == "80" || port == "443" { - return ip - } - return ip + ":" + port -} - -// GetSchemaFromRequest 从当前请求中获取协议(http/https) -func GetSchemaFromRequest(ctx context.Context) string { - r := g.RequestFromCtx(ctx) - if r == nil { - return "http" - } - - // 1. 代理场景:X-Forwarded-Proto - if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" { - return proto - } - - // 2. 代理场景:X-Forwarded-Scheme - if proto := r.Header.Get("X-Forwarded-Scheme"); proto != "" { - return proto - } - - // 3. TLS 连接(直接 HTTPS) - if r.TLS != nil { - return "https" - } - - // 4. 默认 HTTP(这行很重要!) - return "http" // ← 确保有这行 -} - -// GetLocalBaseURL 获取局域网基础 URL(动态协议 + IP + 端口) -func GetLocalBaseURL(ctx context.Context) string { - schema := GetSchemaFromRequest(ctx) - addr := GetLocalAddress(ctx) - return schema + "://" + addr -} - -// GetCallbackURL 获取回调地址(完整 URL) -func GetCallbackURL(ctx context.Context, path string) string { - baseURL := GetLocalBaseURL(ctx) - // 确保 path 以 / 开头 - if !strings.HasPrefix(path, "/") { - path = "/" + path - } - return baseURL + path -} - -// GetServerPort 从配置获取服务端口 -func GetServerPort(ctx context.Context) string { - address := g.Cfg().MustGet(ctx, "server.address", ":8080").String() - // address 格式如 ":3009",去掉冒号 - if strings.HasPrefix(address, ":") { - return address[1:] - } - return "8080" -} diff --git a/dao/model_dao.go b/dao/model_dao.go index 88a256d..0f741f3 100644 --- a/dao/model_dao.go +++ b/dao/model_dao.go @@ -2,12 +2,10 @@ package dao import ( "context" - "fmt" "model-gateway/consts/public" "model-gateway/model/dto" "model-gateway/model/entity" "strconv" - "strings" "gitea.com/red-future/common/db/gfdb" "github.com/gogf/gf/v2/frame/g" @@ -58,42 +56,58 @@ func (d *modelDao) Delete(ctx context.Context, req *entity.AsynchModel) (rows in return r.RowsAffected() } -// Get 按ID获取(带租户隔离,只查当前租户) +// Get 获取模型 func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) { - var whereCondition strings.Builder - var queryParams []interface{} - if !g.IsEmpty(req.Id) { - whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.Id)) - queryParams = append(queryParams, req.Id) - } - if !g.IsEmpty(req.Creator) { - whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.Creator)) - queryParams = append(queryParams, req.Creator) - } - if !g.IsEmpty(req.IsChatModel) { - whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.IsChatModel)) - queryParams = append(queryParams, req.IsChatModel) - } - if !g.IsEmpty(req.ModelName) { - whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.ModelName)) - queryParams = append(queryParams, req.ModelName) - } - // 完整 SQL - sql := `SELECT * FROM "asynch_models" WHERE "deleted_at" IS NULL` + whereCondition.String() - r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, queryParams...) + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). + OmitEmpty(). + Where(entity.AsynchModelCol.Id, req.Id). + Where(entity.AsynchModelCol.Creator, req.Creator). + Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel). + Where(entity.AsynchModelCol.ModelName, req.ModelName). + Fields(fields).One() if err != nil { return } - var i []*entity.AsynchModel - if err = r.Structs(&i); err != nil { - return nil, err - } - for _, item := range i { - m = item - } + err = r.Struct(&m) return } +//// Get 按ID获取(带租户隔离,只查当前租户) +//func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) { +// var whereCondition strings.Builder +// var queryParams []interface{} +// if !g.IsEmpty(req.Id) { +// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.Id)) +// queryParams = append(queryParams, req.Id) +// } +// if !g.IsEmpty(req.Creator) { +// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.Creator)) +// queryParams = append(queryParams, req.Creator) +// } +// if !g.IsEmpty(req.IsChatModel) { +// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.IsChatModel)) +// queryParams = append(queryParams, req.IsChatModel) +// } +// if !g.IsEmpty(req.ModelName) { +// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.ModelName)) +// queryParams = append(queryParams, req.ModelName) +// } +// // 完整 SQL +// sql := `SELECT * FROM "asynch_models" WHERE "deleted_at" IS NULL` + whereCondition.String() +// r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, queryParams...) +// if err != nil { +// return +// } +// var i []*entity.AsynchModel +// if err = r.Structs(&i); err != nil { +// return nil, err +// } +// for _, item := range i { +// m = item +// } +// return +//} + // GetByAcrossTenant 按ID获取(跨租户,查所有租户) func (d *modelDao) GetByAcrossTenant(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) { r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). diff --git a/service/task/task_service.go b/service/task/task_service.go index f87c469..6aa3366 100644 --- a/service/task/task_service.go +++ b/service/task/task_service.go @@ -91,7 +91,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res * apiPath := "/task/createTask" httpMethod := "POST" if r := g.RequestFromCtx(ctx); r != nil { - ip = util.GetLocalIP() + ip = utils.GetLocalIP() ua = r.UserAgent() apiPath = r.URL.Path httpMethod = r.Method