重构数据引擎和报表引擎
This commit is contained in:
524
service/public/public_query_service.go
Normal file
524
service/public/public_query_service.go
Normal file
@@ -0,0 +1,524 @@
|
||||
package public
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"dataengine/model/dto/public"
|
||||
"dataengine/model/entity/dict"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var PublicQuery = new(publicQueryService)
|
||||
|
||||
// tableColumnsCache 表定义缓存
|
||||
var tableColumnsCache = make(map[string][]string)
|
||||
|
||||
type publicQueryService struct{}
|
||||
|
||||
// Query 执行公共查询
|
||||
func (s *publicQueryService) Query(ctx context.Context, req *public.QueryReq) (res *public.QueryRes, err error) {
|
||||
// 1. 验证表名白名单
|
||||
if err = s.validateTable(ctx, req.Table); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 验证字段白名单
|
||||
allowedFields, err := s.getAllowedFields(ctx, req.Table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3. 构建 SELECT 部分
|
||||
selectFields := "*"
|
||||
if req.Fields != "" {
|
||||
selectFields, err = s.buildSelectFields(req.Fields, allowedFields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 构建 WHERE 条件
|
||||
whereClause, whereArgs, err := s.buildWhereClause(req.Where, allowedFields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 5. 构建 GROUP BY
|
||||
groupByClause, err := s.buildGroupBy(req.GroupBy, allowedFields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 6. 强制租户过滤
|
||||
tenantClause := s.buildTenantClause(allowedFields)
|
||||
|
||||
// 7. 组合完整 WHERE
|
||||
fullWhere := tenantClause
|
||||
if whereClause != "" {
|
||||
if fullWhere != "" {
|
||||
fullWhere += " AND " + whereClause
|
||||
} else {
|
||||
fullWhere = whereClause
|
||||
}
|
||||
}
|
||||
|
||||
// 8. 校验分页参数
|
||||
if req.Page < 1 {
|
||||
req.Page = 1
|
||||
}
|
||||
if req.PageSize < 1 {
|
||||
req.PageSize = 20
|
||||
}
|
||||
if req.PageSize > 100 {
|
||||
req.PageSize = 100
|
||||
}
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
|
||||
// 9. 构建 ORDER BY
|
||||
orderByClause, err := s.buildOrderBy(req.OrderBy, allowedFields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 10. 统计总数
|
||||
countSql := fmt.Sprintf("SELECT COUNT(*) FROM %s", req.Table)
|
||||
if fullWhere != "" {
|
||||
countSql += " WHERE " + fullWhere
|
||||
}
|
||||
if groupByClause != "" {
|
||||
countSql = fmt.Sprintf("SELECT COUNT(*) FROM (SELECT 1 FROM %s WHERE %s GROUP BY %s) AS t",
|
||||
req.Table, fullWhere, groupByClause)
|
||||
}
|
||||
|
||||
result, err := gfdb.DB(ctx).GetAll(ctx, countSql, whereArgs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("统计总数失败: %v", err)
|
||||
}
|
||||
var total int64
|
||||
if result.Len() > 0 {
|
||||
total = result[0]["count"].Int64()
|
||||
}
|
||||
|
||||
// 11. 查询数据
|
||||
querySql := fmt.Sprintf("SELECT %s FROM %s", selectFields, req.Table)
|
||||
if fullWhere != "" {
|
||||
querySql += " WHERE " + fullWhere
|
||||
}
|
||||
if groupByClause != "" {
|
||||
querySql += " GROUP BY " + groupByClause
|
||||
}
|
||||
if orderByClause != "" {
|
||||
querySql += " ORDER BY " + orderByClause
|
||||
}
|
||||
querySql += fmt.Sprintf(" LIMIT %d OFFSET %d", req.PageSize, offset)
|
||||
|
||||
dataResult, err := gfdb.DB(ctx).GetAll(ctx, querySql, whereArgs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询数据失败: %v", err)
|
||||
}
|
||||
|
||||
var list []map[string]interface{}
|
||||
if dataResult.Len() > 0 {
|
||||
list = dataResult.List()
|
||||
}
|
||||
|
||||
return &public.QueryRes{
|
||||
List: list,
|
||||
Total: total,
|
||||
Page: req.Page,
|
||||
Size: req.PageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetTableList 获取可查询表列表
|
||||
func (s *publicQueryService) GetTableList(ctx context.Context) (*public.TableListRes, error) {
|
||||
var ifaces []dict.ApiInterface
|
||||
err := gfdb.DB(ctx).Model(ctx, "api_interface").
|
||||
Where("table_definition IS NOT NULL").
|
||||
Where("table_definition->>'table_name' != ''").
|
||||
Where("status", "active").
|
||||
Scan(&ifaces)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询表列表失败: %v", err)
|
||||
}
|
||||
|
||||
// 查询平台名称
|
||||
var platforms []dict.DatasourcePlatform
|
||||
_ = gfdb.DB(ctx).Model(ctx, "api_datasource_platform").Scan(&platforms)
|
||||
platformMap := make(map[int64]string)
|
||||
for _, p := range platforms {
|
||||
platformMap[p.ID] = p.PlatformName
|
||||
}
|
||||
|
||||
var list []public.TableInfo
|
||||
for _, iface := range ifaces {
|
||||
tableName := s.getStringFromMap(iface.TableDefinition, "table_name")
|
||||
if tableName == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
columns := s.extractColumnsFromMap(iface.TableDefinition)
|
||||
list = append(list, public.TableInfo{
|
||||
TableName: tableName,
|
||||
PlatformName: platformMap[iface.PlatformId],
|
||||
InterfaceName: iface.Name,
|
||||
Columns: columns,
|
||||
})
|
||||
}
|
||||
|
||||
return &public.TableListRes{List: list}, nil
|
||||
}
|
||||
|
||||
// GetColumnList 获取表字段列表
|
||||
func (s *publicQueryService) GetColumnList(ctx context.Context, tableName string) (*public.ColumnListRes, error) {
|
||||
if err := s.validateTable(ctx, tableName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns, err := s.getColumnDetails(ctx, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &public.ColumnListRes{
|
||||
TableName: tableName,
|
||||
Columns: columns,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validateTable 验证表名白名单
|
||||
func (s *publicQueryService) validateTable(ctx context.Context, tableName string) error {
|
||||
if tableName == "" {
|
||||
return fmt.Errorf("表名不能为空")
|
||||
}
|
||||
|
||||
// 表名格式校验
|
||||
if matched, _ := regexp.MatchString(`^[a-zA-Z_][a-zA-Z0-9_]*$`, tableName); !matched {
|
||||
return fmt.Errorf("表名格式非法,只允许字母、数字、下划线")
|
||||
}
|
||||
|
||||
// 禁止系统表
|
||||
systemTables := []string{"pg_catalog", "information_schema"}
|
||||
for _, t := range systemTables {
|
||||
if strings.HasPrefix(strings.ToLower(tableName), t) {
|
||||
return fmt.Errorf("禁止查询系统表")
|
||||
}
|
||||
}
|
||||
|
||||
// 检查白名单
|
||||
count, err := gfdb.DB(ctx).Model(ctx, "api_interface").
|
||||
Where("table_definition->>'table_name' = ?", tableName).
|
||||
Where("status", "active").
|
||||
Count()
|
||||
if err != nil {
|
||||
return fmt.Errorf("表名验证失败: %v", err)
|
||||
}
|
||||
if count == 0 {
|
||||
return fmt.Errorf("表 [%s] 不在可查询白名单中", tableName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAllowedFields 获取表允许的字段
|
||||
func (s *publicQueryService) getAllowedFields(ctx context.Context, tableName string) ([]string, error) {
|
||||
if cols, ok := tableColumnsCache[tableName]; ok {
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
var iface dict.ApiInterface
|
||||
_, err := gfdb.DB(ctx).Model(ctx, "api_interface").
|
||||
Where("table_definition->>'table_name' = ?", tableName).
|
||||
Where("status", "active").
|
||||
One(&iface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取表字段失败: %v", err)
|
||||
}
|
||||
|
||||
columns := s.extractColumnsFromMap(iface.TableDefinition)
|
||||
tableColumnsCache[tableName] = columns
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
// extractColumnsFromMap 从 map 中提取字段
|
||||
func (s *publicQueryService) extractColumnsFromMap(tableDef map[string]interface{}) []string {
|
||||
var columns []string
|
||||
if cols, ok := tableDef["columns"].([]interface{}); ok {
|
||||
for _, c := range cols {
|
||||
if col, ok := c.(map[string]interface{}); ok {
|
||||
if name, ok := col["name"].(string); ok {
|
||||
columns = append(columns, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
// buildSelectFields 构建 SELECT 字段
|
||||
func (s *publicQueryService) buildSelectFields(fields string, allowedFields []string) (string, error) {
|
||||
allowedMap := make(map[string]bool)
|
||||
for _, f := range allowedFields {
|
||||
allowedMap[strings.ToLower(f)] = true
|
||||
}
|
||||
allowedMap["id"] = true
|
||||
allowedMap["tenant_id"] = true
|
||||
allowedMap["created_at"] = true
|
||||
allowedMap["updated_at"] = true
|
||||
allowedMap["raw_data"] = true
|
||||
|
||||
var result []string
|
||||
for _, f := range strings.Split(fields, ",") {
|
||||
f = strings.TrimSpace(f)
|
||||
if f == "" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(f, " ") || strings.Contains(f, "(") {
|
||||
result = append(result, f)
|
||||
continue
|
||||
}
|
||||
fLower := strings.ToLower(f)
|
||||
if !allowedMap[fLower] {
|
||||
return "", fmt.Errorf("字段 [%s] 不在允许列表中", f)
|
||||
}
|
||||
result = append(result, f)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return "*", nil
|
||||
}
|
||||
return strings.Join(result, ", "), nil
|
||||
}
|
||||
|
||||
// buildWhereClause 构建 WHERE 条件
|
||||
func (s *publicQueryService) buildWhereClause(where map[string]interface{}, allowedFields []string) (string, []interface{}, error) {
|
||||
if len(where) == 0 {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
allowedMap := make(map[string]bool)
|
||||
for _, f := range allowedFields {
|
||||
allowedMap[strings.ToLower(f)] = true
|
||||
}
|
||||
allowedMap["tenant_id"] = true
|
||||
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
|
||||
for field, value := range where {
|
||||
fieldLower := strings.ToLower(field)
|
||||
if !allowedMap[fieldLower] {
|
||||
return "", nil, fmt.Errorf("字段 [%s] 不在允许列表中", field)
|
||||
}
|
||||
|
||||
// 处理操作符后缀
|
||||
opSuffixes := []struct {
|
||||
suffix string
|
||||
format string
|
||||
like bool
|
||||
}{
|
||||
{"_eq", "%s = ?", false},
|
||||
{"_ne", "%s != ?", false},
|
||||
{"_gt", "%s > ?", false},
|
||||
{"_lt", "%s < ?", false},
|
||||
{"_ge", "%s >= ?", false},
|
||||
{"_le", "%s <= ?", false},
|
||||
{"_like", "%s LIKE ?", true},
|
||||
}
|
||||
|
||||
matched := false
|
||||
for _, op := range opSuffixes {
|
||||
if strings.HasSuffix(fieldLower, op.suffix) {
|
||||
cleanField := field[:len(field)-len(op.suffix)]
|
||||
conditions = append(conditions, fmt.Sprintf(op.format, cleanField))
|
||||
if op.like {
|
||||
args = append(args, "%"+gconv.String(value)+"%")
|
||||
} else {
|
||||
args = append(args, value)
|
||||
}
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if matched {
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理 _in
|
||||
if strings.HasSuffix(fieldLower, "_in") {
|
||||
cleanField := field[:len(field)-3]
|
||||
if arr, ok := value.([]interface{}); ok {
|
||||
placeholders := make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
placeholders[i] = "?"
|
||||
args = append(args, v)
|
||||
}
|
||||
conditions = append(conditions, fmt.Sprintf("%s IN (%s)", cleanField, strings.Join(placeholders, ",")))
|
||||
} else if str, ok := value.(string); ok {
|
||||
parts := strings.Split(str, ",")
|
||||
placeholders := make([]string, len(parts))
|
||||
for i, p := range parts {
|
||||
placeholders[i] = "?"
|
||||
args = append(args, strings.TrimSpace(p))
|
||||
}
|
||||
conditions = append(conditions, fmt.Sprintf("%s IN (%s)", cleanField, strings.Join(placeholders, ",")))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理 _between
|
||||
if strings.HasSuffix(fieldLower, "_between") {
|
||||
cleanField := field[:len(field)-8]
|
||||
if arr, ok := value.([]interface{}); ok && len(arr) >= 2 {
|
||||
conditions = append(conditions, fmt.Sprintf("%s BETWEEN ? AND ?", cleanField))
|
||||
args = append(args, arr[0], arr[1])
|
||||
} else if arr, ok := value.([]string); ok && len(arr) >= 2 {
|
||||
conditions = append(conditions, fmt.Sprintf("%s BETWEEN ? AND ?", cleanField))
|
||||
args = append(args, arr[0], arr[1])
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 默认等于
|
||||
conditions = append(conditions, fmt.Sprintf("%s = ?", field))
|
||||
args = append(args, value)
|
||||
}
|
||||
|
||||
return strings.Join(conditions, " AND "), args, nil
|
||||
}
|
||||
|
||||
// buildGroupBy 构建 GROUP BY
|
||||
func (s *publicQueryService) buildGroupBy(groupBy string, allowedFields []string) (string, error) {
|
||||
if groupBy == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
allowedMap := make(map[string]bool)
|
||||
for _, f := range allowedFields {
|
||||
allowedMap[strings.ToLower(f)] = true
|
||||
}
|
||||
|
||||
var fields []string
|
||||
for _, f := range strings.Split(groupBy, ",") {
|
||||
f = strings.TrimSpace(f)
|
||||
if f == "" {
|
||||
continue
|
||||
}
|
||||
if !allowedMap[strings.ToLower(f)] {
|
||||
return "", fmt.Errorf("分组字段 [%s] 不在允许列表中", f)
|
||||
}
|
||||
fields = append(fields, f)
|
||||
}
|
||||
|
||||
return strings.Join(fields, ", "), nil
|
||||
}
|
||||
|
||||
// buildOrderBy 构建 ORDER BY
|
||||
func (s *publicQueryService) buildOrderBy(orderBy string, allowedFields []string) (string, error) {
|
||||
if orderBy == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
allowedMap := make(map[string]bool)
|
||||
for _, f := range allowedFields {
|
||||
allowedMap[strings.ToLower(f)] = true
|
||||
}
|
||||
allowedMap["id"] = true
|
||||
allowedMap["created_at"] = true
|
||||
allowedMap["updated_at"] = true
|
||||
|
||||
var clauses []string
|
||||
for _, part := range strings.Split(orderBy, ",") {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.Fields(part)
|
||||
if len(parts) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
field := parts[0]
|
||||
dir := "ASC"
|
||||
if len(parts) > 1 {
|
||||
if strings.ToUpper(parts[1]) == "DESC" {
|
||||
dir = "DESC"
|
||||
}
|
||||
}
|
||||
|
||||
if !allowedMap[strings.ToLower(field)] {
|
||||
return "", fmt.Errorf("排序字段 [%s] 不在允许列表中", field)
|
||||
}
|
||||
field = regexp.MustCompile(`[^a-zA-Z0-9_]`).ReplaceAllString(field, "")
|
||||
clauses = append(clauses, field+" "+dir)
|
||||
}
|
||||
|
||||
return strings.Join(clauses, ", "), nil
|
||||
}
|
||||
|
||||
// buildTenantClause 构建租户过滤条件
|
||||
func (s *publicQueryService) buildTenantClause(allowedFields []string) string {
|
||||
for _, f := range allowedFields {
|
||||
if strings.ToLower(f) == "tenant_id" {
|
||||
return "tenant_id = 1"
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getColumnDetails 获取表字段详情
|
||||
func (s *publicQueryService) getColumnDetails(ctx context.Context, tableName string) ([]public.Column, error) {
|
||||
var iface dict.ApiInterface
|
||||
_, err := gfdb.DB(ctx).Model(ctx, "api_interface").
|
||||
Where("table_definition->>'table_name' = ?", tableName).
|
||||
Where("status", "active").
|
||||
One(&iface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取表字段详情失败: %v", err)
|
||||
}
|
||||
|
||||
var columns []public.Column
|
||||
if cols, ok := iface.TableDefinition["columns"].([]interface{}); ok {
|
||||
for _, c := range cols {
|
||||
if col, ok := c.(map[string]interface{}); ok {
|
||||
columns = append(columns, public.Column{
|
||||
Name: gconv.String(col["name"]),
|
||||
Type: gconv.String(col["type"]),
|
||||
Comment: gconv.String(col["comment"]),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
columns = append(columns, public.Column{Name: "id", Type: "BIGINT", Comment: "主键ID"})
|
||||
columns = append(columns, public.Column{Name: "tenant_id", Type: "BIGINT", Comment: "租户ID"})
|
||||
columns = append(columns, public.Column{Name: "created_at", Type: "TIMESTAMP", Comment: "创建时间"})
|
||||
columns = append(columns, public.Column{Name: "updated_at", Type: "TIMESTAMP", Comment: "更新时间"})
|
||||
columns = append(columns, public.Column{Name: "raw_data", Type: "JSONB", Comment: "原始数据"})
|
||||
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
// getStringFromMap 从 map 中获取字符串值
|
||||
func (s *publicQueryService) getStringFromMap(data map[string]interface{}, key string) string {
|
||||
if v, ok := data[key].(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ClearTableCache 清除表缓存
|
||||
func (s *publicQueryService) ClearTableCache() {
|
||||
tableColumnsCache = make(map[string][]string)
|
||||
}
|
||||
|
||||
// InvalidateTableCache 失效指定表的缓存
|
||||
func (s *publicQueryService) InvalidateTableCache(tableName string) {
|
||||
delete(tableColumnsCache, tableName)
|
||||
}
|
||||
Reference in New Issue
Block a user