refactor(util): 重构映射工具函数并优化异步任务轮询逻辑
This commit is contained in:
@@ -1,49 +1,13 @@
|
|||||||
package util
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"prompts-core/model/entity"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/encoding/gjson"
|
"github.com/gogf/gf/v2/encoding/gjson"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性
|
|
||||||
// 校验逻辑:只校验 requestMapping 中默认值为空的必填字段
|
|
||||||
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
|
|
||||||
// 1) 获取校验配置,并取值
|
|
||||||
requestMapping := model.RequestMapping
|
|
||||||
contentStr, ok := raw[model.ResponseBody].(string)
|
|
||||||
if !ok || contentStr == "" {
|
|
||||||
return fmt.Errorf("%s 字段为空或不是字符串", model.ResponseBody)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2) 解析 content 为 JSON 数组
|
|
||||||
var rounds []map[string]any
|
|
||||||
if err := gjson.DecodeTo(contentStr, &rounds); err != nil {
|
|
||||||
return fmt.Errorf("解析 content JSON 数组失败: %w", err)
|
|
||||||
}
|
|
||||||
if len(rounds) == 0 {
|
|
||||||
return fmt.Errorf("content 数组为空")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3) 逐条校验:只检查默认值为空的必填字段是否存在
|
|
||||||
for i, round := range rounds {
|
|
||||||
for path, defaultValue := range requestMapping {
|
|
||||||
if !g.IsEmpty(defaultValue) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if gjson.New(round).Get(path).IsNil() {
|
|
||||||
return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, path)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReverseMap 映射 payload 到 mapping
|
// ReverseMap 映射 payload 到 mapping
|
||||||
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
|
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
|
||||||
jsonObj := gjson.New("{}")
|
jsonObj := gjson.New("{}")
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"prompts-core/dao"
|
"prompts-core/dao"
|
||||||
"prompts-core/model/dto"
|
"prompts-core/model/dto"
|
||||||
"prompts-core/model/entity"
|
"prompts-core/model/entity"
|
||||||
|
"prompts-core/service/gateway"
|
||||||
|
|
||||||
promptService "prompts-core/service/prompt"
|
promptService "prompts-core/service/prompt"
|
||||||
|
|
||||||
@@ -42,7 +43,7 @@ func (c *prompt) Text(ctx context.Context, req *dto.TextReq) (res *dto.TextRes,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||||
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
|
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
|
||||||
ModelName: composeTask.ModelName,
|
ModelName: composeTask.ModelName,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -16,8 +16,3 @@ var Session = new(session)
|
|||||||
func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) {
|
func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) {
|
||||||
return sessionService.Callback(ctx, req)
|
return sessionService.Callback(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO:后期历史相关服务可能拆分(三个接口)
|
|
||||||
// 1. 添加历史会话
|
|
||||||
// 2. 获取历史会话
|
|
||||||
// 3. 更新历史信息
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ type modelDao struct{}
|
|||||||
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
||||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||||
OmitEmpty().
|
OmitEmpty().
|
||||||
|
Where(entity.AsynchModelCol.Id, req.Id).
|
||||||
Where(entity.AsynchModelCol.Creator, req.Creator).
|
Where(entity.AsynchModelCol.Creator, req.Creator).
|
||||||
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
|
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
|
||||||
Where(entity.AsynchModelCol.ModelName, req.ModelName).
|
Where(entity.AsynchModelCol.ModelName, req.ModelName).
|
||||||
@@ -26,15 +27,3 @@ func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...s
|
|||||||
err = r.Struct(&m)
|
err = r.Struct(&m)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetsByModelName 批量获取模型
|
|
||||||
func (d *modelDao) GetsByModelName(ctx context.Context, creator string, modelNames []string, fields ...string) (list []*entity.AsynchModel, err error) {
|
|
||||||
err = gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
|
||||||
OmitEmpty().
|
|
||||||
Where(entity.AsynchModelCol.Creator, creator).
|
|
||||||
WhereIn(entity.AsynchModelCol.ModelName, modelNames).
|
|
||||||
Fields(fields).
|
|
||||||
Scan(&list)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"prompts-core/common/util"
|
"prompts-core/common/util"
|
||||||
"prompts-core/model/entity"
|
"prompts-core/model/entity"
|
||||||
|
|
||||||
|
"gitea.com/red-future/common/beans"
|
||||||
commonHttp "gitea.com/red-future/common/http"
|
commonHttp "gitea.com/red-future/common/http"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
"github.com/gogf/gf/v2/os/gtime"
|
"github.com/gogf/gf/v2/os/gtime"
|
||||||
@@ -37,6 +38,59 @@ func CreateGatewayTask(ctx context.Context, payload map[string]any) (string, err
|
|||||||
return req.TaskId, nil
|
return req.TaskId, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type GetModelConfigResp struct {
|
||||||
|
Model *AsynchModel `json:"model"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AsynchModel struct {
|
||||||
|
beans.SQLBaseDO `orm:",inline"`
|
||||||
|
ModelName string `orm:"model_name" json:"modelName"`
|
||||||
|
ModelType int `orm:"model_type" json:"modelType"`
|
||||||
|
BaseURL string `orm:"base_url" json:"baseUrl"`
|
||||||
|
HttpMethod string `orm:"http_method" json:"httpMethod"`
|
||||||
|
HeadMsg map[string]any `orm:"head_msg" json:"headMsg"`
|
||||||
|
Form []map[string]any `orm:"form_json" json:"form"`
|
||||||
|
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
|
||||||
|
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
|
||||||
|
ResponseBody string `orm:"response_body" json:"responseBody"`
|
||||||
|
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
|
||||||
|
IsPrivate *int `orm:"is_private" json:"isPrivate"`
|
||||||
|
IsChatModel int `orm:"is_chat_model" json:"isChatModel"`
|
||||||
|
IsAsync *int `orm:"is_async" json:"isAsync"`
|
||||||
|
IsStream *int `orm:"is_stream" json:"isStream"`
|
||||||
|
ApiKey string `orm:"api_key" json:"apiKey"`
|
||||||
|
Enabled *int `orm:"enabled" json:"enabled"`
|
||||||
|
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
|
||||||
|
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
|
||||||
|
RetryTimes int `orm:"retry_times" json:"retryTimes"`
|
||||||
|
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
|
||||||
|
IsOwner *int `json:"isOwner" orm:"is_owner"`
|
||||||
|
OperatorName string `orm:"operator_name" json:"operatorName"`
|
||||||
|
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
|
||||||
|
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
|
||||||
|
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
|
||||||
|
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
|
||||||
|
FirstFrame string `orm:"first_frame" json:"firstFrame"`
|
||||||
|
LastFrame string `orm:"last_frame" json:"lastFrame"`
|
||||||
|
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelConfig 获取模型配置
|
||||||
|
func GetModelConfig(ctx context.Context, req *AsynchModel) (model *AsynchModel, err error) {
|
||||||
|
fmt.Println("req参数", req)
|
||||||
|
fullURL := fmt.Sprintf("model-gateway/model/getModel?creator=%s&modelName=%s&isChatModel=%d",
|
||||||
|
req.Creator, req.ModelName, req.IsChatModel)
|
||||||
|
headers := util.ForwardHeaders(ctx)
|
||||||
|
var resp GetModelConfigResp
|
||||||
|
if err = commonHttp.Get(ctx, fullURL, headers, &resp, nil); err != nil {
|
||||||
|
return nil, fmt.Errorf("获取模型配置失败: %w", err)
|
||||||
|
}
|
||||||
|
if resp.Model == nil {
|
||||||
|
return nil, fmt.Errorf("模型不存在: creator=%s modelName=%s isChatModel=%d", req.Creator, req.ModelName, req.IsChatModel)
|
||||||
|
}
|
||||||
|
return resp.Model, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetTaskResultRes 任务结果响应
|
// GetTaskResultRes 任务结果响应
|
||||||
type GetTaskResultRes struct {
|
type GetTaskResultRes struct {
|
||||||
OssFile string `json:"ossFile" dc:"结果文件OSS地址"`
|
OssFile string `json:"ossFile" dc:"结果文件OSS地址"`
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"prompts-core/consts/public"
|
"prompts-core/consts/public"
|
||||||
|
"prompts-core/service/gateway"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"prompts-core/common/util"
|
"prompts-core/common/util"
|
||||||
@@ -29,7 +30,7 @@ type UserPromptPayload struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// buildInferenceRequest 构建推理请求
|
// buildInferenceRequest 构建推理请求
|
||||||
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
|
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, history []map[string]any) (map[string]any, error) {
|
||||||
//1) 处理表单分批
|
//1) 处理表单分批
|
||||||
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
|
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -47,7 +48,7 @@ func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, cha
|
|||||||
}
|
}
|
||||||
|
|
||||||
// buildPromptTypeRequest 构建提示词类型请求(BuildType=1)
|
// buildPromptTypeRequest 构建提示词类型请求(BuildType=1)
|
||||||
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
|
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
|
||||||
//1) 构建系统提示词
|
//1) 构建系统提示词
|
||||||
systemPrompt := promptBuildWithRounds(ctx, req, chatModel, aiModel, totalBatches)
|
systemPrompt := promptBuildWithRounds(ctx, req, chatModel, aiModel, totalBatches)
|
||||||
ir.AddSystem(systemPrompt)
|
ir.AddSystem(systemPrompt)
|
||||||
@@ -69,13 +70,13 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
|
|||||||
}
|
}
|
||||||
|
|
||||||
// buildNodeTypeRequest 构建节点类型请求(BuildType=2)
|
// buildNodeTypeRequest 构建节点类型请求(BuildType=2)
|
||||||
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, ir *PromptIR) (map[string]any, error) {
|
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
|
||||||
ir.AddUser(NodeBuild(ctx, req))
|
ir.AddUser(NodeBuild(ctx, req))
|
||||||
return compileToProviderRequest(ctx, ir, chatModel)
|
return compileToProviderRequest(ctx, ir, chatModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
// compileToProviderRequest 编译为 Provider 请求
|
// compileToProviderRequest 编译为 Provider 请求
|
||||||
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *entity.AsynchModel) (map[string]any, error) {
|
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel) (map[string]any, error) {
|
||||||
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
|
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("获取协议配置失败: %w", err)
|
return nil, fmt.Errorf("获取协议配置失败: %w", err)
|
||||||
@@ -97,7 +98,7 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *enti
|
|||||||
}
|
}
|
||||||
|
|
||||||
// promptBuildWithRounds 构建系统提示词
|
// promptBuildWithRounds 构建系统提示词
|
||||||
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, batches int) string {
|
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, batches int) string {
|
||||||
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||||
ProviderName: chatModel.OperatorName,
|
ProviderName: chatModel.OperatorName,
|
||||||
Status: 1,
|
Status: 1,
|
||||||
@@ -144,7 +145,7 @@ func buildUserFormContent(userForm []map[string]any) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// checkOverallContent 检查整体内容是否超出窗口
|
// checkOverallContent 检查整体内容是否超出窗口
|
||||||
func checkOverallContent(ir *PromptIR, model *entity.AsynchModel) bool {
|
func checkOverallContent(ir *PromptIR, model *gateway.AsynchModel) bool {
|
||||||
fullContent := ir.String()
|
fullContent := ir.String()
|
||||||
return util.CountToken(fullContent, model.TokenConfig)
|
return util.CountToken(fullContent, model.TokenConfig)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,14 +42,14 @@ func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.Com
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetModelMessage 获取模型信息
|
// GetModelMessage 获取模型信息
|
||||||
func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) {
|
func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*gateway.AsynchModel, *gateway.AsynchModel, error) {
|
||||||
userInfo, err := utils.GetUserInfo(ctx)
|
userInfo, err := utils.GetUserInfo(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
|
return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
|
||||||
}
|
}
|
||||||
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
chatModel, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||||
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
|
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
|
||||||
IsChatModel: new(1),
|
IsChatModel: 1,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
@@ -57,8 +57,8 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.
|
|||||||
if chatModel == nil {
|
if chatModel == nil {
|
||||||
return nil, nil, errors.New("当前没有对话模型,请添加")
|
return nil, nil, errors.New("当前没有对话模型,请添加")
|
||||||
}
|
}
|
||||||
aiModels, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
aiModels, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||||
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
|
SQLBaseDO: beans.SQLBaseDO{TenantId: userInfo.TenantId, Creator: userInfo.UserName},
|
||||||
ModelName: req.ModelName,
|
ModelName: req.ModelName,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -72,7 +72,7 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// validateUserForm 校验用户表单
|
// validateUserForm 校验用户表单
|
||||||
func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) error {
|
func validateUserForm(req *dto.ComposeMessagesReq, model *gateway.AsynchModel) error {
|
||||||
if len(req.UserForm) == 0 {
|
if len(req.UserForm) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -90,7 +90,7 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handlePromptBuild 处理提示词构建(BuildType=1)
|
// handlePromptBuild 处理提示词构建(BuildType=1)
|
||||||
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
||||||
// 获取历史会话
|
// 获取历史会话
|
||||||
history, err := session.GetHistoryMessages(ctx, req.SessionId)
|
history, err := session.GetHistoryMessages(ctx, req.SessionId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -123,7 +123,7 @@ func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatMod
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleNodeBuild 处理节点构建(BuildType=2)
|
// handleNodeBuild 处理节点构建(BuildType=2)
|
||||||
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
||||||
taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
|
taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("调用推理模型失败: %w", err)
|
return nil, fmt.Errorf("调用推理模型失败: %w", err)
|
||||||
@@ -148,7 +148,7 @@ func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel
|
|||||||
}
|
}
|
||||||
|
|
||||||
// callInferenceModel 调用推理模型
|
// callInferenceModel 调用推理模型
|
||||||
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (string, int64, error) {
|
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, history []map[string]any) (string, int64, error) {
|
||||||
taskReq, err := buildInferenceRequest(ctx, req, chatModel, aiModel, history)
|
taskReq, err := buildInferenceRequest(ctx, req, chatModel, aiModel, history)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", 0, fmt.Errorf("构建推理请求失败: %w", err)
|
return "", 0, fmt.Errorf("构建推理请求失败: %w", err)
|
||||||
@@ -186,7 +186,7 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("查询任务失败: %w", err)
|
return fmt.Errorf("查询任务失败: %w", err)
|
||||||
}
|
}
|
||||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||||
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
|
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
|
||||||
ModelName: composeTask.ModelName,
|
ModelName: composeTask.ModelName,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"prompts-core/common/util"
|
"prompts-core/common/util"
|
||||||
|
"prompts-core/service/gateway"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"prompts-core/dao"
|
"prompts-core/dao"
|
||||||
@@ -178,7 +179,7 @@ func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Compile 将 PromptIR 按协议配置编译为 Provider Request
|
// Compile 将 PromptIR 按协议配置编译为 Provider Request
|
||||||
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (map[string]any, error) {
|
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) {
|
||||||
if ir == nil || p == nil {
|
if ir == nil || p == nil {
|
||||||
return nil, fmt.Errorf("ir and protocol are required")
|
return nil, fmt.Errorf("ir and protocol are required")
|
||||||
}
|
}
|
||||||
@@ -262,7 +263,7 @@ func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// buildRequest 按 target_field 和 request_template 构建请求体
|
// buildRequest 按 target_field 和 request_template 构建请求体
|
||||||
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *entity.AsynchModel) map[string]any {
|
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gateway.AsynchModel) map[string]any {
|
||||||
if len(p.RequestTemplate) > 0 {
|
if len(p.RequestTemplate) > 0 {
|
||||||
return renderTemplate(p.RequestTemplate, messages, chatModel)
|
return renderTemplate(p.RequestTemplate, messages, chatModel)
|
||||||
}
|
}
|
||||||
@@ -273,7 +274,7 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *ent
|
|||||||
}
|
}
|
||||||
|
|
||||||
// renderTemplate 简单的 {{key}} 模板替换
|
// renderTemplate 简单的 {{key}} 模板替换
|
||||||
func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *entity.AsynchModel) map[string]any {
|
func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
|
||||||
b, _ := json.Marshal(tmpl)
|
b, _ := json.Marshal(tmpl)
|
||||||
str := string(b)
|
str := string(b)
|
||||||
|
|
||||||
|
|||||||
@@ -3,17 +3,17 @@ package prompt
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"prompts-core/service/gateway"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
|
||||||
"prompts-core/common/util"
|
"prompts-core/common/util"
|
||||||
"prompts-core/model/dto"
|
"prompts-core/model/dto"
|
||||||
"prompts-core/model/entity"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProcessUserFormBatches 处理 UserForm 分批(按 token 大小拼接内容)
|
// ProcessUserFormBatches 处理 UserForm 分批(按 token 大小拼接内容)
|
||||||
func ProcessUserFormBatches(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) (*dto.ComposeMessagesReq, int, error) {
|
func ProcessUserFormBatches(ctx context.Context, req *dto.ComposeMessagesReq, model *gateway.AsynchModel) (*dto.ComposeMessagesReq, int, error) {
|
||||||
if model.TokenConfig == nil || len(req.UserForm) == 0 {
|
if model.TokenConfig == nil || len(req.UserForm) == 0 {
|
||||||
return req, 1, nil
|
return req, 1, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user