Files
common/ragflow/chat.go

199 lines
6.7 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package ragflow
import (
"context"
"github.com/gogf/gf/v2/errors/gerror"
)
// CreateChatReq 创建对话配置请求
type CreateChatReq struct {
Name string `json:"name"` // 对话配置名称(助理姓名)
Description string `json:"description,omitempty"` // 助理描述
DatasetIds []string `json:"dataset_ids"` // 关联的知识库ID列表
Prompt *PromptConfig `json:"prompt"` // 提示词配置
Llm *Llm `json:"llm,omitempty"` // LLM配置
}
// PromptConfig 提示词配置
type PromptConfig struct {
Prompt string `json:"prompt"` // 提示词内容
SimilarityThreshold float64 `json:"similarity_threshold"` // 相似度阈值
KeywordsSimilarityWeight float64 `json:"keywords_similarity_weight"` // 关键词相似度权重
TopN int `json:"top_n"` // 返回顶部N个chunk
EmptyResponse string `json:"empty_response"` // 无匹配时回复必须显式传入空字符串才能让LLM自由发挥不传入会使用RAGFlow默认提示词
Opener string `json:"opener,omitempty"` // 开场白
ShowQuote bool `json:"show_quote,omitempty"` // 是否显示引用
Variables []map[string]interface{} `json:"variables,omitempty"` // 变量列表
}
// CreateChatRes 创建对话配置响应
type CreateChatRes struct {
ChatId string `json:"id"` // 对话配置ID
}
// UpdateChatReq 更新对话配置请求
type UpdateChatReq struct {
Name string `json:"name,omitempty"` // 对话配置名称
Description string `json:"description,omitempty"` // 对话描述
DatasetIds []string `json:"dataset_ids,omitempty"` // 关联的知识库ID列表RAGFlow API使用下划线格式
Prompt *PromptConfig `json:"prompt,omitempty"` // 提示词配置
}
// 聊天助手管理
// 参考: https://ragflow.com.cn/docs/dev/http_api_reference#聊天助手管理
// Chat 聊天助手结构体
type Chat struct {
Id string `json:"id"`
Name string `json:"name"`
Avatar string `json:"avatar"`
DatasetIds []string `json:"dataset_ids"`
Llm Llm `json:"llm"`
Prompt Prompt `json:"prompt"`
Description string `json:"description"`
DoRefer string `json:"do_refer"`
Language string `json:"language"`
PromptType string `json:"prompt_type"`
Status string `json:"status"`
TenantId string `json:"tenant_id"`
TopK int `json:"top_k"`
CreateDate string `json:"create_date"`
CreateTime int64 `json:"create_time"`
UpdateDate string `json:"update_date"`
UpdateTime int64 `json:"update_time"`
}
type Llm struct {
ModelName string `json:"model_name,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
}
type Prompt struct {
SimilarityThreshold float64 `json:"similarity_threshold,omitempty"`
KeywordsSimilarityWeight float64 `json:"keywords_similarity_weight,omitempty"`
Opener string `json:"opener,omitempty"`
Prompt string `json:"prompt,omitempty"`
RerankModel string `json:"rerank_model,omitempty"`
TopN int `json:"top_n,omitempty"`
Variables []Variable `json:"variables,omitempty"`
EmptyResponse string `json:"empty_response,omitempty"`
}
type Variable struct {
Key string `json:"key"`
Optional bool `json:"optional"`
}
// ListChatsReq 列出聊天助手请求
type ListChatsReq struct {
Page int `json:"page,omitempty"`
PageSize int `json:"page_size,omitempty"`
OrderBy string `json:"orderby,omitempty"`
Desc bool `json:"desc,omitempty"`
Name string `json:"name,omitempty"`
Id string `json:"id,omitempty"`
}
// ListChatsRes 列出聊天助手响应
// 注意API 不返回 total 字段,仅返回 data 数组
type ListChatsRes struct {
Code int `json:"code"` // 状态码0 表示成功
Data []*Chat `json:"data"` // 聊天助手列表
}
// DeleteChatsReq 删除聊天助手请求
type DeleteChatsReq struct {
Ids []string `json:"ids"`
}
// CreateChat 创建聊天助手
func (c *Client) CreateChat(ctx context.Context, req *CreateChatReq) (*Chat, error) {
var res struct {
Code int `json:"code"`
Data *Chat `json:"data"`
Msg string `json:"message"`
}
if err := c.request(ctx, "POST", "/api/v1/chats", req, &res); err != nil {
return nil, err
}
if res.Code != 0 {
return nil, gerror.Newf("create chat failed: %s", res.Msg)
}
// 检查响应数据是否为空防止RAGFlow API返回 {"code":0, "data":null}
// 如果不检查直接返回,调用方会收到 (nil, nil),导致空指针异常
if res.Data == nil {
return nil, gerror.Newf("create chat returned null data: %s", res.Msg)
}
return res.Data, nil
}
// ListChats 列出聊天助手
func (c *Client) ListChats(ctx context.Context, req *ListChatsReq) (*ListChatsRes, error) {
path := "/api/v1/chats"
params := map[string]interface{}{}
if req.Page > 0 {
params["page"] = req.Page
}
if req.PageSize > 0 {
params["page_size"] = req.PageSize
}
if req.OrderBy != "" {
params["orderby"] = req.OrderBy
}
if req.Desc {
params["desc"] = "true"
} else {
params["desc"] = "false"
}
if req.Name != "" {
params["name"] = req.Name
}
if req.Id != "" {
params["id"] = req.Id
}
query := buildQueryString(params)
if query != "" {
path += "?" + query
}
var res ListChatsRes
if err := c.request(ctx, "GET", path, nil, &res); err != nil {
return nil, err
}
if res.Code != 0 {
return nil, gerror.Newf("list chats failed: code=%d", res.Code)
}
return &res, nil
}
// DeleteChats 删除聊天助手
func (c *Client) DeleteChats(ctx context.Context, ids []string) (err error) {
req := DeleteChatsReq{Ids: ids}
var res CommonResponse
if err = c.request(ctx, "DELETE", "/api/v1/chats", req, &res); err != nil {
return
}
if !res.IsSuccess() {
return gerror.Newf("delete chats failed: %s", res.Message)
}
return
}
// UpdateChat 更新聊天助手
func (c *Client) UpdateChat(ctx context.Context, id string, req *UpdateChatReq) (err error) {
var res CommonResponse
path := "/api/v1/chats/" + id
if err = c.request(ctx, "PUT", path, req, &res); err != nil {
return
}
if !res.IsSuccess() {
return gerror.Newf("update chat failed: %s", res.Message)
}
return
}