Files
data-engine/common/report/builder/sql_builder.go
2026-06-11 13:06:54 +08:00

519 lines
14 KiB
Go

package builder
import (
	"context"
	"fmt"
	"reflect"
	"regexp"
	"strings"

	"dataengine/common/report/config"
	"dataengine/common/report/model"
)
// SQLBuilder builds dynamic report SQL from user selections, validating them
// against report and field configuration obtained from a config.ConfigLoader.
type SQLBuilder struct {
	loader *config.ConfigLoader // source of report and field configuration
}
// NewSQLBuilder returns a SQLBuilder whose configuration loader is the one
// returned by config.GetLoader.
func NewSQLBuilder() *SQLBuilder {
	b := &SQLBuilder{}
	b.loader = config.GetLoader()
	return b
}
// BuildQuerySQL builds the report query for a user's selection.
//
// It loads the report and field configuration for req.BusinessCode /
// req.ReportCode, then assembles the SELECT, WHERE, GROUP BY and ORDER BY
// clauses from req, validated against that configuration.
//
// Return values:
//   - the data query;
//   - the `?` placeholder arguments, which all belong to the WHERE clause;
//   - metadata:
//     "countSql" counts the rows the data query returns and takes the same
//     arguments; "tableName" is the stat table; "reportConfig" is the loaded
//     report;
//   - an error if configuration cannot be loaded or a clause builder rejects
//     the selection.
func (b *SQLBuilder) BuildQuerySQL(ctx context.Context, req *model.UserSelectQueryReq) (string, []interface{}, map[string]interface{}, error) {
	// 1. Load configuration.
	report, err := b.loader.GetReport(ctx, req.BusinessCode, req.ReportCode)
	if err != nil {
		return "", nil, nil, fmt.Errorf("获取报表配置失败: %w", err)
	}
	fieldMap, err := b.loader.GetFieldMap(ctx, req.BusinessCode, req.ReportCode)
	if err != nil {
		return "", nil, nil, fmt.Errorf("获取字段配置失败: %w", err)
	}
	tableName := report.StatTableName

	// 2. Build each clause; the first failing builder aborts.
	selectClause, err := b.buildSelectClause(req, fieldMap, report)
	if err != nil {
		return "", nil, nil, err
	}
	whereClause, whereArgs, err := b.buildWhereClause(ctx, req, fieldMap, report)
	if err != nil {
		return "", nil, nil, err
	}
	groupByClause, err := b.buildGroupByClause(req, fieldMap)
	if err != nil {
		return "", nil, nil, err
	}
	// buildSelectClause adds a non-aggregated "<expr> AS time_group" column
	// for any time grouping other than "day". That column must also be
	// grouped, or the database rejects the aggregate query. PostgreSQL
	// accepts an output-column alias in GROUP BY; input column names win on
	// ambiguity.
	if req.TimeGroup != "" && req.TimeGroup != "day" {
		if groupByClause != "" {
			groupByClause += ", "
		}
		groupByClause += "time_group"
	}
	orderByClause, err := b.buildOrderByClause(req, fieldMap)
	if err != nil {
		return "", nil, nil, err
	}

	// 3. Assemble the un-ordered statement, which the count query reuses.
	var base strings.Builder
	base.WriteString("SELECT ")
	base.WriteString(selectClause)
	base.WriteString(" FROM ")
	base.WriteString(tableName)
	if whereClause != "" {
		base.WriteString(" WHERE ")
		base.WriteString(whereClause)
	}
	if groupByClause != "" {
		base.WriteString(" GROUP BY ")
		base.WriteString(groupByClause)
	}
	baseSQL := base.String()

	querySQL := baseSQL
	if orderByClause != "" {
		querySQL += " ORDER BY " + orderByClause
	}

	// 4. Count SQL. The data query selects aggregated indicators, so its row
	// count is the number of groups (a single row without GROUP BY), not the
	// number of table rows. Wrapping the statement counts exactly what
	// pagination pages over. It also never emits an empty "WHERE".
	countSQL := "SELECT COUNT(*) FROM (" + baseSQL + ") AS t"

	metadata := map[string]interface{}{
		"countSql":     countSQL,
		"tableName":    tableName,
		"reportConfig": report,
	}
	return querySQL, whereArgs, metadata, nil
}
// selectAliasPattern matches a plain unquoted SQL identifier: a letter or
// underscore followed by letters, digits or underscores. Request-supplied
// indicator aliases must match it because they are spliced into the SQL text.
var selectAliasPattern = regexp.MustCompile(`^[\p{L}_][\p{L}\p{N}_]*$`)

// selectAggPattern matches a bare function name such as SUM or COUNT. The
// aggregate is spliced into the SQL as "<agg>(<field>)".
var selectAggPattern = regexp.MustCompile(`(?i)^[a-z_][a-z0-9_]*$`)

// buildSelectClause builds the SELECT list from the user's selection.
//
// The list has three parts, in order:
//   - the requested dimensions, trimmed, with blanks skipped; each must exist
//     in fieldMap with the dimension or filter role;
//   - one column per indicator (at least one is required). A CALCULATED field
//     with an expression uses parseExpression; any other field uses
//     "<AGG>(<field>) AS <alias>";
//   - a "time_group" column when req.TimeGroup is set and is not "day".
//
// The aggregate is, in order of precedence: the requested one, the field's
// DefaultAggregate, then model.AggregateSum. It must be in the field's
// ValidAggregates when that list is configured. The alias defaults to the
// field code.
func (b *SQLBuilder) buildSelectClause(req *model.UserSelectQueryReq, fieldMap map[string]*model.FieldConfig, report *model.ReportConfig) (string, error) {
	var selectParts []string

	// 1. Dimension columns.
	for _, dim := range req.Dimensions {
		dim = strings.TrimSpace(dim)
		if dim == "" {
			continue
		}
		fc, ok := fieldMap[dim]
		if !ok {
			return "", fmt.Errorf("维度字段不存在: %s", dim)
		}
		if fc.FieldRole != model.RoleDimension && fc.FieldRole != model.RoleFilter {
			return "", fmt.Errorf("字段 %s 不可作为维度", dim)
		}
		selectParts = append(selectParts, dim)
	}

	// 2. Indicator columns (aggregated or calculated).
	if len(req.Indicators) == 0 {
		return "", fmt.Errorf("必须选择至少一个指标")
	}
	for _, ind := range req.Indicators {
		fc, ok := fieldMap[ind.FieldCode]
		if !ok {
			return "", fmt.Errorf("指标字段不存在: %s", ind.FieldCode)
		}

		alias := ind.Alias
		if alias == "" {
			alias = ind.FieldCode
		} else if !selectAliasPattern.MatchString(alias) {
			// The alias comes straight from the request and is interpolated
			// verbatim. Anything but a plain identifier could inject SQL.
			return "", fmt.Errorf("指标别名不合法: %s", alias)
		}

		agg := strings.ToUpper(ind.Aggregate)
		if agg == "" {
			// Upper-case the configured default as well, so it compares
			// equal to the upper-cased ValidAggregates entries below.
			agg = strings.ToUpper(fc.DefaultAggregate)
			if agg == "" {
				agg = model.AggregateSum
			}
		}

		// Enforce the field's aggregate whitelist when one is configured.
		if len(fc.ValidAggregates) > 0 {
			valid := false
			for _, v := range fc.ValidAggregates {
				if strings.ToUpper(v) == agg {
					valid = true
					break
				}
			}
			if !valid {
				return "", fmt.Errorf("字段 %s 不支持聚合方式 %s", ind.FieldCode, agg)
			}
		}

		// A derived indicator uses its configured expression instead of an
		// aggregate call.
		if fc.ExpressionType == "CALCULATED" && fc.Expression != "" {
			expr := b.parseExpression(fc.Expression, req.Indicators)
			selectParts = append(selectParts, fmt.Sprintf("%s AS %s", expr, alias))
			continue
		}

		// The aggregate may come straight from the request, and the whitelist
		// above is optional. Only a bare function name may reach the SQL.
		if !selectAggPattern.MatchString(agg) {
			return "", fmt.Errorf("字段 %s 不支持聚合方式 %s", ind.FieldCode, agg)
		}
		selectParts = append(selectParts, fmt.Sprintf("%s(%s) AS %s", agg, ind.FieldCode, alias))
	}

	// 3. Optional time-bucket column on the report's date field.
	if req.TimeGroup != "" && req.TimeGroup != "day" {
		dateField := report.DateField
		if dateField == "" {
			dateField = "stat_date"
		}
		selectParts = append(selectParts, b.buildTimeGroupExpr(dateField, req.TimeGroup))
	}

	return strings.Join(selectParts, ", "), nil
}
// buildWhereClause builds the WHERE conditions, joined with AND, together
// with their `?` arguments. Conditions are added in this order:
//  1. tenant_id = 1 (a literal, so it takes no argument);
//  2. an optional date range on report.DateField (default "stat_date"),
//     with one argument per non-empty bound;
//  3. business_code = req.BusinessCode;
//  4. one condition per req.Filters entry, rendered by buildFilterCondition.
//
// A filter is rejected when its field is missing from fieldMap or is not
// filterable. It is also rejected when the field lists FilterOperators and
// the operator (upper-cased, default "=") is not among them.
// ctx is currently unused.
func (b *SQLBuilder) buildWhereClause(ctx context.Context, req *model.UserSelectQueryReq, fieldMap map[string]*model.FieldConfig, report *model.ReportConfig) (string, []interface{}, error) {
	// NOTE(review): the tenant is hard-coded to 1; confirm whether it should
	// come from ctx or req instead.
	conds := []string{"tenant_id = 1"}
	var params []interface{}

	if tr := req.TimeRange; tr != nil {
		col := report.DateField
		if col == "" {
			col = "stat_date"
		}
		if tr.StartDate != "" {
			conds = append(conds, fmt.Sprintf("%s >= ?", col))
			params = append(params, tr.StartDate)
		}
		if tr.EndDate != "" {
			conds = append(conds, fmt.Sprintf("%s <= ?", col))
			params = append(params, tr.EndDate)
		}
	}

	conds = append(conds, "business_code = ?")
	params = append(params, req.BusinessCode)

	for _, f := range req.Filters {
		fc, found := fieldMap[f.FieldCode]
		switch {
		case !found:
			return "", nil, fmt.Errorf("筛选字段不存在: %s", f.FieldCode)
		case !fc.IsFilterable:
			return "", nil, fmt.Errorf("字段 %s 不可用于筛选", f.FieldCode)
		}

		op := "="
		if f.Operator != "" {
			op = strings.ToUpper(f.Operator)
		}

		// Operators are only restricted when the field configures a list.
		if n := len(fc.FilterOperators); n > 0 {
			permitted := false
			for i := 0; i < n && !permitted; i++ {
				permitted = strings.ToUpper(fc.FilterOperators[i]) == op
			}
			if !permitted {
				return "", nil, fmt.Errorf("字段 %s 不支持操作符 %s", f.FieldCode, op)
			}
		}

		cond, vals, err := b.buildFilterCondition(f, op, fc.FieldType)
		if err != nil {
			return "", nil, err
		}
		conds = append(conds, cond)
		params = append(params, vals...)
	}

	return strings.Join(conds, " AND "), params, nil
}
// likePatternEscaper escapes LIKE metacharacters, and the escape character
// itself, so user text is matched literally. Backslash is the default LIKE
// escape character in PostgreSQL.
var likePatternEscaper = strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`)

// buildFilterCondition renders one filter as a parameterized predicate on
// filter.FieldCode and returns it with its `?` arguments.
//
// Supported operators (op is expected to be upper-cased already):
//   - =, !=, >, <, >=, <=: one argument, filter.Value;
//   - IN: one placeholder per value from convertToSlice; an empty list is an
//     error, because "IN ()" is not valid SQL;
//   - LIKE: a "contains" match; the value is formatted with %v and its
//     wildcards are escaped;
//   - BETWEEN: two arguments, filter.Value and filter.Value2.
//
// Any other operator falls back to equality. fieldType is currently unused.
func (b *SQLBuilder) buildFilterCondition(filter model.FilterCondition, op string, fieldType string) (string, []interface{}, error) {
	field := filter.FieldCode
	switch op {
	case "=", "!=", ">", "<", ">=", "<=":
		// op is one of the literal cases above, so it is safe to interpolate.
		return fmt.Sprintf("%s %s ?", field, op), []interface{}{filter.Value}, nil
	case "IN":
		values, err := b.convertToSlice(filter.Value)
		if err != nil {
			return "", nil, err
		}
		if len(values) == 0 {
			return "", nil, fmt.Errorf("字段 %s 的 IN 条件值不能为空", field)
		}
		placeholders := make([]string, len(values))
		args := make([]interface{}, len(values))
		for i, v := range values {
			placeholders[i] = "?"
			args[i] = v
		}
		return fmt.Sprintf("%s IN (%s)", field, strings.Join(placeholders, ",")), args, nil
	case "LIKE":
		// Escape the user's text so "50%" or "a_b" are matched literally
		// rather than as wildcards.
		pattern := "%" + likePatternEscaper.Replace(fmt.Sprintf("%v", filter.Value)) + "%"
		return fmt.Sprintf("%s LIKE ?", field), []interface{}{pattern}, nil
	case "BETWEEN":
		return fmt.Sprintf("%s BETWEEN ? AND ?", field), []interface{}{filter.Value, filter.Value2}, nil
	default:
		// Unknown operators fall back to equality. The caller only validates
		// operators when the field configures FilterOperators.
		return fmt.Sprintf("%s = ?", field), []interface{}{filter.Value}, nil
	}
}
// buildGroupByClause returns the comma-separated GROUP BY list for the
// requested dimensions, or "" when there are none.
//
// Names are trimmed, and blank entries are skipped, exactly as
// buildSelectClause does. This keeps every selected dimension column grouped.
// Without the trim, a padded name such as " region" would be selected but
// missing here, which makes the aggregate query invalid. Names absent from
// fieldMap, or without the dimension/filter role, are skipped. The error
// result is currently always nil.
func (b *SQLBuilder) buildGroupByClause(req *model.UserSelectQueryReq, fieldMap map[string]*model.FieldConfig) (string, error) {
	groupFields := make([]string, 0, len(req.Dimensions))
	for _, dim := range req.Dimensions {
		dim = strings.TrimSpace(dim)
		if dim == "" {
			continue
		}
		fc, ok := fieldMap[dim]
		if !ok {
			continue
		}
		if fc.FieldRole == model.RoleDimension || fc.FieldRole == model.RoleFilter {
			groupFields = append(groupFields, dim)
		}
	}
	// Join of an empty slice is "", meaning "no GROUP BY".
	return strings.Join(groupFields, ", "), nil
}
// buildOrderByClause renders req.OrderBy as "<field> <ASC|DESC>, ...", or ""
// when no ordering is requested.
//
// Direction is case-insensitive and defaults to ASC; any other value is an
// error. Each field must exist in fieldMap and be marked sortable.
func (b *SQLBuilder) buildOrderByClause(req *model.UserSelectQueryReq, fieldMap map[string]*model.FieldConfig) (string, error) {
	terms := make([]string, 0, len(req.OrderBy))
	for _, o := range req.OrderBy {
		dir := "ASC"
		if o.Direction != "" {
			dir = strings.ToUpper(o.Direction)
		}
		switch dir {
		case "ASC", "DESC":
		default:
			return "", fmt.Errorf("排序方向必须是 ASC 或 DESC")
		}

		fc, ok := fieldMap[o.FieldCode]
		switch {
		case !ok:
			return "", fmt.Errorf("排序字段不存在: %s", o.FieldCode)
		case !fc.IsSortable:
			return "", fmt.Errorf("字段 %s 不可排序", o.FieldCode)
		}

		terms = append(terms, o.FieldCode+" "+dir)
	}
	return strings.Join(terms, ", "), nil
}
// buildTimeGroupExpr returns a PostgreSQL select-list item, aliased
// "time_group", that buckets dateField by timeGroup:
//   - "week": the DATE_TRUNC('week', ...) result cast to text;
//   - "month": TO_CHAR as YYYY-MM;
//   - "quarter": TO_CHAR as YYYY-"Q"Q (for example 2024-Q1);
//   - anything else: dateField unchanged.
func (b *SQLBuilder) buildTimeGroupExpr(dateField, timeGroup string) string {
	var bucket string
	switch timeGroup {
	case "week":
		bucket = "DATE_TRUNC('week', " + dateField + "::date)::text"
	case "month":
		bucket = "TO_CHAR(" + dateField + "::date, 'YYYY-MM')"
	case "quarter":
		bucket = "TO_CHAR(" + dateField + `::date, 'YYYY-"Q"Q')`
	default:
		bucket = dateField
	}
	return fmt.Sprintf("%s AS time_group", bucket)
}
// exprPlaceholderPattern matches "{field_code}" placeholders in a derived
// indicator expression. It is compiled once instead of on every call.
var exprPlaceholderPattern = regexp.MustCompile(`\{([^}]+)\}`)

// parseExpression resolves the "{field_code}" placeholders in a derived
// indicator expression.
//
// A placeholder whose field code is among the selected indicators is
// replaced by the bare field code. Any other placeholder is left as-is,
// braces included.
// NOTE(review): leftover braces are not valid SQL; confirm whether
// expressions may reference indicators the user did not select.
func (b *SQLBuilder) parseExpression(expr string, indicators []model.IndicatorSelect) string {
	return exprPlaceholderPattern.ReplaceAllStringFunc(expr, func(match string) string {
		fieldCode := match[1 : len(match)-1] // strip the surrounding braces
		for _, ind := range indicators {
			if ind.FieldCode == fieldCode {
				return fieldCode
			}
		}
		return match
	})
}
// convertToSlice normalizes an IN-filter value into a list of bind arguments:
//   - []interface{} is returned as-is;
//   - []string is copied element by element;
//   - a string is split on commas, and each part is trimmed of surrounding
//     whitespace;
//   - any other slice or array (for example []int or []int64 from a typed
//     request) is expanded element by element. Byte slices and arrays are the
//     exception: they are kept as one scalar value;
//   - anything else, including nil, becomes a one-element list.
//
// The error result is currently always nil; it is kept for callers.
func (b *SQLBuilder) convertToSlice(v interface{}) ([]interface{}, error) {
	switch val := v.(type) {
	case []interface{}:
		return val, nil
	case []string:
		result := make([]interface{}, len(val))
		for i, s := range val {
			result[i] = s
		}
		return result, nil
	case string:
		parts := strings.Split(val, ",")
		result := make([]interface{}, len(parts))
		for i, p := range parts {
			result[i] = strings.TrimSpace(p)
		}
		return result, nil
	}

	// Other typed slices used to be wrapped as a single value, which bound a
	// whole slice to one "?" and broke the query. Expand them generically.
	rv := reflect.ValueOf(v) // the zero Value (Kind Invalid) for nil
	if k := rv.Kind(); (k == reflect.Slice || k == reflect.Array) && rv.Type().Elem().Kind() != reflect.Uint8 {
		result := make([]interface{}, rv.Len())
		for i := range result {
			result[i] = rv.Index(i).Interface()
		}
		return result, nil
	}
	return []interface{}{v}, nil
}
// countSelectPattern matches "SELECT <list> FROM" case-insensitively. The s
// flag lets the select list span several lines. It is compiled once at
// package scope.
var countSelectPattern = regexp.MustCompile(`(?is)SELECT\s+.*?\s+FROM`)

// BuildCountSQL turns a query into a row-count query. It replaces the first
// "SELECT <list> FROM" with "SELECT COUNT(*) FROM" and leaves the rest of the
// statement unchanged. Subqueries later in the statement, such as
// "IN (SELECT id FROM u)", keep their meaning. A query without such a prefix
// is returned unchanged.
//
// Limitations, which the caller must handle:
//   - GROUP BY, ORDER BY and LIMIT/OFFSET are kept as-is;
//   - a subquery containing FROM inside the select list defeats the
//     rewrite.
func (b *SQLBuilder) BuildCountSQL(sql string) string {
	loc := countSelectPattern.FindStringIndex(sql)
	if loc == nil {
		return sql
	}
	return sql[:loc[0]] + "SELECT COUNT(*) FROM" + sql[loc[1]:]
}
// AddLimit appends "LIMIT <pageSize> OFFSET <offset>" to sql.
//
// page is 1-based and is raised to 1 when smaller. pageSize defaults to 20
// when below 1 and is capped at 1000.
func (b *SQLBuilder) AddLimit(sql string, page, pageSize int) string {
	switch {
	case pageSize < 1:
		pageSize = 20
	case pageSize > 1000:
		pageSize = 1000
	}
	if page < 1 {
		page = 1
	}
	skip := pageSize * (page - 1)
	return fmt.Sprintf("%s LIMIT %d OFFSET %d", sql, pageSize, skip)
}
// GenerateInsertSQL builds a PostgreSQL INSERT with $1..$N placeholders, one
// per column, in column order.
//
// With conflict keys it becomes an upsert, "ON CONFLICT (<keys>)" followed by
// either:
//   - "DO UPDATE SET col = EXCLUDED.col" for every column except id and
//     created_at, or
//   - "DO NOTHING" when no column is left to update. A bare
//     "DO UPDATE SET" would be a syntax error.
//
// tableName, columns and conflictKeys are interpolated verbatim, so they must
// be trusted identifiers.
func (b *SQLBuilder) GenerateInsertSQL(tableName string, columns []string, conflictKeys []string) string {
	placeholders := make([]string, len(columns))
	for i := range columns {
		placeholders[i] = fmt.Sprintf("$%d", i+1)
	}
	sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
		tableName, strings.Join(columns, ", "), strings.Join(placeholders, ", "))
	if len(conflictKeys) == 0 {
		return sql
	}

	// Keep the row identity and creation time on update.
	updateParts := make([]string, 0, len(columns))
	for _, col := range columns {
		if col == "id" || col == "created_at" {
			continue
		}
		updateParts = append(updateParts, fmt.Sprintf("%s = EXCLUDED.%s", col, col))
	}

	sql += " ON CONFLICT (" + strings.Join(conflictKeys, ", ") + ")"
	if len(updateParts) == 0 {
		return sql + " DO NOTHING"
	}
	return sql + " DO UPDATE SET " + strings.Join(updateParts, ", ")
}
// BuildExtractSQL builds the SELECT that extracts source data for
// extractConfig, stamped with statDate.
//
// The select list starts with tenant_id, then the business code and statDate
// as string literals aliased business_code and stat_date. It then has one
// column per FieldMappings entry: either "<table>.<source> AS <target>" or,
// when a TransformRule is set, the expression from applyTransformRule.
// Columns are qualified with SourceTableAlias when set, otherwise with
// SourceTableName.
//
// The FROM clause is the source table (plus its alias) and one JOIN per
// JoinConfigs entry: INNER or RIGHT when requested, LEFT otherwise.
// FilterExpression, when non-empty, becomes the WHERE clause.
//
// Table and field names, join conditions and the filter expression come from
// configuration and are interpolated verbatim. The returned args are
// currently always empty, and ctx is unused.
func (b *SQLBuilder) BuildExtractSQL(ctx context.Context, extractConfig *model.ExtractConfig, statDate string) (string, []interface{}, error) {
	var args []interface{}

	// quoteLiteral renders s as a single-quoted SQL string literal. Embedded
	// quotes are doubled, so a value such as statDate cannot end the literal
	// early and inject SQL. This assumes standard_conforming_strings is on
	// (the PostgreSQL default), where backslashes are not escapes.
	quoteLiteral := func(s string) string {
		return "'" + strings.ReplaceAll(s, "'", "''") + "'"
	}

	// Base columns.
	selectParts := []string{
		"tenant_id",
		quoteLiteral(extractConfig.BusinessCode) + " AS business_code",
		quoteLiteral(statDate) + " AS stat_date",
	}

	// Mapped columns.
	sourceTable := extractConfig.SourceTableName
	if extractConfig.SourceTableAlias != "" {
		sourceTable = extractConfig.SourceTableAlias
	}
	for _, mapping := range extractConfig.FieldMappings {
		targetField := mapping.TargetField
		sourceField := mapping.SourceField
		if mapping.TransformRule != nil {
			expr := b.applyTransformRule(mapping.TransformRule, sourceField)
			selectParts = append(selectParts, fmt.Sprintf("%s AS %s", expr, targetField))
			continue
		}
		selectParts = append(selectParts, fmt.Sprintf("%s.%s AS %s", sourceTable, sourceField, targetField))
	}

	// FROM and JOINs.
	fromClause := extractConfig.SourceTableName
	if extractConfig.SourceTableAlias != "" {
		fromClause += " " + extractConfig.SourceTableAlias
	}
	for _, join := range extractConfig.JoinConfigs {
		joinType := "LEFT JOIN"
		switch strings.ToUpper(join.JoinType) {
		case "INNER":
			joinType = "INNER JOIN"
		case "RIGHT":
			joinType = "RIGHT JOIN"
		}
		fromClause += fmt.Sprintf(" %s %s %s ON %s", joinType, join.JoinTable, join.JoinAlias, join.JoinCondition)
	}

	// Optional WHERE.
	whereClause := ""
	if extractConfig.FilterExpression != "" {
		whereClause = " WHERE " + extractConfig.FilterExpression
	}

	sql := fmt.Sprintf("SELECT %s FROM %s%s", strings.Join(selectParts, ", "), fromClause, whereClause)
	return sql, args, nil
}
// applyTransformRule turns sourceField into a select-list expression
// according to rule:
//   - CALCULATE with a non-empty Expression: every "{source}" in Expression
//     is replaced by sourceField;
//   - FORMAT with a non-empty Format: TO_CHAR(sourceField, '<Format>').
//     NOTE(review): Format is placed in a quoted literal without escaping.
//   - MAPPING, any other rule type, or an empty Expression/Format:
//     sourceField unchanged. MAPPING is presumably applied in code at
//     runtime.
func (b *SQLBuilder) applyTransformRule(rule *model.TransformRule, sourceField string) string {
	switch {
	case rule.RuleType == "CALCULATE" && rule.Expression != "":
		return strings.ReplaceAll(rule.Expression, "{source}", sourceField)
	case rule.RuleType == "FORMAT" && rule.Format != "":
		return fmt.Sprintf("TO_CHAR(%s, '%s')", sourceField, rule.Format)
	default:
		return sourceField
	}
}