149 lines
3.6 KiB
Go
149 lines
3.6 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"rag/consts/document"
|
|
"rag/consts/public"
|
|
"rag/consts/task"
|
|
"rag/dao"
|
|
"rag/model/dto"
|
|
|
|
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
|
"github.com/gogf/gf/v2/database/gdb"
|
|
"github.com/gogf/gf/v2/frame/g"
|
|
"github.com/gogf/gf/v2/util/gconv"
|
|
)
|
|
|
|
// taskService holds the task-related business operations. It carries no
// state, so one shared instance is enough.
type taskService struct{}

// Task is the package-level entry point for task operations.
var Task = &taskService{}
|
|
|
|
// WriteTaskProgress 写入任务进度(核心方法)
|
|
func (s *taskService) WriteTaskProgress(ctx context.Context, req *dto.WriteTaskProgressReq) (err error) {
|
|
t, total, err := dao.Task.Get(ctx, &dto.GetTaskReq{
|
|
TaskId: req.TaskId,
|
|
})
|
|
if err != nil {
|
|
g.Log().Errorf(ctx, "查询任务失败: %v", err)
|
|
return err
|
|
}
|
|
completed := false
|
|
if total != 0 {
|
|
taskVO := make([]*dto.TaskVO, 0, total)
|
|
err = gconv.Struct(t, &taskVO)
|
|
if err != nil {
|
|
g.Log().Errorf(ctx, "转换任务失败: %v", err)
|
|
return err
|
|
}
|
|
taskVO = append(taskVO, &dto.TaskVO{
|
|
TaskType: req.TaskType,
|
|
Status: req.Status,
|
|
})
|
|
completed = IsAllSubTasks(taskVO, task.TaskStatusCompleted)
|
|
}
|
|
|
|
// 1. 查询是否已存在该文档的该类型任务
|
|
existTask, _, err := dao.Task.Get(ctx, &dto.GetTaskReq{
|
|
TaskId: req.TaskId,
|
|
TaskType: req.TaskType,
|
|
})
|
|
if err != nil {
|
|
g.Log().Errorf(ctx, "查询任务失败: %v", err)
|
|
return err
|
|
}
|
|
err = gfdb.DB(ctx, public.DbNameKnowledge).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
|
// 2. 如果不存在,则创建新任务
|
|
if g.IsEmpty(existTask) {
|
|
createReq := &dto.CreateTaskReq{
|
|
TaskId: req.TaskId,
|
|
TaskType: req.TaskType,
|
|
Status: req.Status,
|
|
Remark: req.Remark,
|
|
}
|
|
_, err = dao.Task.Insert(ctx, createReq)
|
|
} else {
|
|
// 3. 如果已存在,则更新任务
|
|
updateReq := &dto.UpdateTaskReq{
|
|
Id: existTask[0].Id,
|
|
Status: req.Status,
|
|
Remark: req.Remark,
|
|
}
|
|
_, err = dao.Task.Update(ctx, updateReq)
|
|
if err != nil {
|
|
g.Log().Errorf(ctx, "更新任务失败: %v", err)
|
|
return err
|
|
}
|
|
}
|
|
|
|
if completed {
|
|
// 3. 如果已存在,则更新任务
|
|
_, err = dao.Task.Update(ctx, &dto.UpdateTaskReq{
|
|
TaskId: req.TaskId,
|
|
Status: task.TaskStatusCompleted,
|
|
Remark: "文档解析完成",
|
|
})
|
|
if err != nil {
|
|
g.Log().Errorf(ctx, "更新任务失败: %v", err)
|
|
return err
|
|
}
|
|
_, err = dao.Document.Update(ctx, &dto.UpdateDocumentReq{
|
|
Id: req.TaskId,
|
|
VectorStatus: document.VectorStatusCompleted.Code(),
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
if task.TaskStatusFailed == req.Status {
|
|
_, err = dao.Document.Update(ctx, &dto.UpdateDocumentReq{
|
|
Id: req.TaskId,
|
|
VectorStatus: document.VectorStatusFailed.Code(),
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
|
|
return
|
|
}
|
|
|
|
// IsAllSubTasks 判断三个子任务
|
|
func IsAllSubTasks(subTasks []*dto.TaskVO, taskStatus task.TaskStatus) bool {
|
|
// 必须包含 3 种任务类型
|
|
hasKeywords := false
|
|
hasVector := false
|
|
hasFullText := false
|
|
|
|
for _, t := range subTasks {
|
|
// 子任务必须是【已完成】状态才计数
|
|
if t.Status == taskStatus {
|
|
switch t.TaskType {
|
|
case task.TaskTypeExtractKeywords:
|
|
hasKeywords = true
|
|
case task.TaskTypeGenerateVector:
|
|
hasVector = true
|
|
case task.TaskTypeFullTextSearch:
|
|
hasFullText = true
|
|
}
|
|
}
|
|
}
|
|
|
|
// 三个任务全部完成 → 返回true
|
|
return hasKeywords && hasVector && hasFullText
|
|
}
|
|
|
|
// Get queries the tasks matching req through dao.Task.Get and returns them
// converted to a ListTaskRes, together with the total count the dao reports.
// On any error (query or conversion) res is nil; a partially filled result is
// never returned.
func (s *taskService) Get(ctx context.Context, req *dto.GetTaskReq) (res *dto.ListTaskRes, err error) {
	list, total, err := dao.Task.Get(ctx, req)
	if err != nil {
		return nil, err
	}

	out := &dto.ListTaskRes{
		Total: total,
	}
	if err = gconv.Struct(list, &out.List); err != nil {
		return nil, err
	}
	return out, nil
}
|