feat: 重构知识库文档导入任务化流程

- 新增上传建单、异步解析、分块处理与异步向量化闭环

- 收口分享页权限、完成态检索过滤与 SSE 局部状态刷新
This commit is contained in:
2026-04-15 19:27:22 +08:00
parent a41b50959e
commit 2689adfa40
56 changed files with 6376 additions and 1060 deletions

View File

@@ -103,6 +103,10 @@
<groupId>tech.easyflow</groupId>
<artifactId>easyflow-common-chat-protocol</artifactId>
</dependency>
<dependency>
<groupId>tech.easyflow</groupId>
<artifactId>easyflow-common-mq</artifactId>
</dependency>
<dependency>
<groupId>com.easyagents</groupId>

View File

@@ -12,8 +12,9 @@ public class ThreadPoolConfig {
private static final Logger log = LoggerFactory.getLogger(ThreadPoolConfig.class);
/**
* SSE消息发送专用线程池
* 核心原则IO密集型任务网络推送线程数 = CPU核心数 * 2 + 1
* 创建 SSE 消息发送线程池
*
* @return SSE 推送线程池
*/
@Bean(name = "sseThreadPool")
public ThreadPoolTaskExecutor sseThreadPool() {
@@ -37,4 +38,29 @@ public class ThreadPoolConfig {
executor.initialize();
return executor;
}
/**
* 创建知识库文档导入任务线程池。
*
* @return 文档导入任务线程池
*/
@Bean(name = "documentImportTaskExecutor")
public ThreadPoolTaskExecutor documentImportTaskExecutor() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
int cpuCoreNum = Runtime.getRuntime().availableProcessors();
executor.setCorePoolSize(Math.max(2, cpuCoreNum));
executor.setMaxPoolSize(Math.max(4, cpuCoreNum * 2));
executor.setQueueCapacity(200);
executor.setKeepAliveSeconds(60);
executor.setThreadNamePrefix("document-import-");
executor.setRejectedExecutionHandler((runnable, executorService) -> {
log.error("文档导入线程池过载!核心线程数:{},最大线程数:{},队列任务数:{}",
executorService.getCorePoolSize(),
executorService.getMaximumPoolSize(),
executorService.getQueue().size());
throw new BusinessException("文档导入任务繁忙,请稍后重试");
});
executor.initialize();
return executor;
}
}

View File

@@ -5,6 +5,8 @@ import com.easyagents.document.core.model.ParseResponse;
import com.easyagents.document.core.model.ParseResult;
import com.easyagents.document.core.model.ParseTaskInfo;
import com.easyagents.document.core.model.ParseTaskStatus;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.lang.Nullable;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
@@ -30,6 +32,8 @@ import tech.easyflow.ai.utils.DocUtil;
@Service
public class DocumentParseBridgeServiceImpl implements DocumentParseBridgeService {
private static final Logger LOG = LoggerFactory.getLogger(DocumentParseBridgeServiceImpl.class);
@Nullable
private final DocumentParseService documentParseService;
private final DocumentSourceLoader documentSourceLoader;
@@ -52,12 +56,21 @@ public class DocumentParseBridgeServiceImpl implements DocumentParseBridgeServic
@Override
public DocumentParsedResult parse(DocumentSourceRef source, DocumentParseScenario scenario) {
try {
LoadedDocumentSource loadedSource = preparePdfSource(source);
LoadedDocumentSource loadedSource = prepareSupportedSource(source);
LOG.info("桥接服务开始同步解析文档: fileName={}, contentType={}, scenario={}",
loadedSource.getFileName(), loadedSource.getContentType(), scenario);
ParseResponse response = requireService().parse(parseRequestFactory.build(loadedSource, scenario));
return parseResultMapper.map(extractSingleResult(response, false));
DocumentParsedResult result = parseResultMapper.map(extractSingleResult(response, false));
LOG.info("桥接服务同步解析完成: fileName={}, scenario={}, preferredTextLength={}",
loadedSource.getFileName(), scenario, resolveTextLength(result));
return result;
} catch (DocumentParseBridgeException e) {
LOG.error("桥接服务同步解析失败: fileName={}, scenario={}",
source == null ? null : source.getFileName(), scenario, e);
throw e;
} catch (Exception e) {
LOG.error("桥接服务同步解析异常: fileName={}, scenario={}",
source == null ? null : source.getFileName(), scenario, e);
throw DocumentParseBridgeException.parseFailed("同步文档解析失败", e);
}
}
@@ -68,12 +81,21 @@ public class DocumentParseBridgeServiceImpl implements DocumentParseBridgeServic
@Override
public DocumentParseTaskStatus submit(DocumentSourceRef source, DocumentParseScenario scenario) {
try {
LoadedDocumentSource loadedSource = preparePdfSource(source);
LoadedDocumentSource loadedSource = prepareSupportedSource(source);
LOG.info("桥接服务开始提交异步解析任务: fileName={}, contentType={}, scenario={}",
loadedSource.getFileName(), loadedSource.getContentType(), scenario);
ParseTaskStatus taskStatus = requireService().submit(parseRequestFactory.build(loadedSource, scenario));
return parseResultMapper.map(taskStatus);
DocumentParseTaskStatus mappedStatus = parseResultMapper.map(taskStatus);
LOG.info("桥接服务异步解析任务提交完成: fileName={}, scenario={}, providerTaskId={}, status={}",
loadedSource.getFileName(), scenario, mappedStatus.getTaskId(), mappedStatus.getStatus());
return mappedStatus;
} catch (DocumentParseBridgeException e) {
LOG.error("桥接服务提交异步解析任务失败: fileName={}, scenario={}",
source == null ? null : source.getFileName(), scenario, e);
throw e;
} catch (Exception e) {
LOG.error("桥接服务提交异步解析任务异常: fileName={}, scenario={}",
source == null ? null : source.getFileName(), scenario, e);
throw DocumentParseBridgeException.taskFailed("提交异步文档解析任务失败", e);
}
}
@@ -104,11 +126,17 @@ public class DocumentParseBridgeServiceImpl implements DocumentParseBridgeServic
throw DocumentParseBridgeException.resultFetchFailed("taskId 不能为空");
}
try {
LOG.info("桥接服务开始获取异步解析结果: providerTaskId={}", taskId);
ParseResponse response = requireService().queryResult(taskId);
return parseResultMapper.map(extractSingleResult(response, true));
DocumentParsedResult result = parseResultMapper.map(extractSingleResult(response, true));
LOG.info("桥接服务获取异步解析结果完成: providerTaskId={}, preferredTextLength={}",
taskId, resolveTextLength(result));
return result;
} catch (DocumentParseBridgeException e) {
LOG.error("桥接服务获取异步解析结果失败: providerTaskId={}", taskId, e);
throw e;
} catch (Exception e) {
LOG.error("桥接服务获取异步解析结果异常: providerTaskId={}", taskId, e);
throw DocumentParseBridgeException.resultFetchFailed("获取异步文档解析结果失败", e);
}
}
@@ -123,14 +151,32 @@ public class DocumentParseBridgeServiceImpl implements DocumentParseBridgeServic
}
try {
ParseTaskInfo taskInfo = requireService().queryTaskInfo(taskId);
return parseResultMapper.map(taskInfo);
DocumentParseTaskInfo mappedTaskInfo = parseResultMapper.map(taskInfo);
LOG.info("桥接服务查询异步解析任务状态: providerTaskId={}, status={}, hasResult={}",
taskId,
mappedTaskInfo == null ? null : mappedTaskInfo.getStatus(),
mappedTaskInfo != null && mappedTaskInfo.getResult() != null);
return mappedTaskInfo;
} catch (DocumentParseBridgeException e) {
LOG.error("桥接服务查询异步解析任务状态失败: providerTaskId={}", taskId, e);
throw e;
} catch (Exception e) {
LOG.error("桥接服务查询异步解析任务状态异常: providerTaskId={}", taskId, e);
throw DocumentParseBridgeException.taskFailed("聚合查询异步文档解析任务信息失败", e);
}
}
private int resolveTextLength(DocumentParsedResult result) {
String text = result == null ? null : result.getPreferredText();
if (!StringUtils.hasText(text) && result != null) {
text = result.getMarkdown();
}
if (!StringUtils.hasText(text) && result != null) {
text = result.getPlainText();
}
return text == null ? 0 : text.length();
}
private DocumentParseService requireService() {
if (documentParseService == null) {
throw DocumentParseBridgeException.serviceNotEnabled();
@@ -138,24 +184,32 @@ public class DocumentParseBridgeServiceImpl implements DocumentParseBridgeServic
return documentParseService;
}
private LoadedDocumentSource preparePdfSource(DocumentSourceRef source) {
private LoadedDocumentSource prepareSupportedSource(DocumentSourceRef source) {
LoadedDocumentSource loadedSource = documentSourceLoader.load(source);
if (!isPdf(loadedSource)) {
throw DocumentParseBridgeException.unsupportedSource("统一文档解析桥接首版仅支持 PDF 文件");
if (!isSupportedByBridge(loadedSource)) {
throw DocumentParseBridgeException.unsupportedSource("统一文档解析桥接当前仅支持 PDF、DOCX 文件");
}
return loadedSource;
}
private boolean isPdf(LoadedDocumentSource loadedSource) {
private boolean isSupportedByBridge(LoadedDocumentSource loadedSource) {
String contentType = loadedSource.getContentType();
if (StringUtils.hasText(contentType) && contentType.toLowerCase().contains("pdf")) {
return true;
if (StringUtils.hasText(contentType)) {
String normalizedContentType = contentType.toLowerCase();
if (normalizedContentType.contains("pdf")
|| normalizedContentType.contains("wordprocessingml.document")) {
return true;
}
}
String fileName = loadedSource.getFileName();
if (!StringUtils.hasText(fileName) || !fileName.contains(".")) {
return false;
}
return "pdf".equals(DocUtil.normalizeSuffix(DocUtil.getSuffix(fileName)));
String suffix = DocUtil.normalizeSuffix(DocUtil.getSuffix(fileName));
if ("pdf".equals(suffix) || "docx".equals(suffix)) {
return true;
}
return false;
}
private ParseResult extractSingleResult(ParseResponse response, boolean resultFetchPhase) {

View File

@@ -93,6 +93,7 @@ public final class DocumentImportDtos {
public static class PreviewRequest implements Serializable {
private BigInteger knowledgeId;
private BigInteger documentId;
private List<PreviewFileRequest> files = new ArrayList<PreviewFileRequest>();
public BigInteger getKnowledgeId() {
@@ -103,6 +104,14 @@ public final class DocumentImportDtos {
this.knowledgeId = knowledgeId;
}
public BigInteger getDocumentId() {
return documentId;
}
public void setDocumentId(BigInteger documentId) {
this.documentId = documentId;
}
public List<PreviewFileRequest> getFiles() {
return files;
}
@@ -114,6 +123,7 @@ public final class DocumentImportDtos {
public static class CommitRequest implements Serializable {
private BigInteger knowledgeId;
private BigInteger documentId;
private List<String> previewSessionIds = new ArrayList<String>();
public BigInteger getKnowledgeId() {
@@ -124,6 +134,14 @@ public final class DocumentImportDtos {
this.knowledgeId = knowledgeId;
}
public BigInteger getDocumentId() {
return documentId;
}
public void setDocumentId(BigInteger documentId) {
this.documentId = documentId;
}
public List<String> getPreviewSessionIds() {
return previewSessionIds;
}
@@ -241,16 +259,158 @@ public final class DocumentImportDtos {
}
}
public static class PreviewSourceRange implements Serializable {
private Integer start;
private Integer end;
public Integer getStart() {
return start;
}
public void setStart(Integer start) {
this.start = start;
}
public Integer getEnd() {
return end;
}
public void setEnd(Integer end) {
this.end = end;
}
}
public static class PreviewChunkResult implements Serializable {
private String answer;
private Integer charCount;
private String chunkId;
private String chunkType;
private String content;
private List<String> headingPath = new ArrayList<String>();
private Integer partNo;
private Integer partTotal;
private String question;
private String sourceLabel;
private Integer tokenEstimate;
private List<String> warnings = new ArrayList<String>();
private List<PreviewSourceRange> sourceRanges = new ArrayList<PreviewSourceRange>();
public String getAnswer() {
return answer;
}
public void setAnswer(String answer) {
this.answer = answer;
}
public Integer getCharCount() {
return charCount;
}
public void setCharCount(Integer charCount) {
this.charCount = charCount;
}
public String getChunkId() {
return chunkId;
}
public void setChunkId(String chunkId) {
this.chunkId = chunkId;
}
public String getChunkType() {
return chunkType;
}
public void setChunkType(String chunkType) {
this.chunkType = chunkType;
}
public String getContent() {
return content;
}
public void setContent(String content) {
this.content = content;
}
public List<String> getHeadingPath() {
return headingPath;
}
public void setHeadingPath(List<String> headingPath) {
this.headingPath = headingPath;
}
public Integer getPartNo() {
return partNo;
}
public void setPartNo(Integer partNo) {
this.partNo = partNo;
}
public Integer getPartTotal() {
return partTotal;
}
public void setPartTotal(Integer partTotal) {
this.partTotal = partTotal;
}
public String getQuestion() {
return question;
}
public void setQuestion(String question) {
this.question = question;
}
public String getSourceLabel() {
return sourceLabel;
}
public void setSourceLabel(String sourceLabel) {
this.sourceLabel = sourceLabel;
}
public Integer getTokenEstimate() {
return tokenEstimate;
}
public void setTokenEstimate(Integer tokenEstimate) {
this.tokenEstimate = tokenEstimate;
}
public List<String> getWarnings() {
return warnings;
}
public void setWarnings(List<String> warnings) {
this.warnings = warnings;
}
public List<PreviewSourceRange> getSourceRanges() {
return sourceRanges;
}
public void setSourceRanges(List<PreviewSourceRange> sourceRanges) {
this.sourceRanges = sourceRanges;
}
}
public static class PreviewFileResult implements Serializable {
private String previewSessionId;
private String filePath;
private String fileName;
private String normalizedContent;
private String strategyCode;
private String strategyLabel;
private AnalysisResult analysis;
private Integer totalChunks;
private Integer totalWarnings;
private List<RagChunk> chunks = new ArrayList<RagChunk>();
private List<PreviewChunkResult> chunks = new ArrayList<PreviewChunkResult>();
public String getPreviewSessionId() {
return previewSessionId;
@@ -276,6 +436,14 @@ public final class DocumentImportDtos {
this.fileName = fileName;
}
public String getNormalizedContent() {
return normalizedContent;
}
public void setNormalizedContent(String normalizedContent) {
this.normalizedContent = normalizedContent;
}
public String getStrategyCode() {
return strategyCode;
}
@@ -316,11 +484,11 @@ public final class DocumentImportDtos {
this.totalWarnings = totalWarnings;
}
public List<RagChunk> getChunks() {
public List<PreviewChunkResult> getChunks() {
return chunks;
}
public void setChunks(List<RagChunk> chunks) {
public void setChunks(List<PreviewChunkResult> chunks) {
this.chunks = chunks;
}
}
@@ -454,6 +622,7 @@ public final class DocumentImportDtos {
public static class PreviewSession implements Serializable {
private String sessionId;
private BigInteger knowledgeId;
private BigInteger documentId;
private String filePath;
private String fileName;
private String sourceFormat;
@@ -480,6 +649,14 @@ public final class DocumentImportDtos {
this.knowledgeId = knowledgeId;
}
public BigInteger getDocumentId() {
return documentId;
}
public void setDocumentId(BigInteger documentId) {
this.documentId = documentId;
}
public String getFilePath() {
return filePath;
}
@@ -552,4 +729,265 @@ public final class DocumentImportDtos {
this.createdAt = createdAt;
}
}
public static class TaskCreateRequest implements Serializable {
private BigInteger knowledgeId;
private String filePath;
private String fileName;
public BigInteger getKnowledgeId() {
return knowledgeId;
}
public void setKnowledgeId(BigInteger knowledgeId) {
this.knowledgeId = knowledgeId;
}
public String getFilePath() {
return filePath;
}
public void setFilePath(String filePath) {
this.filePath = filePath;
}
public String getFileName() {
return fileName;
}
public void setFileName(String fileName) {
this.fileName = fileName;
}
}
public static class TaskCreateResponse implements Serializable {
private BigInteger documentId;
private BigInteger taskId;
private String processStatus;
public BigInteger getDocumentId() {
return documentId;
}
public void setDocumentId(BigInteger documentId) {
this.documentId = documentId;
}
public BigInteger getTaskId() {
return taskId;
}
public void setTaskId(BigInteger taskId) {
this.taskId = taskId;
}
public String getProcessStatus() {
return processStatus;
}
public void setProcessStatus(String processStatus) {
this.processStatus = processStatus;
}
}
public static class TaskDetailResponse implements Serializable {
private BigInteger taskId;
private BigInteger documentId;
private BigInteger knowledgeId;
private String phase;
private String status;
private String processStatus;
private Integer progressPercent;
private Integer totalChunks;
private Integer completedChunks;
private Integer failedChunks;
private String providerTaskId;
private String errorSummary;
private Date startedAt;
private Date finishedAt;
public BigInteger getTaskId() {
return taskId;
}
public void setTaskId(BigInteger taskId) {
this.taskId = taskId;
}
public BigInteger getDocumentId() {
return documentId;
}
public void setDocumentId(BigInteger documentId) {
this.documentId = documentId;
}
public BigInteger getKnowledgeId() {
return knowledgeId;
}
public void setKnowledgeId(BigInteger knowledgeId) {
this.knowledgeId = knowledgeId;
}
public String getPhase() {
return phase;
}
public void setPhase(String phase) {
this.phase = phase;
}
public String getStatus() {
return status;
}
public void setStatus(String status) {
this.status = status;
}
public String getProcessStatus() {
return processStatus;
}
public void setProcessStatus(String processStatus) {
this.processStatus = processStatus;
}
public Integer getProgressPercent() {
return progressPercent;
}
public void setProgressPercent(Integer progressPercent) {
this.progressPercent = progressPercent;
}
public Integer getTotalChunks() {
return totalChunks;
}
public void setTotalChunks(Integer totalChunks) {
this.totalChunks = totalChunks;
}
public Integer getCompletedChunks() {
return completedChunks;
}
public void setCompletedChunks(Integer completedChunks) {
this.completedChunks = completedChunks;
}
public Integer getFailedChunks() {
return failedChunks;
}
public void setFailedChunks(Integer failedChunks) {
this.failedChunks = failedChunks;
}
public String getProviderTaskId() {
return providerTaskId;
}
public void setProviderTaskId(String providerTaskId) {
this.providerTaskId = providerTaskId;
}
public String getErrorSummary() {
return errorSummary;
}
public void setErrorSummary(String errorSummary) {
this.errorSummary = errorSummary;
}
public Date getStartedAt() {
return startedAt;
}
public void setStartedAt(Date startedAt) {
this.startedAt = startedAt;
}
public Date getFinishedAt() {
return finishedAt;
}
public void setFinishedAt(Date finishedAt) {
this.finishedAt = finishedAt;
}
}
public static class TaskStartIndexRequest implements Serializable {
private BigInteger knowledgeId;
private BigInteger documentId;
private String previewSessionId;
public BigInteger getKnowledgeId() {
return knowledgeId;
}
public void setKnowledgeId(BigInteger knowledgeId) {
this.knowledgeId = knowledgeId;
}
public BigInteger getDocumentId() {
return documentId;
}
public void setDocumentId(BigInteger documentId) {
this.documentId = documentId;
}
public String getPreviewSessionId() {
return previewSessionId;
}
public void setPreviewSessionId(String previewSessionId) {
this.previewSessionId = previewSessionId;
}
}
public static class TaskStartIndexResponse implements Serializable {
private BigInteger taskId;
private String processStatus;
public BigInteger getTaskId() {
return taskId;
}
public void setTaskId(BigInteger taskId) {
this.taskId = taskId;
}
public String getProcessStatus() {
return processStatus;
}
public void setProcessStatus(String processStatus) {
this.processStatus = processStatus;
}
}
public static class TaskRetryRequest implements Serializable {
private BigInteger knowledgeId;
private BigInteger documentId;
public BigInteger getKnowledgeId() {
return knowledgeId;
}
public void setKnowledgeId(BigInteger knowledgeId) {
this.knowledgeId = knowledgeId;
}
public BigInteger getDocumentId() {
return documentId;
}
public void setDocumentId(BigInteger documentId) {
this.documentId = documentId;
}
}
}

View File

@@ -18,4 +18,8 @@ public final class DocumentImportKeys {
public static final String KEY_DOCUMENT_ANALYSIS_SUMMARY = "splitter.analysisSummary";
public static final String KEY_DOCUMENT_SOURCE_FILE_EXT = "splitter.sourceFileExt";
public static final String KEY_DOCUMENT_PREVIEW_VERSION = "splitter.previewVersion";
public static final String KEY_DOCUMENT_PARSE_BACKEND = "parse.backend";
public static final String KEY_DOCUMENT_PARSE_METADATA = "parse.metadata";
public static final String KEY_DOCUMENT_PARSE_WARNINGS = "parse.warnings";
public static final String KEY_DOCUMENT_PROVIDER_TASK_ID = "parse.providerTaskId";
}

View File

@@ -0,0 +1,76 @@
package tech.easyflow.ai.documentimport.task;
import com.alibaba.fastjson2.JSON;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import tech.easyflow.common.mq.config.MQProperties;
import tech.easyflow.common.mq.core.MQConsumerHandler;
import tech.easyflow.common.mq.core.MQMessage;
import tech.easyflow.common.mq.core.MQSubscription;
import java.util.List;
/**
* 文档向量化任务消费者。
*
* @author Codex
* @since 2026-04-14
*/
@Component
public class DocumentImportIndexTaskConsumer implements MQConsumerHandler {
private static final Logger LOG = LoggerFactory.getLogger(DocumentImportIndexTaskConsumer.class);
private final KnowledgeDocumentImportTaskAppService appService;
private final MQProperties mqProperties;
public DocumentImportIndexTaskConsumer(KnowledgeDocumentImportTaskAppService appService,
MQProperties mqProperties) {
this.appService = appService;
this.mqProperties = mqProperties;
}
@Override
public MQSubscription subscription() {
MQSubscription subscription = new MQSubscription();
subscription.setTopic(DocumentImportTaskMqConstants.INDEX_TOPIC);
subscription.setConsumerGroup(DocumentImportTaskMqConstants.INDEX_GROUP);
subscription.setShardCount(resolveShardCount());
return subscription;
}
@Override
public void handle(List<MQMessage> messages) {
LOG.info("文档向量化消费者收到消息批次: count={}", messages == null ? 0 : messages.size());
for (MQMessage message : messages) {
DocumentImportTaskMessage event = JSON.parseObject(message.getBody(), DocumentImportTaskMessage.class);
if (event == null || event.getTaskId() == null) {
LOG.warn("文档向量化消费者跳过非法消息: streamMessageId={}, messageId={}",
message == null ? null : message.getStreamMessageId(),
message == null ? null : message.getMessageId());
continue;
}
LOG.info("文档向量化消费者开始处理消息: taskId={}, messageId={}, streamMessageId={}",
event.getTaskId(), message.getMessageId(), message.getStreamMessageId());
try {
appService.handleIndexTask(event.getTaskId());
LOG.info("文档向量化消费者处理完成: taskId={}, messageId={}, streamMessageId={}",
event.getTaskId(), message.getMessageId(), message.getStreamMessageId());
} catch (Exception exception) {
LOG.error("文档向量化消费者处理失败: taskId={}, messageId={}, streamMessageId={}",
event.getTaskId(), message.getMessageId(), message.getStreamMessageId(), exception);
throw exception;
}
}
}
/**
* 向量化消费者需覆盖生产端的所有分片,避免消息落入未订阅分片。
*
* @return 当前 Redis Stream 分片数
*/
private int resolveShardCount() {
return Math.max(mqProperties.getRedis().getChatPersistShardCount(), 1);
}
}

View File

@@ -0,0 +1,52 @@
package tech.easyflow.ai.documentimport.task;
import com.alibaba.fastjson2.JSON;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import tech.easyflow.common.mq.core.MQMessage;
import tech.easyflow.common.mq.core.MQProducer;
import java.math.BigInteger;
import java.util.Date;
/**
* 文档向量化任务消息生产者。
*
* @author Codex
* @since 2026-04-14
*/
@Service
public class DocumentImportIndexTaskProducer {
private static final Logger LOG = LoggerFactory.getLogger(DocumentImportIndexTaskProducer.class);
private final MQProducer mqProducer;
public DocumentImportIndexTaskProducer(MQProducer mqProducer) {
this.mqProducer = mqProducer;
}
/**
* 发送向量化任务消息。
*
* @param taskId 任务 ID
*/
public void send(BigInteger taskId) {
DocumentImportTaskMessage event = new DocumentImportTaskMessage();
event.setTaskId(taskId);
event.setOccurredAt(new Date());
MQMessage message = new MQMessage();
message.setMessageId("index-" + taskId);
message.setTopic(DocumentImportTaskMqConstants.INDEX_TOPIC);
message.setKey(String.valueOf(taskId));
message.setCreatedAt(event.getOccurredAt());
message.setBody(JSON.toJSONString(event));
LOG.info("准备投递文档向量化 MQ 消息: topic={}, taskId={}, messageId={}",
message.getTopic(), taskId, message.getMessageId());
String recordId = mqProducer.send(message);
LOG.info("文档向量化 MQ 消息投递完成: topic={}, taskId={}, messageId={}, recordId={}",
message.getTopic(), taskId, message.getMessageId(), recordId);
}
}

View File

@@ -0,0 +1,33 @@
package tech.easyflow.ai.documentimport.task;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
/**
* 知识库文档解析任务收敛器。
*
* <p>该调度器只负责轮询运行中的桥接解析任务,不承担提交任务职责。</p>
*
* @author Codex
* @since 2026-04-15
*/
@Component
public class DocumentImportParseMonitor {
private final KnowledgeDocumentImportTaskAppService appService;
public DocumentImportParseMonitor(KnowledgeDocumentImportTaskAppService appService) {
this.appService = appService;
}
/**
* 定时收敛运行中的桥接解析任务状态。
*/
@Scheduled(
fixedDelayString = "${easyflow.ai.document-import.parse-monitor.fixed-delay:3000}",
initialDelayString = "${easyflow.ai.document-import.parse-monitor.initial-delay:5000}"
)
public void reconcileRunningParseTasks() {
appService.monitorRunningParseTasks();
}
}

View File

@@ -0,0 +1,76 @@
package tech.easyflow.ai.documentimport.task;
import com.alibaba.fastjson2.JSON;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import tech.easyflow.common.mq.config.MQProperties;
import tech.easyflow.common.mq.core.MQConsumerHandler;
import tech.easyflow.common.mq.core.MQMessage;
import tech.easyflow.common.mq.core.MQSubscription;
import java.util.List;
/**
* 文档解析任务消费者。
*
* @author Codex
* @since 2026-04-14
*/
@Component
public class DocumentImportParseTaskConsumer implements MQConsumerHandler {
private static final Logger LOG = LoggerFactory.getLogger(DocumentImportParseTaskConsumer.class);
private final KnowledgeDocumentImportTaskAppService appService;
private final MQProperties mqProperties;
public DocumentImportParseTaskConsumer(KnowledgeDocumentImportTaskAppService appService,
MQProperties mqProperties) {
this.appService = appService;
this.mqProperties = mqProperties;
}
@Override
public MQSubscription subscription() {
MQSubscription subscription = new MQSubscription();
subscription.setTopic(DocumentImportTaskMqConstants.PARSE_TOPIC);
subscription.setConsumerGroup(DocumentImportTaskMqConstants.PARSE_GROUP);
subscription.setShardCount(resolveShardCount());
return subscription;
}
@Override
public void handle(List<MQMessage> messages) {
LOG.info("文档解析消费者收到消息批次: count={}", messages == null ? 0 : messages.size());
for (MQMessage message : messages) {
DocumentImportTaskMessage event = JSON.parseObject(message.getBody(), DocumentImportTaskMessage.class);
if (event == null || event.getTaskId() == null) {
LOG.warn("文档解析消费者跳过非法消息: streamMessageId={}, messageId={}",
message == null ? null : message.getStreamMessageId(),
message == null ? null : message.getMessageId());
continue;
}
LOG.info("文档解析消费者开始处理消息: taskId={}, messageId={}, streamMessageId={}",
event.getTaskId(), message.getMessageId(), message.getStreamMessageId());
try {
appService.handleParseTask(event.getTaskId());
LOG.info("文档解析消费者处理完成: taskId={}, messageId={}, streamMessageId={}",
event.getTaskId(), message.getMessageId(), message.getStreamMessageId());
} catch (Exception exception) {
LOG.error("文档解析消费者处理失败: taskId={}, messageId={}, streamMessageId={}",
event.getTaskId(), message.getMessageId(), message.getStreamMessageId(), exception);
throw exception;
}
}
}
/**
* 解析消费者需覆盖生产端的所有分片,避免消息落入未订阅分片。
*
* @return 当前 Redis Stream 分片数
*/
private int resolveShardCount() {
return Math.max(mqProperties.getRedis().getChatPersistShardCount(), 1);
}
}

View File

@@ -0,0 +1,52 @@
package tech.easyflow.ai.documentimport.task;
import com.alibaba.fastjson2.JSON;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import tech.easyflow.common.mq.core.MQMessage;
import tech.easyflow.common.mq.core.MQProducer;
import java.math.BigInteger;
import java.util.Date;
/**
* 文档解析任务消息生产者。
*
* @author Codex
* @since 2026-04-14
*/
@Service
public class DocumentImportParseTaskProducer {
private static final Logger LOG = LoggerFactory.getLogger(DocumentImportParseTaskProducer.class);
private final MQProducer mqProducer;
public DocumentImportParseTaskProducer(MQProducer mqProducer) {
this.mqProducer = mqProducer;
}
/**
* 发送解析任务消息。
*
* @param taskId 任务 ID
*/
public void send(BigInteger taskId) {
DocumentImportTaskMessage event = new DocumentImportTaskMessage();
event.setTaskId(taskId);
event.setOccurredAt(new Date());
MQMessage message = new MQMessage();
message.setMessageId("parse-" + taskId);
message.setTopic(DocumentImportTaskMqConstants.PARSE_TOPIC);
message.setKey(String.valueOf(taskId));
message.setCreatedAt(event.getOccurredAt());
message.setBody(JSON.toJSONString(event));
LOG.info("准备投递文档解析 MQ 消息: topic={}, taskId={}, messageId={}",
message.getTopic(), taskId, message.getMessageId());
String recordId = mqProducer.send(message);
LOG.info("文档解析 MQ 消息投递完成: topic={}, taskId={}, messageId={}, recordId={}",
message.getTopic(), taskId, message.getMessageId(), recordId);
}
}

View File

@@ -0,0 +1,33 @@
package tech.easyflow.ai.documentimport.task;
import java.io.Serializable;
import java.math.BigInteger;
import java.util.Date;
/**
* 文档导入任务消息。
*
* @author Codex
* @since 2026-04-14
*/
public class DocumentImportTaskMessage implements Serializable {
private BigInteger taskId;
private Date occurredAt;
public BigInteger getTaskId() {
return taskId;
}
public void setTaskId(BigInteger taskId) {
this.taskId = taskId;
}
public Date getOccurredAt() {
return occurredAt;
}
public void setOccurredAt(Date occurredAt) {
this.occurredAt = occurredAt;
}
}

View File

@@ -0,0 +1,18 @@
package tech.easyflow.ai.documentimport.task;
/**
* 文档导入任务 MQ 常量。
*
* @author Codex
* @since 2026-04-14
*/
public final class DocumentImportTaskMqConstants {
private DocumentImportTaskMqConstants() {
}
public static final String PARSE_TOPIC = "knowledge-document-parse";
public static final String PARSE_GROUP = "knowledge-document-parse-group";
public static final String INDEX_TOPIC = "knowledge-document-index";
public static final String INDEX_GROUP = "knowledge-document-index-group";
}

View File

@@ -0,0 +1,156 @@
package tech.easyflow.ai.documentimport.task;
import org.springframework.http.MediaType;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.stereotype.Service;
import org.springframework.transaction.support.TransactionSynchronization;
import org.springframework.transaction.support.TransactionSynchronizationManager;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import tech.easyflow.ai.entity.Document;
import tech.easyflow.ai.mapper.DocumentMapper;
import tech.easyflow.common.web.exceptions.BusinessException;
import javax.annotation.Resource;
import java.math.BigInteger;
import java.time.Duration;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
/**
* 知识库文档任务状态 SSE 推送服务。
*
* @author Codex
* @since 2026-04-15
*/
@Service
public class DocumentImportTaskStatusStreamService {
private static final long SSE_TIMEOUT_MS = Duration.ofMinutes(30).toMillis();
private final Map<String, Set<SseEmitter>> knowledgeEmitters = new ConcurrentHashMap<String, Set<SseEmitter>>();
@Resource
private DocumentMapper documentMapper;
@Resource(name = "sseThreadPool")
private ThreadPoolTaskExecutor sseThreadPool;
/**
* 订阅知识库文档任务状态流。
*
* @param knowledgeId 知识库 ID
* @return SSE 连接
*/
public SseEmitter subscribe(BigInteger knowledgeId) {
if (knowledgeId == null) {
throw new BusinessException("知识库id不能为空");
}
String topicKey = toTopicKey(knowledgeId);
SseEmitter emitter = new SseEmitter(SSE_TIMEOUT_MS);
knowledgeEmitters.computeIfAbsent(topicKey, key -> ConcurrentHashMap.newKeySet()).add(emitter);
emitter.onCompletion(() -> removeEmitter(topicKey, emitter));
emitter.onTimeout(() -> {
removeEmitter(topicKey, emitter);
emitter.complete();
});
emitter.onError(error -> removeEmitter(topicKey, emitter));
sendAsync(topicKey, emitter, "connected", buildConnectedPayload(knowledgeId));
return emitter;
}
/**
* 在事务提交后推送文档任务状态变更。
*
* @param documentId 文档 ID
*/
public void publishAfterCommit(BigInteger documentId) {
if (documentId == null) {
return;
}
Runnable publishAction = () -> publishNow(documentId);
if (TransactionSynchronizationManager.isSynchronizationActive()
&& TransactionSynchronizationManager.isActualTransactionActive()) {
TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {
@Override
public void afterCommit() {
publishAction.run();
}
});
return;
}
publishAction.run();
}
private void publishNow(BigInteger documentId) {
Document document = documentMapper.selectOneById(documentId);
if (document == null || document.getCollectionId() == null) {
return;
}
String topicKey = toTopicKey(document.getCollectionId());
Set<SseEmitter> emitters = knowledgeEmitters.get(topicKey);
if (emitters == null || emitters.isEmpty()) {
return;
}
Map<String, Object> payload = buildDocumentPayload(document);
for (SseEmitter emitter : emitters) {
sendAsync(topicKey, emitter, "document-status", payload);
}
}
private Map<String, Object> buildConnectedPayload(BigInteger knowledgeId) {
Map<String, Object> payload = new LinkedHashMap<String, Object>();
payload.put("knowledgeId", knowledgeId.toString());
payload.put("type", "connected");
return payload;
}
private Map<String, Object> buildDocumentPayload(Document document) {
Map<String, Object> payload = new LinkedHashMap<String, Object>();
payload.put("type", "document-status");
payload.put("knowledgeId", document.getCollectionId() == null ? null : document.getCollectionId().toString());
payload.put("documentId", document.getId() == null ? null : document.getId().toString());
payload.put("processStatus", document.getProcessStatus());
payload.put("progressPercent", document.getProgressPercent());
payload.put("totalChunks", document.getTotalChunks());
payload.put("completedChunks", document.getCompletedChunks());
payload.put("failedChunks", document.getFailedChunks());
payload.put("lastTaskError", document.getLastTaskError());
payload.put("taskModifiedAt", document.getTaskModifiedAt());
return payload;
}
private void sendAsync(String topicKey, SseEmitter emitter, String eventName, Map<String, Object> payload) {
sseThreadPool.execute(() -> {
try {
emitter.send(
SseEmitter.event()
.name(eventName)
.data(payload, MediaType.APPLICATION_JSON)
);
} catch (Exception e) {
removeEmitter(topicKey, emitter);
try {
emitter.completeWithError(e);
} catch (Exception ignored) {
}
}
});
}
private void removeEmitter(String topicKey, SseEmitter emitter) {
Set<SseEmitter> emitters = knowledgeEmitters.get(topicKey);
if (emitters == null) {
return;
}
emitters.remove(emitter);
if (emitters.isEmpty()) {
knowledgeEmitters.remove(topicKey);
}
}
private String toTopicKey(BigInteger knowledgeId) {
return String.valueOf(knowledgeId);
}
}

View File

@@ -55,4 +55,16 @@ public class Document extends DocumentBase {
public void setOverlapSize(int overlapSize) {
this.overlapSize = overlapSize;
}
/**
* 获取列表展示时的优先分块数。
*
* @return 分块数
*/
public long getDisplayChunkCount() {
if (getTotalChunks() != null && getTotalChunks() > 0) {
return getTotalChunks();
}
return chunkCount == null ? 0L : chunkCount.longValue();
}
}

View File

@@ -0,0 +1,182 @@
package tech.easyflow.ai.entity;
import com.mybatisflex.annotation.Column;
import com.mybatisflex.annotation.Id;
import com.mybatisflex.annotation.KeyType;
import com.mybatisflex.annotation.Table;
import com.mybatisflex.core.handler.FastjsonTypeHandler;
import tech.easyflow.common.entity.DateEntity;
import java.io.Serializable;
import java.math.BigInteger;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.Map;
/**
* 知识库文档导入任务。
*
* @author Codex
* @since 2026-04-14
*/
@Table(value = "tb_document_import_task", comment = "知识库文档导入任务")
public class DocumentImportTask extends DateEntity implements Serializable {
@Id(keyType = KeyType.Generator, value = "snowFlakeId")
private BigInteger id;
@Column(comment = "文档ID")
private BigInteger documentId;
@Column(comment = "知识库ID")
private BigInteger knowledgeId;
@Column(comment = "任务阶段")
private String phase;
@Column(comment = "任务状态")
private String status;
@Column(comment = "底层任务ID")
private String providerTaskId;
@Column(typeHandler = FastjsonTypeHandler.class, comment = "任务载荷")
private Map<String, Object> payloadJson;
@Column(comment = "错误摘要")
private String errorSummary;
@Column(comment = "开始时间")
private Date startedAt;
@Column(comment = "结束时间")
private Date finishedAt;
@Column(comment = "创建时间")
private Date created;
@Column(comment = "创建人")
private BigInteger createdBy;
@Column(comment = "修改时间")
private Date modified;
@Column(comment = "修改人")
private BigInteger modifiedBy;
public BigInteger getId() {
return id;
}
public void setId(BigInteger id) {
this.id = id;
}
public BigInteger getDocumentId() {
return documentId;
}
public void setDocumentId(BigInteger documentId) {
this.documentId = documentId;
}
public BigInteger getKnowledgeId() {
return knowledgeId;
}
public void setKnowledgeId(BigInteger knowledgeId) {
this.knowledgeId = knowledgeId;
}
public String getPhase() {
return phase;
}
public void setPhase(String phase) {
this.phase = phase;
}
public String getStatus() {
return status;
}
public void setStatus(String status) {
this.status = status;
}
public String getProviderTaskId() {
return providerTaskId;
}
public void setProviderTaskId(String providerTaskId) {
this.providerTaskId = providerTaskId;
}
public Map<String, Object> getPayloadJson() {
return payloadJson == null ? new LinkedHashMap<String, Object>() : payloadJson;
}
public void setPayloadJson(Map<String, Object> payloadJson) {
this.payloadJson = payloadJson == null ? new LinkedHashMap<String, Object>() : payloadJson;
}
public String getErrorSummary() {
return errorSummary;
}
public void setErrorSummary(String errorSummary) {
this.errorSummary = errorSummary;
}
public Date getStartedAt() {
return startedAt;
}
public void setStartedAt(Date startedAt) {
this.startedAt = startedAt;
}
public Date getFinishedAt() {
return finishedAt;
}
public void setFinishedAt(Date finishedAt) {
this.finishedAt = finishedAt;
}
@Override
public Date getCreated() {
return created;
}
@Override
public void setCreated(Date created) {
this.created = created;
}
public BigInteger getCreatedBy() {
return createdBy;
}
public void setCreatedBy(BigInteger createdBy) {
this.createdBy = createdBy;
}
@Override
public Date getModified() {
return modified;
}
@Override
public void setModified(Date modified) {
this.modified = modified;
}
public BigInteger getModifiedBy() {
return modifiedBy;
}
public void setModifiedBy(BigInteger modifiedBy) {
this.modifiedBy = modifiedBy;
}
}

View File

@@ -72,6 +72,48 @@ public class DocumentBase extends DateEntity implements Serializable {
@Column(typeHandler = FastjsonTypeHandler.class, comment = "其他配置项")
private Map<String, Object> options;
/**
* 处理状态
*/
@Column(comment = "处理状态")
private String processStatus;
/**
* 总分块数
*/
@Column(comment = "总分块数")
private Integer totalChunks;
/**
* 已完成分块数
*/
@Column(comment = "已完成分块数")
private Integer completedChunks;
/**
* 失败分块数
*/
@Column(comment = "失败分块数")
private Integer failedChunks;
/**
* 处理进度百分比
*/
@Column(comment = "处理进度百分比")
private Integer progressPercent;
/**
* 最近任务错误摘要
*/
@Column(comment = "最近任务错误摘要")
private String lastTaskError;
/**
* 任务状态更新时间
*/
@Column(comment = "任务状态更新时间")
private Date taskModifiedAt;
/**
* 创建时间
*/
@@ -176,6 +218,62 @@ public class DocumentBase extends DateEntity implements Serializable {
this.options = options;
}
public String getProcessStatus() {
return processStatus;
}
public void setProcessStatus(String processStatus) {
this.processStatus = processStatus;
}
public Integer getTotalChunks() {
return totalChunks;
}
public void setTotalChunks(Integer totalChunks) {
this.totalChunks = totalChunks;
}
public Integer getCompletedChunks() {
return completedChunks;
}
public void setCompletedChunks(Integer completedChunks) {
this.completedChunks = completedChunks;
}
public Integer getFailedChunks() {
return failedChunks;
}
public void setFailedChunks(Integer failedChunks) {
this.failedChunks = failedChunks;
}
public Integer getProgressPercent() {
return progressPercent;
}
public void setProgressPercent(Integer progressPercent) {
this.progressPercent = progressPercent;
}
public String getLastTaskError() {
return lastTaskError;
}
public void setLastTaskError(String lastTaskError) {
this.lastTaskError = lastTaskError;
}
public Date getTaskModifiedAt() {
return taskModifiedAt;
}
public void setTaskModifiedAt(Date taskModifiedAt) {
this.taskModifiedAt = taskModifiedAt;
}
public Date getCreated() {
return created;
}

View File

@@ -0,0 +1,20 @@
package tech.easyflow.ai.enums;
/**
* 文档导入任务阶段。
*
* @author Codex
* @since 2026-04-14
*/
public enum DocumentImportTaskPhase {
/**
* 文档解析阶段。
*/
PARSE,
/**
* 向量化阶段。
*/
INDEX
}

View File

@@ -0,0 +1,30 @@
package tech.easyflow.ai.enums;
/**
* 文档导入任务状态。
*
* @author Codex
* @since 2026-04-14
*/
public enum DocumentImportTaskStatus {
/**
* 已创建,等待执行。
*/
PENDING,
/**
* 正在执行。
*/
RUNNING,
/**
* 执行失败。
*/
FAILED,
/**
* 执行完成。
*/
COMPLETED
}

View File

@@ -0,0 +1,59 @@
package tech.easyflow.ai.enums;
/**
* 文档处理状态。
*
* @author Codex
* @since 2026-04-14
*/
public enum DocumentProcessStatus {
/**
* 已上传,尚未进入异步处理。
*/
UPLOADED,
/**
* 解析中。
*/
PARSING,
/**
* 解析失败。
*/
PARSE_FAILED,
/**
* 可继续配置分块。
*/
READY_FOR_SEGMENT,
/**
* 已确认分块,可开始向量化。
*/
READY_FOR_INDEX,
/**
* 向量化处理中。
*/
INDEXING,
/**
* 向量化失败。
*/
INDEX_FAILED,
/**
* 全流程完成。
*/
COMPLETED;
/**
* 判断当前状态是否属于运行中状态。
*
* @return 是否运行中
*/
public boolean isProcessing() {
return this == PARSING || this == INDEXING;
}
}

View File

@@ -0,0 +1,13 @@
package tech.easyflow.ai.mapper;
import com.mybatisflex.core.BaseMapper;
import tech.easyflow.ai.entity.DocumentImportTask;
/**
* 文档导入任务映射层。
*
* @author Codex
* @since 2026-04-14
*/
public interface DocumentImportTaskMapper extends BaseMapper<DocumentImportTask> {
}

View File

@@ -9,6 +9,7 @@ public class KnowledgeRetrievalRequest {
private BigInteger knowledgeId;
private String query;
private Integer limit;
private Double minSimilarity;
private RetrievalMode retrievalMode = RetrievalMode.HYBRID;
private String callerType;
private String callerId;
@@ -37,6 +38,24 @@ public class KnowledgeRetrievalRequest {
this.limit = limit;
}
/**
* 返回检索时使用的最小相似度阈值。
*
* @return 最小相似度阈值
*/
public Double getMinSimilarity() {
return minSimilarity;
}
/**
* 设置检索时使用的最小相似度阈值。
*
* @param minSimilarity 最小相似度阈值
*/
public void setMinSimilarity(Double minSimilarity) {
this.minSimilarity = minSimilarity;
}
public RetrievalMode getRetrievalMode() {
return retrievalMode;
}

View File

@@ -0,0 +1,13 @@
package tech.easyflow.ai.service;
import com.mybatisflex.core.service.IService;
import tech.easyflow.ai.entity.DocumentImportTask;
/**
* 文档导入任务服务。
*
* @author Codex
* @since 2026-04-14
*/
public interface DocumentImportTaskService extends IService<DocumentImportTask> {
}

View File

@@ -32,4 +32,16 @@ public interface DocumentService extends IService<Document> {
Result<DocumentImportDtos.PreviewResponse> previewImport(DocumentImportDtos.PreviewRequest request);
Result<DocumentImportDtos.CommitResponse> commitImport(DocumentImportDtos.CommitRequest request);
Result<DocumentImportDtos.TaskCreateResponse> createImportTask(DocumentImportDtos.TaskCreateRequest request);
Result<DocumentImportDtos.TaskDetailResponse> getImportTaskDetail(BigInteger taskId);
Result<DocumentImportDtos.PreviewResponse> previewImportTask(DocumentImportDtos.PreviewRequest request);
Result<DocumentImportDtos.TaskStartIndexResponse> startIndexTask(DocumentImportDtos.TaskStartIndexRequest request);
Result<DocumentImportDtos.TaskStartIndexResponse> retryParseTask(DocumentImportDtos.TaskRetryRequest request);
Result<DocumentImportDtos.TaskStartIndexResponse> retryIndexTask(DocumentImportDtos.TaskRetryRequest request);
}

View File

@@ -30,9 +30,11 @@ import tech.easyflow.ai.entity.DocumentChunk;
import tech.easyflow.ai.entity.DocumentCollection;
import tech.easyflow.ai.entity.FaqItem;
import tech.easyflow.ai.entity.Model;
import tech.easyflow.ai.enums.DocumentProcessStatus;
import tech.easyflow.ai.enums.PublishStatus;
import tech.easyflow.ai.mapper.DocumentChunkMapper;
import tech.easyflow.ai.mapper.DocumentCollectionMapper;
import tech.easyflow.ai.mapper.DocumentMapper;
import tech.easyflow.ai.mapper.FaqItemMapper;
import tech.easyflow.ai.rag.KnowledgeRetrievalRequest;
import tech.easyflow.ai.service.DocumentCollectionService;
@@ -71,6 +73,8 @@ public class DocumentCollectionServiceImpl extends ServiceImpl<DocumentCollectio
private static final Logger LOG = LoggerFactory.getLogger(DocumentCollectionServiceImpl.class);
private static final int MAX_FAQ_IMAGES_IN_PROMPT = 3;
private static final int INTERNAL_RECALL_MULTIPLIER = 5;
private static final int MAX_INTERNAL_RECALL_LIMIT = 100;
@Resource
private ModelService llmService;
@@ -81,6 +85,9 @@ public class DocumentCollectionServiceImpl extends ServiceImpl<DocumentCollectio
@Autowired
private DocumentChunkMapper documentChunkMapper;
@Autowired
private DocumentMapper documentMapper;
@Autowired
private FaqItemMapper faqItemMapper;
@@ -111,24 +118,27 @@ public class DocumentCollectionServiceImpl extends ServiceImpl<DocumentCollectio
throw new BusinessException("知识库不存在");
}
int docRecallMaxNum = readIntegerOption(documentCollection, KEY_DOC_RECALL_MAX_NUM, 5);
float minSimilarity = readFloatOption(documentCollection, KEY_SIMILARITY_THRESHOLD, 0.6F);
int docRecallMaxNum = resolveDocRecallMaxNum(request, documentCollection);
int internalRecallLimit = resolveInternalRecallLimit(docRecallMaxNum);
float minSimilarity = resolveMinSimilarity(request, documentCollection);
RagQuery ragQuery = new RagQuery();
ragQuery.setQuery(keyword);
ragQuery.setRetrievalMode(retrievalMode);
ragQuery.setTopK(docRecallMaxNum);
ragQuery.setTopK(internalRecallLimit);
ragQuery.setMinScore((double) minSimilarity);
RagRetrievalExecutor retrievalExecutor = new RagRetrievalExecutor(
buildVectorRetriever(documentCollection, docRecallMaxNum, retrievalMode == RetrievalMode.VECTOR ? minSimilarity : null),
buildKeywordRetriever(documentCollection, docRecallMaxNum),
buildVectorRetriever(documentCollection, internalRecallLimit, retrievalMode == RetrievalMode.VECTOR ? minSimilarity : null),
buildKeywordRetriever(documentCollection, internalRecallLimit),
new RrfFusionStrategy()
);
RagRetrievalResult retrievalResult = retrievalExecutor.retrieve(ragQuery);
List<Document> searchDocuments = toDocuments(retrievalResult.getHits());
fillSearchContent(documentCollection, searchDocuments);
List<Document> searchDocuments = prepareSearchDocuments(
documentCollection,
toDocuments(retrievalResult.getHits())
);
if (searchDocuments.isEmpty()) {
return Collections.emptyList();
}
@@ -138,7 +148,10 @@ public class DocumentCollectionServiceImpl extends ServiceImpl<DocumentCollectio
if (rerankModel != null) {
try {
RagRetrievalResult rerankResult = retrievalExecutor.rerank(keyword, toRagHits(searchDocuments), rerankModel, docRecallMaxNum);
searchDocuments = toDocuments(rerankResult.getHits());
searchDocuments = prepareSearchDocuments(
documentCollection,
toDocuments(rerankResult.getHits())
);
reranked = true;
} catch (RerankException e) {
LOG.warn("Rerank failed for collectionId={}, modelId={}, fallback to retrieved results. message={}",
@@ -320,6 +333,84 @@ public class DocumentCollectionServiceImpl extends ServiceImpl<DocumentCollectio
return !reranked && retrievalMode == RetrievalMode.VECTOR;
}
/**
* 解析本次查询使用的召回上限,优先采用请求参数,其次回退到知识库默认配置。
*
* @param request 查询请求
* @param documentCollection 知识库实体
* @return 规范化后的召回上限
*/
private int resolveDocRecallMaxNum(KnowledgeRetrievalRequest request, DocumentCollection documentCollection) {
Integer requestLimit = request == null ? null : request.getLimit();
if (requestLimit != null) {
return Math.max(requestLimit, 1);
}
return readIntegerOption(documentCollection, KEY_DOC_RECALL_MAX_NUM, 5);
}
/**
* 解析本次查询使用的最小相似度阈值,优先采用请求参数,其次回退到知识库默认配置。
*
* @param request 查询请求
* @param documentCollection 知识库实体
* @return 规范化后的最小相似度
*/
private float resolveMinSimilarity(KnowledgeRetrievalRequest request, DocumentCollection documentCollection) {
Double requestMinSimilarity = request == null ? null : request.getMinSimilarity();
if (requestMinSimilarity != null) {
double normalizedValue = Math.max(0D, Math.min(requestMinSimilarity, 1D));
return (float) normalizedValue;
}
return readFloatOption(documentCollection, KEY_SIMILARITY_THRESHOLD, 0.6F);
}
/**
* 计算内部候选召回上限。
*
* @param docRecallMaxNum 业务召回上限
* @return 内部候选集上限
*/
private int resolveInternalRecallLimit(int docRecallMaxNum) {
int normalizedLimit = Math.max(docRecallMaxNum, 1);
return Math.min(normalizedLimit * INTERNAL_RECALL_MULTIPLIER, MAX_INTERNAL_RECALL_LIMIT);
}
/**
* 在重排前过滤掉未完成文档对应的 chunk 命中,避免进行中的文档占用正式召回名额。
*
* @param documentCollection 知识库
* @param searchDocuments 当前召回结果
* @return 仅保留完成态文档命中的结果
*/
private List<Document> prepareSearchDocuments(DocumentCollection documentCollection, List<Document> searchDocuments) {
if (searchDocuments == null || searchDocuments.isEmpty()) {
return Collections.emptyList();
}
if (documentCollection == null) {
return searchDocuments;
}
if (documentCollection.isFaqCollection()) {
fillSearchContent(documentCollection, searchDocuments);
return searchDocuments;
}
DocumentHitSnapshot hitSnapshot = loadDocumentHitSnapshot(documentCollection, searchDocuments);
if (hitSnapshot.isEmpty()) {
return Collections.emptyList();
}
return searchDocuments.stream()
.filter(Objects::nonNull)
.filter(item -> {
String content = hitSnapshot.findChunkContent(item.getId());
if (!StringUtil.hasText(content)) {
return false;
}
item.setContent(content);
return true;
})
.collect(Collectors.toList());
}
@Override
public DocumentCollection getDetail(String idOrAlias) {
DocumentCollection knowledge = null;
@@ -418,18 +509,93 @@ public class DocumentCollectionServiceImpl extends ServiceImpl<DocumentCollectio
return;
}
QueryWrapper queryWrapper = QueryWrapper.create();
queryWrapper.in(DocumentChunk::getId, ids);
queryWrapper.eq(DocumentChunk::getDocumentCollectionId, documentCollection.getId());
Map<String, DocumentChunk> chunkMap = documentChunkMapper.selectListByQuery(queryWrapper).stream()
.collect(Collectors.toMap(item -> item.getId().toString(), item -> item, (a, b) -> a));
searchDocuments.removeIf(item -> !chunkMap.containsKey(String.valueOf(item.getId())));
DocumentHitSnapshot hitSnapshot = loadDocumentHitSnapshot(documentCollection, searchDocuments);
searchDocuments.forEach(item -> {
DocumentChunk documentChunk = chunkMap.get(String.valueOf(item.getId()));
if (documentChunk != null && !StringUtil.noText(documentChunk.getContent())) {
item.setContent(documentChunk.getContent());
}
item.setContent(hitSnapshot.findChunkContent(item.getId()));
});
searchDocuments.removeIf(item -> !StringUtil.hasText(item.getContent()));
}
/**
* 批量加载命中 chunk 及其完成态父文档,供过滤和内容填充复用。
*
* @param documentCollection 知识库
* @param searchDocuments 检索命中
* @return 命中快照
*/
private DocumentHitSnapshot loadDocumentHitSnapshot(DocumentCollection documentCollection, List<Document> searchDocuments) {
if (documentCollection == null || searchDocuments == null || searchDocuments.isEmpty()) {
return DocumentHitSnapshot.empty();
}
List<Serializable> chunkIds = searchDocuments.stream()
.map(Document::getId)
.filter(Objects::nonNull)
.map(item -> (Serializable) item)
.collect(Collectors.toList());
if (chunkIds.isEmpty()) {
return DocumentHitSnapshot.empty();
}
QueryWrapper chunkWrapper = QueryWrapper.create();
chunkWrapper.in(DocumentChunk::getId, chunkIds);
chunkWrapper.eq(DocumentChunk::getDocumentCollectionId, documentCollection.getId());
Map<String, DocumentChunk> chunkMap = documentChunkMapper.selectListByQuery(chunkWrapper).stream()
.collect(Collectors.toMap(item -> item.getId().toString(), item -> item, (a, b) -> a));
if (chunkMap.isEmpty()) {
return DocumentHitSnapshot.empty();
}
List<Serializable> documentIds = chunkMap.values().stream()
.map(DocumentChunk::getDocumentId)
.filter(Objects::nonNull)
.distinct()
.map(item -> (Serializable) item)
.collect(Collectors.toList());
if (documentIds.isEmpty()) {
return DocumentHitSnapshot.empty();
}
QueryWrapper documentWrapper = QueryWrapper.create();
documentWrapper.in(tech.easyflow.ai.entity.Document::getId, documentIds);
documentWrapper.eq(tech.easyflow.ai.entity.Document::getCollectionId, documentCollection.getId());
documentWrapper.eq(tech.easyflow.ai.entity.Document::getProcessStatus, DocumentProcessStatus.COMPLETED.name());
Map<String, tech.easyflow.ai.entity.Document> documentMap = documentMapper.selectListByQuery(documentWrapper).stream()
.collect(Collectors.toMap(item -> item.getId().toString(), item -> item, (a, b) -> a));
return new DocumentHitSnapshot(chunkMap, documentMap);
}
/**
* 文档检索命中的批量快照,避免过滤和填充阶段重复查询。
*/
private static class DocumentHitSnapshot {
private final Map<String, DocumentChunk> chunkMap;
private final Map<String, tech.easyflow.ai.entity.Document> documentMap;
private DocumentHitSnapshot(Map<String, DocumentChunk> chunkMap,
Map<String, tech.easyflow.ai.entity.Document> documentMap) {
this.chunkMap = chunkMap == null ? Collections.emptyMap() : chunkMap;
this.documentMap = documentMap == null ? Collections.emptyMap() : documentMap;
}
private static DocumentHitSnapshot empty() {
return new DocumentHitSnapshot(Collections.emptyMap(), Collections.emptyMap());
}
private boolean isEmpty() {
return chunkMap.isEmpty() || documentMap.isEmpty();
}
private String findChunkContent(Object chunkId) {
DocumentChunk documentChunk = chunkMap.get(String.valueOf(chunkId));
if (documentChunk == null || documentChunk.getDocumentId() == null) {
return null;
}
if (!documentMap.containsKey(String.valueOf(documentChunk.getDocumentId()))) {
return null;
}
return StringUtil.noText(documentChunk.getContent()) ? null : documentChunk.getContent();
}
}
private String buildFaqPromptContent(FaqItem faqItem, List<Map<String, String>> images) {

View File

@@ -0,0 +1,18 @@
package tech.easyflow.ai.service.impl;
import com.mybatisflex.spring.service.impl.ServiceImpl;
import org.springframework.stereotype.Service;
import tech.easyflow.ai.entity.DocumentImportTask;
import tech.easyflow.ai.mapper.DocumentImportTaskMapper;
import tech.easyflow.ai.service.DocumentImportTaskService;
/**
* 文档导入任务服务实现。
*
* @author Codex
* @since 2026-04-14
*/
@Service
public class DocumentImportTaskServiceImpl extends ServiceImpl<DocumentImportTaskMapper, DocumentImportTask>
implements DocumentImportTaskService {
}

View File

@@ -34,7 +34,9 @@ import tech.easyflow.ai.config.SearcherFactory;
import tech.easyflow.ai.documentimport.DocumentImportDtos;
import tech.easyflow.ai.documentimport.DocumentImportKeys;
import tech.easyflow.ai.documentimport.DocumentImportPreviewService;
import tech.easyflow.ai.documentimport.task.KnowledgeDocumentImportTaskAppService;
import tech.easyflow.ai.entity.*;
import tech.easyflow.ai.enums.DocumentProcessStatus;
import tech.easyflow.ai.mapper.DocumentChunkMapper;
import tech.easyflow.ai.mapper.DocumentMapper;
import tech.easyflow.ai.service.DocumentChunkService;
@@ -69,6 +71,7 @@ import static tech.easyflow.ai.entity.table.DocumentTableDef.DOCUMENT;
@Service("AiService")
public class DocumentServiceImpl extends ServiceImpl<DocumentMapper, Document> implements DocumentService {
protected Logger Log = LoggerFactory.getLogger(DocumentServiceImpl.class);
private static final String SOURCE_RANGES_KEY = "sourceRanges";
@Resource
private DocumentMapper documentMapper;
@@ -97,6 +100,9 @@ public class DocumentServiceImpl extends ServiceImpl<DocumentMapper, Document> i
@Autowired
private DocumentImportPreviewService documentImportPreviewService;
@Autowired
private KnowledgeDocumentImportTaskAppService importTaskAppService;
@Override
public Page<Document> getDocumentList(String knowledgeId, int pageSize, int pageNum, String fileName) {
QueryWrapper queryWrapper=QueryWrapper.create()
@@ -130,6 +136,13 @@ public class DocumentServiceImpl extends ServiceImpl<DocumentMapper, Document> i
// 查询该文档对应哪些分割的字段,先删除
QueryWrapper queryWrapperDocument = QueryWrapper.create().eq(Document::getId, id);
Document oneByQuery = documentMapper.selectOneByQuery(queryWrapperDocument);
if (oneByQuery == null) {
return false;
}
if (DocumentProcessStatus.PARSING.name().equals(oneByQuery.getProcessStatus())
|| DocumentProcessStatus.INDEXING.name().equals(oneByQuery.getProcessStatus())) {
throw new BusinessException("文档处理中,暂不允许删除");
}
DocumentCollection knowledge = knowledgeService.getById(oneByQuery.getCollectionId());
if (knowledge == null) {
return false;
@@ -209,12 +222,18 @@ public class DocumentServiceImpl extends ServiceImpl<DocumentMapper, Document> i
aiDocument.setDocumentPath(filePath);
aiDocument.setCreated(new Date());
aiDocument.setModifiedBy(BigInteger.valueOf(StpUtil.getLoginIdAsLong()));
aiDocument.setModified(new Date());
aiDocument.setContent(document.getContent());
aiDocument.setChunkSize(documentCollectionSplitParams.getChunkSize());
aiDocument.setOverlapSize(documentCollectionSplitParams.getOverlapSize());
aiDocument.setTitle(fileOriginName);
Map<String, Object> res = new HashMap<>();
aiDocument.setModified(new Date());
aiDocument.setContent(document.getContent());
aiDocument.setChunkSize(documentCollectionSplitParams.getChunkSize());
aiDocument.setOverlapSize(documentCollectionSplitParams.getOverlapSize());
aiDocument.setTitle(fileOriginName);
aiDocument.setProcessStatus(DocumentProcessStatus.COMPLETED.name());
aiDocument.setTotalChunks(previewList.size());
aiDocument.setCompletedChunks(previewList.size());
aiDocument.setFailedChunks(0);
aiDocument.setProgressPercent(100);
aiDocument.setTaskModifiedAt(new Date());
Map<String, Object> res = new HashMap<>();
List<DocumentChunk> documentChunks = null;
String operation = documentCollectionSplitParams.getOperation();
@@ -334,10 +353,11 @@ public class DocumentServiceImpl extends ServiceImpl<DocumentMapper, Document> i
item.setPreviewSessionId(sessionId);
item.setFilePath(file.getFilePath());
item.setFileName(file.getFileName());
item.setNormalizedContent(session.getAnalysis() == null ? null : session.getAnalysis().getNormalizedContent());
item.setStrategyCode(session.getStrategyConfig().getStrategyCode());
item.setStrategyLabel(ragIngestionService.toStrategyLabel(session.getStrategyConfig().getStrategyCode()));
item.setAnalysis(session.getAnalysis());
item.setChunks(session.getPreviewChunks());
item.setChunks(toPreviewChunkResults(session.getPreviewChunks()));
item.setTotalChunks(session.getPreviewChunks().size());
item.setTotalWarnings(countWarnings(session.getPreviewChunks()));
items.add(item);
@@ -398,6 +418,12 @@ public class DocumentServiceImpl extends ServiceImpl<DocumentMapper, Document> i
document.setModified(new Date());
document.setCreatedBy(BigInteger.valueOf(StpUtil.getLoginIdAsLong()));
document.setModifiedBy(BigInteger.valueOf(StpUtil.getLoginIdAsLong()));
document.setProcessStatus(DocumentProcessStatus.COMPLETED.name());
document.setTotalChunks(session.getDocumentChunks().size());
document.setCompletedChunks(session.getDocumentChunks().size());
document.setFailedChunks(0);
document.setProgressPercent(100);
document.setTaskModifiedAt(new Date());
for (DocumentChunk chunk : session.getDocumentChunks()) {
chunk.setDocumentId(document.getId());
chunk.setDocumentCollectionId(document.getCollectionId());
@@ -430,6 +456,7 @@ public class DocumentServiceImpl extends ServiceImpl<DocumentMapper, Document> i
DocumentImportDtos.PreviewSession session = new DocumentImportDtos.PreviewSession();
session.setKnowledgeId(knowledge.getId());
session.setDocumentId(document.getId());
session.setFilePath(fileRequest.getFilePath());
session.setFileName(fileRequest.getFileName());
session.setSourceFormat(analysis.getSourceFormat());
@@ -656,6 +683,55 @@ public class DocumentServiceImpl extends ServiceImpl<DocumentMapper, Document> i
return total;
}
private List<DocumentImportDtos.PreviewChunkResult> toPreviewChunkResults(List<RagChunk> chunks) {
List<DocumentImportDtos.PreviewChunkResult> result = new ArrayList<>();
if (chunks == null) {
return result;
}
for (RagChunk chunk : chunks) {
DocumentImportDtos.PreviewChunkResult item = new DocumentImportDtos.PreviewChunkResult();
item.setAnswer(chunk.getAnswer());
item.setCharCount(chunk.getCharCount());
item.setChunkId(chunk.getChunkId());
item.setChunkType(chunk.getChunkType());
item.setContent(chunk.getContent());
item.setHeadingPath(chunk.getHeadingPath() == null ? new ArrayList<>() : new ArrayList<>(chunk.getHeadingPath()));
item.setPartNo(chunk.getPartNo());
item.setPartTotal(chunk.getPartTotal());
item.setQuestion(chunk.getQuestion());
item.setSourceLabel(chunk.getSourceLabel());
item.setTokenEstimate(chunk.getTokenEstimate());
item.setWarnings(chunk.getWarnings() == null ? new ArrayList<>() : new ArrayList<>(chunk.getWarnings()));
item.setSourceRanges(copySourceRanges(chunk));
result.add(item);
}
return result;
}
@SuppressWarnings("unchecked")
private List<DocumentImportDtos.PreviewSourceRange> copySourceRanges(RagChunk chunk) {
List<DocumentImportDtos.PreviewSourceRange> result = new ArrayList<>();
if (chunk == null || chunk.getOptions() == null) {
return result;
}
Object rawRanges = chunk.getOptions().get(SOURCE_RANGES_KEY);
if (!(rawRanges instanceof List<?> rangeList)) {
return result;
}
for (Object item : rangeList) {
if (!(item instanceof Map<?, ?> rangeMap)) {
continue;
}
DocumentImportDtos.PreviewSourceRange range = new DocumentImportDtos.PreviewSourceRange();
range.setStart(asInteger(rangeMap.get("start"), null));
range.setEnd(asInteger(rangeMap.get("end"), null));
if (range.getStart() != null && range.getEnd() != null) {
result.add(range);
}
}
return result;
}
private StoreExecutionContext prepareStoreContext(Document entity) {
DocumentCollection knowledge = knowledgeService.getById(entity.getCollectionId());
if (knowledge == null) {
@@ -882,4 +958,34 @@ public class DocumentServiceImpl extends ServiceImpl<DocumentMapper, Document> i
}
return null;
}
@Override
public Result<DocumentImportDtos.TaskCreateResponse> createImportTask(DocumentImportDtos.TaskCreateRequest request) {
return importTaskAppService.createImportTask(request);
}
@Override
public Result<DocumentImportDtos.TaskDetailResponse> getImportTaskDetail(BigInteger taskId) {
return importTaskAppService.getImportTaskDetail(taskId);
}
@Override
public Result<DocumentImportDtos.PreviewResponse> previewImportTask(DocumentImportDtos.PreviewRequest request) {
return importTaskAppService.previewImportTask(request);
}
@Override
public Result<DocumentImportDtos.TaskStartIndexResponse> startIndexTask(DocumentImportDtos.TaskStartIndexRequest request) {
return importTaskAppService.startIndexTask(request);
}
@Override
public Result<DocumentImportDtos.TaskStartIndexResponse> retryParseTask(DocumentImportDtos.TaskRetryRequest request) {
return importTaskAppService.retryParseTask(request);
}
@Override
public Result<DocumentImportDtos.TaskStartIndexResponse> retryIndexTask(DocumentImportDtos.TaskRetryRequest request) {
return importTaskAppService.retryIndexTask(request);
}
}

View File

@@ -37,6 +37,7 @@ public class KnowledgeSharePermissionServiceImpl implements KnowledgeSharePermis
"/public-api/knowledge-share/detail",
"/public-api/knowledge-share/document/page",
"/public-api/knowledge-share/document/download",
"/public-api/knowledge-share/document/import/task/detail",
"/public-api/knowledge-share/documentChunk/page",
"/public-api/knowledge-share/faq/page",
"/public-api/knowledge-share/faq/detail"
@@ -48,6 +49,11 @@ public class KnowledgeSharePermissionServiceImpl implements KnowledgeSharePermis
"/public-api/knowledge-share/document/import/analyze",
"/public-api/knowledge-share/document/import/preview",
"/public-api/knowledge-share/document/import/commit",
"/public-api/knowledge-share/document/import/task/create",
"/public-api/knowledge-share/document/import/task/preview",
"/public-api/knowledge-share/document/import/task/startIndex",
"/public-api/knowledge-share/document/import/task/retryParse",
"/public-api/knowledge-share/document/import/task/retryIndex",
"/public-api/knowledge-share/faq/save"
));
URI_SCOPE_MAPPING.put(KnowledgeShareActionScope.CONTENT_UPDATE.name(), List.of(

View File

@@ -0,0 +1,44 @@
package tech.easyflow.ai.vo;
import tech.easyflow.ai.entity.DocumentCollection;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
/**
* 知识库分享页详情视图。
*
* @author Codex
* @since 2026-04-15
*/
public class KnowledgeShareViewDetail implements Serializable {
/**
* 当前分享对应的知识库。
*/
private DocumentCollection knowledge;
/**
* 当前分享授权范围。
*/
private List<String> permissionScopes = new ArrayList<String>();
public DocumentCollection getKnowledge() {
return knowledge;
}
public void setKnowledge(DocumentCollection knowledge) {
this.knowledge = knowledge;
}
public List<String> getPermissionScopes() {
return permissionScopes;
}
public void setPermissionScopes(List<String> permissionScopes) {
this.permissionScopes = permissionScopes == null
? new ArrayList<String>()
: new ArrayList<String>(permissionScopes);
}
}

View File

@@ -0,0 +1,148 @@
package tech.easyflow.ai.documentimport.task;
import org.junit.Assert;
import org.junit.Test;
import tech.easyflow.ai.entity.DocumentImportTask;
import tech.easyflow.ai.enums.DocumentImportTaskStatus;
import tech.easyflow.ai.enums.DocumentProcessStatus;
import tech.easyflow.ai.mapper.DocumentMapper;
import tech.easyflow.ai.service.DocumentImportTaskService;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.math.BigInteger;
import java.util.concurrent.atomic.AtomicReference;
/**
* {@link KnowledgeDocumentImportTaskAppService} 回归测试。
*
* @author Codex
* @since 2026-04-15
*/
public class KnowledgeDocumentImportTaskAppServiceTest {
/**
* 验证向量化失败会按整文档失败语义重置进度,并刷新任务错误信息。
*
* @throws Exception 反射调用异常
*/
@Test
public void markIndexFailedShouldResetProgressAndPersistLatestError() throws Exception {
BigInteger documentId = BigInteger.valueOf(10);
BigInteger knowledgeId = BigInteger.valueOf(20);
tech.easyflow.ai.entity.Document persistedDocument = new tech.easyflow.ai.entity.Document();
persistedDocument.setId(documentId);
persistedDocument.setCollectionId(knowledgeId);
persistedDocument.setProcessStatus(DocumentProcessStatus.INDEXING.name());
persistedDocument.setTotalChunks(8);
persistedDocument.setCompletedChunks(5);
persistedDocument.setFailedChunks(1);
persistedDocument.setProgressPercent(63);
persistedDocument.setLastTaskError("旧错误");
AtomicReference<tech.easyflow.ai.entity.Document> updatedDocumentRef = new AtomicReference<tech.easyflow.ai.entity.Document>();
AtomicReference<DocumentImportTask> updatedTaskRef = new AtomicReference<DocumentImportTask>();
KnowledgeDocumentImportTaskAppService service = new KnowledgeDocumentImportTaskAppService();
setField(service, "documentMapper", mockDocumentMapper(persistedDocument, updatedDocumentRef));
setField(service, "documentImportTaskService", mockDocumentImportTaskService(updatedTaskRef));
setField(service, "documentImportTaskStatusStreamService", new NoopTaskStatusStreamService());
DocumentImportTask task = new DocumentImportTask();
task.setId(BigInteger.valueOf(30));
task.setDocumentId(documentId);
task.setKnowledgeId(knowledgeId);
task.setStatus(DocumentImportTaskStatus.RUNNING.name());
task.setErrorSummary("旧错误");
tech.easyflow.ai.entity.Document inputDocument = new tech.easyflow.ai.entity.Document();
inputDocument.setId(documentId);
inputDocument.setCollectionId(knowledgeId);
Method method = KnowledgeDocumentImportTaskAppService.class.getDeclaredMethod(
"markIndexFailed",
DocumentImportTask.class,
tech.easyflow.ai.entity.Document.class,
String.class
);
method.setAccessible(true);
method.invoke(service, task, inputDocument, "新错误");
tech.easyflow.ai.entity.Document updatedDocument = updatedDocumentRef.get();
Assert.assertNotNull(updatedDocument);
Assert.assertEquals(DocumentProcessStatus.INDEX_FAILED.name(), updatedDocument.getProcessStatus());
Assert.assertEquals(Integer.valueOf(0), updatedDocument.getCompletedChunks());
Assert.assertEquals(Integer.valueOf(8), updatedDocument.getFailedChunks());
Assert.assertEquals(Integer.valueOf(0), updatedDocument.getProgressPercent());
Assert.assertEquals("新错误", updatedDocument.getLastTaskError());
DocumentImportTask updatedTask = updatedTaskRef.get();
Assert.assertNotNull(updatedTask);
Assert.assertEquals(DocumentImportTaskStatus.FAILED.name(), updatedTask.getStatus());
Assert.assertEquals("新错误", updatedTask.getErrorSummary());
}
private static DocumentMapper mockDocumentMapper(tech.easyflow.ai.entity.Document persistedDocument,
AtomicReference<tech.easyflow.ai.entity.Document> updatedDocumentRef) {
return (DocumentMapper) Proxy.newProxyInstance(
DocumentMapper.class.getClassLoader(),
new Class<?>[]{DocumentMapper.class},
(proxy, method, args) -> {
if ("selectOneById".equals(method.getName())) {
return persistedDocument;
}
if ("update".equals(method.getName())) {
updatedDocumentRef.set((tech.easyflow.ai.entity.Document) args[0]);
return 1;
}
return defaultValue(method.getReturnType());
}
);
}
private static DocumentImportTaskService mockDocumentImportTaskService(AtomicReference<DocumentImportTask> updatedTaskRef) {
return (DocumentImportTaskService) Proxy.newProxyInstance(
DocumentImportTaskService.class.getClassLoader(),
new Class<?>[]{DocumentImportTaskService.class},
(proxy, method, args) -> {
if ("updateById".equals(method.getName())) {
updatedTaskRef.set((DocumentImportTask) args[0]);
return true;
}
return defaultValue(method.getReturnType());
}
);
}
private static void setField(Object target, String fieldName, Object value) throws Exception {
Field field = KnowledgeDocumentImportTaskAppService.class.getDeclaredField(fieldName);
field.setAccessible(true);
field.set(target, value);
}
private static Object defaultValue(Class<?> returnType) {
if (returnType == boolean.class) {
return false;
}
if (returnType == int.class) {
return 0;
}
if (returnType == long.class) {
return 0L;
}
return null;
}
/**
* 测试用 SSE 推送桩,避免依赖线程池和真实推送。
*/
private static class NoopTaskStatusStreamService extends DocumentImportTaskStatusStreamService {
@Override
public void publishAfterCommit(BigInteger documentId) {
// no-op
}
}
}

View File

@@ -0,0 +1,238 @@
package tech.easyflow.ai.service.impl;
import com.easyagents.core.document.Document;
import com.easyagents.search.engine.service.DocumentSearcher;
import com.easyagents.search.engine.service.KeywordSearchRequest;
import org.junit.Assert;
import org.junit.Test;
import org.springframework.beans.factory.ObjectProvider;
import tech.easyflow.ai.config.SearcherFactory;
import tech.easyflow.ai.enums.DocumentProcessStatus;
import tech.easyflow.ai.mapper.DocumentChunkMapper;
import tech.easyflow.ai.mapper.DocumentMapper;
import java.io.Serializable;
import java.lang.reflect.Field;
import java.lang.reflect.Proxy;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static tech.easyflow.ai.entity.DocumentCollection.KEY_DOC_RECALL_MAX_NUM;
import static tech.easyflow.ai.entity.DocumentCollection.KEY_SIMILARITY_THRESHOLD;
/**
* {@link DocumentCollectionServiceImpl} 回归测试。
*
* @author Codex
* @since 2026-04-15
*/
public class DocumentCollectionServiceImplTest {
/**
* 验证检索结果会在重排前过滤掉未完成文档,避免高分进行中文档挤占最终名额。
*
* @throws Exception 反射注入异常
*/
@Test
public void searchShouldFilterNonCompletedChunksBeforeFinalTopK() throws Exception {
BigInteger knowledgeId = BigInteger.ONE;
BigInteger completedDocumentId = BigInteger.valueOf(101);
BigInteger indexingDocumentId = BigInteger.valueOf(102);
BigInteger completedChunkId = BigInteger.valueOf(1001);
BigInteger indexingChunkId = BigInteger.valueOf(1002);
tech.easyflow.ai.entity.DocumentCollection collection = new tech.easyflow.ai.entity.DocumentCollection();
collection.setId(knowledgeId);
collection.setCollectionType(tech.easyflow.ai.entity.DocumentCollection.TYPE_DOCUMENT);
collection.setOptions(new HashMap<String, Object>() {{
put(KEY_DOC_RECALL_MAX_NUM, 1);
put(KEY_SIMILARITY_THRESHOLD, BigDecimal.ZERO);
}});
tech.easyflow.ai.entity.DocumentChunk completedChunk = new tech.easyflow.ai.entity.DocumentChunk();
completedChunk.setId(completedChunkId);
completedChunk.setDocumentId(completedDocumentId);
completedChunk.setDocumentCollectionId(knowledgeId);
completedChunk.setContent("completed chunk");
tech.easyflow.ai.entity.DocumentChunk indexingChunk = new tech.easyflow.ai.entity.DocumentChunk();
indexingChunk.setId(indexingChunkId);
indexingChunk.setDocumentId(indexingDocumentId);
indexingChunk.setDocumentCollectionId(knowledgeId);
indexingChunk.setContent("indexing chunk");
tech.easyflow.ai.entity.Document completedDocument = new tech.easyflow.ai.entity.Document();
completedDocument.setId(completedDocumentId);
completedDocument.setCollectionId(knowledgeId);
completedDocument.setProcessStatus(DocumentProcessStatus.COMPLETED.name());
completedDocument.setTitle("completed");
TestKeywordSearcher searcher = new TestKeywordSearcher(List.of(
buildHit(indexingChunkId, 0.99D),
buildHit(completedChunkId, 0.75D)
));
DocumentCollectionServiceImpl service = new TestDocumentCollectionService(collection);
setField(service, "searcherFactory", new SearcherFactory(new StaticObjectProvider<DocumentSearcher>(searcher)));
setField(service, "documentChunkMapper", mockDocumentChunkMapper(completedChunk, indexingChunk));
setField(service, "documentMapper", mockDocumentMapper(completedDocument));
tech.easyflow.ai.rag.KnowledgeRetrievalRequest request = new tech.easyflow.ai.rag.KnowledgeRetrievalRequest();
request.setKnowledgeId(knowledgeId);
request.setQuery("test-query");
request.setRetrievalMode(com.easyagents.rag.retrieval.RetrievalMode.KEYWORD);
List<Document> result = service.search(request);
Assert.assertEquals("内部关键词召回应扩容到业务 topK 的 5 倍", 5, searcher.lastRequestCount);
Assert.assertEquals("知识库过滤后只应保留完成态文档", 1, result.size());
Assert.assertEquals(completedChunkId, result.get(0).getId());
Assert.assertEquals("completed chunk", result.get(0).getContent());
Assert.assertEquals(String.valueOf(knowledgeId), searcher.lastKnowledgeId);
}
private static Document buildHit(BigInteger id, double score) {
Document document = new Document();
document.setId(id);
document.setScore(score);
document.setContent("raw-hit-" + id);
return document;
}
private static DocumentChunkMapper mockDocumentChunkMapper(tech.easyflow.ai.entity.DocumentChunk... chunks) {
Map<String, tech.easyflow.ai.entity.DocumentChunk> chunkMap = new HashMap<String, tech.easyflow.ai.entity.DocumentChunk>();
for (tech.easyflow.ai.entity.DocumentChunk chunk : chunks) {
chunkMap.put(String.valueOf(chunk.getId()), chunk);
}
return (DocumentChunkMapper) Proxy.newProxyInstance(
DocumentChunkMapper.class.getClassLoader(),
new Class<?>[]{DocumentChunkMapper.class},
(proxy, method, args) -> {
if ("selectListByQuery".equals(method.getName())) {
return List.copyOf(chunkMap.values());
}
return defaultValue(method.getReturnType());
}
);
}
private static DocumentMapper mockDocumentMapper(tech.easyflow.ai.entity.Document completedDocument) {
return (DocumentMapper) Proxy.newProxyInstance(
DocumentMapper.class.getClassLoader(),
new Class<?>[]{DocumentMapper.class},
(proxy, method, args) -> {
if ("selectListByQuery".equals(method.getName())) {
return List.of(completedDocument);
}
return defaultValue(method.getReturnType());
}
);
}
private static void setField(Object target, String fieldName, Object value) throws Exception {
Field field = DocumentCollectionServiceImpl.class.getDeclaredField(fieldName);
field.setAccessible(true);
field.set(target, value);
}
private static Object defaultValue(Class<?> returnType) {
if (returnType == boolean.class) {
return false;
}
if (returnType == int.class) {
return 0;
}
if (returnType == long.class) {
return 0L;
}
return null;
}
/**
* 固定返回测试知识库实体,避免依赖数据库。
*/
private static class TestDocumentCollectionService extends DocumentCollectionServiceImpl {
private final tech.easyflow.ai.entity.DocumentCollection collection;
private TestDocumentCollectionService(tech.easyflow.ai.entity.DocumentCollection collection) {
this.collection = collection;
}
@Override
public tech.easyflow.ai.entity.DocumentCollection getById(Serializable id) {
return collection;
}
}
/**
* 记录关键词检索请求参数的搜索器桩实现。
*/
private static class TestKeywordSearcher implements DocumentSearcher {
private final List<Document> documents;
private int lastRequestCount;
private String lastKnowledgeId;
private TestKeywordSearcher(List<Document> documents) {
this.documents = documents;
}
@Override
public boolean addDocument(Document document) {
return true;
}
@Override
public boolean deleteDocument(Object id) {
return true;
}
@Override
public boolean updateDocument(Document document) {
return true;
}
@Override
public List<Document> searchDocuments(KeywordSearchRequest request) {
this.lastRequestCount = request.getCount();
this.lastKnowledgeId = request.getKnowledgeId();
return documents;
}
}
/**
* 最小 ObjectProvider 实现,仅服务搜索器工厂测试注入。
*/
private static class StaticObjectProvider<T> implements ObjectProvider<T> {
private final T value;
private StaticObjectProvider(T value) {
this.value = value;
}
@Override
public T getObject(Object... args) {
return value;
}
@Override
public T getIfAvailable() {
return value;
}
@Override
public T getIfUnique() {
return value;
}
@Override
public T getObject() {
return value;
}
}
}