167 lines
4.2 KiB
Go
167 lines
4.2 KiB
Go
package eino
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
"os"
|
||
|
||
"github.com/cloudwego/eino/components/prompt"
|
||
"github.com/cloudwego/eino/components/retriever"
|
||
"github.com/cloudwego/eino/schema"
|
||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||
)
|
||
|
||
func main() {
|
||
ctx := context.Background()
|
||
|
||
// ==========================================
|
||
// 1. 初始化三大组件
|
||
// ==========================================
|
||
// 1.1 向量检索(从知识库查客服知识)
|
||
ragRetriever := NewPGVectorRetriever()
|
||
|
||
// 1.2 提示词模板(客服角色 + 历史 + 知识库 + 用户问题)
|
||
chatTpl := newCustomerServiceTemplate()
|
||
|
||
// 1.3 大模型(ARK)
|
||
chatModel, err := ark.NewChatModel(ctx, &ark.ChatModelConfig{
|
||
APIKey: os.Getenv("ARK_API_KEY"),
|
||
Model: os.Getenv("ARK_MODEL_ID"),
|
||
})
|
||
if err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
|
||
// ==========================================
|
||
// 2. 模拟会话:从 DB 读取历史对话
|
||
// ==========================================
|
||
sessionHistory := []*schema.Message{
|
||
{Role: schema.User, Content: "你们发什么快递?"},
|
||
{Role: schema.Assistant, Content: "默认发中通快递"},
|
||
{Role: schema.User, Content: "可以发顺丰吗?"},
|
||
}
|
||
|
||
// 当前用户问题
|
||
userQuery := "那顺丰需要加钱吗?"
|
||
|
||
// ==========================================
|
||
// 3. RAG 检索知识库
|
||
// ==========================================
|
||
docs, err := ragRetriever.Retrieve(ctx, userQuery)
|
||
if err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
|
||
// 拼接参考知识
|
||
knowledge := ""
|
||
for i, doc := range docs {
|
||
knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content)
|
||
}
|
||
|
||
// ==========================================
|
||
// 4. 模板格式化:系统提示 + 历史 + 知识 + 当前问题
|
||
// ==========================================
|
||
msgs, err := chatTpl.Format(ctx, map[string]any{
|
||
"history": sessionHistory,
|
||
"knowledge": knowledge,
|
||
"question": userQuery,
|
||
})
|
||
if err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
|
||
// ==========================================
|
||
// 5. 流式调用大模型生成客服回答
|
||
// ==========================================
|
||
fmt.Println("\n=== 客服回复 ===")
|
||
stream, err := chatModel.Stream(ctx, msgs)
|
||
if err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
|
||
fullReply := make([]*schema.Message, 0, 100)
|
||
for {
|
||
chunk, err := stream.Recv()
|
||
if errors.Is(err, io.EOF) {
|
||
break
|
||
}
|
||
if err != nil {
|
||
log.Fatal(err)
|
||
}
|
||
fmt.Print(chunk.Content)
|
||
fullReply = append(fullReply, chunk)
|
||
}
|
||
|
||
// ==========================================
|
||
// 6. 拼接完整回复,存入 DB 作为新历史
|
||
// ==========================================
|
||
replyMsg, _ := schema.ConcatMessages(fullReply)
|
||
sessionHistory = append(sessionHistory,
|
||
&schema.Message{Role: schema.User, Content: userQuery},
|
||
replyMsg,
|
||
)
|
||
|
||
// 接下来把 sessionHistory 存回你的 MySQL/Redis 即可
|
||
}
|
||
|
||
// ==========================================
|
||
// 本地客服提示词模板(不需要 MCP)
|
||
// ==========================================
|
||
func newCustomerServiceTemplate() prompt.ChatTemplate {
|
||
// 系统提示 + 多轮对话 + 知识库 + 用户问题
|
||
return prompt.FromMessages(schema.Messages{
|
||
{
|
||
Role: schema.System,
|
||
Content: `你是电商智能客服,语气友好简洁。
|
||
请严格根据参考知识回答,不知道就说“抱歉,这个问题我需要帮你转接人工”。
|
||
|
||
参考知识:
|
||
{{.knowledge}}`,
|
||
},
|
||
// 历史对话会自动渲染在这里
|
||
{{range .history}}{{.}},{{end}},
|
||
// 当前用户问题
|
||
{Role: schema.User, Content: "{{.question}}"},
|
||
})
|
||
}
|
||
|
||
// ==========================================
|
||
// PGVector 检索器(简化可直接用)
|
||
// ==========================================
|
||
type PGVectorRetriever struct {
|
||
topK int
|
||
}
|
||
|
||
func NewPGVectorRetriever() retriever.Retriever {
|
||
return &PGVectorRetriever{topK: 3}
|
||
}
|
||
|
||
func (r *PGVectorRetriever) Retrieve(
|
||
ctx context.Context,
|
||
query string,
|
||
opts ...retriever.Option,
|
||
) ([]*schema.Document, error) {
|
||
|
||
options := retriever.GetCommonOptions(nil, opts...)
|
||
topK := r.topK
|
||
if options.TopK != nil {
|
||
topK = *options.TopK
|
||
}
|
||
|
||
// ===== 这里替换成你真实的 PG 向量检索 SQL =====
|
||
// 模拟知识库
|
||
return []*schema.Document{
|
||
{
|
||
ID: "1",
|
||
Content: "顺丰快递需要补10元运费差价",
|
||
},
|
||
{
|
||
ID: "2",
|
||
Content: "订单满99元可免费升级顺丰",
|
||
},
|
||
}, nil
|
||
}
|