Files
prompts-core/service/prompt/prompt_files_handle_service.go

293 lines
7.5 KiB
Go

package prompt
import (
	"archive/zip"
	"bytes"
	"context"
	"fmt"
	"io"
	"net/http"
	"sort"
	"strings"
	"time"

	"github.com/gogf/gf/v2/frame/g"

	"prompts-core/common/util"
	"prompts-core/model/dto"
	"prompts-core/service/gateway"
)
// Byte multipliers used to turn size limits configured in KB / MB into bytes.
const (
	bytesPerKB = 1 << 10 // 1024
	bytesPerMB = 1 << 20 // 1024 * 1024
)
// ExtractFileTexts gathers the non-empty Url of every item in consult, in
// slice order, and returns whatever FetchFileTextsAsString produces for those
// URLs.
func ExtractFileTexts(ctx context.Context, consult []dto.ConsultItem) string {
	fileURLs := make([]string, 0, len(consult))
	for i := range consult {
		if u := consult[i].Url; u != "" {
			fileURLs = append(fileURLs, u)
		}
	}
	return FetchFileTextsAsString(ctx, fileURLs)
}
// FetchFileTextsAsString fetches every URL in urls and concatenates the
// extracted text into one string. An empty urls slice yields "".
//
// Each raw URL is first passed through util.SanitizeURL; URLs that sanitize to
// "" or that have a banned extension are skipped. A URL with a zip extension
// is expanded with fetchZipFileTexts and every entry's text is appended,
// followed by a newline. Any other URL is fetched with
// fetchAndCleanFileContent and, when the result is non-empty, appended under a
// "【文件:<url>】" header line.
//
// One HTTP client, configured from "userFiles.httpTimeoutSec" (default 8), is
// shared by all downloads of a call.
func FetchFileTextsAsString(ctx context.Context, urls []string) string {
	if len(urls) == 0 {
		return ""
	}
	client := createHTTPClient(ctx, "userFiles.httpTimeoutSec", 8)
	var builder strings.Builder
	for _, rawURL := range urls {
		url := util.SanitizeURL(rawURL)
		if url == "" || util.IsBannedExtension(url) {
			continue
		}
		if util.IsZipExtension(url) {
			entries := fetchZipFileTexts(ctx, client, url)
			// Go randomizes map iteration order, so walk the entry keys in
			// sorted order to make the concatenated text identical from one
			// call to the next for the same archive.
			keys := make([]string, 0, len(entries))
			for key := range entries {
				keys = append(keys, key)
			}
			sort.Strings(keys)
			// The keys are used for ordering only: unlike the single-file
			// branch below, entry text is written without a per-file header.
			for _, key := range keys {
				builder.WriteString(entries[key])
				builder.WriteString("\n")
			}
			continue
		}
		if text := fetchAndCleanFileContent(ctx, client, url); text != "" {
			fmt.Fprintf(&builder, "【文件:%s】\n%s\n", url, text)
		}
	}
	return builder.String()
}
// fetchAndCleanFileContent fetches url via fetchFileContent and returns the
// text after running it through util.CleanSymbols. A fetch error or an empty
// body results in "".
func fetchAndCleanFileContent(ctx context.Context, client *http.Client, url string) string {
	raw, fetchErr := fetchFileContent(ctx, client, url)
	if fetchErr == nil && raw != "" {
		return util.CleanSymbols(raw)
	}
	return ""
}
// fetchZipFileTexts downloads the zip archive at url and returns the text of
// its entries, keyed by "<url>::<entry name>".
//
// The archive download is capped by "userFiles.zipMaxSizeMB" (default 10) and
// each entry read by "userFiles.zipEntryMaxSizeKB" (default 500). Entries
// rejected by shouldSkipZipEntry, and entries for which
// extractZipEntryContent yields "", are left out. A download failure or an
// unparsable archive produces an empty, non-nil map.
func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map[string]string {
	texts := make(map[string]string)

	archiveLimit := int64(g.Cfg().MustGet(ctx, "userFiles.zipMaxSizeMB", 10).Int()) * bytesPerMB
	archive, err := downloadFile(client, url, archiveLimit)
	if err != nil {
		return texts
	}

	zr, err := zip.NewReader(bytes.NewReader(archive), int64(len(archive)))
	if err != nil {
		return texts
	}

	entryLimit := int64(g.Cfg().MustGet(ctx, "userFiles.zipEntryMaxSizeKB", 500).Int()) * bytesPerKB
	for _, entry := range zr.File {
		if shouldSkipZipEntry(entry.Name) {
			continue
		}
		text := extractZipEntryContent(entry, entryLimit)
		if text == "" {
			continue
		}
		texts[url+"::"+entry.Name] = text
	}
	return texts
}
// shouldSkipZipEntry reports whether a zip entry must be ignored: its name
// either has a banned extension or is itself a zip archive.
func shouldSkipZipEntry(fileName string) bool {
	if util.IsBannedExtension(fileName) {
		return true
	}
	return util.IsZipExtension(fileName)
}
// extractZipEntryContent reads at most maxSize bytes of a zip entry and
// returns them after util.CleanSymbols. It returns "" when the entry cannot be
// opened or read, or when the content type sniffed by http.DetectContentType
// is not accepted by util.IsReadableContentType.
func extractZipEntryContent(file *zip.File, maxSize int64) string {
	entry, openErr := file.Open()
	if openErr != nil {
		return ""
	}
	defer entry.Close()

	raw, readErr := io.ReadAll(io.LimitReader(entry, maxSize))
	if readErr != nil {
		return ""
	}

	sniffed := http.DetectContentType(raw)
	if util.IsReadableContentType(sniffed) {
		return util.CleanSymbols(string(raw))
	}
	return ""
}
// downloadFile issues a GET for url with client and returns the complete
// response body.
//
// maxSize is the largest body, in bytes, that is accepted; a negative value is
// treated as 0. A body longer than maxSize is rejected with an error rather
// than returned truncated, so callers never receive a partial file that looks
// like a successful download. A response status outside 2xx is also reported
// as an error.
//
// The request is created without a context, so only the client's own settings
// (such as its Timeout) bound how long the call may take.
func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error) {
	if maxSize < 0 {
		maxSize = 0
	}
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("执行请求失败: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
	}
	// Read one byte past the limit so that an oversized body can be told apart
	// from one that is exactly maxSize bytes long. The guard avoids overflowing
	// int64 when maxSize is already the maximum value.
	readLimit := maxSize
	if readLimit < 1<<63-1 {
		readLimit++
	}
	body, err := io.ReadAll(io.LimitReader(resp.Body, readLimit))
	if err != nil {
		return nil, fmt.Errorf("读取响应失败: %w", err)
	}
	if int64(len(body)) > maxSize {
		return nil, fmt.Errorf("文件大小超过限制: 最大 %d 字节", maxSize)
	}
	return body, nil
}
// fetchFileContent downloads a single text file from url and returns its
// content with surrounding whitespace trimmed.
//
// The request is bound to ctx, so cancelling ctx or exceeding its deadline
// aborts the download. An error is returned when the request cannot be built
// or executed, when the status is outside 2xx, when the response Content-Type
// header is rejected by util.IsReadableContentType, or when reading the body
// fails. At most "userFiles.textFileMaxSizeKB" (default 500) KB of the body
// are read; anything beyond that is dropped.
func fetchFileContent(ctx context.Context, client *http.Client, url string) (string, error) {
	// Use the context-aware constructor so caller cancellation reaches the
	// HTTP round trip instead of relying on the client timeout alone.
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return "", fmt.Errorf("创建请求失败: %w", err)
	}
	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("执行请求失败: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return "", fmt.Errorf("HTTP %d", resp.StatusCode)
	}
	contentType := resp.Header.Get("Content-Type")
	if !util.IsReadableContentType(contentType) {
		return "", fmt.Errorf("不可读的内容类型: %s", contentType)
	}
	maxSize := int64(g.Cfg().MustGet(ctx, "userFiles.textFileMaxSizeKB", 500).Int()) * bytesPerKB
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
	if err != nil {
		return "", fmt.Errorf("读取响应失败: %w", err)
	}
	return strings.TrimSpace(string(body)), nil
}
// SkillMdContent builds a Markdown document for the skill named skillName.
//
// It looks the skill up with gateway.GetSkillUser, downloads the archive at
// ImgAddressPrefix+FileUrl (timeout from "skillFiles.httpTimeoutSec", default
// 30; size cap from "skillFiles.zipMaxSizeMB", default 10), extracts its .md
// entries with extractMdFiles and renders them with buildSkillMarkdown.
// Any failure — lookup, download, archive parsing, or no Markdown content —
// is logged as a warning and results in "".
func SkillMdContent(ctx context.Context, skillName string) string {
	skill, err := gateway.GetSkillUser(ctx, skillName)
	if err != nil {
		g.Log().Warningf(ctx, "[SkillMd] GetSkillUser 失败: %v", err)
		return ""
	}

	archiveURL := skill.ImgAddressPrefix + skill.FileUrl
	httpClient := createHTTPClient(ctx, "skillFiles.httpTimeoutSec", 30)
	sizeLimit := int64(g.Cfg().MustGet(ctx, "skillFiles.zipMaxSizeMB", 10).Int()) * bytesPerMB

	archive, err := downloadFile(httpClient, archiveURL, sizeLimit)
	if err != nil {
		g.Log().Warningf(ctx, "[SkillMd] 下载失败 url=%s err=%v", archiveURL, err)
		return ""
	}

	docs, err := extractMdFiles(ctx, archive)
	if err == nil && len(docs) > 0 {
		return buildSkillMarkdown(skill, docs)
	}
	g.Log().Warningf(ctx, "[SkillMd] 提取md失败 count=%d err=%v", len(docs), err)
	return ""
}
// buildSkillMarkdown renders a skill and its Markdown files as one document.
//
// The output starts with a "# Skill: <Name>" heading, followed by the
// Description as a block quote when it is non-empty. Each entry of mdContents
// then becomes a "## <file name>" section terminated by a "---" rule.
// Sections are emitted in ascending file-name order so the same input always
// produces the same document. The result is trimmed of surrounding
// whitespace.
func buildSkillMarkdown(skillResp *gateway.SkillUserVO, mdContents map[string]string) string {
	var builder strings.Builder
	fmt.Fprintf(&builder, "# Skill: %s\n\n", skillResp.Name)
	if skillResp.Description != "" {
		fmt.Fprintf(&builder, "> %s\n\n", skillResp.Description)
	}

	// Go randomizes map iteration order; sort the names for a stable layout.
	fileNames := make([]string, 0, len(mdContents))
	for fileName := range mdContents {
		fileNames = append(fileNames, fileName)
	}
	sort.Strings(fileNames)

	for _, fileName := range fileNames {
		fmt.Fprintf(&builder, "## %s\n\n", fileName)
		builder.WriteString(mdContents[fileName])
		builder.WriteString("\n\n---\n\n")
	}
	return strings.TrimSpace(builder.String())
}
// extractMdFiles opens zipBytes as a zip archive and returns the content of
// every non-directory entry whose name passes isMarkdownFile, keyed by entry
// name. Each entry read is capped by "skillFiles.mdMaxSizeKB" (default 500);
// entries for which readMarkdownFileContent yields "" are omitted. An archive
// that cannot be parsed yields a nil map and a wrapped error.
func extractMdFiles(ctx context.Context, zipBytes []byte) (map[string]string, error) {
	zr, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
	if err != nil {
		return nil, fmt.Errorf("创建 zip 阅读器失败: %w", err)
	}

	entryLimit := int64(g.Cfg().MustGet(ctx, "skillFiles.mdMaxSizeKB", 500).Int()) * bytesPerKB
	docs := make(map[string]string)
	for _, entry := range zr.File {
		if entry.FileInfo().IsDir() {
			continue
		}
		if !isMarkdownFile(entry.Name) {
			continue
		}
		body := readMarkdownFileContent(entry, entryLimit)
		if body != "" {
			docs[entry.Name] = body
		}
	}
	return docs, nil
}
// isMarkdownFile reports whether fileName ends in ".md", ignoring letter case.
func isMarkdownFile(fileName string) bool {
	const markdownExt = ".md"
	lowered := strings.ToLower(fileName)
	return strings.HasSuffix(lowered, markdownExt)
}
// readMarkdownFileContent reads at most maxSize bytes from a zip entry and
// returns them as a string with surrounding whitespace trimmed. It returns ""
// when the entry cannot be opened or read, or when nothing was read.
func readMarkdownFileContent(file *zip.File, maxSize int64) string {
	entry, openErr := file.Open()
	if openErr != nil {
		return ""
	}
	defer entry.Close()

	raw, readErr := io.ReadAll(io.LimitReader(entry, maxSize))
	if readErr != nil || len(raw) == 0 {
		return ""
	}
	return strings.TrimSpace(string(raw))
}
// createHTTPClient returns a new http.Client whose Timeout is the number of
// seconds stored under configKey, falling back to defaultSeconds when the key
// has no configured value.
func createHTTPClient(ctx context.Context, configKey string, defaultSeconds int) *http.Client {
	seconds := g.Cfg().MustGet(ctx, configKey, defaultSeconds).Int()
	client := new(http.Client)
	client.Timeout = time.Duration(seconds) * time.Second
	return client
}