package public import ( "context" "fmt" "regexp" "strings" "dataengine/model/dto/public" "dataengine/model/entity/dict" "gitea.redpowerfuture.com/red-future/common/db/gfdb" "github.com/gogf/gf/v2/util/gconv" ) var PublicQuery = new(publicQueryService) // tableColumnsCache 表定义缓存 var tableColumnsCache = make(map[string][]string) type publicQueryService struct{} // Query 执行公共查询 func (s *publicQueryService) Query(ctx context.Context, req *public.QueryReq) (res *public.QueryRes, err error) { // 1. 验证表名白名单 if err = s.validateTable(ctx, req.Table); err != nil { return nil, err } // 2. 验证字段白名单 allowedFields, err := s.getAllowedFields(ctx, req.Table) if err != nil { return nil, err } // 3. 构建 SELECT 部分 selectFields := "*" if req.Fields != "" { selectFields, err = s.buildSelectFields(req.Fields, allowedFields) if err != nil { return nil, err } } // 4. 构建 WHERE 条件 whereClause, whereArgs, err := s.buildWhereClause(req.Where, allowedFields) if err != nil { return nil, err } // 5. 构建 GROUP BY groupByClause, err := s.buildGroupBy(req.GroupBy, allowedFields) if err != nil { return nil, err } // 6. 强制租户过滤 tenantClause := s.buildTenantClause(allowedFields) // 7. 组合完整 WHERE fullWhere := tenantClause if whereClause != "" { if fullWhere != "" { fullWhere += " AND " + whereClause } else { fullWhere = whereClause } } // 8. 校验分页参数 if req.Page < 1 { req.Page = 1 } if req.PageSize < 1 { req.PageSize = 20 } if req.PageSize > 100 { req.PageSize = 100 } offset := (req.Page - 1) * req.PageSize // 9. 构建 ORDER BY orderByClause, err := s.buildOrderBy(req.OrderBy, allowedFields) if err != nil { return nil, err } // 10. 统计总数 countSql := fmt.Sprintf("SELECT COUNT(*) FROM %s", req.Table) if fullWhere != "" { countSql += " WHERE " + fullWhere } if groupByClause != "" { countSql = fmt.Sprintf("SELECT COUNT(*) FROM (SELECT 1 FROM %s WHERE %s GROUP BY %s) AS t", req.Table, fullWhere, groupByClause) } result, err := gfdb.DB(ctx).GetAll(ctx, countSql, whereArgs...) if err != nil { return nil, fmt.Errorf("统计总数失败: %v", err) } var total int64 if result.Len() > 0 { total = result[0]["count"].Int64() } // 11. 查询数据 querySql := fmt.Sprintf("SELECT %s FROM %s", selectFields, req.Table) if fullWhere != "" { querySql += " WHERE " + fullWhere } if groupByClause != "" { querySql += " GROUP BY " + groupByClause } if orderByClause != "" { querySql += " ORDER BY " + orderByClause } querySql += fmt.Sprintf(" LIMIT %d OFFSET %d", req.PageSize, offset) dataResult, err := gfdb.DB(ctx).GetAll(ctx, querySql, whereArgs...) if err != nil { return nil, fmt.Errorf("查询数据失败: %v", err) } var list []map[string]interface{} if dataResult.Len() > 0 { list = dataResult.List() } return &public.QueryRes{ List: list, Total: total, Page: req.Page, Size: req.PageSize, }, nil } // GetTableList 获取可查询表列表 func (s *publicQueryService) GetTableList(ctx context.Context) (*public.TableListRes, error) { var ifaces []dict.ApiInterface err := gfdb.DB(ctx).Model(ctx, "api_interface"). Where("table_definition IS NOT NULL"). Where("table_definition->>'table_name' != ''"). Where("status", "active"). Scan(&ifaces) if err != nil { return nil, fmt.Errorf("查询表列表失败: %v", err) } // 查询平台名称 var platforms []dict.DatasourcePlatform _ = gfdb.DB(ctx).Model(ctx, "api_datasource_platform").Scan(&platforms) platformMap := make(map[int64]string) for _, p := range platforms { platformMap[p.ID] = p.PlatformName } var list []public.TableInfo for _, iface := range ifaces { tableName := s.getStringFromMap(iface.TableDefinition, "table_name") if tableName == "" { continue } columns := s.extractColumnsFromMap(iface.TableDefinition) list = append(list, public.TableInfo{ TableName: tableName, PlatformName: platformMap[iface.PlatformId], InterfaceName: iface.Name, Columns: columns, }) } return &public.TableListRes{List: list}, nil } // GetColumnList 获取表字段列表 func (s *publicQueryService) GetColumnList(ctx context.Context, tableName string) (*public.ColumnListRes, error) { if err := s.validateTable(ctx, tableName); err != nil { return nil, err } columns, err := s.getColumnDetails(ctx, tableName) if err != nil { return nil, err } return &public.ColumnListRes{ TableName: tableName, Columns: columns, }, nil } // validateTable 验证表名白名单 func (s *publicQueryService) validateTable(ctx context.Context, tableName string) error { if tableName == "" { return fmt.Errorf("表名不能为空") } // 表名格式校验 if matched, _ := regexp.MatchString(`^[a-zA-Z_][a-zA-Z0-9_]*$`, tableName); !matched { return fmt.Errorf("表名格式非法,只允许字母、数字、下划线") } // 禁止系统表 systemTables := []string{"pg_catalog", "information_schema"} for _, t := range systemTables { if strings.HasPrefix(strings.ToLower(tableName), t) { return fmt.Errorf("禁止查询系统表") } } // 检查白名单 count, err := gfdb.DB(ctx).Model(ctx, "api_interface"). Where("table_definition->>'table_name' = ?", tableName). Where("status", "active"). Count() if err != nil { return fmt.Errorf("表名验证失败: %v", err) } if count == 0 { return fmt.Errorf("表 [%s] 不在可查询白名单中", tableName) } return nil } // getAllowedFields 获取表允许的字段 func (s *publicQueryService) getAllowedFields(ctx context.Context, tableName string) ([]string, error) { if cols, ok := tableColumnsCache[tableName]; ok { return cols, nil } var iface dict.ApiInterface _, err := gfdb.DB(ctx).Model(ctx, "api_interface"). Where("table_definition->>'table_name' = ?", tableName). Where("status", "active"). One(&iface) if err != nil { return nil, fmt.Errorf("获取表字段失败: %v", err) } columns := s.extractColumnsFromMap(iface.TableDefinition) tableColumnsCache[tableName] = columns return columns, nil } // extractColumnsFromMap 从 map 中提取字段 func (s *publicQueryService) extractColumnsFromMap(tableDef map[string]interface{}) []string { var columns []string if cols, ok := tableDef["columns"].([]interface{}); ok { for _, c := range cols { if col, ok := c.(map[string]interface{}); ok { if name, ok := col["name"].(string); ok { columns = append(columns, name) } } } } return columns } // buildSelectFields 构建 SELECT 字段 func (s *publicQueryService) buildSelectFields(fields string, allowedFields []string) (string, error) { allowedMap := make(map[string]bool) for _, f := range allowedFields { allowedMap[strings.ToLower(f)] = true } allowedMap["id"] = true allowedMap["tenant_id"] = true allowedMap["created_at"] = true allowedMap["updated_at"] = true allowedMap["raw_data"] = true var result []string for _, f := range strings.Split(fields, ",") { f = strings.TrimSpace(f) if f == "" { continue } if strings.Contains(f, " ") || strings.Contains(f, "(") { result = append(result, f) continue } fLower := strings.ToLower(f) if !allowedMap[fLower] { return "", fmt.Errorf("字段 [%s] 不在允许列表中", f) } result = append(result, f) } if len(result) == 0 { return "*", nil } return strings.Join(result, ", "), nil } // buildWhereClause 构建 WHERE 条件 func (s *publicQueryService) buildWhereClause(where map[string]interface{}, allowedFields []string) (string, []interface{}, error) { if len(where) == 0 { return "", nil, nil } allowedMap := make(map[string]bool) for _, f := range allowedFields { allowedMap[strings.ToLower(f)] = true } allowedMap["tenant_id"] = true var conditions []string var args []interface{} for field, value := range where { fieldLower := strings.ToLower(field) if !allowedMap[fieldLower] { return "", nil, fmt.Errorf("字段 [%s] 不在允许列表中", field) } // 处理操作符后缀 opSuffixes := []struct { suffix string format string like bool }{ {"_eq", "%s = ?", false}, {"_ne", "%s != ?", false}, {"_gt", "%s > ?", false}, {"_lt", "%s < ?", false}, {"_ge", "%s >= ?", false}, {"_le", "%s <= ?", false}, {"_like", "%s LIKE ?", true}, } matched := false for _, op := range opSuffixes { if strings.HasSuffix(fieldLower, op.suffix) { cleanField := field[:len(field)-len(op.suffix)] conditions = append(conditions, fmt.Sprintf(op.format, cleanField)) if op.like { args = append(args, "%"+gconv.String(value)+"%") } else { args = append(args, value) } matched = true break } } if matched { continue } // 处理 _in if strings.HasSuffix(fieldLower, "_in") { cleanField := field[:len(field)-3] if arr, ok := value.([]interface{}); ok { placeholders := make([]string, len(arr)) for i, v := range arr { placeholders[i] = "?" args = append(args, v) } conditions = append(conditions, fmt.Sprintf("%s IN (%s)", cleanField, strings.Join(placeholders, ","))) } else if str, ok := value.(string); ok { parts := strings.Split(str, ",") placeholders := make([]string, len(parts)) for i, p := range parts { placeholders[i] = "?" args = append(args, strings.TrimSpace(p)) } conditions = append(conditions, fmt.Sprintf("%s IN (%s)", cleanField, strings.Join(placeholders, ","))) } continue } // 处理 _between if strings.HasSuffix(fieldLower, "_between") { cleanField := field[:len(field)-8] if arr, ok := value.([]interface{}); ok && len(arr) >= 2 { conditions = append(conditions, fmt.Sprintf("%s BETWEEN ? AND ?", cleanField)) args = append(args, arr[0], arr[1]) } else if arr, ok := value.([]string); ok && len(arr) >= 2 { conditions = append(conditions, fmt.Sprintf("%s BETWEEN ? AND ?", cleanField)) args = append(args, arr[0], arr[1]) } continue } // 默认等于 conditions = append(conditions, fmt.Sprintf("%s = ?", field)) args = append(args, value) } return strings.Join(conditions, " AND "), args, nil } // buildGroupBy 构建 GROUP BY func (s *publicQueryService) buildGroupBy(groupBy string, allowedFields []string) (string, error) { if groupBy == "" { return "", nil } allowedMap := make(map[string]bool) for _, f := range allowedFields { allowedMap[strings.ToLower(f)] = true } var fields []string for _, f := range strings.Split(groupBy, ",") { f = strings.TrimSpace(f) if f == "" { continue } if !allowedMap[strings.ToLower(f)] { return "", fmt.Errorf("分组字段 [%s] 不在允许列表中", f) } fields = append(fields, f) } return strings.Join(fields, ", "), nil } // buildOrderBy 构建 ORDER BY func (s *publicQueryService) buildOrderBy(orderBy string, allowedFields []string) (string, error) { if orderBy == "" { return "", nil } allowedMap := make(map[string]bool) for _, f := range allowedFields { allowedMap[strings.ToLower(f)] = true } allowedMap["id"] = true allowedMap["created_at"] = true allowedMap["updated_at"] = true var clauses []string for _, part := range strings.Split(orderBy, ",") { part = strings.TrimSpace(part) if part == "" { continue } parts := strings.Fields(part) if len(parts) == 0 { continue } field := parts[0] dir := "ASC" if len(parts) > 1 { if strings.ToUpper(parts[1]) == "DESC" { dir = "DESC" } } if !allowedMap[strings.ToLower(field)] { return "", fmt.Errorf("排序字段 [%s] 不在允许列表中", field) } field = regexp.MustCompile(`[^a-zA-Z0-9_]`).ReplaceAllString(field, "") clauses = append(clauses, field+" "+dir) } return strings.Join(clauses, ", "), nil } // buildTenantClause 构建租户过滤条件 func (s *publicQueryService) buildTenantClause(allowedFields []string) string { for _, f := range allowedFields { if strings.ToLower(f) == "tenant_id" { return "tenant_id = 1" } } return "" } // getColumnDetails 获取表字段详情 func (s *publicQueryService) getColumnDetails(ctx context.Context, tableName string) ([]public.Column, error) { var iface dict.ApiInterface _, err := gfdb.DB(ctx).Model(ctx, "api_interface"). Where("table_definition->>'table_name' = ?", tableName). Where("status", "active"). One(&iface) if err != nil { return nil, fmt.Errorf("获取表字段详情失败: %v", err) } var columns []public.Column if cols, ok := iface.TableDefinition["columns"].([]interface{}); ok { for _, c := range cols { if col, ok := c.(map[string]interface{}); ok { columns = append(columns, public.Column{ Name: gconv.String(col["name"]), Type: gconv.String(col["type"]), Comment: gconv.String(col["comment"]), }) } } } columns = append(columns, public.Column{Name: "id", Type: "BIGINT", Comment: "主键ID"}) columns = append(columns, public.Column{Name: "tenant_id", Type: "BIGINT", Comment: "租户ID"}) columns = append(columns, public.Column{Name: "created_at", Type: "TIMESTAMP", Comment: "创建时间"}) columns = append(columns, public.Column{Name: "updated_at", Type: "TIMESTAMP", Comment: "更新时间"}) columns = append(columns, public.Column{Name: "raw_data", Type: "JSONB", Comment: "原始数据"}) return columns, nil } // getStringFromMap 从 map 中获取字符串值 func (s *publicQueryService) getStringFromMap(data map[string]interface{}, key string) string { if v, ok := data[key].(string); ok { return v } return "" } // ClearTableCache 清除表缓存 func (s *publicQueryService) ClearTableCache() { tableColumnsCache = make(map[string][]string) } // InvalidateTableCache 失效指定表的缓存 func (s *publicQueryService) InvalidateTableCache(tableName string) { delete(tableColumnsCache, tableName) }