refactor(model-gateway): 重构代码结构并优化数据库查询
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"gitea.com/red-future/common/utils"
|
||||||
"github.com/gogf/gf/v2/encoding/gjson"
|
"github.com/gogf/gf/v2/encoding/gjson"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
@@ -263,6 +264,6 @@ func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL
|
|||||||
if callbackURL == "" {
|
if callbackURL == "" {
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
payload[callbackURL] = GetCallbackURL(ctx, "/task/modelCallback")
|
payload[callbackURL] = utils.GetCallbackURL(ctx, "/task/modelCallback")
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,140 +0,0 @@
|
|||||||
package util
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
|
||||||
)
|
|
||||||
|
|
||||||
// GetLocalIP 获取本机有效的局域网 IPv4 地址
|
|
||||||
func GetLocalIP() string {
|
|
||||||
addrs, err := net.InterfaceAddrs()
|
|
||||||
if err != nil {
|
|
||||||
return "127.0.0.1"
|
|
||||||
}
|
|
||||||
|
|
||||||
var validIPs []string
|
|
||||||
|
|
||||||
for _, addr := range addrs {
|
|
||||||
ipnet, ok := addr.(*net.IPNet)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := ipnet.IP
|
|
||||||
|
|
||||||
if isIPValid(ip) {
|
|
||||||
validIPs = append(validIPs, ip.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 优先返回非 169.254.x.x 的 IP
|
|
||||||
for _, ip := range validIPs {
|
|
||||||
if !strings.HasPrefix(ip, "169.254.") {
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 其次返回 169.254.x.x(最后的选择)
|
|
||||||
if len(validIPs) > 0 {
|
|
||||||
return validIPs[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
return "127.0.0.1"
|
|
||||||
}
|
|
||||||
|
|
||||||
// isIPValid 判断 IP 是否有效
|
|
||||||
func isIPValid(ip net.IP) bool {
|
|
||||||
// 不是 loopback (127.0.0.1)
|
|
||||||
if ip.IsLoopback() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// 是 IPv4
|
|
||||||
if ip.To4() == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// 不是链路本地地址 (169.254.0.0/16)
|
|
||||||
if ip[0] == 169 && ip[1] == 254 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// 不是组播地址
|
|
||||||
if ip.IsMulticast() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// 不是未指定地址 (0.0.0.0)
|
|
||||||
if ip.IsUnspecified() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLocalAddress 获取局域网地址(IP:端口)
|
|
||||||
func GetLocalAddress(ctx context.Context) string {
|
|
||||||
ip := GetLocalIP()
|
|
||||||
port := GetServerPort(ctx)
|
|
||||||
|
|
||||||
if port == "80" || port == "443" {
|
|
||||||
return ip
|
|
||||||
}
|
|
||||||
return ip + ":" + port
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSchemaFromRequest 从当前请求中获取协议(http/https)
|
|
||||||
func GetSchemaFromRequest(ctx context.Context) string {
|
|
||||||
r := g.RequestFromCtx(ctx)
|
|
||||||
if r == nil {
|
|
||||||
return "http"
|
|
||||||
}
|
|
||||||
|
|
||||||
// 1. 代理场景:X-Forwarded-Proto
|
|
||||||
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
|
|
||||||
return proto
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. 代理场景:X-Forwarded-Scheme
|
|
||||||
if proto := r.Header.Get("X-Forwarded-Scheme"); proto != "" {
|
|
||||||
return proto
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. TLS 连接(直接 HTTPS)
|
|
||||||
if r.TLS != nil {
|
|
||||||
return "https"
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4. 默认 HTTP(这行很重要!)
|
|
||||||
return "http" // ← 确保有这行
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLocalBaseURL 获取局域网基础 URL(动态协议 + IP + 端口)
|
|
||||||
func GetLocalBaseURL(ctx context.Context) string {
|
|
||||||
schema := GetSchemaFromRequest(ctx)
|
|
||||||
addr := GetLocalAddress(ctx)
|
|
||||||
return schema + "://" + addr
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCallbackURL 获取回调地址(完整 URL)
|
|
||||||
func GetCallbackURL(ctx context.Context, path string) string {
|
|
||||||
baseURL := GetLocalBaseURL(ctx)
|
|
||||||
// 确保 path 以 / 开头
|
|
||||||
if !strings.HasPrefix(path, "/") {
|
|
||||||
path = "/" + path
|
|
||||||
}
|
|
||||||
return baseURL + path
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetServerPort 从配置获取服务端口
|
|
||||||
func GetServerPort(ctx context.Context) string {
|
|
||||||
address := g.Cfg().MustGet(ctx, "server.address", ":8080").String()
|
|
||||||
// address 格式如 ":3009",去掉冒号
|
|
||||||
if strings.HasPrefix(address, ":") {
|
|
||||||
return address[1:]
|
|
||||||
}
|
|
||||||
return "8080"
|
|
||||||
}
|
|
||||||
@@ -2,12 +2,10 @@ package dao
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"model-gateway/consts/public"
|
"model-gateway/consts/public"
|
||||||
"model-gateway/model/dto"
|
"model-gateway/model/dto"
|
||||||
"model-gateway/model/entity"
|
"model-gateway/model/entity"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"gitea.com/red-future/common/db/gfdb"
|
"gitea.com/red-future/common/db/gfdb"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
@@ -58,42 +56,58 @@ func (d *modelDao) Delete(ctx context.Context, req *entity.AsynchModel) (rows in
|
|||||||
return r.RowsAffected()
|
return r.RowsAffected()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get 按ID获取(带租户隔离,只查当前租户)
|
// Get 获取模型
|
||||||
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
||||||
var whereCondition strings.Builder
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||||
var queryParams []interface{}
|
OmitEmpty().
|
||||||
if !g.IsEmpty(req.Id) {
|
Where(entity.AsynchModelCol.Id, req.Id).
|
||||||
whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.Id))
|
Where(entity.AsynchModelCol.Creator, req.Creator).
|
||||||
queryParams = append(queryParams, req.Id)
|
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
|
||||||
}
|
Where(entity.AsynchModelCol.ModelName, req.ModelName).
|
||||||
if !g.IsEmpty(req.Creator) {
|
Fields(fields).One()
|
||||||
whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.Creator))
|
|
||||||
queryParams = append(queryParams, req.Creator)
|
|
||||||
}
|
|
||||||
if !g.IsEmpty(req.IsChatModel) {
|
|
||||||
whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.IsChatModel))
|
|
||||||
queryParams = append(queryParams, req.IsChatModel)
|
|
||||||
}
|
|
||||||
if !g.IsEmpty(req.ModelName) {
|
|
||||||
whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.ModelName))
|
|
||||||
queryParams = append(queryParams, req.ModelName)
|
|
||||||
}
|
|
||||||
// 完整 SQL
|
|
||||||
sql := `SELECT * FROM "asynch_models" WHERE "deleted_at" IS NULL` + whereCondition.String()
|
|
||||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, queryParams...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var i []*entity.AsynchModel
|
err = r.Struct(&m)
|
||||||
if err = r.Structs(&i); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
for _, item := range i {
|
|
||||||
m = item
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//// Get 按ID获取(带租户隔离,只查当前租户)
|
||||||
|
//func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
||||||
|
// var whereCondition strings.Builder
|
||||||
|
// var queryParams []interface{}
|
||||||
|
// if !g.IsEmpty(req.Id) {
|
||||||
|
// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.Id))
|
||||||
|
// queryParams = append(queryParams, req.Id)
|
||||||
|
// }
|
||||||
|
// if !g.IsEmpty(req.Creator) {
|
||||||
|
// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.Creator))
|
||||||
|
// queryParams = append(queryParams, req.Creator)
|
||||||
|
// }
|
||||||
|
// if !g.IsEmpty(req.IsChatModel) {
|
||||||
|
// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.IsChatModel))
|
||||||
|
// queryParams = append(queryParams, req.IsChatModel)
|
||||||
|
// }
|
||||||
|
// if !g.IsEmpty(req.ModelName) {
|
||||||
|
// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.ModelName))
|
||||||
|
// queryParams = append(queryParams, req.ModelName)
|
||||||
|
// }
|
||||||
|
// // 完整 SQL
|
||||||
|
// sql := `SELECT * FROM "asynch_models" WHERE "deleted_at" IS NULL` + whereCondition.String()
|
||||||
|
// r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, queryParams...)
|
||||||
|
// if err != nil {
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// var i []*entity.AsynchModel
|
||||||
|
// if err = r.Structs(&i); err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
// for _, item := range i {
|
||||||
|
// m = item
|
||||||
|
// }
|
||||||
|
// return
|
||||||
|
//}
|
||||||
|
|
||||||
// GetByAcrossTenant 按ID获取(跨租户,查所有租户)
|
// GetByAcrossTenant 按ID获取(跨租户,查所有租户)
|
||||||
func (d *modelDao) GetByAcrossTenant(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
func (d *modelDao) GetByAcrossTenant(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
||||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
|||||||
apiPath := "/task/createTask"
|
apiPath := "/task/createTask"
|
||||||
httpMethod := "POST"
|
httpMethod := "POST"
|
||||||
if r := g.RequestFromCtx(ctx); r != nil {
|
if r := g.RequestFromCtx(ctx); r != nil {
|
||||||
ip = util.GetLocalIP()
|
ip = utils.GetLocalIP()
|
||||||
ua = r.UserAgent()
|
ua = r.UserAgent()
|
||||||
apiPath = r.URL.Path
|
apiPath = r.URL.Path
|
||||||
httpMethod = r.Method
|
httpMethod = r.Method
|
||||||
|
|||||||
Reference in New Issue
Block a user