refactor(prompt): 重构提示词构建服务与数据模型

This commit is contained in:
2026-05-20 11:36:39 +08:00
parent c49144794d
commit 35bc3bd6ec
24 changed files with 1682 additions and 759 deletions

View File

@@ -8,11 +8,12 @@ import (
)
// GetModelPrompt 获取请求模型的提示词
func GetModelPrompt(ctx context.Context, Type int) string {
return g.Cfg().MustGet(ctx, "modelPrompts.types."+gconv.String(Type), "").String()
func GetModelPrompt(ctx context.Context, modelType int) string {
key := "modelPrompts.types." + gconv.String(modelType)
return g.Cfg().MustGet(ctx, key, "").String()
}
// GetBuildPrompt 获取构建提示词
func GetBuildPrompt(ctx context.Context, Type int) string {
return g.Cfg().MustGet(ctx, "buildProject.types."+gconv.String(Type), "").String()
// GetBuildPrompt 获取节点构建提示词
func GetBuildPrompt(ctx context.Context) string {
return g.Cfg().MustGet(ctx, "nodePrompts", "").String()
}

View File

@@ -6,38 +6,41 @@ import (
"strings"
)
// AllowedMIMEPrefixes 允许的文本类 MIME 类型前缀
var AllowedMIMEPrefixes = []string{
"text/",
"application/json",
"application/xml",
"application/javascript",
"application/x-yaml",
"application/yaml",
"application/toml",
"application/x-httpd-php",
"application/x-sh",
"application/x-python",
"application/x-perl",
"application/x-ruby",
}
var (
// AllowedMIMEPrefixes 允许的文本类 MIME 类型前缀
AllowedMIMEPrefixes = []string{
"text/",
"application/json",
"application/xml",
"application/javascript",
"application/x-yaml",
"application/yaml",
"application/toml",
"application/x-httpd-php",
"application/x-sh",
"application/x-python",
"application/x-perl",
"application/x-ruby",
}
// BannedExtensions 禁止的文件扩展名
var BannedExtensions = map[string]bool{
".png": true, ".jpg": true, ".jpeg": true, ".gif": true, ".bmp": true,
".webp": true, ".svg": true, ".ico": true, ".tiff": true, ".tif": true,
".mp3": true, ".wav": true, ".ogg": true, ".flac": true, ".aac": true,
".wma": true, ".m4a": true,
".mp4": true, ".avi": true, ".mkv": true, ".mov": true, ".wmv": true,
".flv": true, ".webm": true,
".tar": true, ".gz": true, ".rar": true, ".7z": true,
".exe": true, ".dll": true, ".so": true, ".bin": true, ".dat": true,
".class": true, ".pyc": true,
".pdf": true, ".doc": true, ".docx": true, ".xls": true, ".xlsx": true,
".ppt": true, ".pptx": true,
}
// BannedExtensions 禁止的文件扩展名
BannedExtensions = map[string]bool{
".png": true, ".jpg": true, ".jpeg": true, ".gif": true, ".bmp": true,
".webp": true, ".svg": true, ".ico": true, ".tiff": true, ".tif": true,
".mp3": true, ".wav": true, ".ogg": true, ".flac": true, ".aac": true,
".wma": true, ".m4a": true,
".mp4": true, ".avi": true, ".mkv": true, ".mov": true, ".wmv": true,
".flv": true, ".webm": true,
".tar": true, ".gz": true, ".rar": true, ".7z": true,
".exe": true, ".dll": true, ".so": true, ".bin": true, ".dat": true,
".class": true, ".pyc": true,
".pdf": true, ".doc": true, ".docx": true, ".xls": true, ".xlsx": true,
".ppt": true, ".pptx": true,
}
var symbolCleaner = regexp.MustCompile(`[\x00-\x08\x0B\x0C\x0E-\x1F]`)
symbolCleaner = regexp.MustCompile(`[\x00-\x08\x0B\x0C\x0E-\x1F]`)
multiNewlines = regexp.MustCompile(`\n{3,}`)
)
// SanitizeURL 清洗 URL 字符串
func SanitizeURL(raw string) string {
@@ -51,25 +54,19 @@ func CleanSymbols(text string) string {
text = symbolCleaner.ReplaceAllString(text, "")
text = strings.ReplaceAll(text, "\r\n", "\n")
text = strings.ReplaceAll(text, "\r", "\n")
text = regexp.MustCompile(`\n{3,}`).ReplaceAllString(text, "\n\n")
text = multiNewlines.ReplaceAllString(text, "\n\n")
return strings.TrimSpace(text)
}
// IsBannedExtension 判断是否为禁止的文件扩展名
func IsBannedExtension(url string) bool {
ext := strings.ToLower(filepath.Ext(url))
if idx := strings.Index(ext, "?"); idx != -1 {
ext = ext[:idx]
}
ext := extractExtension(url)
return BannedExtensions[ext]
}
// IsZipExtension 判断是否为 zip 文件
func IsZipExtension(url string) bool {
ext := strings.ToLower(filepath.Ext(url))
if idx := strings.Index(ext, "?"); idx != -1 {
ext = ext[:idx]
}
ext := extractExtension(url)
return ext == ".zip"
}
@@ -78,11 +75,22 @@ func IsReadableContentType(contentType string) bool {
if contentType == "" {
return false
}
ct := strings.ToLower(contentType)
for _, prefix := range AllowedMIMEPrefixes {
if strings.HasPrefix(ct, prefix) {
return true
}
}
return false
}
// extractExtension 提取文件扩展名并清理查询参数
func extractExtension(url string) string {
ext := strings.ToLower(filepath.Ext(url))
if idx := strings.Index(ext, "?"); idx != -1 {
ext = ext[:idx]
}
return ext
}

View File

@@ -10,6 +10,7 @@ import (
// AsyncCtx 固化异步上下文中的 token 和用户信息,避免请求结束后丢失
func AsyncCtx(ctx context.Context) context.Context {
asyncCtx := context.WithoutCancel(ctx)
if r := g.RequestFromCtx(ctx); r != nil {
if token := r.Header.Get("Authorization"); token != "" {
asyncCtx = context.WithValue(asyncCtx, "token", token)
@@ -18,9 +19,11 @@ func AsyncCtx(ctx context.Context) context.Context {
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
}
}
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
asyncCtx = context.WithValue(asyncCtx, "user", user)
}
return asyncCtx
}
@@ -28,25 +31,37 @@ func AsyncCtx(ctx context.Context) context.Context {
func ForwardHeaders(ctx context.Context) map[string]string {
headers := make(map[string]string)
if token, ok := ctx.Value("token").(string); ok && token != "" {
headers["Authorization"] = token
}
if x, ok := ctx.Value("xUserInfo").(string); ok && x != "" {
headers["X-User-Info"] = x
}
setHeaderFromContext(headers, ctx, "Authorization", "token")
setHeaderFromContext(headers, ctx, "X-User-Info", "xUserInfo")
fallbackToRequestHeaders(headers, ctx)
// 兜底:从请求头获取
if r := g.RequestFromCtx(ctx); r != nil {
if headers["Authorization"] == "" {
if token := r.Header.Get("Authorization"); token != "" {
headers["Authorization"] = token
}
}
if headers["X-User-Info"] == "" {
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
headers["X-User-Info"] = userInfo
}
}
}
return headers
}
// setHeaderFromContext 从上下文中设置 header
func setHeaderFromContext(headers map[string]string, ctx context.Context, headerKey, ctxKey string) {
if value, ok := ctx.Value(ctxKey).(string); ok && value != "" {
headers[headerKey] = value
}
}
// fallbackToRequestHeaders 从请求头中获取作为兜底
func fallbackToRequestHeaders(headers map[string]string, ctx context.Context) {
r := g.RequestFromCtx(ctx)
if r == nil {
return
}
if headers["Authorization"] == "" {
if token := r.Header.Get("Authorization"); token != "" {
headers["Authorization"] = token
}
}
if headers["X-User-Info"] == "" {
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
headers["X-User-Info"] = userInfo
}
}
}

View File

@@ -15,6 +15,7 @@ func ParseOutput(text string) (map[string]any, error) {
if err != nil {
return nil, fmt.Errorf("解析模型输出失败: %w", err)
}
return j.Map(), nil
}
@@ -23,26 +24,17 @@ func ConvertToMessages(raw any) []map[string]any {
if raw == nil {
return nil
}
j, err := gjson.LoadJson(gconv.Bytes(raw))
if err != nil {
return nil
}
// 如果有 messages 字段,直接返回
if j.Contains("messages") {
return gconv.Maps(j.Get("messages").Array())
}
// 否则当成单条 message
return []map[string]any{
j.Map(),
}
}
// IsMessageValid 校验推理结果是否合法
func IsMessageValid(message map[string]any) bool {
if message == nil {
return false
}
return true
return []map[string]any{j.Map()}
}
// FormToJSON 将表单数据转换为 JSON 字符串
@@ -50,6 +42,17 @@ func FormToJSON(form map[string]any) string {
if form == nil {
return "{}"
}
b, _ := json.Marshal(form)
return string(b)
}
// UserFormToJSON 将用户表单数据转换为 JSON 字符串
func UserFormToJSON(form []map[string]any) string {
if form == nil {
return "{}"
}
b, _ := json.Marshal(form)
return string(b)
}
@@ -60,39 +63,16 @@ func MustMarshal(v any) string {
if err != nil {
return "{}"
}
return string(b)
}
// ParseJSONField 解析 JSON 字段
func ParseJSONField(field any) any {
var v *gvar.Var
switch val := field.(type) {
case *gvar.Var:
v = val
default:
return field
}
if v == nil || v.IsNil() || v.IsEmpty() {
return nil
}
str := v.String()
var result any
if json.Unmarshal([]byte(str), &result) == nil {
return result
}
return str
}
// JSONPretty 将任意类型转为格式化的 JSON 字符串
func JSONPretty(v any) string {
// 处理 *gvar.Var 类型
if gv, ok := v.(*gvar.Var); ok {
v = gconv.Map(gv.String())
}
// 统一转 map 再美化
var tmp map[string]any
if err := gconv.Struct(v, &tmp); err != nil {
return gconv.String(v)
@@ -101,3 +81,71 @@ func JSONPretty(v any) string {
b, _ := json.MarshalIndent(tmp, "", " ")
return string(b)
}
// GvarToMap 将 *gvar.Var 类型转换为 map[string]any
func GvarToMap(v *gvar.Var) map[string]any {
if v == nil || v.IsNil() {
return nil
}
result := make(map[string]any)
// 方法1尝试获取 map 值
if m := v.Map(); len(m) > 0 {
return m
}
// 方法2尝试解析 JSON 字符串
str := v.String()
if str != "" && str != "<nil>" {
json.Unmarshal([]byte(str), &result)
if len(result) > 0 {
return result
}
}
// 方法3尝试获取 interface 再转换
if val := v.Val(); val != nil {
switch val.(type) {
case map[string]any:
return val.(map[string]any)
default:
data, _ := json.Marshal(val)
json.Unmarshal(data, &result)
}
}
return result
}
// ParseJSONFieldFromGvar 专门处理 *gvar.Var 类型的 JSON 字段解析
func ParseJSONFieldFromGvar(source any, target any) {
if source == nil {
return
}
switch v := source.(type) {
case *gvar.Var:
if v.IsNil() {
return
}
// 尝试获取 map
if m := v.Map(); len(m) > 0 {
data, _ := json.Marshal(m)
json.Unmarshal(data, target)
return
}
// 尝试解析 JSON 字符串
str := v.String()
if str != "" && str != "<nil>" {
json.Unmarshal([]byte(str), target)
}
default:
// 其他类型走原来的逻辑
data, _ := json.Marshal(source)
json.Unmarshal(data, target)
}
}

229
common/util/token.go Normal file
View File

@@ -0,0 +1,229 @@
package util
import (
"encoding/json"
"fmt"
"regexp"
"strings"
"unicode"
"github.com/gogf/gf/v2/container/gvar"
)
var (
enWordRegex = regexp.MustCompile(`[A-Za-z]+`)
punctRegex = regexp.MustCompile(`[[:punct:]]`)
)
// TokenConfig Token计算配置
type TokenConfig struct {
ZhRatio float64 `json:"zh_ratio"`
EnRatio float64 `json:"en_ratio"`
SpaceRatio float64 `json:"space_ratio"`
PunctuationRatio float64 `json:"punctuation_ratio"`
MaxWindowSize int `json:"max_window_size"`
ReserveRatio float64 `json:"reserve_ratio"`
MinReserve int `json:"min_reserve"`
}
// CalculateTokens 计算文本token数
func CalculateTokens(text string, tokenConfig any) int {
config := parseConfig(tokenConfig)
if config == nil {
return 0
}
if text == "" {
return 0
}
zhCount := countChineseChars(text)
enCount := countEnglishWords(text)
spaceCount := strings.Count(text, " ")
punctCount := countPunctuation(text)
totalTokens := int(
float64(zhCount)*config.ZhRatio +
float64(enCount)*config.EnRatio +
float64(spaceCount)*config.SpaceRatio +
float64(punctCount)*config.PunctuationRatio,
)
return totalTokens
}
// CountToken 计算token是否超出窗口限制
// 返回: true - 未超出(可用), false - 已超出(不可用)
func CountToken(text string, tokenConfig any) bool {
config := parseConfig(tokenConfig)
if config == nil {
return false
}
estimatedTokens := CalculateTokens(text, tokenConfig)
availableWindow := GetAvailableWindow(tokenConfig)
return estimatedTokens <= availableWindow
}
// GetAvailableWindow 获取可用窗口大小
func GetAvailableWindow(tokenConfig any) int {
config := parseConfig(tokenConfig)
if config == nil {
return 4096
}
reserveByRatio := int(float64(config.MaxWindowSize) * config.ReserveRatio)
reserve := reserveByRatio
if config.MinReserve > reserve {
reserve = config.MinReserve
}
available := config.MaxWindowSize - reserve
if available < 0 {
available = 0
}
return available
}
// GetMaxWindowSize 获取模型最大窗口大小
func GetMaxWindowSize(tokenConfig any) int {
config := parseConfig(tokenConfig)
if config == nil {
return 4096
}
return config.MaxWindowSize
}
// CheckUserFormWithinWindow 校验 UserForm 是否在窗口大小内
// 返回: isValid, exceedTokens, error
func CheckUserFormWithinWindow(userForm []map[string]any, tokenConfig any) (bool, int, error) {
config := parseConfig(tokenConfig)
if config == nil || len(userForm) == 0 {
return true, 0, nil
}
totalTokens := calculateUserFormTokens(userForm, tokenConfig)
availableWindow := GetAvailableWindow(tokenConfig)
if totalTokens > availableWindow {
return false, totalTokens - availableWindow, nil
}
return true, 0, nil
}
// CheckUserFormBatchWithinWindow 检查 UserForm 分批是否在窗口内
// 返回: 需要拆分的批次数, 每批的token数, 错误
func CheckUserFormBatchWithinWindow(userForm []map[string]any, tokenConfig any) (int, []int, error) {
config := parseConfig(tokenConfig)
if config == nil || len(userForm) == 0 {
return 1, nil, nil
}
availableWindow := GetAvailableWindow(tokenConfig)
batches := 1
currentTokens := 0
batchTokens := make([]int, 0)
for _, item := range userForm {
itemStr := fmt.Sprintf("%v", item)
itemTokens := CalculateTokens(itemStr, tokenConfig)
if currentTokens+itemTokens > availableWindow {
batchTokens = append(batchTokens, currentTokens)
batches++
currentTokens = itemTokens
} else {
currentTokens += itemTokens
}
}
if currentTokens > 0 {
batchTokens = append(batchTokens, currentTokens)
}
return batches, batchTokens, nil
}
// parseConfig 解析配置
func parseConfig(tokenConfig any) *TokenConfig {
if tokenConfig == nil {
return nil
}
switch v := tokenConfig.(type) {
case *gvar.Var:
return parseGVarConfig(v)
case map[string]any:
return parseMapConfig(v)
case *TokenConfig:
return v
case TokenConfig:
return &v
default:
return nil
}
}
// parseGVarConfig 解析 GVar 配置
func parseGVarConfig(v *gvar.Var) *TokenConfig {
if v.IsNil() {
return nil
}
mapVal := v.Map()
if mapVal == nil {
return nil
}
config := &TokenConfig{}
data, _ := json.Marshal(mapVal)
json.Unmarshal(data, config)
return config
}
// parseMapConfig 解析 Map 配置
func parseMapConfig(v map[string]any) *TokenConfig {
config := &TokenConfig{}
data, _ := json.Marshal(v)
json.Unmarshal(data, config)
return config
}
// countChineseChars 统计中文字符数量
func countChineseChars(text string) int {
count := 0
for _, r := range text {
if unicode.Is(unicode.Han, r) {
count++
}
}
return count
}
// countEnglishWords 统计英文单词数量
func countEnglishWords(text string) int {
return len(enWordRegex.FindAllString(text, -1))
}
// countPunctuation 统计标点符号数量
func countPunctuation(text string) int {
return len(punctRegex.FindAllString(text, -1))
}
// calculateUserFormTokens 计算 UserForm 总 token 数
func calculateUserFormTokens(userForm []map[string]any, tokenConfig any) int {
totalTokens := 0
for _, item := range userForm {
itemStr := fmt.Sprintf("%v", item)
totalTokens += CalculateTokens(itemStr, tokenConfig)
}
return totalTokens
}