feat: rag初始版
This commit is contained in:
177
common/eino/a.go
Normal file
177
common/eino/a.go
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
package eino
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"rag/dao"
|
||||||
|
"rag/model/dto"
|
||||||
|
"rag/model/entity"
|
||||||
|
|
||||||
|
"gitea.com/red-future/common/beans"
|
||||||
|
"github.com/cloudwego/eino/callbacks"
|
||||||
|
"github.com/cloudwego/eino/components/indexer"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
"github.com/gogf/gf/v2/os/glog"
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
|
"github.com/pgvector/pgvector-go"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PGVectorIndexerOptions struct {
|
||||||
|
BatchSize int // 每批处理多少条
|
||||||
|
}
|
||||||
|
|
||||||
|
type PGVectorIndexer struct {
|
||||||
|
opts *PGVectorIndexerOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPGVectorIndexer(opts *PGVectorIndexerOptions) *PGVectorIndexer {
|
||||||
|
// 默认值
|
||||||
|
if opts.BatchSize <= 0 {
|
||||||
|
opts.BatchSize = 5
|
||||||
|
}
|
||||||
|
return &PGVectorIndexer{opts: opts}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *PGVectorIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (rows int64, err error) {
|
||||||
|
commonOpts := indexer.GetCommonOptions(&indexer.Options{}, opts...)
|
||||||
|
|
||||||
|
if commonOpts.Embedding == nil {
|
||||||
|
return 0, errors.New("embedding model not set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 回调
|
||||||
|
ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs})
|
||||||
|
|
||||||
|
ids, err := i.bulkStore(ctx, docs, commonOpts)
|
||||||
|
if err != nil {
|
||||||
|
callbacks.OnError(ctx, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: gconv.Strings(ids)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *PGVectorIndexer) bulkStore(ctx context.Context, docs []*schema.Document, opts *indexer.Options) (rows int64, err error) {
|
||||||
|
var batchDocs []*schema.Document
|
||||||
|
|
||||||
|
// 官方ES同款逻辑:满 BatchSize 就处理一批
|
||||||
|
for _, doc := range docs {
|
||||||
|
batchDocs = append(batchDocs, doc)
|
||||||
|
|
||||||
|
// 满了 → 处理
|
||||||
|
if len(batchDocs) >= i.opts.BatchSize {
|
||||||
|
var r int64
|
||||||
|
r, err = i.doStore(ctx, batchDocs, opts)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rows = rows + r
|
||||||
|
batchDocs = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 最后一批
|
||||||
|
if len(batchDocs) > 0 {
|
||||||
|
var r int64
|
||||||
|
r, err = i.doStore(ctx, batchDocs, opts)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rows = rows + r
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *PGVectorIndexer) doStore(ctx context.Context, docs []*schema.Document, opts *indexer.Options) (rows int64, err error) {
|
||||||
|
|
||||||
|
texts := make([]string, len(docs))
|
||||||
|
for i, d := range docs {
|
||||||
|
texts[i] = d.Content
|
||||||
|
}
|
||||||
|
|
||||||
|
// 向量化(官方ES也没有重试!)
|
||||||
|
vectors, err := opts.Embedding.EmbedStrings(ctx, texts)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转成业务实体
|
||||||
|
var chunks []*dto.VectorDocumentChunkMsg
|
||||||
|
for idx, doc := range docs {
|
||||||
|
ck := new(dto.VectorDocumentChunkMsg)
|
||||||
|
err = gconv.Struct(doc.MetaData, ck)
|
||||||
|
if err != nil {
|
||||||
|
glog.Errorf(ctx, "doStore err: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ck.Content = doc.Content
|
||||||
|
ck.Vector = pgvector.NewVector(gconv.Float32s(vectors[idx]))
|
||||||
|
ck.VectorStatus = gconv.PtrInt8(1)
|
||||||
|
ck.Status = gconv.PtrInt8(1)
|
||||||
|
chunks = append(chunks, ck)
|
||||||
|
}
|
||||||
|
if len(chunks) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx = context.WithValue(ctx, "user", &beans.User{
|
||||||
|
TenantId: chunks[0].TenantId,
|
||||||
|
UserName: chunks[0].Creator,
|
||||||
|
})
|
||||||
|
// 创建索引
|
||||||
|
if err = i.createOrUpdateDatasetIndex(ctx, chunks[0].DatasetId, len(vectors[0]), int64(len(chunks))); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 入库
|
||||||
|
rows, err = dao.DocumentChunk.BatchInsert(ctx, chunks)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *PGVectorIndexer) createOrUpdateDatasetIndex(ctx context.Context, datasetId int64, dimension int, vectorCount int64) error {
|
||||||
|
exist, err := dao.DatasetIndex.GetByDatasetId(ctx, datasetId)
|
||||||
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if exist != nil {
|
||||||
|
_ = dao.DatasetIndex.IncVectorCount(ctx, exist.Id, vectorCount)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
indexName := fmt.Sprintf("idx_dataset_%d_vector", datasetId)
|
||||||
|
idx := &entity.DatasetIndex{
|
||||||
|
DatasetId: datasetId,
|
||||||
|
Name: indexName,
|
||||||
|
Dimension: dimension,
|
||||||
|
FieldType: "float",
|
||||||
|
MetricType: "COSINE",
|
||||||
|
Status: gconv.PtrInt8(1),
|
||||||
|
VectorCount: vectorCount,
|
||||||
|
Description: fmt.Sprintf("数据集%d向量索引", datasetId),
|
||||||
|
}
|
||||||
|
_, err = dao.DatasetIndex.Insert(ctx, idx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return i.createRealPGVectorIndex(ctx, indexName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *PGVectorIndexer) createRealPGVectorIndex(ctx context.Context, indexName string) error {
|
||||||
|
if err := dao.DatasetIndex.InsertIndex(ctx, indexName); err != nil {
|
||||||
|
glog.Errorf(ctx, "create vector index failed: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
glog.Infof(ctx, "created pgvector index: %s", indexName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *PGVectorIndexer) GetType() string {
|
||||||
|
return "pgvector_indexer"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *PGVectorIndexer) IsCallbacksEnabled() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
107
common/eino/b.go
Normal file
107
common/eino/b.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package eino
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
"github.com/elastic/go-elasticsearch/v8"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino-ext/components/indexer/es8"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
indexName = "eino_example"
|
||||||
|
fieldContent = "content"
|
||||||
|
fieldContentVector = "content_vector"
|
||||||
|
fieldExtraLocation = "location"
|
||||||
|
docExtraLocation = "location"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIndexer() {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// 1. 创建 ES 客户端
|
||||||
|
client, err := elasticsearch.NewClient(elasticsearch.Config{
|
||||||
|
Addresses: []string{"http://localhost:9200"},
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("create client error: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 定义 Index Spec(选填:如果索引不存在,将自动创建)
|
||||||
|
indexSpec := &es8.IndexSpec{
|
||||||
|
Settings: map[string]any{
|
||||||
|
"number_of_shards": 1,
|
||||||
|
"number_of_replicas": 0,
|
||||||
|
},
|
||||||
|
Mappings: map[string]any{
|
||||||
|
"properties": map[string]any{
|
||||||
|
fieldContentVector: map[string]any{
|
||||||
|
"type": "dense_vector",
|
||||||
|
"dims": 1024,
|
||||||
|
"index": true,
|
||||||
|
"similarity": "l2_norm",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 准备文档
|
||||||
|
// 文档通常包含 ID 和 Content
|
||||||
|
// 也可以包含额外的 Metadata 用于过滤或其他用途
|
||||||
|
docs := []*schema.Document{
|
||||||
|
{
|
||||||
|
ID: "1",
|
||||||
|
Content: "Eiffel Tower: Located in Paris, France.",
|
||||||
|
MetaData: map[string]any{
|
||||||
|
docExtraLocation: "France",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "2",
|
||||||
|
Content: "The Great Wall: Located in China.",
|
||||||
|
MetaData: map[string]any{
|
||||||
|
docExtraLocation: "China",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. 创建 ES 索引器组件
|
||||||
|
indexer, err := es8.NewIndexer(ctx, &es8.IndexerConfig{
|
||||||
|
Client: client,
|
||||||
|
Index: indexName,
|
||||||
|
IndexSpec: indexSpec, // 添加此项以启用自动索引创建
|
||||||
|
BatchSize: 10,
|
||||||
|
// DocumentToFields 指定如何将文档字段映射到 ES 字段
|
||||||
|
DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]es8.FieldValue, err error) {
|
||||||
|
return map[string]es8.FieldValue{
|
||||||
|
fieldContent: {
|
||||||
|
Value: doc.Content,
|
||||||
|
EmbedKey: fieldContentVector, // 对文档内容进行向量化并保存到 "content_vector" 字段
|
||||||
|
},
|
||||||
|
fieldExtraLocation: {
|
||||||
|
// 额外的 metadata 字段
|
||||||
|
Value: doc.MetaData[docExtraLocation],
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
// 提供 embedding 组件用于向量化
|
||||||
|
Embedding: EmbedderDashscope,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("create indexer error: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. 索引文档
|
||||||
|
ids, err := indexer.Store(ctx, docs)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("index error: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Println("indexed ids:", ids)
|
||||||
|
}
|
||||||
49
common/eino/base_task.go
Normal file
49
common/eino/base_task.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package eino
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gitea.com/red-future/common/beans"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BaseTask 任务基类 - MongoDB版本
|
||||||
|
type BaseTask struct {
|
||||||
|
beans.MongoBaseDO `bson:",inline"`
|
||||||
|
// 任务信息
|
||||||
|
TaskType TaskType `bson:"taskType" json:"taskType"`
|
||||||
|
Status TaskStatus `bson:"status" json:"status"`
|
||||||
|
Priority TaskPriority `bson:"priority,omitempty" json:"priority,omitempty"`
|
||||||
|
// 进度
|
||||||
|
TotalItems int64 `bson:"totalItems" json:"totalItems"`
|
||||||
|
ProcessedItems int64 `bson:"processedItems" json:"processedItems"`
|
||||||
|
Progress float64 `bson:"progress" json:"progress"`
|
||||||
|
// 结果
|
||||||
|
StartTime *time.Time `bson:"startTime" json:"startTime"`
|
||||||
|
EndTime *time.Time `bson:"endTime,omitempty" json:"endTime,omitempty"`
|
||||||
|
Duration int64 `bson:"duration,omitempty" json:"duration,omitempty"`
|
||||||
|
SuccessCount int64 `bson:"successCount" json:"successCount"`
|
||||||
|
FailCount int64 `bson:"failCount" json:"failCount"`
|
||||||
|
// 其他
|
||||||
|
Executor string `bson:"executor,omitempty" json:"executor,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SQLBaseTask 任务基类 - SQL版本
|
||||||
|
type SQLBaseTask struct {
|
||||||
|
beans.SQLBaseDO
|
||||||
|
// 任务信息
|
||||||
|
TaskType TaskType `json:"taskType"`
|
||||||
|
Status TaskStatus `json:"status"`
|
||||||
|
Priority TaskPriority `json:"priority,omitempty"`
|
||||||
|
// 进度
|
||||||
|
TotalItems int64 `json:"totalItems"`
|
||||||
|
ProcessedItems int64 `json:"processedItems"`
|
||||||
|
Progress float64 `json:"progress"`
|
||||||
|
// 结果
|
||||||
|
StartTime *time.Time `json:"startTime"`
|
||||||
|
EndTime *time.Time `json:"endTime,omitempty"`
|
||||||
|
Duration int64 `json:"duration,omitempty"`
|
||||||
|
SuccessCount int64 `json:"successCount"`
|
||||||
|
FailCount int64 `json:"failCount"`
|
||||||
|
// 其他
|
||||||
|
Executor string `json:"executor,omitempty"`
|
||||||
|
}
|
||||||
94
common/eino/c.go
Normal file
94
common/eino/c.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package eino
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
"github.com/elastic/go-elasticsearch/v8"
|
||||||
|
"github.com/elastic/go-elasticsearch/v8/typedapi/types"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino-ext/components/retriever/es8"
|
||||||
|
"github.com/cloudwego/eino-ext/components/retriever/es8/search_mode"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRetriever() {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
client, _ := elasticsearch.NewClient(elasticsearch.Config{
|
||||||
|
Addresses: []string{"http://localhost:9200"},
|
||||||
|
})
|
||||||
|
|
||||||
|
// 创建 retriever 组件
|
||||||
|
retriever, _ := es8.NewRetriever(ctx, &es8.RetrieverConfig{
|
||||||
|
Client: client,
|
||||||
|
Index: indexName,
|
||||||
|
TopK: 5,
|
||||||
|
SearchMode: search_mode.SearchModeApproximate(&search_mode.ApproximateConfig{
|
||||||
|
QueryFieldName: fieldContent,
|
||||||
|
VectorFieldName: fieldContentVector,
|
||||||
|
Hybrid: false,
|
||||||
|
// RRF 仅在特定许可证下可用
|
||||||
|
// 参见: https://www.elastic.co/subscriptions
|
||||||
|
RRF: false,
|
||||||
|
RRFRankConstant: nil,
|
||||||
|
RRFWindowSize: nil,
|
||||||
|
}),
|
||||||
|
ResultParser: func(ctx context.Context, hit types.Hit) (doc *schema.Document, err error) {
|
||||||
|
doc = &schema.Document{
|
||||||
|
ID: *hit.Id_,
|
||||||
|
Content: "",
|
||||||
|
MetaData: map[string]any{},
|
||||||
|
}
|
||||||
|
|
||||||
|
var src map[string]any
|
||||||
|
if err = json.Unmarshal(hit.Source_, &src); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for field, val := range src {
|
||||||
|
switch field {
|
||||||
|
case fieldContent:
|
||||||
|
doc.Content = val.(string)
|
||||||
|
case fieldContentVector:
|
||||||
|
var v []float64
|
||||||
|
for _, item := range val.([]interface{}) {
|
||||||
|
v = append(v, item.(float64))
|
||||||
|
}
|
||||||
|
doc.WithDenseVector(v)
|
||||||
|
case fieldExtraLocation:
|
||||||
|
doc.MetaData[docExtraLocation] = val.(string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if hit.Score_ != nil {
|
||||||
|
doc.WithScore(float64(*hit.Score_))
|
||||||
|
}
|
||||||
|
|
||||||
|
return doc, nil
|
||||||
|
},
|
||||||
|
Embedding: EmbedderDashscope,
|
||||||
|
})
|
||||||
|
|
||||||
|
// 不带过滤器的搜索
|
||||||
|
docs, _ := retriever.Retrieve(ctx, "tourist attraction")
|
||||||
|
|
||||||
|
// 带过滤器的搜索
|
||||||
|
docs, _ = retriever.Retrieve(ctx, "tourist attraction",
|
||||||
|
es8.WithFilters([]types.Query{{
|
||||||
|
Term: map[string]types.TermQuery{
|
||||||
|
fieldExtraLocation: {
|
||||||
|
CaseInsensitive: of(true),
|
||||||
|
Value: "China",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}}),
|
||||||
|
)
|
||||||
|
|
||||||
|
fmt.Printf("retrieved docs: %+v\n", docs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func of[T any](v T) *T {
|
||||||
|
return &v
|
||||||
|
}
|
||||||
8
common/eino/consts.go
Normal file
8
common/eino/consts.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package eino
|
||||||
|
|
||||||
|
const (
|
||||||
|
providerArk = "ark"
|
||||||
|
providerOpenai = "openai"
|
||||||
|
providerQianfan = "qianfan"
|
||||||
|
providerDashscope = "dashscope"
|
||||||
|
)
|
||||||
51
common/eino/document_loader.go
Normal file
51
common/eino/document_loader.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package eino
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"gitea.com/red-future/common/utils"
|
||||||
|
"github.com/cloudwego/eino-ext/components/document/loader/url"
|
||||||
|
"github.com/cloudwego/eino-ext/components/document/parser/docx"
|
||||||
|
"github.com/cloudwego/eino-ext/components/document/parser/pdf"
|
||||||
|
"github.com/cloudwego/eino-ext/components/document/parser/xlsx"
|
||||||
|
"github.com/cloudwego/eino/components/document"
|
||||||
|
"github.com/cloudwego/eino/components/document/parser"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LoadDocument 业务函数:加载文件
|
||||||
|
func LoadDocument(ctx context.Context, filePath, fileFormat string) (docs []*schema.Document, err error) {
|
||||||
|
p, err := docsParser(ctx, fileFormat)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
loader, err := url.NewLoader(ctx, &url.LoaderConfig{
|
||||||
|
Parser: p,
|
||||||
|
})
|
||||||
|
imageUrl, err := utils.GetFileAddressPrefix(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
docs, err = loader.Load(context.Background(), document.Source{
|
||||||
|
URI: fmt.Sprintf("%s%s", imageUrl, filePath),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func docsParser(ctx context.Context, fileFormat string) (p parser.Parser, err error) {
|
||||||
|
switch fileFormat {
|
||||||
|
case "docx":
|
||||||
|
p, err = docx.NewDocxParser(ctx, &docx.Config{
|
||||||
|
ToSections: true,
|
||||||
|
IncludeHeaders: true,
|
||||||
|
IncludeFooters: true,
|
||||||
|
IncludeTables: true,
|
||||||
|
})
|
||||||
|
case "pdf":
|
||||||
|
p, err = pdf.NewPDFParser(ctx, &pdf.Config{})
|
||||||
|
case "xlsx":
|
||||||
|
p, err = xlsx.NewXlsxParser(ctx, &xlsx.Config{})
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
64
common/eino/document_semantic.go
Normal file
64
common/eino/document_semantic.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package eino
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive"
|
||||||
|
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/semantic"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SemanticSplitDocument 语义分割文档
|
||||||
|
func SemanticSplitDocument(ctx context.Context, docs []*schema.Document) (res []*schema.Document, err error) {
|
||||||
|
// 默认分隔符(支持中英文)
|
||||||
|
separators := []string{"\n\n", "\n", "。", "!", "?", ";", ".", "!", "?", ";"}
|
||||||
|
// 读取配置,使用合理的默认值
|
||||||
|
bufferSize := g.Cfg().MustGet(ctx, "eino.splitter.bufferSize").Int()
|
||||||
|
minChunkSize := g.Cfg().MustGet(ctx, "eino.splitter.minChunkSize").Int()
|
||||||
|
percentile := g.Cfg().MustGet(ctx, "eino.splitter.percentile").Float64()
|
||||||
|
batchSize := g.Cfg().MustGet(ctx, "eino.splitter.batchSize").Int()
|
||||||
|
if batchSize <= 0 {
|
||||||
|
batchSize = 10 // doubao-embedding-vision 限制每批最多 10 个
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用批量包装器
|
||||||
|
var batchEmbedder *BatchEmbedder
|
||||||
|
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
|
||||||
|
switch provider {
|
||||||
|
case providerArk:
|
||||||
|
batchEmbedder = NewBatchEmbedder(EmbedderArk, batchSize)
|
||||||
|
case providerOpenai:
|
||||||
|
batchEmbedder = NewBatchEmbedder(EmbedderOpenAI, batchSize)
|
||||||
|
case providerDashscope:
|
||||||
|
batchEmbedder = NewBatchEmbedder(EmbedderDashscope, batchSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
splitter, err := semantic.NewSplitter(ctx, &semantic.Config{
|
||||||
|
Embedding: batchEmbedder,
|
||||||
|
BufferSize: bufferSize,
|
||||||
|
MinChunkSize: minChunkSize,
|
||||||
|
Percentile: percentile,
|
||||||
|
Separators: separators,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return splitter.Transform(ctx, docs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecursiveSplitDocument 递归分割文档
|
||||||
|
func RecursiveSplitDocument(ctx context.Context, docs []*schema.Document) (res []*schema.Document, err error) {
|
||||||
|
// 默认分隔符(支持中英文)
|
||||||
|
separators := []string{"\n\n", "\n", "。", "!", "?", ";", ".", "!", "?", ";"}
|
||||||
|
splitter, err := recursive.NewSplitter(ctx, &recursive.Config{
|
||||||
|
ChunkSize: 512,
|
||||||
|
OverlapSize: 100,
|
||||||
|
KeepType: recursive.KeepTypeNone,
|
||||||
|
Separators: separators,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return splitter.Transform(ctx, docs)
|
||||||
|
}
|
||||||
69
common/eino/embedding.go
Normal file
69
common/eino/embedding.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package eino
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino-ext/components/embedding/ark"
|
||||||
|
"github.com/cloudwego/eino-ext/components/embedding/dashscope"
|
||||||
|
"github.com/cloudwego/eino-ext/components/embedding/openai"
|
||||||
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
"github.com/golang/glog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 全局只初始化一次
|
||||||
|
var (
|
||||||
|
EmbedderArk *ark.Embedder
|
||||||
|
EmbedderDashscope *dashscope.Embedder
|
||||||
|
EmbedderOpenAI *openai.Embedder
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
ctx := context.Background()
|
||||||
|
if !g.Cfg().MustGet(ctx, "eino.embedding").IsEmpty() {
|
||||||
|
var err error
|
||||||
|
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
|
||||||
|
switch provider {
|
||||||
|
case providerArk:
|
||||||
|
cfg := &ark.EmbeddingConfig{
|
||||||
|
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
|
||||||
|
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
|
||||||
|
}
|
||||||
|
if apiType := g.Cfg().MustGet(ctx, "eino.embedding.apiType").String(); apiType != "" {
|
||||||
|
apiTypeVal := ark.APIType(apiType)
|
||||||
|
cfg.APIType = &apiTypeVal
|
||||||
|
}
|
||||||
|
EmbedderArk, err = ark.NewEmbedder(ctx, cfg)
|
||||||
|
case providerOpenai:
|
||||||
|
chatModelConfig := &openai.EmbeddingConfig{
|
||||||
|
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
|
||||||
|
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
|
||||||
|
}
|
||||||
|
EmbedderOpenAI, err = openai.NewEmbedder(ctx, chatModelConfig)
|
||||||
|
case providerDashscope:
|
||||||
|
cfg := &dashscope.EmbeddingConfig{
|
||||||
|
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
|
||||||
|
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
|
||||||
|
}
|
||||||
|
EmbedderDashscope, err = dashscope.NewEmbedder(ctx, cfg)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
glog.Fatalf("NewEmbedder of %v error: %v", provider, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func EmbedStrings(ctx context.Context, texts []string) (embeddings [][]float64, err error) {
|
||||||
|
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
|
||||||
|
switch provider {
|
||||||
|
case providerArk:
|
||||||
|
return EmbedderArk.EmbedStrings(ctx, texts)
|
||||||
|
case providerOpenai:
|
||||||
|
return EmbedderOpenAI.EmbedStrings(ctx, texts)
|
||||||
|
case providerDashscope:
|
||||||
|
return EmbedderDashscope.EmbedStrings(ctx, texts)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unsupported provider: %v", provider)
|
||||||
|
}
|
||||||
47
common/eino/embedding_batch.go
Normal file
47
common/eino/embedding_batch.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package eino
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/components/embedding"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BatchEmbedder 包装器,支持批量限制
|
||||||
|
type BatchEmbedder struct {
|
||||||
|
embedder embedding.Embedder
|
||||||
|
batchSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBatchEmbedder 创建支持批量限制的 embedding 包装器
|
||||||
|
func NewBatchEmbedder(embedder embedding.Embedder, batchSize int) *BatchEmbedder {
|
||||||
|
if batchSize <= 0 {
|
||||||
|
batchSize = 10 // 默认每批 10 个
|
||||||
|
}
|
||||||
|
return &BatchEmbedder{
|
||||||
|
embedder: embedder,
|
||||||
|
batchSize: batchSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbedStrings 分批调用 embedding
|
||||||
|
func (b *BatchEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) {
|
||||||
|
if len(texts) <= b.batchSize {
|
||||||
|
return b.embedder.EmbedStrings(ctx, texts, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
var allEmbeddings [][]float64
|
||||||
|
for i := 0; i < len(texts); i += b.batchSize {
|
||||||
|
end := i + b.batchSize
|
||||||
|
if end > len(texts) {
|
||||||
|
end = len(texts)
|
||||||
|
}
|
||||||
|
|
||||||
|
batch := texts[i:end]
|
||||||
|
embeddings, err := b.embedder.EmbedStrings(ctx, batch, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
allEmbeddings = append(allEmbeddings, embeddings...)
|
||||||
|
}
|
||||||
|
return allEmbeddings, nil
|
||||||
|
}
|
||||||
273
common/eino/embedding_qwen.go
Normal file
273
common/eino/embedding_qwen.go
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2024 Red Future Authors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package eino
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/callbacks"
|
||||||
|
"github.com/cloudwego/eino/components"
|
||||||
|
"github.com/cloudwego/eino/components/embedding"
|
||||||
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
"github.com/gogf/gf/v2/net/gclient"
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// 千问API默认配置
|
||||||
|
defaultBaseURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding"
|
||||||
|
defaultTimeout = 10 * time.Minute
|
||||||
|
defaultRetryTimes = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type QwenEmbeddingConfig struct {
|
||||||
|
// Timeout specifies the maximum duration to wait for API responses
|
||||||
|
// Optional. Default: 10 minutes
|
||||||
|
Timeout *time.Duration `json:"timeout"`
|
||||||
|
|
||||||
|
// HTTPClient specifies the client to send HTTP requests.
|
||||||
|
// Optional. Default &http.Client{Timeout: Timeout}
|
||||||
|
HTTPClient *http.Client `json:"http_client"`
|
||||||
|
|
||||||
|
// RetryTimes specifies the number of retry attempts for failed API calls
|
||||||
|
// Optional. Default: 2
|
||||||
|
RetryTimes *int `json:"retry_times"`
|
||||||
|
|
||||||
|
// BaseURL specifies the base URL for Qwen DashScope service
|
||||||
|
// Optional. Default: "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding"
|
||||||
|
BaseURL string `json:"base_url"`
|
||||||
|
|
||||||
|
// APIKey specifies the API Key for authentication
|
||||||
|
// Required
|
||||||
|
APIKey string `json:"api_key"`
|
||||||
|
|
||||||
|
// Model specifies the model name for Qwen embedding
|
||||||
|
// Required. Examples: "text-embedding-v2", "text-embedding-v3"
|
||||||
|
Model string `json:"model"`
|
||||||
|
|
||||||
|
// TextType specifies the type of text: "document" or "query"
|
||||||
|
// Optional. Default: "document"
|
||||||
|
TextType string `json:"text_type"`
|
||||||
|
|
||||||
|
// MaxConcurrentRequests specifies the maximum number of concurrent requests allowed
|
||||||
|
// Optional. Default: 5
|
||||||
|
MaxConcurrentRequests *int `json:"max_concurrent_requests"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type QwenEmbedder struct {
|
||||||
|
client *gclient.Client
|
||||||
|
conf *QwenEmbeddingConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbeddingRequest 千问embedding请求结构
|
||||||
|
type EmbeddingRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Input struct {
|
||||||
|
Texts []string `json:"texts"`
|
||||||
|
} `json:"input"`
|
||||||
|
Parameters struct {
|
||||||
|
TextType string `json:"text_type,omitempty"`
|
||||||
|
} `json:"parameters,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbeddingResponse 千问embedding响应结构
|
||||||
|
type EmbeddingResponse struct {
|
||||||
|
Output struct {
|
||||||
|
Embeddings []struct {
|
||||||
|
TextIndex int `json:"text_index"`
|
||||||
|
Embedding []float64 `json:"embedding"`
|
||||||
|
} `json:"embeddings"`
|
||||||
|
} `json:"output"`
|
||||||
|
Usage struct {
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
} `json:"usage"`
|
||||||
|
RequestID string `json:"request_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type APIError struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
RequestID string `json:"request_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *APIError) Error() string {
|
||||||
|
return fmt.Sprintf("API Error: %s - %s (RequestID: %s)", e.Code, e.Message, e.RequestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildQwenClient(config *QwenEmbeddingConfig) *gclient.Client {
|
||||||
|
if len(config.BaseURL) == 0 {
|
||||||
|
config.BaseURL = defaultBaseURL
|
||||||
|
}
|
||||||
|
if config.Timeout == nil {
|
||||||
|
config.Timeout = &defaultTimeout
|
||||||
|
}
|
||||||
|
if config.RetryTimes == nil {
|
||||||
|
defaultRetryTimes := 2
|
||||||
|
config.RetryTimes = &defaultRetryTimes
|
||||||
|
}
|
||||||
|
if len(config.TextType) == 0 {
|
||||||
|
config.TextType = "document"
|
||||||
|
}
|
||||||
|
if config.MaxConcurrentRequests == nil {
|
||||||
|
defaultMaxConcurrentRequests := 5
|
||||||
|
config.MaxConcurrentRequests = &defaultMaxConcurrentRequests
|
||||||
|
}
|
||||||
|
|
||||||
|
client := g.Client()
|
||||||
|
client.SetTimeout(*config.Timeout)
|
||||||
|
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewQwenEmbedder(ctx context.Context, config *QwenEmbeddingConfig) (*QwenEmbedder, error) {
|
||||||
|
if len(config.APIKey) == 0 {
|
||||||
|
return nil, fmt.Errorf("[Qwen] APIKey is required")
|
||||||
|
}
|
||||||
|
if len(config.Model) == 0 {
|
||||||
|
return nil, fmt.Errorf("[Qwen] Model is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
client := buildQwenClient(config)
|
||||||
|
|
||||||
|
return &QwenEmbedder{
|
||||||
|
client: client,
|
||||||
|
conf: config,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *QwenEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) (
|
||||||
|
[][]float64, error) {
|
||||||
|
|
||||||
|
if len(texts) == 0 {
|
||||||
|
return nil, fmt.Errorf("[Qwen] texts cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
options := embedding.GetCommonOptions(&embedding.Options{
|
||||||
|
Model: &e.conf.Model,
|
||||||
|
}, opts...)
|
||||||
|
|
||||||
|
conf := &embedding.Config{
|
||||||
|
Model: dereferenceOrZero(options.Model),
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = callbacks.EnsureRunInfo(ctx, e.GetType(), components.ComponentOfEmbedding)
|
||||||
|
ctx = callbacks.OnStart(ctx, &embedding.CallbackInput{
|
||||||
|
Texts: texts,
|
||||||
|
Config: conf,
|
||||||
|
})
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
callbacks.OnError(ctx, fmt.Errorf("[Qwen] panic: %v", err))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var usage *embedding.TokenUsage
|
||||||
|
var embeddings [][]float64
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// 调用千问API获取embedding
|
||||||
|
embeddings, usage, err = e.callEmbeddingAPI(ctx, texts)
|
||||||
|
if err != nil {
|
||||||
|
callbacks.OnError(ctx, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
callbacks.OnEnd(ctx, &embedding.CallbackOutput{
|
||||||
|
Embeddings: embeddings,
|
||||||
|
Config: conf,
|
||||||
|
TokenUsage: usage,
|
||||||
|
})
|
||||||
|
|
||||||
|
return embeddings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *QwenEmbedder) callEmbeddingAPI(ctx context.Context, texts []string) ([][]float64, *embedding.TokenUsage, error) {
|
||||||
|
// 构建请求
|
||||||
|
var req EmbeddingRequest
|
||||||
|
req.Model = e.conf.Model
|
||||||
|
req.Input.Texts = texts
|
||||||
|
req.Parameters.TextType = e.conf.TextType
|
||||||
|
|
||||||
|
// 调用API
|
||||||
|
client := e.client.Clone()
|
||||||
|
client.SetHeader("Authorization", "Bearer "+e.conf.APIKey)
|
||||||
|
client.SetHeader("Content-Type", "application/json")
|
||||||
|
client.SetTimeout(*e.conf.Timeout)
|
||||||
|
|
||||||
|
resp, err := client.Post(ctx, e.conf.BaseURL, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("[Qwen] HTTP request error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Close()
|
||||||
|
|
||||||
|
// 检查状态码
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
var errResp APIError
|
||||||
|
result := resp.ReadAll()
|
||||||
|
if err = gconv.Struct(result, &errResp); err == nil && errResp.Code != "" {
|
||||||
|
return nil, nil, &errResp
|
||||||
|
}
|
||||||
|
return nil, nil, fmt.Errorf("[Qwen] HTTP status error: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析响应
|
||||||
|
var apiResp EmbeddingResponse
|
||||||
|
result := resp.ReadAll()
|
||||||
|
if err = gconv.Struct(result, &apiResp); err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("[Qwen] parse response error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析响应结果
|
||||||
|
embeddings := make([][]float64, len(texts))
|
||||||
|
for _, emb := range apiResp.Output.Embeddings {
|
||||||
|
if emb.TextIndex >= 0 && emb.TextIndex < len(embeddings) {
|
||||||
|
embeddings[emb.TextIndex] = emb.Embedding
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
usage := &embedding.TokenUsage{
|
||||||
|
TotalTokens: apiResp.Usage.TotalTokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
g.Log().Debugf(ctx, "[Qwen] Embedding success: request_id=%s, total_tokens=%d", apiResp.RequestID, usage.TotalTokens)
|
||||||
|
|
||||||
|
return embeddings, usage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *QwenEmbedder) GetType() string {
|
||||||
|
return getType()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *QwenEmbedder) IsCallbacksEnabled() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func getType() string {
|
||||||
|
return "Qwen"
|
||||||
|
}
|
||||||
|
|
||||||
|
func dereferenceOrZero[T any](v *T) T {
|
||||||
|
if v == nil {
|
||||||
|
var t T
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
return *v
|
||||||
|
}
|
||||||
11
common/eino/priority_enum.go
Normal file
11
common/eino/priority_enum.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package eino
|
||||||
|
|
||||||
|
// TaskPriority 任务优先级
|
||||||
|
type TaskPriority string
|
||||||
|
|
||||||
|
const (
|
||||||
|
TaskPriorityLow TaskPriority = "low" // 低优先级
|
||||||
|
TaskPriorityMedium TaskPriority = "medium" // 中优先级
|
||||||
|
TaskPriorityHigh TaskPriority = "high" // 高优先级
|
||||||
|
TaskPriorityUrgent TaskPriority = "urgent" // 紧急
|
||||||
|
)
|
||||||
12
common/eino/status_enum.go
Normal file
12
common/eino/status_enum.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package eino
|
||||||
|
|
||||||
|
// TaskStatus 任务状态
|
||||||
|
type TaskStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
TaskStatusPending TaskStatus = "pending" // 待处理
|
||||||
|
TaskStatusRunning TaskStatus = "running" // 运行中
|
||||||
|
TaskStatusCompleted TaskStatus = "completed" // 已完成
|
||||||
|
TaskStatusFailed TaskStatus = "failed" // 失败
|
||||||
|
TaskStatusCancelled TaskStatus = "cancelled" // 已取消
|
||||||
|
)
|
||||||
14
common/eino/task_type.go
Normal file
14
common/eino/task_type.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package eino
|
||||||
|
|
||||||
|
// TaskType 任务类型
|
||||||
|
type TaskType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
TaskTypeDocumentIngestion TaskType = "document_ingestion" // 文档摄入任务
|
||||||
|
TaskTypeVectorIngestion TaskType = "vector_ingestion" // 向量摄入任务
|
||||||
|
TaskTypeIndexCreation TaskType = "index_creation" // 索引创建任务
|
||||||
|
TaskTypeQAProcessing TaskType = "qa_processing" // 问答处理任务
|
||||||
|
TaskTypeKnowledgeConstruction TaskType = "knowledge_construction" // 知识库构建任务
|
||||||
|
TaskTypeGraphBuilding TaskType = "graph_building" // 图谱构建任务
|
||||||
|
TaskTypeKnowledgeSync TaskType = "knowledge_sync" // 知识同步任务
|
||||||
|
)
|
||||||
114
common/gse/utils.go
Normal file
114
common/gse/utils.go
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
package gse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"github.com/go-ego/gse"
|
||||||
|
"github.com/go-ego/gse/hmm/extracker"
|
||||||
|
"github.com/go-ego/gse/hmm/segment"
|
||||||
|
"github.com/gogf/gf/v2/os/glog"
|
||||||
|
)
|
||||||
|
|
||||||
|
var GseTool *gseTool
|
||||||
|
|
||||||
|
// 初始化函数:程序启动时执行一次
|
||||||
|
func init() {
|
||||||
|
var err error
|
||||||
|
GseTool, err = newGseTool()
|
||||||
|
if err != nil {
|
||||||
|
glog.Error(context.Background(), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// gseTool 关键词提取工具(gse v1.0.2 标准)
|
||||||
|
type gseTool struct {
|
||||||
|
seg gse.Segmenter
|
||||||
|
tfidf *extracker.TagExtracter
|
||||||
|
tr *extracker.TextRanker
|
||||||
|
}
|
||||||
|
|
||||||
|
// newGseTool 初始化工具(内置词典 + 停用词)
|
||||||
|
func newGseTool() (tool *gseTool, err error) {
|
||||||
|
// 1. 初始化分词器
|
||||||
|
var seg gse.Segmenter
|
||||||
|
// 内置词典(无外部文件)
|
||||||
|
err = seg.LoadDictEmbed()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 内置停用词(v1.0.2 标准)
|
||||||
|
err = seg.LoadStopEmbed()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 初始化 TF-IDF 提取器
|
||||||
|
tfidf := &extracker.TagExtracter{}
|
||||||
|
tfidf.WithGse(seg)
|
||||||
|
err = tfidf.LoadIdf()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 初始化 TextRank 提取器
|
||||||
|
tr := &extracker.TextRanker{}
|
||||||
|
tr.WithGse(seg)
|
||||||
|
|
||||||
|
tool = &gseTool{
|
||||||
|
seg: seg,
|
||||||
|
tfidf: tfidf,
|
||||||
|
tr: tr,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cut 分词(关键词提取唯一正确模式:精确模式 + HMM)
|
||||||
|
func (k *gseTool) Cut(text string) []string {
|
||||||
|
return k.seg.Cut(text, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keyword 最终输出:关键词 + 权重
|
||||||
|
type Keyword struct {
|
||||||
|
Word string `json:"word"`
|
||||||
|
Score float64 `json:"score"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *gseTool) Extract(text string, topN int) []Keyword {
|
||||||
|
// 1. 提取 TF-IDF
|
||||||
|
tfTags := k.extractTFIDF(text, topN)
|
||||||
|
|
||||||
|
// 2. 提取 TextRank
|
||||||
|
trTags := k.extractTextRank(text, topN)
|
||||||
|
|
||||||
|
// 3. 合并成最终关键词(业务最常用)
|
||||||
|
scoreMap := make(map[string]float64)
|
||||||
|
for _, tag := range tfTags {
|
||||||
|
scoreMap[tag.Text] = tag.Weight
|
||||||
|
}
|
||||||
|
for _, tag := range trTags {
|
||||||
|
scoreMap[tag.Text] = tag.Weight
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转成切片并排序(高分在前)
|
||||||
|
res := make([]Keyword, 0, len(scoreMap))
|
||||||
|
for word, score := range scoreMap {
|
||||||
|
res = append(res, Keyword{Word: word, Score: score})
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(res, func(i, j int) bool {
|
||||||
|
return res[i].Score > res[j].Score
|
||||||
|
})
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractTFIDF TF-IDF 关键词(带权重)90% 业务:文章标签、搜索、关键词
|
||||||
|
func (k *gseTool) extractTFIDF(text string, topN int) segment.Segments {
|
||||||
|
return k.tfidf.ExtractTags(text, topN)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractTextRank TextRank 关键词(带权重)长文本、摘要、语义理解
|
||||||
|
func (k *gseTool) extractTextRank(text string, topN int) segment.Segments {
|
||||||
|
return k.tr.TextRank(text, topN)
|
||||||
|
}
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
package controller
|
|
||||||
|
|
||||||
type datasetIndex struct{}
|
|
||||||
|
|
||||||
var DatasetIndex = new(datasetIndex)
|
|
||||||
32
go.mod
32
go.mod
@@ -3,15 +3,29 @@ module rag
|
|||||||
go 1.26.0
|
go 1.26.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
gitea.com/red-future/common v0.0.6
|
gitea.com/red-future/common v0.0.11
|
||||||
github.com/bjang03/gmq v0.0.0-00010101000000-000000000000
|
github.com/bjang03/gmq v0.0.0-00010101000000-000000000000
|
||||||
github.com/cloudwego/eino v0.8.6
|
github.com/cloudwego/eino v0.8.6
|
||||||
|
github.com/cloudwego/eino-ext/components/document/loader/url v0.0.0-20260323112355-f061db7e8419
|
||||||
|
github.com/cloudwego/eino-ext/components/document/parser/docx v0.0.0-20260323112355-f061db7e8419
|
||||||
|
github.com/cloudwego/eino-ext/components/document/parser/pdf v0.0.0-20260323112355-f061db7e8419
|
||||||
|
github.com/cloudwego/eino-ext/components/document/parser/xlsx v0.0.0-20260323112355-f061db7e8419
|
||||||
|
github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260323112355-f061db7e8419
|
||||||
|
github.com/cloudwego/eino-ext/components/document/transformer/splitter/semantic v0.0.0-20260323112355-f061db7e8419
|
||||||
|
github.com/cloudwego/eino-ext/components/embedding/ark v0.1.1
|
||||||
|
github.com/cloudwego/eino-ext/components/embedding/dashscope v0.0.0-20260323112355-f061db7e8419
|
||||||
|
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260323112355-f061db7e8419
|
||||||
|
github.com/cloudwego/eino-ext/components/indexer/es8 v0.0.0-20260331071634-4f359694d2d9
|
||||||
|
github.com/cloudwego/eino-ext/components/retriever/es8 v0.0.0-20260331071634-4f359694d2d9
|
||||||
|
github.com/elastic/go-elasticsearch/v8 v8.16.0
|
||||||
|
github.com/go-ego/gse v1.0.2
|
||||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0
|
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0
|
||||||
github.com/gogf/gf/v2 v2.10.0
|
github.com/gogf/gf/v2 v2.10.0
|
||||||
|
github.com/golang/glog v1.2.5
|
||||||
github.com/pgvector/pgvector-go v0.3.0
|
github.com/pgvector/pgvector-go v0.3.0
|
||||||
)
|
)
|
||||||
|
|
||||||
replace gitea.com/red-future/common v0.0.6 => ../common
|
replace gitea.com/red-future/common v0.0.11 => ../common
|
||||||
|
|
||||||
replace github.com/bjang03/gmq => ../gmq
|
replace github.com/bjang03/gmq => ../gmq
|
||||||
|
|
||||||
@@ -35,18 +49,7 @@ require (
|
|||||||
github.com/clipperhouse/displaywidth v0.11.0 // indirect
|
github.com/clipperhouse/displaywidth v0.11.0 // indirect
|
||||||
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
|
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
|
||||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||||
github.com/cloudwego/eino-ext/components/document/loader/url v0.0.0-20260323112355-f061db7e8419 // indirect
|
|
||||||
github.com/cloudwego/eino-ext/components/document/parser/docx v0.0.0-20260323112355-f061db7e8419 // indirect
|
|
||||||
github.com/cloudwego/eino-ext/components/document/parser/html v0.0.0-20241224063832-9fbcc0e56c28 // indirect
|
github.com/cloudwego/eino-ext/components/document/parser/html v0.0.0-20241224063832-9fbcc0e56c28 // indirect
|
||||||
github.com/cloudwego/eino-ext/components/document/parser/pdf v0.0.0-20260323112355-f061db7e8419 // indirect
|
|
||||||
github.com/cloudwego/eino-ext/components/document/parser/xlsx v0.0.0-20260323112355-f061db7e8419 // indirect
|
|
||||||
github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260323112355-f061db7e8419 // indirect
|
|
||||||
github.com/cloudwego/eino-ext/components/document/transformer/splitter/semantic v0.0.0-20260323112355-f061db7e8419 // indirect
|
|
||||||
github.com/cloudwego/eino-ext/components/embedding/ark v0.1.1 // indirect
|
|
||||||
github.com/cloudwego/eino-ext/components/embedding/dashscope v0.0.0-20260323112355-f061db7e8419 // indirect
|
|
||||||
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260323112355-f061db7e8419 // indirect
|
|
||||||
github.com/cloudwego/eino-ext/components/indexer/es8 v0.0.0-20260331071634-4f359694d2d9 // indirect
|
|
||||||
github.com/cloudwego/eino-ext/components/retriever/es8 v0.0.0-20260331071634-4f359694d2d9 // indirect
|
|
||||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.14 // indirect
|
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.14 // indirect
|
||||||
github.com/dgraph-io/badger/v4 v4.2.0 // indirect
|
github.com/dgraph-io/badger/v4 v4.2.0 // indirect
|
||||||
github.com/dgraph-io/ristretto v0.1.1 // indirect
|
github.com/dgraph-io/ristretto v0.1.1 // indirect
|
||||||
@@ -56,13 +59,11 @@ require (
|
|||||||
github.com/eino-contrib/docx2md v0.0.1 // indirect
|
github.com/eino-contrib/docx2md v0.0.1 // indirect
|
||||||
github.com/eino-contrib/jsonschema v1.0.3 // indirect
|
github.com/eino-contrib/jsonschema v1.0.3 // indirect
|
||||||
github.com/elastic/elastic-transport-go/v8 v8.10.0 // indirect
|
github.com/elastic/elastic-transport-go/v8 v8.10.0 // indirect
|
||||||
github.com/elastic/go-elasticsearch/v8 v8.16.0 // indirect
|
|
||||||
github.com/emirpasic/gods/v2 v2.0.0-alpha // indirect
|
github.com/emirpasic/gods/v2 v2.0.0-alpha // indirect
|
||||||
github.com/evanphx/json-patch v0.5.2 // indirect
|
github.com/evanphx/json-patch v0.5.2 // indirect
|
||||||
github.com/fatih/color v1.19.0 // indirect
|
github.com/fatih/color v1.19.0 // indirect
|
||||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||||
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
|
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
|
||||||
github.com/go-ego/gse v1.0.2 // indirect
|
|
||||||
github.com/go-logr/logr v1.4.3 // indirect
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
github.com/go-logr/stdr v1.2.2 // indirect
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
github.com/go-playground/locales v0.14.1 // indirect
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
@@ -74,7 +75,6 @@ require (
|
|||||||
github.com/gogf/gf/contrib/trace/otlphttp/v2 v2.9.5 // indirect
|
github.com/gogf/gf/contrib/trace/otlphttp/v2 v2.9.5 // indirect
|
||||||
github.com/gogo/protobuf v1.3.2 // indirect
|
github.com/gogo/protobuf v1.3.2 // indirect
|
||||||
github.com/golang-jwt/jwt/v5 v5.3.1 // indirect
|
github.com/golang-jwt/jwt/v5 v5.3.1 // indirect
|
||||||
github.com/golang/glog v1.2.5 // indirect
|
|
||||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||||
github.com/golang/protobuf v1.5.4 // indirect
|
github.com/golang/protobuf v1.5.4 // indirect
|
||||||
github.com/golang/snappy v1.0.0 // indirect
|
github.com/golang/snappy v1.0.0 // indirect
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -33,6 +33,8 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9
|
|||||||
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
||||||
entgo.io/ent v0.14.3 h1:wokAV/kIlH9TeklJWGGS7AYJdVckr0DloWjIcO9iIIQ=
|
entgo.io/ent v0.14.3 h1:wokAV/kIlH9TeklJWGGS7AYJdVckr0DloWjIcO9iIIQ=
|
||||||
entgo.io/ent v0.14.3/go.mod h1:aDPE/OziPEu8+OWbzy4UlvWmD2/kbRuWfK2A40hcxJM=
|
entgo.io/ent v0.14.3/go.mod h1:aDPE/OziPEu8+OWbzy4UlvWmD2/kbRuWfK2A40hcxJM=
|
||||||
|
gitea.com/red-future/common v0.0.11 h1:AV7W3G0uZ8aPpHHSHd4ZHmLWe5+2STPKe/AYPoPCWVc=
|
||||||
|
gitea.com/red-future/common v0.0.11/go.mod h1:B8syUI4XbLCDQSeRHURYxEwnWw8mEFgmqCxjC+lM+NU=
|
||||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||||
github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
|
github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
|
||||||
github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk=
|
github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk=
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
package dto
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
var DatasetIndex = new(datasetIndexService)
|
|
||||||
|
|
||||||
type datasetIndexService struct{}
|
|
||||||
@@ -3,6 +3,8 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"rag/common/eino"
|
||||||
|
"rag/common/gse"
|
||||||
"rag/consts/document"
|
"rag/consts/document"
|
||||||
"rag/consts/public"
|
"rag/consts/public"
|
||||||
"rag/dao"
|
"rag/dao"
|
||||||
@@ -16,8 +18,6 @@ import (
|
|||||||
"gitea.com/red-future/common/db/gfdb"
|
"gitea.com/red-future/common/db/gfdb"
|
||||||
"gitea.com/red-future/common/full-text-search/meilisearch"
|
"gitea.com/red-future/common/full-text-search/meilisearch"
|
||||||
"gitea.com/red-future/common/http"
|
"gitea.com/red-future/common/http"
|
||||||
"gitea.com/red-future/common/rag/eino"
|
|
||||||
"gitea.com/red-future/common/rag/gse"
|
|
||||||
"gitea.com/red-future/common/utils"
|
"gitea.com/red-future/common/utils"
|
||||||
gmq "github.com/bjang03/gmq/core/gmq"
|
gmq "github.com/bjang03/gmq/core/gmq"
|
||||||
"github.com/bjang03/gmq/mq"
|
"github.com/bjang03/gmq/mq"
|
||||||
@@ -251,7 +251,7 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 3. 组装向量文档
|
// 3. 组装向量文档
|
||||||
var vectorDocs = make([]dto.VectorDocumentChunkMsg, 0)
|
var docsChunk = make([]*schema.Document, 0)
|
||||||
for i, t := range docsSplit {
|
for i, t := range docsSplit {
|
||||||
contentHash := gmd5.MustEncryptString(t.Content)
|
contentHash := gmd5.MustEncryptString(t.Content)
|
||||||
// 检查是否重复
|
// 检查是否重复
|
||||||
@@ -263,27 +263,26 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu
|
|||||||
if !success {
|
if !success {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
vectorDocs = append(vectorDocs, dto.VectorDocumentChunkMsg{
|
var metaData = make(map[string]any)
|
||||||
TenantId: doc.TenantId,
|
metaData[entity.DocumentCol.TenantId] = doc.TenantId
|
||||||
Creator: doc.Creator,
|
metaData[entity.DocumentCol.Creator] = doc.Creator
|
||||||
DatasetId: doc.DatasetId,
|
metaData[entity.DocumentCol.DatasetId] = doc.DatasetId
|
||||||
DocumentId: doc.Id,
|
metaData[entity.DocumentChunkCol.DocumentId] = doc.Id
|
||||||
Content: t.Content,
|
metaData[entity.DocumentChunkCol.ContentHash] = contentHash
|
||||||
ContentHash: contentHash,
|
metaData[entity.DocumentChunkCol.ChunkIndex] = gconv.Int64(i)
|
||||||
ChunkIndex: gconv.Int64(i),
|
t.MetaData = metaData
|
||||||
})
|
docsChunk = append(docsChunk, t)
|
||||||
|
|
||||||
}
|
}
|
||||||
// 4. 发送消息到队列
|
// 4. 发送消息到队列
|
||||||
if len(vectorDocs) > 0 {
|
if len(docsChunk) > 0 {
|
||||||
err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{
|
err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{
|
||||||
PubMessage: types.PubMessage{
|
PubMessage: types.PubMessage{
|
||||||
Topic: public.KnowledgeDocumentChunkTopic,
|
Topic: public.KnowledgeDocumentChunkTopic,
|
||||||
Data: vectorDocs,
|
Data: docsChunk,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
vectorDocsCount = gconv.Int64(len(vectorDocs))
|
vectorDocsCount = gconv.Int64(len(docsChunk))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -318,12 +317,12 @@ func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Docum
|
|||||||
}
|
}
|
||||||
// 构建Meilisearch文档
|
// 构建Meilisearch文档
|
||||||
meiliDocs = append(meiliDocs, map[string]interface{}{
|
meiliDocs = append(meiliDocs, map[string]interface{}{
|
||||||
"id": contentHash,
|
entity.DocumentChunkCol.Id: contentHash,
|
||||||
"datasetId": doc.DatasetId,
|
entity.DocumentChunkCol.DatasetId: doc.DatasetId,
|
||||||
"documentId": doc.Id,
|
entity.DocumentChunkCol.DocumentId: doc.Id,
|
||||||
"content": t.Content,
|
entity.DocumentChunkCol.Content: t.Content,
|
||||||
"contentHash": contentHash,
|
entity.DocumentChunkCol.ContentHash: contentHash,
|
||||||
"chunkIndex": i,
|
entity.DocumentChunkCol.ChunkIndex: i,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
// 4. 写入到meilisearch数据库中
|
// 4. 写入到meilisearch数据库中
|
||||||
|
|||||||
@@ -2,23 +2,20 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"rag/common/eino"
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"rag/consts/document"
|
"rag/consts/document"
|
||||||
"rag/consts/public"
|
"rag/consts/public"
|
||||||
"rag/dao"
|
"rag/dao"
|
||||||
"rag/model/dto"
|
"rag/model/dto"
|
||||||
"rag/model/entity"
|
"rag/model/entity"
|
||||||
|
|
||||||
"gitea.com/red-future/common/beans"
|
|
||||||
"gitea.com/red-future/common/rag/eino"
|
|
||||||
gmq "github.com/bjang03/gmq/core/gmq"
|
gmq "github.com/bjang03/gmq/core/gmq"
|
||||||
"github.com/bjang03/gmq/mq"
|
"github.com/bjang03/gmq/mq"
|
||||||
"github.com/bjang03/gmq/types"
|
"github.com/bjang03/gmq/types"
|
||||||
|
"github.com/cloudwego/eino/components/indexer"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
"github.com/pgvector/pgvector-go"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var DocumentChunk = new(documentChunkService)
|
var DocumentChunk = new(documentChunkService)
|
||||||
@@ -49,114 +46,124 @@ func (s *documentChunkService) List(ctx context.Context, req *dto.ListDocumentCh
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *documentChunkService) DocsChunkMsg(ctx context.Context, msg any) (err error) {
|
func (s *documentChunkService) DocsChunkMsg(ctx context.Context, msg any) (err error) {
|
||||||
var req = make([]*dto.VectorDocumentChunkMsg, 0)
|
var docs = make([]*schema.Document, 0)
|
||||||
msgMap := gconv.Map(msg)
|
msgMap := gconv.Map(msg)
|
||||||
if err = gconv.Structs(msgMap["data"], &req); err != nil {
|
if err = gconv.Structs(msgMap["data"], &docs); err != nil {
|
||||||
g.Log().Error(ctx, "DocsChunkMsg err:", err)
|
g.Log().Error(ctx, "DocsChunkMsg err:", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(req) == 0 {
|
if len(docs) == 0 {
|
||||||
g.Log().Error(ctx, "DocsChunkMsg err:", "msg is empty")
|
g.Log().Error(ctx, "DocsChunkMsg err:", "msg is empty")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx = context.WithValue(ctx, "user", &beans.User{
|
//ctx = context.WithValue(ctx, "user", &beans.User{
|
||||||
TenantId: req[0].TenantId,
|
// TenantId: req[0].TenantId,
|
||||||
UserName: req[0].Creator,
|
// UserName: req[0].Creator,
|
||||||
})
|
//})
|
||||||
|
|
||||||
// 调用eino接口获取向量
|
// 调用eino接口获取向量
|
||||||
var vectorDocsStr = make([]string, 0, len(req))
|
//var vectorDocsStr = make([]string, 0, len(req))
|
||||||
for _, t := range req {
|
//for _, t := range req {
|
||||||
vectorDocsStr = append(vectorDocsStr, t.Content)
|
// vectorDocsStr = append(vectorDocsStr, t.Content)
|
||||||
}
|
//}
|
||||||
embeddings, err := eino.EmbedStrings(ctx, vectorDocsStr)
|
//embeddings, err := eino.EmbedStrings(ctx, vectorDocsStr)
|
||||||
if err != nil {
|
//if err != nil {
|
||||||
g.Log().Error(ctx, "DocsChunkMsg err:", err)
|
// g.Log().Error(ctx, "DocsChunkMsg err:", err)
|
||||||
err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusFailed.Code())
|
// err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusFailed.Code())
|
||||||
return
|
// return
|
||||||
}
|
//}
|
||||||
|
|
||||||
// 获取向量维度
|
// 获取向量维度
|
||||||
dimension := 0
|
//dimension := 0
|
||||||
if len(embeddings) > 0 {
|
//if len(embeddings) > 0 {
|
||||||
dimension = len(embeddings[0])
|
// dimension = len(embeddings[0])
|
||||||
}
|
//}
|
||||||
|
|
||||||
// 创建或更新DatasetIndex
|
// 创建或更新DatasetIndex
|
||||||
err = s.createOrUpdateDatasetIndex(ctx, req[0].DatasetId, dimension, int64(len(req)))
|
//err = s.createOrUpdateDatasetIndex(ctx, req[0].DatasetId, dimension, int64(len(req)))
|
||||||
if err != nil {
|
//if err != nil {
|
||||||
g.Log().Error(ctx, "CreateOrUpdateDatasetIndex err:", err)
|
// g.Log().Error(ctx, "CreateOrUpdateDatasetIndex err:", err)
|
||||||
err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusFailed.Code())
|
// err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusFailed.Code())
|
||||||
return
|
// return
|
||||||
}
|
//}
|
||||||
|
|
||||||
// 更新向量文档
|
// 更新向量文档
|
||||||
for i, embedding := range embeddings {
|
//for i, embedding := range embeddings {
|
||||||
req[i].Vector = pgvector.NewVector(gconv.Float32s(embedding))
|
// req[i].Vector = pgvector.NewVector(gconv.Float32s(embedding))
|
||||||
req[i].VectorStatus = document.VectorStatusCompleted.Code()
|
// req[i].VectorStatus = document.VectorStatusCompleted.Code()
|
||||||
req[i].Status = document.StatusEnable.Code()
|
// req[i].Status = document.StatusEnable.Code()
|
||||||
}
|
//}
|
||||||
_, err = dao.DocumentChunk.BatchInsert(ctx, req)
|
//_, err = dao.DocumentChunk.BatchInsert(ctx, req)
|
||||||
if err != nil {
|
//if err != nil {
|
||||||
g.Log().Error(ctx, "DocsChunkMsg err:", err)
|
// g.Log().Error(ctx, "DocsChunkMsg err:", err)
|
||||||
err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusFailed.Code())
|
// err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusFailed.Code())
|
||||||
|
// return
|
||||||
|
//}
|
||||||
|
idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
|
||||||
|
BatchSize: 10,
|
||||||
|
})
|
||||||
|
rows, err := idx.Store(ctx, docs, indexer.WithEmbedding(eino.EmbedderDashscope))
|
||||||
|
if err != nil || rows == 0 {
|
||||||
|
g.Log().Error(ctx, "DocsChunkMsg rows: , err:", rows, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
tenantId := docs[0].MetaData[entity.DocumentChunkCol.TenantId].(uint64)
|
||||||
err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusCompleted.Code())
|
creator := docs[0].MetaData[entity.DocumentChunkCol.Creator].(string)
|
||||||
|
documentId := docs[0].MetaData[entity.DocumentChunkCol.DocumentId].(int64)
|
||||||
|
err = s.publishKnowledgeDocumentMsg(ctx, tenantId, creator, documentId, document.VectorStatusCompleted.Code())
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// createOrUpdateDatasetIndex 创建或更新数据集索引
|
//// createOrUpdateDatasetIndex 创建或更新数据集索引
|
||||||
func (s *documentChunkService) createOrUpdateDatasetIndex(ctx context.Context, datasetId int64, dimension int, vectorCount int64) (err error) {
|
//func (s *documentChunkService) createOrUpdateDatasetIndex(ctx context.Context, datasetId int64, dimension int, vectorCount int64) (err error) {
|
||||||
// 查询数据集是否已有索引
|
// // 查询数据集是否已有索引
|
||||||
existIndex, err := dao.DatasetIndex.GetByDatasetId(ctx, datasetId)
|
// existIndex, err := dao.DatasetIndex.GetByDatasetId(ctx, datasetId)
|
||||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
// if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
return err
|
// return err
|
||||||
}
|
// }
|
||||||
|
//
|
||||||
// 已有索引 → 只更新数量
|
// // 已有索引 → 只更新数量
|
||||||
if existIndex != nil {
|
// if existIndex != nil {
|
||||||
_ = dao.DatasetIndex.IncVectorCount(ctx, existIndex.Id, vectorCount)
|
// _ = dao.DatasetIndex.IncVectorCount(ctx, existIndex.Id, vectorCount)
|
||||||
return nil
|
// return nil
|
||||||
}
|
// }
|
||||||
|
//
|
||||||
// ====================== 创建新索引 ======================
|
// // ====================== 创建新索引 ======================
|
||||||
indexName := fmt.Sprintf("idx_dataset_%d_vector", datasetId) // 真实PG索引名
|
// indexName := fmt.Sprintf("idx_dataset_%d_vector", datasetId) // 真实PG索引名
|
||||||
// 1. 插入索引配置
|
// // 1. 插入索引配置
|
||||||
index := &entity.DatasetIndex{
|
// index := &entity.DatasetIndex{
|
||||||
DatasetId: datasetId,
|
// DatasetId: datasetId,
|
||||||
Name: indexName,
|
// Name: indexName,
|
||||||
Dimension: dimension,
|
// Dimension: dimension,
|
||||||
FieldType: "float",
|
// FieldType: "float",
|
||||||
MetricType: "COSINE",
|
// MetricType: "COSINE",
|
||||||
Status: gconv.PtrInt8(1),
|
// Status: gconv.PtrInt8(1),
|
||||||
VectorCount: vectorCount,
|
// VectorCount: vectorCount,
|
||||||
Description: fmt.Sprintf("数据集%d向量索引", datasetId),
|
// Description: fmt.Sprintf("数据集%d向量索引", datasetId),
|
||||||
}
|
// }
|
||||||
_, err = dao.DatasetIndex.Insert(ctx, index)
|
// _, err = dao.DatasetIndex.Insert(ctx, index)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return err
|
// return err
|
||||||
}
|
// }
|
||||||
|
//
|
||||||
// 2. 真正创建 PGVector 索引(唯一真实索引!)
|
// // 2. 真正创建 PGVector 索引(唯一真实索引!)
|
||||||
err = s.createRealPGVectorIndex(ctx, indexName)
|
// err = s.createRealPGVectorIndex(ctx, indexName)
|
||||||
return err
|
// return err
|
||||||
}
|
//}
|
||||||
|
//
|
||||||
// createRealPGVectorIndex 真正在PostgreSQL创建向量索引(真实可用)
|
//// createRealPGVectorIndex 真正在PostgreSQL创建向量索引(真实可用)
|
||||||
func (s *documentChunkService) createRealPGVectorIndex(ctx context.Context, indexName string) error {
|
//func (s *documentChunkService) createRealPGVectorIndex(ctx context.Context, indexName string) error {
|
||||||
// 执行真实建索引语句
|
// // 执行真实建索引语句
|
||||||
err := dao.DatasetIndex.InsertIndex(ctx, indexName)
|
// err := dao.DatasetIndex.InsertIndex(ctx, indexName)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
g.Log().Error(ctx, "创建向量索引失败:", err)
|
// g.Log().Error(ctx, "创建向量索引失败:", err)
|
||||||
return err
|
// return err
|
||||||
}
|
// }
|
||||||
g.Log().Info(ctx, "PGVector真实索引创建成功:"+indexName)
|
// g.Log().Info(ctx, "PGVector真实索引创建成功:"+indexName)
|
||||||
return nil
|
// return nil
|
||||||
}
|
//}
|
||||||
|
|
||||||
// publishKnowledgeDocumentMsg 发布消息
|
// publishKnowledgeDocumentMsg 发布消息
|
||||||
func (s *documentChunkService) publishKnowledgeDocumentMsg(ctx context.Context, tenantId uint64, creator string, documentId int64, vectorStatus document.VectorStatus) (err error) {
|
func (s *documentChunkService) publishKnowledgeDocumentMsg(ctx context.Context, tenantId uint64, creator string, documentId int64, vectorStatus document.VectorStatus) (err error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user