Compare commits

..

2 Commits

10 changed files with 79 additions and 206 deletions

View File

@@ -1,49 +1,13 @@
package util
import (
"fmt"
"net/url"
"prompts-core/model/entity"
"strings"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性
// 校验逻辑:只校验 requestMapping 中默认值为空的必填字段
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
// 1) 获取校验配置,并取值
requestMapping := model.RequestMapping
contentStr, ok := raw[model.ResponseBody].(string)
if !ok || contentStr == "" {
return fmt.Errorf("%s 字段为空或不是字符串", model.ResponseBody)
}
// 2) 解析 content 为 JSON 数组
var rounds []map[string]any
if err := gjson.DecodeTo(contentStr, &rounds); err != nil {
return fmt.Errorf("解析 content JSON 数组失败: %w", err)
}
if len(rounds) == 0 {
return fmt.Errorf("content 数组为空")
}
// 3) 逐条校验:只检查默认值为空的必填字段是否存在
for i, round := range rounds {
for path, defaultValue := range requestMapping {
if !g.IsEmpty(defaultValue) {
continue
}
if gjson.New(round).Get(path).IsNil() {
return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, path)
}
}
}
return nil
}
// ReverseMap 映射 payload 到 mapping
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
jsonObj := gjson.New("{}")

View File

@@ -6,6 +6,7 @@ import (
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
"prompts-core/service/gateway"
promptService "prompts-core/service/prompt"
@@ -42,7 +43,7 @@ func (c *prompt) Text(ctx context.Context, req *dto.TextReq) (res *dto.TextRes,
if err != nil {
return
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
ModelName: composeTask.ModelName,
})

View File

@@ -16,8 +16,3 @@ var Session = new(session)
func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) {
return sessionService.Callback(ctx, req)
}
//TODO:后期历史相关服务可能拆分(三个接口)
// 1. 添加历史会话
// 2. 获取历史会话
// 3. 更新历史信息

View File

@@ -1,40 +0,0 @@
package dao
import (
"context"
"prompts-core/consts/public"
"prompts-core/model/entity"
"gitea.com/red-future/common/db/gfdb"
)
var Model = &modelDao{}
type modelDao struct{}
// Get 获取模型
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
OmitEmpty().
Where(entity.AsynchModelCol.Creator, req.Creator).
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
Where(entity.AsynchModelCol.ModelName, req.ModelName).
Fields(fields).One()
if err != nil {
return
}
err = r.Struct(&m)
return
}
// GetsByModelName 批量获取模型
func (d *modelDao) GetsByModelName(ctx context.Context, creator string, modelNames []string, fields ...string) (list []*entity.AsynchModel, err error) {
err = gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
OmitEmpty().
Where(entity.AsynchModelCol.Creator, creator).
WhereIn(entity.AsynchModelCol.ModelName, modelNames).
Fields(fields).
Scan(&list)
return
}

View File

@@ -1,103 +0,0 @@
package entity
import "gitea.com/red-future/common/beans"
type asynchModelCol struct {
beans.SQLBaseCol
ModelName string
ModelType string
BaseURL string
HttpMethod string
HeadMsg string
FormJSON string
RequestMapping string
ResponseMapping string
ResponseBody string
ResponseTokenField string
IsPrivate string
IsChatModel string
IsAsync string
IsStream string
ApiKey string
Enabled string
MaxConcurrency string
TimeoutSeconds string
RetryTimes string
AutoCleanSeconds string
IsOwner string
OperatorName string
TokenConfig string
ExtendMapping string
QueryConfig string
StreamConfig string
FirstFrame string
LastFrame string
CallbackUrl string
}
var AsynchModelCol = asynchModelCol{
SQLBaseCol: beans.DefSQLBaseCol,
ModelName: "model_name",
ModelType: "model_type",
BaseURL: "base_url",
HttpMethod: "http_method",
HeadMsg: "head_msg",
FormJSON: "form_json",
RequestMapping: "request_mapping",
ResponseMapping: "response_mapping",
ResponseBody: "response_body",
ResponseTokenField: "response_token_field",
IsPrivate: "is_private",
IsChatModel: "is_chat_model",
IsAsync: "is_async",
IsStream: "is_stream",
ApiKey: "api_key",
Enabled: "enabled",
MaxConcurrency: "max_concurrency",
TimeoutSeconds: "timeout_seconds",
RetryTimes: "retry_times",
AutoCleanSeconds: "auto_clean_seconds",
IsOwner: "is_owner",
OperatorName: "operator_name",
TokenConfig: "token_config",
ExtendMapping: "extend_mapping",
QueryConfig: "query_config",
StreamConfig: "stream_config",
FirstFrame: "first_frame",
LastFrame: "last_frame",
CallbackUrl: "callback_url",
}
// AsynchModel 异步模型配置
type AsynchModel struct {
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
ModelType int `orm:"model_type" json:"modelType"`
BaseURL string `orm:"base_url" json:"baseUrl"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
HeadMsg map[string]any `orm:"head_msg" json:"headMsg"`
Form []map[string]any `orm:"form_json" json:"form"`
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
ResponseBody string `orm:"response_body" json:"responseBody"`
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
IsPrivate *int `orm:"is_private" json:"isPrivate"`
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
IsAsync *int `orm:"is_async" json:"isAsync"`
IsStream *int `orm:"is_stream" json:"isStream"`
ApiKey string `orm:"api_key" json:"apiKey"`
Enabled *int `orm:"enabled" json:"enabled"`
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
RetryTimes int `orm:"retry_times" json:"retryTimes"`
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
IsOwner *int `json:"isOwner" orm:"is_owner"`
OperatorName string `orm:"operator_name" json:"operatorName"`
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
FirstFrame string `orm:"first_frame" json:"firstFrame"`
LastFrame string `orm:"last_frame" json:"lastFrame"`
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
}

View File

@@ -7,6 +7,7 @@ import (
"prompts-core/common/util"
"prompts-core/model/entity"
"gitea.com/red-future/common/beans"
commonHttp "gitea.com/red-future/common/http"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
@@ -37,6 +38,59 @@ func CreateGatewayTask(ctx context.Context, payload map[string]any) (string, err
return req.TaskId, nil
}
type GetModelConfigResp struct {
Model *AsynchModel `json:"model"`
}
type AsynchModel struct {
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
ModelType int `orm:"model_type" json:"modelType"`
BaseURL string `orm:"base_url" json:"baseUrl"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
HeadMsg map[string]any `orm:"head_msg" json:"headMsg"`
Form []map[string]any `orm:"form_json" json:"form"`
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
ResponseBody string `orm:"response_body" json:"responseBody"`
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
IsPrivate *int `orm:"is_private" json:"isPrivate"`
IsChatModel int `orm:"is_chat_model" json:"isChatModel"`
IsAsync *int `orm:"is_async" json:"isAsync"`
IsStream *int `orm:"is_stream" json:"isStream"`
ApiKey string `orm:"api_key" json:"apiKey"`
Enabled *int `orm:"enabled" json:"enabled"`
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
RetryTimes int `orm:"retry_times" json:"retryTimes"`
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
IsOwner *int `json:"isOwner" orm:"is_owner"`
OperatorName string `orm:"operator_name" json:"operatorName"`
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
FirstFrame string `orm:"first_frame" json:"firstFrame"`
LastFrame string `orm:"last_frame" json:"lastFrame"`
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
}
// GetModelConfig 获取模型配置
func GetModelConfig(ctx context.Context, req *AsynchModel) (model *AsynchModel, err error) {
fmt.Println("req参数", req)
fullURL := fmt.Sprintf("model-gateway/model/getModel?creator=%s&modelName=%s&isChatModel=%d",
req.Creator, req.ModelName, req.IsChatModel)
headers := util.ForwardHeaders(ctx)
var resp GetModelConfigResp
if err = commonHttp.Get(ctx, fullURL, headers, &resp, nil); err != nil {
return nil, fmt.Errorf("获取模型配置失败: %w", err)
}
if resp.Model == nil {
return nil, fmt.Errorf("模型不存在: creator=%s modelName=%s isChatModel=%d", req.Creator, req.ModelName, req.IsChatModel)
}
return resp.Model, nil
}
// GetTaskResultRes 任务结果响应
type GetTaskResultRes struct {
OssFile string `json:"ossFile" dc:"结果文件OSS地址"`

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"prompts-core/consts/public"
"prompts-core/service/gateway"
"strings"
"prompts-core/common/util"
@@ -29,7 +30,7 @@ type UserPromptPayload struct {
}
// buildInferenceRequest 构建推理请求
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, history []map[string]any) (map[string]any, error) {
//1) 处理表单分批
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
if err != nil {
@@ -47,7 +48,7 @@ func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, cha
}
// buildPromptTypeRequest 构建提示词类型请求BuildType=1
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
//1) 构建系统提示词
systemPrompt := promptBuildWithRounds(ctx, req, chatModel, aiModel, totalBatches)
ir.AddSystem(systemPrompt)
@@ -69,13 +70,13 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
}
// buildNodeTypeRequest 构建节点类型请求BuildType=2
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, ir *PromptIR) (map[string]any, error) {
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
ir.AddUser(NodeBuild(ctx, req))
return compileToProviderRequest(ctx, ir, chatModel)
}
// compileToProviderRequest 编译为 Provider 请求
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *entity.AsynchModel) (map[string]any, error) {
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel) (map[string]any, error) {
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
if err != nil {
return nil, fmt.Errorf("获取协议配置失败: %w", err)
@@ -97,7 +98,7 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *enti
}
// promptBuildWithRounds 构建系统提示词
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, batches int) string {
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, batches int) string {
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: chatModel.OperatorName,
Status: 1,
@@ -144,7 +145,7 @@ func buildUserFormContent(userForm []map[string]any) string {
}
// checkOverallContent 检查整体内容是否超出窗口
func checkOverallContent(ir *PromptIR, model *entity.AsynchModel) bool {
func checkOverallContent(ir *PromptIR, model *gateway.AsynchModel) bool {
fullContent := ir.String()
return util.CountToken(fullContent, model.TokenConfig)
}

View File

@@ -42,14 +42,14 @@ func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.Com
}
// GetModelMessage 获取模型信息
func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) {
func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*gateway.AsynchModel, *gateway.AsynchModel, error) {
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
}
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
chatModel, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
IsChatModel: new(1),
IsChatModel: 1,
})
if err != nil {
return nil, nil, err
@@ -57,8 +57,8 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.
if chatModel == nil {
return nil, nil, errors.New("当前没有对话模型,请添加")
}
aiModels, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
aiModels, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{TenantId: userInfo.TenantId, Creator: userInfo.UserName},
ModelName: req.ModelName,
})
if err != nil {
@@ -72,7 +72,7 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.
}
// validateUserForm 校验用户表单
func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) error {
func validateUserForm(req *dto.ComposeMessagesReq, model *gateway.AsynchModel) error {
if len(req.UserForm) == 0 {
return nil
}
@@ -90,7 +90,7 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) er
}
// handlePromptBuild 处理提示词构建BuildType=1
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
// 获取历史会话
history, err := session.GetHistoryMessages(ctx, req.SessionId)
if err != nil {
@@ -123,7 +123,7 @@ func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatMod
}
// handleNodeBuild 处理节点构建BuildType=2
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
if err != nil {
return nil, fmt.Errorf("调用推理模型失败: %w", err)
@@ -148,7 +148,7 @@ func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel
}
// callInferenceModel 调用推理模型
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (string, int64, error) {
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, history []map[string]any) (string, int64, error) {
taskReq, err := buildInferenceRequest(ctx, req, chatModel, aiModel, history)
if err != nil {
return "", 0, fmt.Errorf("构建推理请求失败: %w", err)
@@ -186,7 +186,7 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
if err != nil {
return fmt.Errorf("查询任务失败: %w", err)
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
ModelName: composeTask.ModelName,
})

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"prompts-core/common/util"
"prompts-core/service/gateway"
"strings"
"prompts-core/dao"
@@ -178,7 +179,7 @@ func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
}
// Compile 将 PromptIR 按协议配置编译为 Provider Request
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (map[string]any, error) {
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) {
if ir == nil || p == nil {
return nil, fmt.Errorf("ir and protocol are required")
}
@@ -262,7 +263,7 @@ func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
}
// buildRequest 按 target_field 和 request_template 构建请求体
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *entity.AsynchModel) map[string]any {
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gateway.AsynchModel) map[string]any {
if len(p.RequestTemplate) > 0 {
return renderTemplate(p.RequestTemplate, messages, chatModel)
}
@@ -273,7 +274,7 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *ent
}
// renderTemplate 简单的 {{key}} 模板替换
func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *entity.AsynchModel) map[string]any {
func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
b, _ := json.Marshal(tmpl)
str := string(b)

View File

@@ -3,17 +3,17 @@ package prompt
import (
"context"
"fmt"
"prompts-core/service/gateway"
"strings"
"github.com/gogf/gf/v2/frame/g"
"prompts-core/common/util"
"prompts-core/model/dto"
"prompts-core/model/entity"
)
// ProcessUserFormBatches 处理 UserForm 分批(按 token 大小拼接内容)
func ProcessUserFormBatches(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) (*dto.ComposeMessagesReq, int, error) {
func ProcessUserFormBatches(ctx context.Context, req *dto.ComposeMessagesReq, model *gateway.AsynchModel) (*dto.ComposeMessagesReq, int, error) {
if model.TokenConfig == nil || len(req.UserForm) == 0 {
return req, 1, nil
}