feat: rag初始版
This commit is contained in:
94
common/eino/c.go
Normal file
94
common/eino/c.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/elastic/go-elasticsearch/v8"
|
||||
"github.com/elastic/go-elasticsearch/v8/typedapi/types"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/retriever/es8"
|
||||
"github.com/cloudwego/eino-ext/components/retriever/es8/search_mode"
|
||||
)
|
||||
|
||||
func TestRetriever() {
|
||||
ctx := context.Background()
|
||||
|
||||
client, _ := elasticsearch.NewClient(elasticsearch.Config{
|
||||
Addresses: []string{"http://localhost:9200"},
|
||||
})
|
||||
|
||||
// 创建 retriever 组件
|
||||
retriever, _ := es8.NewRetriever(ctx, &es8.RetrieverConfig{
|
||||
Client: client,
|
||||
Index: indexName,
|
||||
TopK: 5,
|
||||
SearchMode: search_mode.SearchModeApproximate(&search_mode.ApproximateConfig{
|
||||
QueryFieldName: fieldContent,
|
||||
VectorFieldName: fieldContentVector,
|
||||
Hybrid: false,
|
||||
// RRF 仅在特定许可证下可用
|
||||
// 参见: https://www.elastic.co/subscriptions
|
||||
RRF: false,
|
||||
RRFRankConstant: nil,
|
||||
RRFWindowSize: nil,
|
||||
}),
|
||||
ResultParser: func(ctx context.Context, hit types.Hit) (doc *schema.Document, err error) {
|
||||
doc = &schema.Document{
|
||||
ID: *hit.Id_,
|
||||
Content: "",
|
||||
MetaData: map[string]any{},
|
||||
}
|
||||
|
||||
var src map[string]any
|
||||
if err = json.Unmarshal(hit.Source_, &src); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for field, val := range src {
|
||||
switch field {
|
||||
case fieldContent:
|
||||
doc.Content = val.(string)
|
||||
case fieldContentVector:
|
||||
var v []float64
|
||||
for _, item := range val.([]interface{}) {
|
||||
v = append(v, item.(float64))
|
||||
}
|
||||
doc.WithDenseVector(v)
|
||||
case fieldExtraLocation:
|
||||
doc.MetaData[docExtraLocation] = val.(string)
|
||||
}
|
||||
}
|
||||
|
||||
if hit.Score_ != nil {
|
||||
doc.WithScore(float64(*hit.Score_))
|
||||
}
|
||||
|
||||
return doc, nil
|
||||
},
|
||||
Embedding: EmbedderDashscope,
|
||||
})
|
||||
|
||||
// 不带过滤器的搜索
|
||||
docs, _ := retriever.Retrieve(ctx, "tourist attraction")
|
||||
|
||||
// 带过滤器的搜索
|
||||
docs, _ = retriever.Retrieve(ctx, "tourist attraction",
|
||||
es8.WithFilters([]types.Query{{
|
||||
Term: map[string]types.TermQuery{
|
||||
fieldExtraLocation: {
|
||||
CaseInsensitive: of(true),
|
||||
Value: "China",
|
||||
},
|
||||
},
|
||||
}}),
|
||||
)
|
||||
|
||||
fmt.Printf("retrieved docs: %+v\n", docs)
|
||||
}
|
||||
|
||||
func of[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
Reference in New Issue
Block a user