Files
rag/common/eino/a.go
2026-04-03 18:26:20 +08:00

167 lines
4.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package eino
import (
"context"
"errors"
"fmt"
"io"
"log"
"os"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/cloudwego/eino-ext/components/model/ark"
)
func main() {
ctx := context.Background()
// ==========================================
// 1. 初始化三大组件
// ==========================================
// 1.1 向量检索(从知识库查客服知识)
ragRetriever := NewPGVectorRetriever()
// 1.2 提示词模板(客服角色 + 历史 + 知识库 + 用户问题)
chatTpl := newCustomerServiceTemplate()
// 1.3 大模型ARK
chatModel, err := ark.NewChatModel(ctx, &ark.ChatModelConfig{
APIKey: os.Getenv("ARK_API_KEY"),
Model: os.Getenv("ARK_MODEL_ID"),
})
if err != nil {
log.Fatal(err)
}
// ==========================================
// 2. 模拟会话:从 DB 读取历史对话
// ==========================================
sessionHistory := []*schema.Message{
{Role: schema.User, Content: "你们发什么快递?"},
{Role: schema.Assistant, Content: "默认发中通快递"},
{Role: schema.User, Content: "可以发顺丰吗?"},
}
// 当前用户问题
userQuery := "那顺丰需要加钱吗?"
// ==========================================
// 3. RAG 检索知识库
// ==========================================
docs, err := ragRetriever.Retrieve(ctx, userQuery)
if err != nil {
log.Fatal(err)
}
// 拼接参考知识
knowledge := ""
for i, doc := range docs {
knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content)
}
// ==========================================
// 4. 模板格式化:系统提示 + 历史 + 知识 + 当前问题
// ==========================================
msgs, err := chatTpl.Format(ctx, map[string]any{
"history": sessionHistory,
"knowledge": knowledge,
"question": userQuery,
})
if err != nil {
log.Fatal(err)
}
// ==========================================
// 5. 流式调用大模型生成客服回答
// ==========================================
fmt.Println("\n=== 客服回复 ===")
stream, err := chatModel.Stream(ctx, msgs)
if err != nil {
log.Fatal(err)
}
fullReply := make([]*schema.Message, 0, 100)
for {
chunk, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
log.Fatal(err)
}
fmt.Print(chunk.Content)
fullReply = append(fullReply, chunk)
}
// ==========================================
// 6. 拼接完整回复,存入 DB 作为新历史
// ==========================================
replyMsg, _ := schema.ConcatMessages(fullReply)
sessionHistory = append(sessionHistory,
&schema.Message{Role: schema.User, Content: userQuery},
replyMsg,
)
// 接下来把 sessionHistory 存回你的 MySQL/Redis 即可
}
// ==========================================
// 本地客服提示词模板(不需要 MCP
// ==========================================
func newCustomerServiceTemplate() prompt.ChatTemplate {
// 系统提示 + 多轮对话 + 知识库 + 用户问题
return prompt.FromMessages(schema.Messages{
{
Role: schema.System,
Content: `你是电商智能客服,语气友好简洁。
请严格根据参考知识回答,不知道就说“抱歉,这个问题我需要帮你转接人工”。
参考知识:
{{.knowledge}}`,
},
// 历史对话会自动渲染在这里
{{range .history}}{{.}},{{end}},
// 当前用户问题
{Role: schema.User, Content: "{{.question}}"},
})
}
// ==========================================
// PGVector 检索器(简化可直接用)
// ==========================================
type PGVectorRetriever struct {
topK int
}
func NewPGVectorRetriever() retriever.Retriever {
return &PGVectorRetriever{topK: 3}
}
func (r *PGVectorRetriever) Retrieve(
ctx context.Context,
query string,
opts ...retriever.Option,
) ([]*schema.Document, error) {
options := retriever.GetCommonOptions(nil, opts...)
topK := r.topK
if options.TopK != nil {
topK = *options.TopK
}
// ===== 这里替换成你真实的 PG 向量检索 SQL =====
// 模拟知识库
return []*schema.Document{
{
ID: "1",
Content: "顺丰快递需要补10元运费差价",
},
{
ID: "2",
Content: "订单满99元可免费升级顺丰",
},
}, nil
}