diff --git a/db/gfdb/gfdb.go b/db/gfdb/gfdb.go index 9706a1e..8c22c1a 100644 --- a/db/gfdb/gfdb.go +++ b/db/gfdb/gfdb.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/json" "fmt" + "hash/fnv" "regexp" "strings" "time" @@ -157,16 +158,38 @@ func catchSQLHook() gdb.HookHandler { } } +func getNodeIdFromIPPort(ctx context.Context) int64 { + // 获取本地IP + ip, err := utils.GetLocalIP() + if err != nil { + return 0 + } + // 获取端口 + port := g.Cfg().MustGet(ctx, "server.address").String() + // 拼接字符串 + addr := fmt.Sprintf("%s%s", ip, port) + + // 计算哈希(保证唯一且稳定) + h := fnv.New64a() + h.Write([]byte(addr)) + hashVal := h.Sum64() + + // 取模 1024 → 得到 0~1023 的合法 node + return int64(hashVal % 1024) +} + // ==================== Insert钩子 ==================== func insertHook(ctx context.Context, in *gdb.HookInsertInput) (result sql.Result, err error) { - userInfo, err := utils.GetUserInfo(ctx) if err != nil { return nil, err } - - node, err := snowflake.NewNode(g.Cfg().MustGet(ctx, "server.workerId").Int64()) + nodeId := getNodeIdFromIPPort(ctx) + if nodeId == 0 { + return nil, fmt.Errorf("nodeId cannot be empty") + } + node, err := snowflake.NewNode(nodeId) if err != nil { return nil, err } @@ -414,7 +437,6 @@ var ( type Gfdb interface { Exec(ctx context.Context, sql string, args ...any) (sql.Result, error) - GetAll(ctx context.Context, sql string, args ...any) (gdb.Result, error) Model(ctx context.Context, tableNameOrStruct ...any) *model Transaction(ctx context.Context, f func(ctx context.Context, tx gdb.TX) error) error } diff --git a/utils/utils.go b/utils/utils.go index f09e32c..d7e934a 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -443,3 +443,62 @@ LOOP: time.Sleep(time.Second) goto LOOP } + +// GetLocalIP 获取本地IP ✅ 阿里云 ECS✅ 腾讯云 CVM✅ 华为云✅ 物理机✅ Docker 容器✅ K8s Pod✅ 虚拟机 +func GetLocalIP() (string, error) { + // 先获取所有网卡 + ifaces, err := net.Interfaces() + if err != nil { + return "", err + } + + // 遍历网卡,找符合条件的 + for _, iface := range ifaces { + // 跳过 禁用、回环、虚拟网卡 + if iface.Flags&net.FlagUp == 0 || // 网卡未启用 + iface.Flags&net.FlagLoopback != 0 || // 回环地址 + strings.Contains(iface.Name, "docker") || // docker 网卡 + strings.Contains(iface.Name, "veth") || // 容器虚拟网卡 + strings.Contains(iface.Name, "bridge") || // 网桥 + strings.Contains(iface.Name, "lo") { // 本地回环 + continue + } + + // 获取网卡地址 + addrs, err := iface.Addrs() + if err != nil { + continue + } + + for _, addr := range addrs { + ipNet, ok := addr.(*net.IPNet) + if !ok || ipNet.IP.IsLoopback() { + continue + } + + ip := ipNet.IP + if ip.To4() != nil && isPrivateIP(ip) { // 只取内网 IPv4 + return ip.String(), nil + } + } + } + + return "", errors.New("cannot find valid local private IP") +} + +// 判断是否内网IP(生产必须) +func isPrivateIP(ip net.IP) bool { + privateIPBlocks := []string{ + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + } + + for _, block := range privateIPBlocks { + _, ipNet, err := net.ParseCIDR(block) + if err == nil && ipNet.Contains(ip) { + return true + } + } + return false +}