Compare commits
11 Commits
72df00f25b
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 7591eb8cda | |||
| ef4528a441 | |||
| cb379e071c | |||
| 8b80770960 | |||
| c316eff5be | |||
| 1ea863cb2c | |||
| 0f4d10c43c | |||
| cc3bb9cff0 | |||
| e39f7521e2 | |||
| 1c205c3720 | |||
| 11e595b088 |
31
Dockerfile
31
Dockerfile
@@ -1,3 +1,4 @@
|
|||||||
|
# 后端构建脚本
|
||||||
FROM --platform=linux/amd64 swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/eclipse-temurin:17-jre
|
FROM --platform=linux/amd64 swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/eclipse-temurin:17-jre
|
||||||
|
|
||||||
ENV LANG=C.UTF-8
|
ENV LANG=C.UTF-8
|
||||||
@@ -8,12 +9,40 @@ ENV EASYFLOW_JAR_PATH=/app/artifacts/easyflow.jar
|
|||||||
ENV EASYFLOW_CONFIG_PATH=file:/app/application.yml
|
ENV EASYFLOW_CONFIG_PATH=file:/app/application.yml
|
||||||
ENV EASYFLOW_LOG_FILE=/app/logs/app.log
|
ENV EASYFLOW_LOG_FILE=/app/logs/app.log
|
||||||
ENV EASYFLOW_JAR_RESTART_GRACE_SECONDS=30
|
ENV EASYFLOW_JAR_RESTART_GRACE_SECONDS=30
|
||||||
|
ENV NPM_CONFIG_REGISTRY=https://registry.npmmirror.com
|
||||||
|
ENV PIP_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple
|
||||||
|
ENV PIP_TRUSTED_HOST=pypi.tuna.tsinghua.edu.cn
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
RUN useradd --system --create-home easyflow && \
|
RUN useradd --system --create-home easyflow && \
|
||||||
apt-get update && \
|
apt-get update && \
|
||||||
apt-get install -y --no-install-recommends python3 inotify-tools tini && \
|
apt-get install -y --no-install-recommends \
|
||||||
|
ca-certificates \
|
||||||
|
curl \
|
||||||
|
gnupg && \
|
||||||
|
mkdir -p /etc/apt/keyrings && \
|
||||||
|
curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key -o /tmp/nodesource.gpg.key && \
|
||||||
|
gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg /tmp/nodesource.gpg.key && \
|
||||||
|
chmod 644 /etc/apt/keyrings/nodesource.gpg && \
|
||||||
|
printf "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_24.x nodistro main\n" > /etc/apt/sources.list.d/nodesource.list && \
|
||||||
|
rm -f /tmp/nodesource.gpg.key && \
|
||||||
|
apt-get update && \
|
||||||
|
apt-get install -y --no-install-recommends \
|
||||||
|
inotify-tools \
|
||||||
|
nodejs \
|
||||||
|
python3 \
|
||||||
|
python3-pip \
|
||||||
|
python3-venv \
|
||||||
|
tini && \
|
||||||
|
ln -sf /usr/bin/python3 /usr/local/bin/python && \
|
||||||
|
ln -sf /usr/bin/pip3 /usr/local/bin/pip && \
|
||||||
|
npm config set registry "${NPM_CONFIG_REGISTRY}" && \
|
||||||
|
printf "registry=%s\n" "${NPM_CONFIG_REGISTRY}" > /etc/npmrc && \
|
||||||
|
npm install -g pnpm@10.17.1 && \
|
||||||
|
pnpm config set registry "${NPM_CONFIG_REGISTRY}" && \
|
||||||
|
mkdir -p /etc/pip && \
|
||||||
|
printf "[global]\nindex-url = %s\ntrusted-host = %s\n" "${PIP_INDEX_URL}" "${PIP_TRUSTED_HOST}" > /etc/pip.conf && \
|
||||||
rm -rf /var/lib/apt/lists/* && \
|
rm -rf /var/lib/apt/lists/* && \
|
||||||
mkdir -p /app/logs /app/artifacts /app/data && \
|
mkdir -p /app/logs /app/artifacts /app/data && \
|
||||||
chown -R easyflow:easyflow /app
|
chown -R easyflow:easyflow /app
|
||||||
|
|||||||
61
config/proguard/common-keep.pro
Normal file
61
config/proguard/common-keep.pro
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
-dontshrink
|
||||||
|
-dontoptimize
|
||||||
|
-dontpreverify
|
||||||
|
-ignorewarnings
|
||||||
|
-dontnote
|
||||||
|
|
||||||
|
-libraryjars <java.home>/jmods/java.base.jmod(!**.jar;!module-info.class)
|
||||||
|
-libraryjars <java.home>/jmods/java.compiler.jmod(!**.jar;!module-info.class)
|
||||||
|
-libraryjars <java.home>/jmods/java.datatransfer.jmod(!**.jar;!module-info.class)
|
||||||
|
-libraryjars <java.home>/jmods/java.desktop.jmod(!**.jar;!module-info.class)
|
||||||
|
-libraryjars <java.home>/jmods/java.instrument.jmod(!**.jar;!module-info.class)
|
||||||
|
-libraryjars <java.home>/jmods/java.logging.jmod(!**.jar;!module-info.class)
|
||||||
|
-libraryjars <java.home>/jmods/java.management.jmod(!**.jar;!module-info.class)
|
||||||
|
-libraryjars <java.home>/jmods/java.naming.jmod(!**.jar;!module-info.class)
|
||||||
|
-libraryjars <java.home>/jmods/java.net.http.jmod(!**.jar;!module-info.class)
|
||||||
|
-libraryjars <java.home>/jmods/java.prefs.jmod(!**.jar;!module-info.class)
|
||||||
|
-libraryjars <java.home>/jmods/java.rmi.jmod(!**.jar;!module-info.class)
|
||||||
|
-libraryjars <java.home>/jmods/java.scripting.jmod(!**.jar;!module-info.class)
|
||||||
|
-libraryjars <java.home>/jmods/java.security.jgss.jmod(!**.jar;!module-info.class)
|
||||||
|
-libraryjars <java.home>/jmods/java.security.sasl.jmod(!**.jar;!module-info.class)
|
||||||
|
-libraryjars <java.home>/jmods/java.sql.jmod(!**.jar;!module-info.class)
|
||||||
|
-libraryjars <java.home>/jmods/java.transaction.xa.jmod(!**.jar;!module-info.class)
|
||||||
|
-libraryjars <java.home>/jmods/java.xml.jmod(!**.jar;!module-info.class)
|
||||||
|
-libraryjars <java.home>/jmods/java.xml.crypto.jmod(!**.jar;!module-info.class)
|
||||||
|
|
||||||
|
-keepattributes RuntimeVisibleAnnotations,RuntimeInvisibleAnnotations,RuntimeVisibleParameterAnnotations,RuntimeInvisibleParameterAnnotations,AnnotationDefault,Signature,InnerClasses,EnclosingMethod,Record,SourceFile,LineNumberTable,MethodParameters
|
||||||
|
|
||||||
|
-keep @org.springframework.stereotype.Controller class * { *; }
|
||||||
|
-keep @org.springframework.web.bind.annotation.RestController class * { *; }
|
||||||
|
-keep @org.springframework.context.annotation.Configuration class * { *; }
|
||||||
|
-keep @org.springframework.boot.context.properties.ConfigurationProperties class * { *; }
|
||||||
|
-keep @org.springframework.boot.autoconfigure.SpringBootApplication class * { *; }
|
||||||
|
|
||||||
|
-keep class **.*Controller { *; }
|
||||||
|
-keep class **.*Mapper { *; }
|
||||||
|
-keep class **.mapper.** { *; }
|
||||||
|
-keep class **.entity.** { *; }
|
||||||
|
-keep class **.dto.** { *; }
|
||||||
|
-keep class **.vo.** { *; }
|
||||||
|
-keep class **.model.** { *; }
|
||||||
|
-keep class **.config.** { *; }
|
||||||
|
-keep class **.enums.** { *; }
|
||||||
|
-keep class **.annotation.** { *; }
|
||||||
|
-keep class **.*Exception { *; }
|
||||||
|
-keep class **.*ErrorCode { *; }
|
||||||
|
-keep class **.*Properties { *; }
|
||||||
|
-keep class **.*Config { *; }
|
||||||
|
-keep class **.*Configuration { *; }
|
||||||
|
-keep interface tech.easyflow.** { *; }
|
||||||
|
-keep enum tech.easyflow.** { *; }
|
||||||
|
|
||||||
|
-keepclassmembers class * {
|
||||||
|
@jakarta.annotation.Resource <fields>;
|
||||||
|
@org.springframework.beans.factory.annotation.Autowired <fields>;
|
||||||
|
@org.springframework.beans.factory.annotation.Value <fields>;
|
||||||
|
@org.springframework.context.annotation.Bean <methods>;
|
||||||
|
}
|
||||||
|
|
||||||
|
-keepclassmembers class * {
|
||||||
|
public <init>(...);
|
||||||
|
}
|
||||||
28
config/proguard/easyflow-module-ai.pro
Normal file
28
config/proguard/easyflow-module-ai.pro
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
-include ../../config/proguard/common-keep.pro
|
||||||
|
|
||||||
|
-keep class tech.easyflow.ai.chattime.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.constants.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.document.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.documentimport.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.easyagents.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.exception.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.mcp.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.node.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.permission.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.plugin.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.publish.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.rag.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.service.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.support.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.utils.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.invoke.service.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.invoke.model.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.invoke.protocol.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.invoke.exception.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.invoke.mapper.OpenAiProtocolMapper { *; }
|
||||||
|
-keep class tech.easyflow.ai.invoke.provider.ModelProviderGateway { *; }
|
||||||
|
-keep class tech.easyflow.ai.invoke.provider.UnifiedChatChunkObserver { *; }
|
||||||
|
-keep class tech.easyflow.ai.easyagentsflow.config.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.easyagentsflow.entity.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.easyagentsflow.service.** { *; }
|
||||||
|
-keep class tech.easyflow.ai.easyagentsflow.support.** { *; }
|
||||||
5
config/proguard/easyflow-module-autoconfig.pro
Normal file
5
config/proguard/easyflow-module-autoconfig.pro
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
-include ../../config/proguard/common-keep.pro
|
||||||
|
|
||||||
|
-keep class tech.easyflow.autoconfig.license.EasyflowLicenseBootstrapValidator { *; }
|
||||||
|
-keep class tech.easyflow.autoconfig.license.EasyflowLicenseProperties { *; }
|
||||||
|
-keep class tech.easyflow.autoconfig.license.EasyflowLicenseVerificationResult { *; }
|
||||||
10
config/proguard/easyflow-module-datacenter.pro
Normal file
10
config/proguard/easyflow-module-datacenter.pro
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
-include ../../config/proguard/common-keep.pro
|
||||||
|
|
||||||
|
-keep class tech.easyflow.datacenter.connector.DatacenterConnector { *; }
|
||||||
|
-keep class tech.easyflow.datacenter.connector.QueryExecutor { *; }
|
||||||
|
-keep class tech.easyflow.datacenter.connector.WriteExecutor { *; }
|
||||||
|
-keep class tech.easyflow.datacenter.connector.MetadataExplorer { *; }
|
||||||
|
-keep class tech.easyflow.datacenter.connector.SourceHealthChecker { *; }
|
||||||
|
-keep class tech.easyflow.datacenter.connector.SqlDialect { *; }
|
||||||
|
-keep class tech.easyflow.datacenter.execution.model.** { *; }
|
||||||
|
-keep class tech.easyflow.datacenter.meta.enums.** { *; }
|
||||||
@@ -78,7 +78,7 @@ public class AgentCategoryController extends BaseCurdController<AgentCategorySer
|
|||||||
for (Serializable id : ids) {
|
for (Serializable id : ids) {
|
||||||
QueryWrapper queryWrapper = QueryWrapper.create().eq(Agent::getCategoryId, id);
|
QueryWrapper queryWrapper = QueryWrapper.create().eq(Agent::getCategoryId, id);
|
||||||
List<Agent> agents = agentMapper.selectListByQuery(queryWrapper);
|
List<Agent> agents = agentMapper.selectListByQuery(queryWrapper);
|
||||||
if (!agents.isEmpty()) {
|
if (agents != null && !agents.isEmpty()) {
|
||||||
throw new BusinessException("请先删除该分类下的所有 Agent");
|
throw new BusinessException("请先删除该分类下的所有 Agent");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import tech.easyflow.common.satoken.util.SaTokenUtil;
|
|||||||
import tech.easyflow.common.web.jsonbody.JsonBody;
|
import tech.easyflow.common.web.jsonbody.JsonBody;
|
||||||
|
|
||||||
import java.math.BigInteger;
|
import java.math.BigInteger;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Agent 管理端会话控制器。
|
* Agent 管理端会话控制器。
|
||||||
@@ -104,6 +105,19 @@ public class AgentSessionController {
|
|||||||
return Result.ok();
|
return Result.ok();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 保存 Agent 会话临时知识库。
|
||||||
|
*
|
||||||
|
* @param sessionId 会话 ID
|
||||||
|
* @param knowledgeIds 临时知识库 ID
|
||||||
|
* @return 操作结果
|
||||||
|
*/
|
||||||
|
@PostMapping("/{sessionId}/extraKnowledges")
|
||||||
|
public Result<ChatWorkspaceSessionDetailView> saveExtraKnowledges(@PathVariable BigInteger sessionId,
|
||||||
|
@JsonBody(value = "knowledgeIds") List<BigInteger> knowledgeIds) {
|
||||||
|
return Result.ok(agentSessionService.saveCurrentUserExtraKnowledges(currentAccount(), sessionId, knowledgeIds));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 删除 Agent 会话。
|
* 删除 Agent 会话。
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import tech.easyflow.ai.entity.Model;
|
|||||||
import tech.easyflow.ai.service.DocumentChunkService;
|
import tech.easyflow.ai.service.DocumentChunkService;
|
||||||
import tech.easyflow.ai.service.DocumentCollectionService;
|
import tech.easyflow.ai.service.DocumentCollectionService;
|
||||||
import tech.easyflow.ai.service.ModelService;
|
import tech.easyflow.ai.service.ModelService;
|
||||||
|
import tech.easyflow.ai.support.DocumentStoreLifecycleSupport;
|
||||||
import tech.easyflow.common.annotation.UsePermission;
|
import tech.easyflow.common.annotation.UsePermission;
|
||||||
import tech.easyflow.common.domain.Result;
|
import tech.easyflow.common.domain.Result;
|
||||||
import tech.easyflow.common.web.controller.BaseCurdController;
|
import tech.easyflow.common.web.controller.BaseCurdController;
|
||||||
@@ -93,22 +94,26 @@ public class DocumentChunkController extends BaseCurdController<DocumentChunkSer
|
|||||||
if (documentStore == null) {
|
if (documentStore == null) {
|
||||||
return Result.fail(2, "知识库没有配置向量库");
|
return Result.fail(2, "知识库没有配置向量库");
|
||||||
}
|
}
|
||||||
// 设置向量模型
|
try {
|
||||||
Model model = modelService.getModelInstance(knowledge.getVectorEmbedModelId());
|
// 设置向量模型
|
||||||
if (model == null) {
|
Model model = modelService.getModelInstance(knowledge.getVectorEmbedModelId());
|
||||||
return Result.fail(3, "知识库没有配置向量模型");
|
if (model == null) {
|
||||||
|
return Result.fail(3, "知识库没有配置向量模型");
|
||||||
|
}
|
||||||
|
EmbeddingModel embeddingModel = model.toEmbeddingModel();
|
||||||
|
documentStore.setEmbeddingModel(embeddingModel);
|
||||||
|
StoreOptions options = StoreOptions.ofCollectionName(knowledge.getVectorStoreCollection());
|
||||||
|
Document document = Document.of(documentChunk.getContent());
|
||||||
|
document.setId(documentChunk.getId());
|
||||||
|
Map<String, Object> metadata = new HashMap<>();
|
||||||
|
metadata.put("keywords", documentChunk.getMetadataKeyWords());
|
||||||
|
metadata.put("questions", documentChunk.getMetadataQuestions());
|
||||||
|
document.setMetadataMap(metadata);
|
||||||
|
StoreResult result = documentStore.update(document, options); // 更新已有记录
|
||||||
|
return Result.ok(result);
|
||||||
|
} finally {
|
||||||
|
DocumentStoreLifecycleSupport.closeQuietly(documentStore);
|
||||||
}
|
}
|
||||||
EmbeddingModel embeddingModel = model.toEmbeddingModel();
|
|
||||||
documentStore.setEmbeddingModel(embeddingModel);
|
|
||||||
StoreOptions options = StoreOptions.ofCollectionName(knowledge.getVectorStoreCollection());
|
|
||||||
Document document = Document.of(documentChunk.getContent());
|
|
||||||
document.setId(documentChunk.getId());
|
|
||||||
Map<String, Object> metadata = new HashMap<>();
|
|
||||||
metadata.put("keywords", documentChunk.getMetadataKeyWords());
|
|
||||||
metadata.put("questions", documentChunk.getMetadataQuestions());
|
|
||||||
document.setMetadataMap(metadata);
|
|
||||||
StoreResult result = documentStore.update(document, options); // 更新已有记录
|
|
||||||
return Result.ok(result);
|
|
||||||
}
|
}
|
||||||
return Result.ok(false);
|
return Result.ok(false);
|
||||||
}
|
}
|
||||||
@@ -135,19 +140,23 @@ public class DocumentChunkController extends BaseCurdController<DocumentChunkSer
|
|||||||
if (documentStore == null) {
|
if (documentStore == null) {
|
||||||
return Result.fail(3, "知识库没有配置向量库");
|
return Result.fail(3, "知识库没有配置向量库");
|
||||||
}
|
}
|
||||||
// 设置向量模型
|
try {
|
||||||
Model model = modelService.getModelInstance(knowledge.getVectorEmbedModelId());
|
// 设置向量模型
|
||||||
if (model == null) {
|
Model model = modelService.getModelInstance(knowledge.getVectorEmbedModelId());
|
||||||
return Result.fail(4, "知识库没有配置向量模型");
|
if (model == null) {
|
||||||
}
|
return Result.fail(4, "知识库没有配置向量模型");
|
||||||
EmbeddingModel embeddingModel = model.toEmbeddingModel();
|
}
|
||||||
documentStore.setEmbeddingModel(embeddingModel);
|
EmbeddingModel embeddingModel = model.toEmbeddingModel();
|
||||||
StoreOptions options = StoreOptions.ofCollectionName(knowledge.getVectorStoreCollection());
|
documentStore.setEmbeddingModel(embeddingModel);
|
||||||
List<BigInteger> deleteList = new ArrayList<>();
|
StoreOptions options = StoreOptions.ofCollectionName(knowledge.getVectorStoreCollection());
|
||||||
deleteList.add(chunkId);
|
List<BigInteger> deleteList = new ArrayList<>();
|
||||||
documentStore.delete(deleteList, options);
|
deleteList.add(chunkId);
|
||||||
documentChunkService.removeChunk(knowledge, chunkId);
|
documentStore.delete(deleteList, options);
|
||||||
|
documentChunkService.removeChunk(knowledge, chunkId);
|
||||||
|
|
||||||
return super.remove(chunkId);
|
return super.remove(chunkId);
|
||||||
|
} finally {
|
||||||
|
DocumentStoreLifecycleSupport.closeQuietly(documentStore);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package tech.easyflow.admin.controller.ai;
|
package tech.easyflow.admin.controller.ai;
|
||||||
|
|
||||||
|
import com.easyagents.mcp.client.McpEnvironmentCheckResult;
|
||||||
import com.mybatisflex.core.paginate.Page;
|
import com.mybatisflex.core.paginate.Page;
|
||||||
import com.mybatisflex.core.query.QueryWrapper;
|
import com.mybatisflex.core.query.QueryWrapper;
|
||||||
import jakarta.servlet.http.HttpServletRequest;
|
import jakarta.servlet.http.HttpServletRequest;
|
||||||
@@ -64,6 +65,11 @@ public class McpController extends BaseCurdController<McpService, Mcp> {
|
|||||||
return Result.ok(service.getMcpTools(id));
|
return Result.ok(service.getMcpTools(id));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@PostMapping("/check")
|
||||||
|
public Result<McpEnvironmentCheckResult> check(@JsonBody("configJson") String configJson) {
|
||||||
|
return Result.ok(service.checkMcp(configJson));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@GetMapping("pageTools")
|
@GetMapping("pageTools")
|
||||||
public Result<Page<Mcp>> pageTools(HttpServletRequest request, String sortKey, String sortType, Long pageNumber, Long pageSize) {
|
public Result<Page<Mcp>> pageTools(HttpServletRequest request, String sortKey, String sortType, Long pageNumber, Long pageSize) {
|
||||||
@@ -80,4 +86,4 @@ public class McpController extends BaseCurdController<McpService, Mcp> {
|
|||||||
|
|
||||||
return Result.ok(service.pageTools(mcpPage));
|
return Result.ok(service.pageTools(mcpPage));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ import tech.easyflow.ai.service.KnowledgeEmbeddingService;
|
|||||||
import tech.easyflow.ai.service.KnowledgeShareAuditService;
|
import tech.easyflow.ai.service.KnowledgeShareAuditService;
|
||||||
import tech.easyflow.ai.service.KnowledgeShareService;
|
import tech.easyflow.ai.service.KnowledgeShareService;
|
||||||
import tech.easyflow.ai.service.ModelService;
|
import tech.easyflow.ai.service.ModelService;
|
||||||
|
import tech.easyflow.ai.support.DocumentStoreLifecycleSupport;
|
||||||
import tech.easyflow.ai.vo.FaqImportResultVo;
|
import tech.easyflow.ai.vo.FaqImportResultVo;
|
||||||
import tech.easyflow.ai.vo.KnowledgeShareAuthContext;
|
import tech.easyflow.ai.vo.KnowledgeShareAuthContext;
|
||||||
import tech.easyflow.ai.vo.KnowledgeShareViewDetail;
|
import tech.easyflow.ai.vo.KnowledgeShareViewDetail;
|
||||||
@@ -520,19 +521,23 @@ public class ShareKnowledgeController {
|
|||||||
if (documentStore == null) {
|
if (documentStore == null) {
|
||||||
return Result.fail(2, "知识库没有配置向量库");
|
return Result.fail(2, "知识库没有配置向量库");
|
||||||
}
|
}
|
||||||
Model model = modelService.getModelInstance(context.getKnowledge().getVectorEmbedModelId());
|
try {
|
||||||
if (model == null) {
|
Model model = modelService.getModelInstance(context.getKnowledge().getVectorEmbedModelId());
|
||||||
return Result.fail(3, "知识库没有配置向量模型");
|
if (model == null) {
|
||||||
|
return Result.fail(3, "知识库没有配置向量模型");
|
||||||
|
}
|
||||||
|
EmbeddingModel embeddingModel = model.toEmbeddingModel();
|
||||||
|
documentStore.setEmbeddingModel(embeddingModel);
|
||||||
|
StoreOptions options = StoreOptions.ofCollectionName(context.getKnowledge().getVectorStoreCollection());
|
||||||
|
com.easyagents.core.document.Document doc = com.easyagents.core.document.Document.of(documentChunk.getContent());
|
||||||
|
doc.setId(documentChunk.getId());
|
||||||
|
StoreResult result = documentStore.update(doc, options);
|
||||||
|
audit(context, "更新分享文档 Chunk", "KNOWLEDGE_SHARE_URL_WRITE", true,
|
||||||
|
auditDetail("knowledgeId", context.getKnowledge().getId(), "chunkId", documentChunk.getId()));
|
||||||
|
return Result.ok(result);
|
||||||
|
} finally {
|
||||||
|
DocumentStoreLifecycleSupport.closeQuietly(documentStore);
|
||||||
}
|
}
|
||||||
EmbeddingModel embeddingModel = model.toEmbeddingModel();
|
|
||||||
documentStore.setEmbeddingModel(embeddingModel);
|
|
||||||
StoreOptions options = StoreOptions.ofCollectionName(context.getKnowledge().getVectorStoreCollection());
|
|
||||||
com.easyagents.core.document.Document doc = com.easyagents.core.document.Document.of(documentChunk.getContent());
|
|
||||||
doc.setId(documentChunk.getId());
|
|
||||||
StoreResult result = documentStore.update(doc, options);
|
|
||||||
audit(context, "更新分享文档 Chunk", "KNOWLEDGE_SHARE_URL_WRITE", true,
|
|
||||||
auditDetail("knowledgeId", context.getKnowledge().getId(), "chunkId", documentChunk.getId()));
|
|
||||||
return Result.ok(result);
|
|
||||||
}
|
}
|
||||||
return Result.ok(false);
|
return Result.ok(false);
|
||||||
}
|
}
|
||||||
@@ -559,17 +564,21 @@ public class ShareKnowledgeController {
|
|||||||
if (documentStore == null) {
|
if (documentStore == null) {
|
||||||
return Result.fail(2, "知识库没有配置向量库");
|
return Result.fail(2, "知识库没有配置向量库");
|
||||||
}
|
}
|
||||||
Model model = modelService.getModelInstance(context.getKnowledge().getVectorEmbedModelId());
|
try {
|
||||||
if (model == null) {
|
Model model = modelService.getModelInstance(context.getKnowledge().getVectorEmbedModelId());
|
||||||
return Result.fail(3, "知识库没有配置向量模型");
|
if (model == null) {
|
||||||
|
return Result.fail(3, "知识库没有配置向量模型");
|
||||||
|
}
|
||||||
|
documentStore.setEmbeddingModel(model.toEmbeddingModel());
|
||||||
|
StoreOptions options = StoreOptions.ofCollectionName(context.getKnowledge().getVectorStoreCollection());
|
||||||
|
documentStore.delete(Collections.singletonList(chunkId), options);
|
||||||
|
documentChunkService.removeById(chunkId);
|
||||||
|
audit(context, "删除分享文档 Chunk", "KNOWLEDGE_SHARE_URL_WRITE", true,
|
||||||
|
auditDetail("knowledgeId", context.getKnowledge().getId(), "chunkId", chunkId));
|
||||||
|
return Result.ok(true);
|
||||||
|
} finally {
|
||||||
|
DocumentStoreLifecycleSupport.closeQuietly(documentStore);
|
||||||
}
|
}
|
||||||
documentStore.setEmbeddingModel(model.toEmbeddingModel());
|
|
||||||
StoreOptions options = StoreOptions.ofCollectionName(context.getKnowledge().getVectorStoreCollection());
|
|
||||||
documentStore.delete(Collections.singletonList(chunkId), options);
|
|
||||||
documentChunkService.removeById(chunkId);
|
|
||||||
audit(context, "删除分享文档 Chunk", "KNOWLEDGE_SHARE_URL_WRITE", true,
|
|
||||||
auditDetail("knowledgeId", context.getKnowledge().getId(), "chunkId", chunkId));
|
|
||||||
return Result.ok(true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -10,13 +10,12 @@ import tech.easyflow.agent.service.AgentService;
|
|||||||
import tech.easyflow.ai.entity.DocumentCollection;
|
import tech.easyflow.ai.entity.DocumentCollection;
|
||||||
import tech.easyflow.ai.enums.PublishStatus;
|
import tech.easyflow.ai.enums.PublishStatus;
|
||||||
import tech.easyflow.ai.service.DocumentCollectionService;
|
import tech.easyflow.ai.service.DocumentCollectionService;
|
||||||
import tech.easyflow.chatlog.domain.dto.ChatHistoryPage;
|
import tech.easyflow.chatlog.domain.command.ChatSessionUpsertCommand;
|
||||||
import tech.easyflow.chatlog.domain.dto.ChatMessageRecord;
|
import tech.easyflow.chatlog.domain.dto.*;
|
||||||
import tech.easyflow.chatlog.domain.dto.ChatSessionPage;
|
|
||||||
import tech.easyflow.chatlog.domain.dto.ChatSessionSummary;
|
|
||||||
import tech.easyflow.chatlog.domain.query.ChatPageQuery;
|
import tech.easyflow.chatlog.domain.query.ChatPageQuery;
|
||||||
import tech.easyflow.chatlog.service.ChatSessionCommandService;
|
import tech.easyflow.chatlog.service.ChatSessionCommandService;
|
||||||
import tech.easyflow.chatlog.service.ChatSessionQueryService;
|
import tech.easyflow.chatlog.service.ChatSessionQueryService;
|
||||||
|
import tech.easyflow.chatlog.support.ChatJsonSupport;
|
||||||
import tech.easyflow.common.entity.LoginAccount;
|
import tech.easyflow.common.entity.LoginAccount;
|
||||||
import tech.easyflow.common.web.exceptions.BusinessException;
|
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||||
import tech.easyflow.system.enums.CategoryResourceType;
|
import tech.easyflow.system.enums.CategoryResourceType;
|
||||||
@@ -40,6 +39,7 @@ public class AgentSessionService {
|
|||||||
private final DocumentCollectionService documentCollectionService;
|
private final DocumentCollectionService documentCollectionService;
|
||||||
private final ResourceAccessService resourceAccessService;
|
private final ResourceAccessService resourceAccessService;
|
||||||
private final AgentRuntimeStateCleanupService agentRuntimeStateCleanupService;
|
private final AgentRuntimeStateCleanupService agentRuntimeStateCleanupService;
|
||||||
|
private final ChatJsonSupport chatJsonSupport;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 创建 Agent 管理端会话服务。
|
* 创建 Agent 管理端会话服务。
|
||||||
@@ -50,19 +50,22 @@ public class AgentSessionService {
|
|||||||
* @param documentCollectionService 知识库服务
|
* @param documentCollectionService 知识库服务
|
||||||
* @param resourceAccessService 资源访问服务
|
* @param resourceAccessService 资源访问服务
|
||||||
* @param agentRuntimeStateCleanupService Agent 运行态清理服务
|
* @param agentRuntimeStateCleanupService Agent 运行态清理服务
|
||||||
|
* @param chatJsonSupport 聊天 JSON 工具
|
||||||
*/
|
*/
|
||||||
public AgentSessionService(ChatSessionQueryService chatSessionQueryService,
|
public AgentSessionService(ChatSessionQueryService chatSessionQueryService,
|
||||||
ChatSessionCommandService chatSessionCommandService,
|
ChatSessionCommandService chatSessionCommandService,
|
||||||
AgentService agentService,
|
AgentService agentService,
|
||||||
DocumentCollectionService documentCollectionService,
|
DocumentCollectionService documentCollectionService,
|
||||||
ResourceAccessService resourceAccessService,
|
ResourceAccessService resourceAccessService,
|
||||||
AgentRuntimeStateCleanupService agentRuntimeStateCleanupService) {
|
AgentRuntimeStateCleanupService agentRuntimeStateCleanupService,
|
||||||
|
ChatJsonSupport chatJsonSupport) {
|
||||||
this.chatSessionQueryService = chatSessionQueryService;
|
this.chatSessionQueryService = chatSessionQueryService;
|
||||||
this.chatSessionCommandService = chatSessionCommandService;
|
this.chatSessionCommandService = chatSessionCommandService;
|
||||||
this.agentService = agentService;
|
this.agentService = agentService;
|
||||||
this.documentCollectionService = documentCollectionService;
|
this.documentCollectionService = documentCollectionService;
|
||||||
this.resourceAccessService = resourceAccessService;
|
this.resourceAccessService = resourceAccessService;
|
||||||
this.agentRuntimeStateCleanupService = agentRuntimeStateCleanupService;
|
this.agentRuntimeStateCleanupService = agentRuntimeStateCleanupService;
|
||||||
|
this.chatJsonSupport = chatJsonSupport;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -103,6 +106,12 @@ public class AgentSessionService {
|
|||||||
Agent displayAgent = availability == null ? null : availability.displayAgent();
|
Agent displayAgent = availability == null ? null : availability.displayAgent();
|
||||||
detail.setAssistant(toAssistantView(displayAgent, summary));
|
detail.setAssistant(toAssistantView(displayAgent, summary));
|
||||||
detail.setBoundKnowledges(resolveBoundKnowledges(displayAgent));
|
detail.setBoundKnowledges(resolveBoundKnowledges(displayAgent));
|
||||||
|
ExtraKnowledgeResolution extraKnowledgeResolution = resolveExtraKnowledges(summary);
|
||||||
|
detail.setExtraKnowledges(extraKnowledgeResolution.validKnowledges());
|
||||||
|
detail.setRemovedExtraKnowledgeNames(extraKnowledgeResolution.removedNames());
|
||||||
|
if (extraKnowledgeResolution.shouldSync()) {
|
||||||
|
syncSessionExtraKnowledges(summary, extraKnowledgeResolution.validKnowledgeIds(), account.getId());
|
||||||
|
}
|
||||||
return detail;
|
return detail;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -150,6 +159,26 @@ public class AgentSessionService {
|
|||||||
chatSessionCommandService.renameSession(sessionId, account.getId(), title.trim(), account.getId());
|
chatSessionCommandService.renameSession(sessionId, account.getId(), title.trim(), account.getId());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 保存当前用户 Agent 会话的临时知识库。
|
||||||
|
*
|
||||||
|
* @param account 当前登录账号
|
||||||
|
* @param sessionId 会话 ID
|
||||||
|
* @param knowledgeIds 临时知识库 ID
|
||||||
|
* @return 更新后的会话详情
|
||||||
|
*/
|
||||||
|
public ChatWorkspaceSessionDetailView saveCurrentUserExtraKnowledges(LoginAccount account,
|
||||||
|
BigInteger sessionId,
|
||||||
|
List<BigInteger> knowledgeIds) {
|
||||||
|
ChatSessionSummary summary = requireUserAgentSession(account, sessionId);
|
||||||
|
ExtraKnowledgeResolution resolution = resolveVisibleKnowledgeViews(normalizeExtraKnowledgeIds(knowledgeIds));
|
||||||
|
if (!resolution.removedNames().isEmpty()) {
|
||||||
|
throw new BusinessException("所选知识库已失效或无权限使用");
|
||||||
|
}
|
||||||
|
syncSessionExtraKnowledges(summary, resolution.validKnowledgeIds(), account.getId());
|
||||||
|
return getCurrentUserSession(account, sessionId);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 删除当前用户的 Agent 会话。
|
* 删除当前用户的 Agent 会话。
|
||||||
*
|
*
|
||||||
@@ -295,8 +324,97 @@ public class AgentSessionService {
|
|||||||
return view;
|
return view;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private ExtraKnowledgeResolution resolveExtraKnowledges(ChatSessionSummary summary) {
|
||||||
|
ChatSessionExtPayload payload = chatJsonSupport.fromJson(summary.getExtJson(), ChatSessionExtPayload.class);
|
||||||
|
List<BigInteger> extraKnowledgeIds = payload == null ? List.of() : payload.getExtraKnowledgeIds();
|
||||||
|
return resolveVisibleKnowledgeViews(extraKnowledgeIds);
|
||||||
|
}
|
||||||
|
|
||||||
|
private ExtraKnowledgeResolution resolveVisibleKnowledgeViews(List<BigInteger> knowledgeIds) {
|
||||||
|
if (knowledgeIds == null || knowledgeIds.isEmpty()) {
|
||||||
|
return new ExtraKnowledgeResolution(List.of(), List.of(), List.of(), false);
|
||||||
|
}
|
||||||
|
List<BigInteger> normalizedIds = normalizeExtraKnowledgeIds(knowledgeIds);
|
||||||
|
if (normalizedIds.isEmpty()) {
|
||||||
|
return new ExtraKnowledgeResolution(List.of(), List.of(), List.of(), false);
|
||||||
|
}
|
||||||
|
List<DocumentCollection> collections = documentCollectionService.listByIds(normalizedIds);
|
||||||
|
Map<BigInteger, DocumentCollection> collectionMap = new LinkedHashMap<>();
|
||||||
|
for (DocumentCollection collection : collections) {
|
||||||
|
collectionMap.put(collection.getId(), collection);
|
||||||
|
}
|
||||||
|
List<ChatWorkspaceKnowledgeView> validKnowledges = new ArrayList<>();
|
||||||
|
List<BigInteger> validKnowledgeIds = new ArrayList<>();
|
||||||
|
List<String> removedNames = new ArrayList<>();
|
||||||
|
boolean changed = false;
|
||||||
|
for (BigInteger knowledgeId : normalizedIds) {
|
||||||
|
DocumentCollection current = collectionMap.get(knowledgeId);
|
||||||
|
if (current == null) {
|
||||||
|
removedNames.add("知识库#" + knowledgeId);
|
||||||
|
changed = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (PublishStatus.from(current.getPublishStatus()) != PublishStatus.PUBLISHED) {
|
||||||
|
removedNames.add(current.getTitle());
|
||||||
|
changed = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!resourceAccessService.canAccess(CategoryResourceType.KNOWLEDGE, current, ResourceAction.USE)) {
|
||||||
|
removedNames.add(current.getTitle());
|
||||||
|
changed = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
validKnowledges.add(toKnowledgeView(documentCollectionService.toPublishedView(current)));
|
||||||
|
validKnowledgeIds.add(current.getId());
|
||||||
|
}
|
||||||
|
if (!Objects.equals(normalizedIds, validKnowledgeIds)) {
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
return new ExtraKnowledgeResolution(validKnowledges, validKnowledgeIds, removedNames, changed);
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<BigInteger> normalizeExtraKnowledgeIds(List<BigInteger> knowledgeIds) {
|
||||||
|
if (knowledgeIds == null || knowledgeIds.isEmpty()) {
|
||||||
|
return List.of();
|
||||||
|
}
|
||||||
|
List<BigInteger> normalizedIds = new ArrayList<>();
|
||||||
|
for (BigInteger knowledgeId : knowledgeIds) {
|
||||||
|
if (knowledgeId != null && !normalizedIds.contains(knowledgeId)) {
|
||||||
|
normalizedIds.add(knowledgeId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (normalizedIds.size() > 3) {
|
||||||
|
throw new BusinessException("临时知识库最多选择 3 个");
|
||||||
|
}
|
||||||
|
return normalizedIds;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void syncSessionExtraKnowledges(ChatSessionSummary summary, List<BigInteger> validKnowledgeIds, BigInteger operatorId) {
|
||||||
|
ChatSessionExtPayload payload = new ChatSessionExtPayload();
|
||||||
|
payload.setExtraKnowledgeIds(validKnowledgeIds);
|
||||||
|
ChatSessionUpsertCommand command = new ChatSessionUpsertCommand();
|
||||||
|
command.setSessionId(summary.getId());
|
||||||
|
command.setTenantId(summary.getTenantId());
|
||||||
|
command.setDeptId(summary.getDeptId());
|
||||||
|
command.setUserId(summary.getUserId());
|
||||||
|
command.setUserAccount(summary.getUserAccount());
|
||||||
|
command.setAssistantId(summary.getAssistantId());
|
||||||
|
command.setAssistantCode(summary.getAssistantCode());
|
||||||
|
command.setAssistantName(summary.getAssistantName());
|
||||||
|
command.setTitle(summary.getTitle());
|
||||||
|
command.setExtJson(chatJsonSupport.toJson(payload));
|
||||||
|
command.setOperatorId(operatorId);
|
||||||
|
chatSessionCommandService.createOrTouchSession(command);
|
||||||
|
}
|
||||||
|
|
||||||
private record AgentAvailability(boolean continuable,
|
private record AgentAvailability(boolean continuable,
|
||||||
ChatWorkspaceReadOnlyReason reason,
|
ChatWorkspaceReadOnlyReason reason,
|
||||||
Agent displayAgent) {
|
Agent displayAgent) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private record ExtraKnowledgeResolution(List<ChatWorkspaceKnowledgeView> validKnowledges,
|
||||||
|
List<BigInteger> validKnowledgeIds,
|
||||||
|
List<String> removedNames,
|
||||||
|
boolean shouldSync) {
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ import tech.easyflow.ai.service.KnowledgeShareAuditService;
|
|||||||
import tech.easyflow.ai.service.KnowledgeSharePermissionService;
|
import tech.easyflow.ai.service.KnowledgeSharePermissionService;
|
||||||
import tech.easyflow.ai.service.ModelService;
|
import tech.easyflow.ai.service.ModelService;
|
||||||
import tech.easyflow.ai.service.impl.KnowledgeSharePermissionServiceImpl;
|
import tech.easyflow.ai.service.impl.KnowledgeSharePermissionServiceImpl;
|
||||||
|
import tech.easyflow.ai.support.DocumentStoreLifecycleSupport;
|
||||||
import tech.easyflow.ai.vo.FaqImportResultVo;
|
import tech.easyflow.ai.vo.FaqImportResultVo;
|
||||||
import tech.easyflow.common.domain.Result;
|
import tech.easyflow.common.domain.Result;
|
||||||
import tech.easyflow.common.filestorage.FileStorageService;
|
import tech.easyflow.common.filestorage.FileStorageService;
|
||||||
@@ -342,18 +343,22 @@ public class PublicKnowledgeShareController {
|
|||||||
if (documentStore == null) {
|
if (documentStore == null) {
|
||||||
return Result.fail(2, "知识库没有配置向量库");
|
return Result.fail(2, "知识库没有配置向量库");
|
||||||
}
|
}
|
||||||
Model model = modelService.getModelInstance(knowledge.getVectorEmbedModelId());
|
try {
|
||||||
if (model == null) {
|
Model model = modelService.getModelInstance(knowledge.getVectorEmbedModelId());
|
||||||
return Result.fail(3, "知识库没有配置向量模型");
|
if (model == null) {
|
||||||
|
return Result.fail(3, "知识库没有配置向量模型");
|
||||||
|
}
|
||||||
|
EmbeddingModel embeddingModel = model.toEmbeddingModel();
|
||||||
|
documentStore.setEmbeddingModel(embeddingModel);
|
||||||
|
StoreOptions options = StoreOptions.ofCollectionName(knowledge.getVectorStoreCollection());
|
||||||
|
com.easyagents.core.document.Document doc = com.easyagents.core.document.Document.of(documentChunk.getContent());
|
||||||
|
doc.setId(current.getId());
|
||||||
|
StoreResult result = documentStore.update(doc, options);
|
||||||
|
audit(apiKey, "API更新文档 Chunk", "KNOWLEDGE_API_SHARE_WRITE", request.getRequestURI(), Map.of("knowledgeId", knowledgeId, "chunkId", documentChunk.getId()));
|
||||||
|
return Result.ok(result);
|
||||||
|
} finally {
|
||||||
|
DocumentStoreLifecycleSupport.closeQuietly(documentStore);
|
||||||
}
|
}
|
||||||
EmbeddingModel embeddingModel = model.toEmbeddingModel();
|
|
||||||
documentStore.setEmbeddingModel(embeddingModel);
|
|
||||||
StoreOptions options = StoreOptions.ofCollectionName(knowledge.getVectorStoreCollection());
|
|
||||||
com.easyagents.core.document.Document doc = com.easyagents.core.document.Document.of(documentChunk.getContent());
|
|
||||||
doc.setId(current.getId());
|
|
||||||
StoreResult result = documentStore.update(doc, options);
|
|
||||||
audit(apiKey, "API更新文档 Chunk", "KNOWLEDGE_API_SHARE_WRITE", request.getRequestURI(), Map.of("knowledgeId", knowledgeId, "chunkId", documentChunk.getId()));
|
|
||||||
return Result.ok(result);
|
|
||||||
}
|
}
|
||||||
return Result.ok(false);
|
return Result.ok(false);
|
||||||
}
|
}
|
||||||
@@ -376,16 +381,20 @@ public class PublicKnowledgeShareController {
|
|||||||
if (documentStore == null) {
|
if (documentStore == null) {
|
||||||
return Result.fail(2, "知识库没有配置向量库");
|
return Result.fail(2, "知识库没有配置向量库");
|
||||||
}
|
}
|
||||||
Model model = modelService.getModelInstance(knowledge.getVectorEmbedModelId());
|
try {
|
||||||
if (model == null) {
|
Model model = modelService.getModelInstance(knowledge.getVectorEmbedModelId());
|
||||||
return Result.fail(3, "知识库没有配置向量模型");
|
if (model == null) {
|
||||||
|
return Result.fail(3, "知识库没有配置向量模型");
|
||||||
|
}
|
||||||
|
documentStore.setEmbeddingModel(model.toEmbeddingModel());
|
||||||
|
StoreOptions options = StoreOptions.ofCollectionName(knowledge.getVectorStoreCollection());
|
||||||
|
documentStore.delete(Collections.singletonList(chunkId), options);
|
||||||
|
documentChunkService.removeById(chunkId);
|
||||||
|
audit(apiKey, "API删除文档 Chunk", "KNOWLEDGE_API_SHARE_WRITE", request.getRequestURI(), Map.of("knowledgeId", knowledgeId, "chunkId", chunkId));
|
||||||
|
return Result.ok(true);
|
||||||
|
} finally {
|
||||||
|
DocumentStoreLifecycleSupport.closeQuietly(documentStore);
|
||||||
}
|
}
|
||||||
documentStore.setEmbeddingModel(model.toEmbeddingModel());
|
|
||||||
StoreOptions options = StoreOptions.ofCollectionName(knowledge.getVectorStoreCollection());
|
|
||||||
documentStore.delete(Collections.singletonList(chunkId), options);
|
|
||||||
documentChunkService.removeById(chunkId);
|
|
||||||
audit(apiKey, "API删除文档 Chunk", "KNOWLEDGE_API_SHARE_WRITE", request.getRequestURI(), Map.of("knowledgeId", knowledgeId, "chunkId", chunkId));
|
|
||||||
return Result.ok(true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -25,6 +25,11 @@
|
|||||||
<groupId>org.springframework.boot</groupId>
|
<groupId>org.springframework.boot</groupId>
|
||||||
<artifactId>spring-boot-autoconfigure</artifactId>
|
<artifactId>spring-boot-autoconfigure</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework.boot</groupId>
|
||||||
|
<artifactId>spring-boot-actuator</artifactId>
|
||||||
|
<version>${spring-boot.version}</version>
|
||||||
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.clickhouse</groupId>
|
<groupId>com.clickhouse</groupId>
|
||||||
<artifactId>clickhouse-jdbc</artifactId>
|
<artifactId>clickhouse-jdbc</artifactId>
|
||||||
|
|||||||
@@ -0,0 +1,41 @@
|
|||||||
|
package tech.easyflow.common.analyticaldb.support;
|
||||||
|
|
||||||
|
import org.springframework.boot.actuate.health.Health;
|
||||||
|
import org.springframework.boot.actuate.health.HealthIndicator;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 分析数据库健康检查。
|
||||||
|
*/
|
||||||
|
@Component("analyticalDbHealthIndicator")
|
||||||
|
public class AnalyticalDBHealthIndicator implements HealthIndicator {
|
||||||
|
|
||||||
|
private final AnalyticalDBHealthSupport healthSupport;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建分析数据库健康检查器。
|
||||||
|
*
|
||||||
|
* @param healthSupport 分析数据库健康检查支持
|
||||||
|
*/
|
||||||
|
public AnalyticalDBHealthIndicator(AnalyticalDBHealthSupport healthSupport) {
|
||||||
|
this.healthSupport = healthSupport;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检查分析数据库是否可用。
|
||||||
|
*
|
||||||
|
* @return 健康状态
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public Health health() {
|
||||||
|
if (!healthSupport.enabled()) {
|
||||||
|
return Health.up().withDetail("enabled", false).build();
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
healthSupport.selfCheck();
|
||||||
|
return Health.up().withDetail("enabled", true).build();
|
||||||
|
} catch (Exception e) {
|
||||||
|
return Health.down(e).withDetail("enabled", true).build();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
package tech.easyflow.common.audio.config;
|
||||||
|
|
||||||
|
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 音频模块线程池配置。
|
||||||
|
*/
|
||||||
|
@ConfigurationProperties(prefix = "easyflow.thread-pool.scheduler")
|
||||||
|
public class AudioThreadPoolProperties {
|
||||||
|
|
||||||
|
private int poolSize = 4;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取调度线程池大小。
|
||||||
|
*
|
||||||
|
* @return 调度线程池大小
|
||||||
|
*/
|
||||||
|
public int getPoolSize() {
|
||||||
|
return poolSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置调度线程池大小。
|
||||||
|
*
|
||||||
|
* @param poolSize 调度线程池大小
|
||||||
|
*/
|
||||||
|
public void setPoolSize(int poolSize) {
|
||||||
|
this.poolSize = poolSize;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,19 +1,38 @@
|
|||||||
package tech.easyflow.common.audio.socket;
|
package tech.easyflow.common.audio.socket;
|
||||||
|
|
||||||
|
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||||
import org.springframework.context.annotation.Bean;
|
import org.springframework.context.annotation.Bean;
|
||||||
import org.springframework.context.annotation.Configuration;
|
import org.springframework.context.annotation.Configuration;
|
||||||
import org.springframework.scheduling.TaskScheduler;
|
import org.springframework.scheduling.TaskScheduler;
|
||||||
import org.springframework.scheduling.annotation.EnableScheduling;
|
import org.springframework.scheduling.annotation.EnableScheduling;
|
||||||
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
|
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
|
||||||
|
import tech.easyflow.common.audio.config.AudioThreadPoolProperties;
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
@EnableScheduling
|
@EnableScheduling
|
||||||
|
@EnableConfigurationProperties(AudioThreadPoolProperties.class)
|
||||||
public class SchedulingConfig {
|
public class SchedulingConfig {
|
||||||
|
|
||||||
|
private final AudioThreadPoolProperties properties;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建音频调度配置。
|
||||||
|
*
|
||||||
|
* @param properties 音频调度线程池配置
|
||||||
|
*/
|
||||||
|
public SchedulingConfig(AudioThreadPoolProperties properties) {
|
||||||
|
this.properties = properties;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建调度线程池。
|
||||||
|
*
|
||||||
|
* @return 调度线程池
|
||||||
|
*/
|
||||||
@Bean
|
@Bean
|
||||||
public TaskScheduler taskScheduler() {
|
public TaskScheduler taskScheduler() {
|
||||||
ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler();
|
ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler();
|
||||||
scheduler.setPoolSize(10);
|
scheduler.setPoolSize(properties.getPoolSize());
|
||||||
scheduler.setThreadNamePrefix("scheduled-task-");
|
scheduler.setThreadNamePrefix("scheduled-task-");
|
||||||
scheduler.setDaemon(true);
|
scheduler.setDaemon(true);
|
||||||
scheduler.initialize();
|
scheduler.initialize();
|
||||||
|
|||||||
@@ -39,7 +39,23 @@
|
|||||||
<artifactId>fastjson</artifactId>
|
<artifactId>fastjson</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework.boot</groupId>
|
||||||
|
<artifactId>spring-boot-starter-aop</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>junit</groupId>
|
||||||
|
<artifactId>junit</artifactId>
|
||||||
|
<version>${junit.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.mockito</groupId>
|
||||||
|
<artifactId>mockito-core</artifactId>
|
||||||
|
<version>5.12.0</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,35 @@
|
|||||||
|
package tech.easyflow.common.cache;
|
||||||
|
|
||||||
|
import java.lang.annotation.ElementType;
|
||||||
|
import java.lang.annotation.Retention;
|
||||||
|
import java.lang.annotation.RetentionPolicy;
|
||||||
|
import java.lang.annotation.Target;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Spring 定时任务 Redis 分布式锁。
|
||||||
|
*/
|
||||||
|
@Target(ElementType.METHOD)
|
||||||
|
@Retention(RetentionPolicy.RUNTIME)
|
||||||
|
public @interface DistributedScheduledLock {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取锁使用的 Redis key。
|
||||||
|
*
|
||||||
|
* @return Redis 锁 key
|
||||||
|
*/
|
||||||
|
String key();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 等待锁的秒数。
|
||||||
|
*
|
||||||
|
* @return 等待锁的秒数
|
||||||
|
*/
|
||||||
|
long waitSeconds() default 0L;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 锁租约秒数。
|
||||||
|
*
|
||||||
|
* @return 锁租约秒数
|
||||||
|
*/
|
||||||
|
long leaseSeconds() default 300L;
|
||||||
|
}
|
||||||
@@ -0,0 +1,111 @@
|
|||||||
|
package tech.easyflow.common.cache;
|
||||||
|
|
||||||
|
import org.aspectj.lang.ProceedingJoinPoint;
|
||||||
|
import org.aspectj.lang.annotation.Around;
|
||||||
|
import org.aspectj.lang.annotation.Aspect;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import jakarta.annotation.PreDestroy;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.ScheduledExecutorService;
|
||||||
|
import java.util.concurrent.ScheduledFuture;
|
||||||
|
import java.util.concurrent.ThreadFactory;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 定时任务分布式锁切面。
|
||||||
|
*/
|
||||||
|
@Aspect
|
||||||
|
@Component
|
||||||
|
public class DistributedScheduledLockAspect {
|
||||||
|
|
||||||
|
private static final Logger LOG = LoggerFactory.getLogger(DistributedScheduledLockAspect.class);
|
||||||
|
|
||||||
|
private final RedisLockExecutor redisLockExecutor;
|
||||||
|
private final ScheduledExecutorService renewExecutor;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建定时任务分布式锁切面。
|
||||||
|
*
|
||||||
|
* @param redisLockExecutor Redis 分布式锁执行器
|
||||||
|
*/
|
||||||
|
public DistributedScheduledLockAspect(RedisLockExecutor redisLockExecutor) {
|
||||||
|
this.redisLockExecutor = redisLockExecutor;
|
||||||
|
this.renewExecutor = Executors.newScheduledThreadPool(
|
||||||
|
1,
|
||||||
|
new DistributedScheduledLockThreadFactory()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 拦截带分布式调度锁的定时任务。
|
||||||
|
*
|
||||||
|
* @param joinPoint 切点
|
||||||
|
* @param lock 锁注解
|
||||||
|
* @return 原方法返回值;未抢到锁时返回 null
|
||||||
|
* @throws Throwable 原方法执行异常或 Redis 访问异常
|
||||||
|
*/
|
||||||
|
@Around("@annotation(lock)")
|
||||||
|
public Object around(ProceedingJoinPoint joinPoint, DistributedScheduledLock lock) throws Throwable {
|
||||||
|
Duration waitTimeout = Duration.ofSeconds(Math.max(lock.waitSeconds(), 0L));
|
||||||
|
Duration leaseTimeout = Duration.ofSeconds(Math.max(lock.leaseSeconds(), 1L));
|
||||||
|
RedisLockExecutor.LockHandle handle = redisLockExecutor.tryAcquire(lock.key(), waitTimeout, leaseTimeout);
|
||||||
|
if (handle == null) {
|
||||||
|
LOG.info("定时任务分布式锁已被其他实例持有,跳过本轮执行: lockKey={}, method={}",
|
||||||
|
lock.key(), joinPoint.getSignature().toShortString());
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
ScheduledFuture<?> renewTask = scheduleRenew(lock.key(), handle, leaseTimeout);
|
||||||
|
try {
|
||||||
|
return joinPoint.proceed();
|
||||||
|
} finally {
|
||||||
|
renewTask.cancel(false);
|
||||||
|
handle.release();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private ScheduledFuture<?> scheduleRenew(String lockKey,
|
||||||
|
RedisLockExecutor.LockHandle handle,
|
||||||
|
Duration leaseTimeout) {
|
||||||
|
long renewIntervalMillis = Math.max(leaseTimeout.toMillis() / 3L, 1000L);
|
||||||
|
return renewExecutor.scheduleWithFixedDelay(() -> {
|
||||||
|
if (!handle.renew()) {
|
||||||
|
LOG.warn("定时任务分布式锁续期失败: lockKey={}", lockKey);
|
||||||
|
}
|
||||||
|
}, renewIntervalMillis, renewIntervalMillis, TimeUnit.MILLISECONDS);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 关闭调度锁续期线程池。
|
||||||
|
*/
|
||||||
|
@PreDestroy
|
||||||
|
public void destroy() {
|
||||||
|
renewExecutor.shutdownNow();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 调度锁续期线程工厂。
|
||||||
|
*/
|
||||||
|
private static final class DistributedScheduledLockThreadFactory implements ThreadFactory {
|
||||||
|
|
||||||
|
private final AtomicInteger index = new AtomicInteger(1);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建续期线程。
|
||||||
|
*
|
||||||
|
* @param runnable 线程任务
|
||||||
|
* @return 续期线程
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public Thread newThread(Runnable runnable) {
|
||||||
|
Thread thread = new Thread(runnable);
|
||||||
|
thread.setName("distributed-scheduled-lock-renew-" + index.getAndIncrement());
|
||||||
|
thread.setDaemon(true);
|
||||||
|
return thread;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,6 +12,9 @@ import java.util.Collections;
|
|||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import java.util.function.Supplier;
|
import java.util.function.Supplier;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Redis 分布式锁执行器。
|
||||||
|
*/
|
||||||
@Component
|
@Component
|
||||||
public class RedisLockExecutor {
|
public class RedisLockExecutor {
|
||||||
|
|
||||||
@@ -42,6 +45,14 @@ public class RedisLockExecutor {
|
|||||||
@Autowired
|
@Autowired
|
||||||
private StringRedisTemplate stringRedisTemplate;
|
private StringRedisTemplate stringRedisTemplate;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 在分布式锁保护下执行无返回任务。
|
||||||
|
*
|
||||||
|
* @param lockKey 锁 key
|
||||||
|
* @param waitTimeout 等待锁的最大时间
|
||||||
|
* @param leaseTimeout 锁租约时间
|
||||||
|
* @param task 业务任务
|
||||||
|
*/
|
||||||
public void executeWithLock(String lockKey, Duration waitTimeout, Duration leaseTimeout, Runnable task) {
|
public void executeWithLock(String lockKey, Duration waitTimeout, Duration leaseTimeout, Runnable task) {
|
||||||
executeWithLock(lockKey, waitTimeout, leaseTimeout, () -> {
|
executeWithLock(lockKey, waitTimeout, leaseTimeout, () -> {
|
||||||
task.run();
|
task.run();
|
||||||
@@ -49,6 +60,16 @@ public class RedisLockExecutor {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 在分布式锁保护下执行有返回任务。
|
||||||
|
*
|
||||||
|
* @param lockKey 锁 key
|
||||||
|
* @param waitTimeout 等待锁的最大时间
|
||||||
|
* @param leaseTimeout 锁租约时间
|
||||||
|
* @param task 业务任务
|
||||||
|
* @param <T> 返回类型
|
||||||
|
* @return 任务返回值
|
||||||
|
*/
|
||||||
public <T> T executeWithLock(String lockKey, Duration waitTimeout, Duration leaseTimeout, Supplier<T> task) {
|
public <T> T executeWithLock(String lockKey, Duration waitTimeout, Duration leaseTimeout, Supplier<T> task) {
|
||||||
LockHandle handle = acquire(lockKey, waitTimeout, leaseTimeout);
|
LockHandle handle = acquire(lockKey, waitTimeout, leaseTimeout);
|
||||||
try {
|
try {
|
||||||
@@ -70,24 +91,46 @@ public class RedisLockExecutor {
|
|||||||
* @return 锁句柄
|
* @return 锁句柄
|
||||||
*/
|
*/
|
||||||
public LockHandle acquire(String lockKey, Duration waitTimeout, Duration leaseTimeout) {
|
public LockHandle acquire(String lockKey, Duration waitTimeout, Duration leaseTimeout) {
|
||||||
|
LockHandle handle = tryAcquire(lockKey, waitTimeout, leaseTimeout);
|
||||||
|
if (handle == null) {
|
||||||
|
throw new IllegalStateException("获取分布式锁失败,请稍后重试,lockKey=" + lockKey);
|
||||||
|
}
|
||||||
|
return handle;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 尝试获取显式释放的分布式锁句柄。
|
||||||
|
*
|
||||||
|
* <p>返回 {@code null} 表示锁当前被其他节点持有。Redis 访问失败或等待过程被中断仍会抛出异常,
|
||||||
|
* 调用方可据此区分“正常跳过”和“基础设施异常”。</p>
|
||||||
|
*
|
||||||
|
* @param lockKey 锁 key
|
||||||
|
* @param waitTimeout 等待时间
|
||||||
|
* @param leaseTimeout 租约时间
|
||||||
|
* @return 获取成功时返回锁句柄,否则返回 null
|
||||||
|
*/
|
||||||
|
public LockHandle tryAcquire(String lockKey, Duration waitTimeout, Duration leaseTimeout) {
|
||||||
String lockValue = UUID.randomUUID().toString();
|
String lockValue = UUID.randomUUID().toString();
|
||||||
boolean acquired = false;
|
boolean acquired = false;
|
||||||
long deadline = System.nanoTime() + waitTimeout.toNanos();
|
long deadline = System.nanoTime() + waitTimeout.toNanos();
|
||||||
try {
|
try {
|
||||||
while (System.nanoTime() <= deadline) {
|
do {
|
||||||
Boolean success = stringRedisTemplate.opsForValue().setIfAbsent(lockKey, lockValue, leaseTimeout);
|
Boolean success = stringRedisTemplate.opsForValue().setIfAbsent(lockKey, lockValue, leaseTimeout);
|
||||||
if (Boolean.TRUE.equals(success)) {
|
if (Boolean.TRUE.equals(success)) {
|
||||||
acquired = true;
|
acquired = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
if (System.nanoTime() >= deadline) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
Thread.sleep(RETRY_INTERVAL_MILLIS);
|
Thread.sleep(RETRY_INTERVAL_MILLIS);
|
||||||
}
|
} while (System.nanoTime() <= deadline);
|
||||||
} catch (InterruptedException e) {
|
} catch (InterruptedException e) {
|
||||||
Thread.currentThread().interrupt();
|
Thread.currentThread().interrupt();
|
||||||
throw new IllegalStateException("等待分布式锁被中断,lockKey=" + lockKey, e);
|
throw new IllegalStateException("等待分布式锁被中断,lockKey=" + lockKey, e);
|
||||||
}
|
}
|
||||||
if (!acquired) {
|
if (!acquired) {
|
||||||
throw new IllegalStateException("获取分布式锁失败,请稍后重试,lockKey=" + lockKey);
|
return null;
|
||||||
}
|
}
|
||||||
return new LockHandle(lockKey, lockValue, leaseTimeout);
|
return new LockHandle(lockKey, lockValue, leaseTimeout);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,108 @@
|
|||||||
|
package tech.easyflow.common.cache;
|
||||||
|
|
||||||
|
import org.aspectj.lang.ProceedingJoinPoint;
|
||||||
|
import org.aspectj.lang.Signature;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.mockito.ArgumentMatchers;
|
||||||
|
import org.mockito.Mockito;
|
||||||
|
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||||
|
import org.springframework.data.redis.core.ValueOperations;
|
||||||
|
import org.springframework.data.redis.core.script.RedisScript;
|
||||||
|
|
||||||
|
import java.lang.reflect.Field;
|
||||||
|
import java.lang.reflect.Method;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@link DistributedScheduledLockAspect} 回归测试。
|
||||||
|
*/
|
||||||
|
public class DistributedScheduledLockAspectTest {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证未抢到调度锁时跳过原方法。
|
||||||
|
*
|
||||||
|
* @throws Throwable 切面执行异常
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void aroundShouldSkipTaskWhenLockIsHeld() throws Throwable {
|
||||||
|
RedisLockExecutor executor = createExecutor(false);
|
||||||
|
DistributedScheduledLockAspect aspect = new DistributedScheduledLockAspect(executor);
|
||||||
|
AtomicInteger proceedCount = new AtomicInteger();
|
||||||
|
|
||||||
|
Object result = aspect.around(
|
||||||
|
mockJoinPoint(proceedCount),
|
||||||
|
annotatedMethod("lockedTask").getAnnotation(DistributedScheduledLock.class)
|
||||||
|
);
|
||||||
|
|
||||||
|
Assert.assertNull(result);
|
||||||
|
Assert.assertEquals(0, proceedCount.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证抢到调度锁时执行原方法并释放锁。
|
||||||
|
*
|
||||||
|
* @throws Throwable 切面执行异常
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void aroundShouldProceedAndReleaseWhenLockAcquired() throws Throwable {
|
||||||
|
RedisLockExecutor executor = createExecutor(true);
|
||||||
|
DistributedScheduledLockAspect aspect = new DistributedScheduledLockAspect(executor);
|
||||||
|
AtomicInteger proceedCount = new AtomicInteger();
|
||||||
|
|
||||||
|
Object result = aspect.around(
|
||||||
|
mockJoinPoint(proceedCount),
|
||||||
|
annotatedMethod("lockedTask").getAnnotation(DistributedScheduledLock.class)
|
||||||
|
);
|
||||||
|
|
||||||
|
Assert.assertEquals("ok", result);
|
||||||
|
Assert.assertEquals(1, proceedCount.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
@DistributedScheduledLock(key = "easyflow:test:scheduled", leaseSeconds = 30L)
|
||||||
|
private void lockedTask() {
|
||||||
|
}
|
||||||
|
|
||||||
|
private Method annotatedMethod(String methodName) throws NoSuchMethodException {
|
||||||
|
Method method = DistributedScheduledLockAspectTest.class.getDeclaredMethod(methodName);
|
||||||
|
method.setAccessible(true);
|
||||||
|
return method;
|
||||||
|
}
|
||||||
|
|
||||||
|
private ProceedingJoinPoint mockJoinPoint(AtomicInteger proceedCount) throws Throwable {
|
||||||
|
ProceedingJoinPoint joinPoint = Mockito.mock(ProceedingJoinPoint.class);
|
||||||
|
Signature signature = Mockito.mock(Signature.class);
|
||||||
|
Mockito.when(signature.toShortString()).thenReturn("lockedTask()");
|
||||||
|
Mockito.when(joinPoint.getSignature()).thenReturn(signature);
|
||||||
|
Mockito.when(joinPoint.proceed()).thenAnswer(invocation -> {
|
||||||
|
proceedCount.incrementAndGet();
|
||||||
|
return "ok";
|
||||||
|
});
|
||||||
|
return joinPoint;
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
private RedisLockExecutor createExecutor(boolean acquired) throws Exception {
|
||||||
|
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||||
|
ValueOperations<String, String> valueOperations = Mockito.mock(ValueOperations.class);
|
||||||
|
Mockito.when(valueOperations.setIfAbsent(
|
||||||
|
ArgumentMatchers.anyString(),
|
||||||
|
ArgumentMatchers.anyString(),
|
||||||
|
ArgumentMatchers.any(Duration.class)
|
||||||
|
)).thenReturn(acquired);
|
||||||
|
Mockito.when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||||
|
Mockito.when(redisTemplate.execute(
|
||||||
|
ArgumentMatchers.<RedisScript<Long>>any(),
|
||||||
|
ArgumentMatchers.<List<String>>any(),
|
||||||
|
ArgumentMatchers.<Object[]>any()
|
||||||
|
)).thenReturn(1L);
|
||||||
|
|
||||||
|
RedisLockExecutor executor = new RedisLockExecutor();
|
||||||
|
Field field = RedisLockExecutor.class.getDeclaredField("stringRedisTemplate");
|
||||||
|
field.setAccessible(true);
|
||||||
|
field.set(executor, redisTemplate);
|
||||||
|
return executor;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,98 @@
|
|||||||
|
package tech.easyflow.common.cache;
|
||||||
|
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.mockito.ArgumentMatchers;
|
||||||
|
import org.mockito.Mockito;
|
||||||
|
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||||
|
import org.springframework.data.redis.core.ValueOperations;
|
||||||
|
import org.springframework.data.redis.core.script.RedisScript;
|
||||||
|
|
||||||
|
import java.lang.reflect.Field;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@link RedisLockExecutor} 回归测试。
|
||||||
|
*/
|
||||||
|
public class RedisLockExecutorTest {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证锁被占用时返回 null,便于调度任务跳过本轮执行。
|
||||||
|
*
|
||||||
|
* @throws Exception 反射注入异常
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void tryAcquireShouldReturnNullWhenLockIsHeld() throws Exception {
|
||||||
|
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||||
|
ValueOperations<String, String> valueOperations = mockValueOperations(false);
|
||||||
|
Mockito.when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||||
|
|
||||||
|
RedisLockExecutor executor = new RedisLockExecutor();
|
||||||
|
setRedisTemplate(executor, redisTemplate);
|
||||||
|
|
||||||
|
RedisLockExecutor.LockHandle handle = executor.tryAcquire(
|
||||||
|
"easyflow:test:lock",
|
||||||
|
Duration.ZERO,
|
||||||
|
Duration.ofSeconds(30)
|
||||||
|
);
|
||||||
|
|
||||||
|
Assert.assertNull(handle);
|
||||||
|
Mockito.verify(valueOperations).setIfAbsent(
|
||||||
|
ArgumentMatchers.eq("easyflow:test:lock"),
|
||||||
|
ArgumentMatchers.anyString(),
|
||||||
|
ArgumentMatchers.eq(Duration.ofSeconds(30))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证锁获取成功后释放会执行 owner token 校验脚本。
|
||||||
|
*
|
||||||
|
* @throws Exception 反射注入异常
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void acquiredHandleShouldReleaseLockWithOwnerToken() throws Exception {
|
||||||
|
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||||
|
ValueOperations<String, String> valueOperations = mockValueOperations(true);
|
||||||
|
Mockito.when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||||
|
Mockito.when(redisTemplate.execute(
|
||||||
|
ArgumentMatchers.<RedisScript<Long>>any(),
|
||||||
|
ArgumentMatchers.<List<String>>any(),
|
||||||
|
ArgumentMatchers.<Object[]>any()
|
||||||
|
)).thenReturn(1L);
|
||||||
|
|
||||||
|
RedisLockExecutor executor = new RedisLockExecutor();
|
||||||
|
setRedisTemplate(executor, redisTemplate);
|
||||||
|
|
||||||
|
RedisLockExecutor.LockHandle handle = executor.tryAcquire(
|
||||||
|
"easyflow:test:lock",
|
||||||
|
Duration.ZERO,
|
||||||
|
Duration.ofSeconds(30)
|
||||||
|
);
|
||||||
|
|
||||||
|
Assert.assertNotNull(handle);
|
||||||
|
handle.release();
|
||||||
|
Mockito.verify(redisTemplate).execute(
|
||||||
|
ArgumentMatchers.<RedisScript<Long>>any(),
|
||||||
|
ArgumentMatchers.eq(List.of("easyflow:test:lock")),
|
||||||
|
ArgumentMatchers.<Object[]>any()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
private ValueOperations<String, String> mockValueOperations(boolean acquired) {
|
||||||
|
ValueOperations<String, String> valueOperations = Mockito.mock(ValueOperations.class);
|
||||||
|
Mockito.when(valueOperations.setIfAbsent(
|
||||||
|
ArgumentMatchers.anyString(),
|
||||||
|
ArgumentMatchers.anyString(),
|
||||||
|
ArgumentMatchers.any(Duration.class)
|
||||||
|
)).thenReturn(acquired);
|
||||||
|
return valueOperations;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void setRedisTemplate(RedisLockExecutor executor, StringRedisTemplate redisTemplate) throws Exception {
|
||||||
|
Field field = RedisLockExecutor.class.getDeclaredField("stringRedisTemplate");
|
||||||
|
field.setAccessible(true);
|
||||||
|
field.set(executor, redisTemplate);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -51,9 +51,22 @@ public class ChatAssistantAccumulator {
|
|||||||
* @param arguments tool 参数
|
* @param arguments tool 参数
|
||||||
*/
|
*/
|
||||||
public void appendToolCall(String id, String name, Object arguments) {
|
public void appendToolCall(String id, String name, Object arguments) {
|
||||||
|
appendToolCall(id, name, null, arguments);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 记录 tool call,同时保留面向前端展示的工具名称。
|
||||||
|
*
|
||||||
|
* @param id tool call id
|
||||||
|
* @param name tool 名称
|
||||||
|
* @param displayName tool 展示名称
|
||||||
|
* @param arguments tool 参数
|
||||||
|
*/
|
||||||
|
public void appendToolCall(String id, String name, String displayName, Object arguments) {
|
||||||
Map<String, Object> chain = findToolChain(id, name);
|
Map<String, Object> chain = findToolChain(id, name);
|
||||||
chain.put("status", "TOOL_CALL");
|
chain.put("status", "TOOL_CALL");
|
||||||
chain.put("arguments", arguments);
|
chain.put("arguments", arguments);
|
||||||
|
putIfNotBlank(chain, "toolDisplayName", displayName);
|
||||||
|
|
||||||
Map<String, Object> assistantMessage = ensureToolCallAssistantMessage();
|
Map<String, Object> assistantMessage = ensureToolCallAssistantMessage();
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
@@ -63,6 +76,7 @@ public class ChatAssistantAccumulator {
|
|||||||
toolCall.put("id", id);
|
toolCall.put("id", id);
|
||||||
toolCall.put("name", name);
|
toolCall.put("name", name);
|
||||||
toolCall.put("arguments", arguments == null ? null : String.valueOf(arguments));
|
toolCall.put("arguments", arguments == null ? null : String.valueOf(arguments));
|
||||||
|
putIfNotBlank(toolCall, "toolDisplayName", displayName);
|
||||||
toolCalls.add(toolCall);
|
toolCalls.add(toolCall);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,9 +88,22 @@ public class ChatAssistantAccumulator {
|
|||||||
* @param result tool 结果
|
* @param result tool 结果
|
||||||
*/
|
*/
|
||||||
public void appendToolResult(String id, String name, Object result) {
|
public void appendToolResult(String id, String name, Object result) {
|
||||||
|
appendToolResult(id, name, null, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 记录 tool result,并保留面向前端展示的工具名称。
|
||||||
|
*
|
||||||
|
* @param id tool call id
|
||||||
|
* @param name tool 名称
|
||||||
|
* @param displayName tool 展示名称
|
||||||
|
* @param result tool 结果
|
||||||
|
*/
|
||||||
|
public void appendToolResult(String id, String name, String displayName, Object result) {
|
||||||
Map<String, Object> chain = findToolChain(id, name);
|
Map<String, Object> chain = findToolChain(id, name);
|
||||||
chain.put("status", "TOOL_RESULT");
|
chain.put("status", "TOOL_RESULT");
|
||||||
chain.put("result", result);
|
chain.put("result", result);
|
||||||
|
putIfNotBlank(chain, "toolDisplayName", displayName);
|
||||||
Map<String, Object> toolMessage = ChatRuntimeHistoryPayloadHelper.toolMessage(
|
Map<String, Object> toolMessage = ChatRuntimeHistoryPayloadHelper.toolMessage(
|
||||||
id,
|
id,
|
||||||
result == null ? null : String.valueOf(result)
|
result == null ? null : String.valueOf(result)
|
||||||
@@ -191,4 +218,10 @@ public class ChatAssistantAccumulator {
|
|||||||
private String stringValue(Object value) {
|
private String stringValue(Object value) {
|
||||||
return value == null ? null : String.valueOf(value);
|
return value == null ? null : String.valueOf(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void putIfNotBlank(Map<String, Object> target, String key, String value) {
|
||||||
|
if (value != null && !value.isBlank()) {
|
||||||
|
target.put(key, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,5 +22,22 @@
|
|||||||
<artifactId>jackson-databind</artifactId>
|
<artifactId>jackson-databind</artifactId>
|
||||||
<version>${jackson.version}</version>
|
<version>${jackson.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.commons</groupId>
|
||||||
|
<artifactId>commons-pool2</artifactId>
|
||||||
|
<version>2.11.1</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>junit</groupId>
|
||||||
|
<artifactId>junit</artifactId>
|
||||||
|
<version>${junit.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.mockito</groupId>
|
||||||
|
<artifactId>mockito-core</artifactId>
|
||||||
|
<version>5.12.0</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</project>
|
</project>
|
||||||
|
|||||||
@@ -9,7 +9,9 @@ import org.springframework.context.annotation.Bean;
|
|||||||
import org.springframework.context.annotation.Configuration;
|
import org.springframework.context.annotation.Configuration;
|
||||||
import org.springframework.data.redis.connection.RedisPassword;
|
import org.springframework.data.redis.connection.RedisPassword;
|
||||||
import org.springframework.data.redis.connection.RedisStandaloneConfiguration;
|
import org.springframework.data.redis.connection.RedisStandaloneConfiguration;
|
||||||
|
import org.springframework.data.redis.connection.lettuce.LettuceClientConfiguration;
|
||||||
import org.springframework.data.redis.connection.lettuce.LettuceConnectionFactory;
|
import org.springframework.data.redis.connection.lettuce.LettuceConnectionFactory;
|
||||||
|
import org.springframework.data.redis.connection.lettuce.LettucePoolingClientConfiguration;
|
||||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||||
import tech.easyflow.common.mq.core.MQConsumerContainer;
|
import tech.easyflow.common.mq.core.MQConsumerContainer;
|
||||||
import tech.easyflow.common.mq.core.MQConsumerHandler;
|
import tech.easyflow.common.mq.core.MQConsumerHandler;
|
||||||
@@ -24,6 +26,10 @@ import tech.easyflow.common.mq.redis.RedisMQProducer;
|
|||||||
import tech.easyflow.common.mq.redis.RedisStreamKeySupport;
|
import tech.easyflow.common.mq.redis.RedisStreamKeySupport;
|
||||||
import tech.easyflow.common.mq.support.MQHealthSupport;
|
import tech.easyflow.common.mq.support.MQHealthSupport;
|
||||||
|
|
||||||
|
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
|
||||||
|
|
||||||
|
import io.lettuce.core.api.StatefulConnection;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
@@ -43,11 +49,27 @@ public class MQConfiguration {
|
|||||||
if (redisProperties.getPassword() != null) {
|
if (redisProperties.getPassword() != null) {
|
||||||
configuration.setPassword(RedisPassword.of(redisProperties.getPassword()));
|
configuration.setPassword(RedisPassword.of(redisProperties.getPassword()));
|
||||||
}
|
}
|
||||||
LettuceConnectionFactory connectionFactory = new LettuceConnectionFactory(configuration);
|
LettuceClientConfiguration clientConfiguration = createClientConfiguration(redisProperties, mqProperties);
|
||||||
|
LettuceConnectionFactory connectionFactory = new LettuceConnectionFactory(configuration, clientConfiguration);
|
||||||
connectionFactory.afterPropertiesSet();
|
connectionFactory.afterPropertiesSet();
|
||||||
return new MQRedisResources(connectionFactory, new StringRedisTemplate(connectionFactory));
|
return new MQRedisResources(connectionFactory, new StringRedisTemplate(connectionFactory));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private LettuceClientConfiguration createClientConfiguration(RedisProperties redisProperties,
|
||||||
|
MQProperties mqProperties) {
|
||||||
|
MQProperties.Redis.Pool pool = mqProperties.getRedis().getPool();
|
||||||
|
GenericObjectPoolConfig<StatefulConnection<?, ?>> poolConfig = new GenericObjectPoolConfig<>();
|
||||||
|
poolConfig.setMaxTotal(pool.getMaxActive());
|
||||||
|
poolConfig.setMaxIdle(pool.getMaxIdle());
|
||||||
|
poolConfig.setMinIdle(pool.getMinIdle());
|
||||||
|
LettucePoolingClientConfiguration.LettucePoolingClientConfigurationBuilder builder =
|
||||||
|
LettucePoolingClientConfiguration.builder().poolConfig(poolConfig);
|
||||||
|
if (redisProperties.getTimeout() != null) {
|
||||||
|
builder.commandTimeout(redisProperties.getTimeout());
|
||||||
|
}
|
||||||
|
return builder.build();
|
||||||
|
}
|
||||||
|
|
||||||
@Bean(name = "mqRedisConnectionFactory", autowireCandidate = false, defaultCandidate = false)
|
@Bean(name = "mqRedisConnectionFactory", autowireCandidate = false, defaultCandidate = false)
|
||||||
@ConditionalOnProperty(prefix = "easyflow.mq", name = "enabled", havingValue = "true", matchIfMissing = true)
|
@ConditionalOnProperty(prefix = "easyflow.mq", name = "enabled", havingValue = "true", matchIfMissing = true)
|
||||||
public LettuceConnectionFactory mqRedisConnectionFactory(MQRedisResources mqRedisResources) {
|
public LettuceConnectionFactory mqRedisConnectionFactory(MQRedisResources mqRedisResources) {
|
||||||
|
|||||||
@@ -1,9 +1,13 @@
|
|||||||
package tech.easyflow.common.mq.config;
|
package tech.easyflow.common.mq.config;
|
||||||
|
|
||||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||||
|
import org.springframework.util.StringUtils;
|
||||||
|
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* EasyFlow MQ 配置。
|
||||||
|
*/
|
||||||
@ConfigurationProperties(prefix = "easyflow.mq")
|
@ConfigurationProperties(prefix = "easyflow.mq")
|
||||||
public class MQProperties {
|
public class MQProperties {
|
||||||
|
|
||||||
@@ -35,11 +39,14 @@ public class MQProperties {
|
|||||||
|
|
||||||
private int database = 1;
|
private int database = 1;
|
||||||
private String streamPrefix = "easyflow:mq";
|
private String streamPrefix = "easyflow:mq";
|
||||||
|
private String consumerInstanceId = defaultConsumerInstanceId();
|
||||||
private int chatPersistShardCount = 4;
|
private int chatPersistShardCount = 4;
|
||||||
private int consumerBatchSize = 200;
|
private int consumerBatchSize = 200;
|
||||||
private Duration consumerBlockTimeout = Duration.ofMillis(2000);
|
private Duration consumerBlockTimeout = Duration.ofMillis(2000);
|
||||||
private Duration pendingClaimIdle = Duration.ofMillis(60000);
|
private Duration pendingClaimIdle = Duration.ofMillis(60000);
|
||||||
private int maxRetry = 16;
|
private int maxRetry = 16;
|
||||||
|
private ConsumerExecutor consumerExecutor = new ConsumerExecutor();
|
||||||
|
private Pool pool = new Pool();
|
||||||
|
|
||||||
public int getDatabase() {
|
public int getDatabase() {
|
||||||
return database;
|
return database;
|
||||||
@@ -57,6 +64,26 @@ public class MQProperties {
|
|||||||
this.streamPrefix = streamPrefix;
|
this.streamPrefix = streamPrefix;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取 Redis Stream 消费实例 ID。
|
||||||
|
*
|
||||||
|
* @return 消费实例 ID
|
||||||
|
*/
|
||||||
|
public String getConsumerInstanceId() {
|
||||||
|
return consumerInstanceId;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置 Redis Stream 消费实例 ID。
|
||||||
|
*
|
||||||
|
* @param consumerInstanceId 消费实例 ID
|
||||||
|
*/
|
||||||
|
public void setConsumerInstanceId(String consumerInstanceId) {
|
||||||
|
this.consumerInstanceId = StringUtils.hasText(consumerInstanceId)
|
||||||
|
? consumerInstanceId.trim()
|
||||||
|
: defaultConsumerInstanceId();
|
||||||
|
}
|
||||||
|
|
||||||
public int getChatPersistShardCount() {
|
public int getChatPersistShardCount() {
|
||||||
return chatPersistShardCount;
|
return chatPersistShardCount;
|
||||||
}
|
}
|
||||||
@@ -96,5 +123,106 @@ public class MQProperties {
|
|||||||
public void setMaxRetry(int maxRetry) {
|
public void setMaxRetry(int maxRetry) {
|
||||||
this.maxRetry = maxRetry;
|
this.maxRetry = maxRetry;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public ConsumerExecutor getConsumerExecutor() {
|
||||||
|
return consumerExecutor;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setConsumerExecutor(ConsumerExecutor consumerExecutor) {
|
||||||
|
this.consumerExecutor = consumerExecutor;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Pool getPool() {
|
||||||
|
return pool;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setPool(Pool pool) {
|
||||||
|
this.pool = pool;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Redis MQ 消费线程池配置。
|
||||||
|
*/
|
||||||
|
public static class ConsumerExecutor {
|
||||||
|
|
||||||
|
private int coreSize = 4;
|
||||||
|
private int maxSize = 12;
|
||||||
|
private int queueCapacity = 64;
|
||||||
|
private int keepAliveSeconds = 60;
|
||||||
|
|
||||||
|
public int getCoreSize() {
|
||||||
|
return coreSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setCoreSize(int coreSize) {
|
||||||
|
this.coreSize = coreSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getMaxSize() {
|
||||||
|
return maxSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setMaxSize(int maxSize) {
|
||||||
|
this.maxSize = maxSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getQueueCapacity() {
|
||||||
|
return queueCapacity;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setQueueCapacity(int queueCapacity) {
|
||||||
|
this.queueCapacity = queueCapacity;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getKeepAliveSeconds() {
|
||||||
|
return keepAliveSeconds;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setKeepAliveSeconds(int keepAliveSeconds) {
|
||||||
|
this.keepAliveSeconds = keepAliveSeconds;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Redis MQ 连接池配置。
|
||||||
|
*/
|
||||||
|
public static class Pool {
|
||||||
|
|
||||||
|
private int maxActive = 12;
|
||||||
|
private int maxIdle = 8;
|
||||||
|
private int minIdle = 1;
|
||||||
|
|
||||||
|
public int getMaxActive() {
|
||||||
|
return maxActive;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setMaxActive(int maxActive) {
|
||||||
|
this.maxActive = maxActive;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getMaxIdle() {
|
||||||
|
return maxIdle;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setMaxIdle(int maxIdle) {
|
||||||
|
this.maxIdle = maxIdle;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getMinIdle() {
|
||||||
|
return minIdle;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setMinIdle(int minIdle) {
|
||||||
|
this.minIdle = minIdle;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String defaultConsumerInstanceId() {
|
||||||
|
String hostName = System.getenv("HOSTNAME");
|
||||||
|
if (StringUtils.hasText(hostName)) {
|
||||||
|
return hostName.trim();
|
||||||
|
}
|
||||||
|
return java.util.UUID.randomUUID().toString();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ public class MQSubscription {
|
|||||||
private String topic;
|
private String topic;
|
||||||
private String consumerGroup;
|
private String consumerGroup;
|
||||||
private int shardCount;
|
private int shardCount;
|
||||||
|
private boolean batchEnabled = true;
|
||||||
|
|
||||||
public String getTopic() {
|
public String getTopic() {
|
||||||
return topic;
|
return topic;
|
||||||
@@ -29,4 +30,22 @@ public class MQSubscription {
|
|||||||
public void setShardCount(int shardCount) {
|
public void setShardCount(int shardCount) {
|
||||||
this.shardCount = shardCount;
|
this.shardCount = shardCount;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用批量消费。
|
||||||
|
*
|
||||||
|
* @return true 表示启用批量消费
|
||||||
|
*/
|
||||||
|
public boolean isBatchEnabled() {
|
||||||
|
return batchEnabled;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置是否启用批量消费。
|
||||||
|
*
|
||||||
|
* @param batchEnabled 是否启用批量消费
|
||||||
|
*/
|
||||||
|
public void setBatchEnabled(boolean batchEnabled) {
|
||||||
|
this.batchEnabled = batchEnabled;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,13 +30,17 @@ import java.util.ArrayList;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
import java.util.regex.Pattern;
|
||||||
|
import java.util.concurrent.ArrayBlockingQueue;
|
||||||
import java.util.concurrent.ExecutorService;
|
import java.util.concurrent.ExecutorService;
|
||||||
import java.util.concurrent.Executors;
|
import java.util.concurrent.ThreadPoolExecutor;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifecycle {
|
public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifecycle {
|
||||||
|
|
||||||
private static final Logger LOG = LoggerFactory.getLogger(RedisMQConsumerContainer.class);
|
private static final Logger LOG = LoggerFactory.getLogger(RedisMQConsumerContainer.class);
|
||||||
|
private static final Pattern UNSAFE_CONSUMER_NAME_CHARS = Pattern.compile("[^A-Za-z0-9_.-]");
|
||||||
|
|
||||||
private final RedisConnectionFactory redisConnectionFactory;
|
private final RedisConnectionFactory redisConnectionFactory;
|
||||||
private final StringRedisTemplate stringRedisTemplate;
|
private final StringRedisTemplate stringRedisTemplate;
|
||||||
@@ -45,7 +49,7 @@ public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifec
|
|||||||
private final MQDeadLetterService deadLetterService;
|
private final MQDeadLetterService deadLetterService;
|
||||||
private final RedisStreamKeySupport keySupport;
|
private final RedisStreamKeySupport keySupport;
|
||||||
private final List<MQConsumerHandler> handlers;
|
private final List<MQConsumerHandler> handlers;
|
||||||
private final ExecutorService executorService = Executors.newCachedThreadPool();
|
private final ExecutorService executorService;
|
||||||
|
|
||||||
private volatile boolean running;
|
private volatile boolean running;
|
||||||
|
|
||||||
@@ -63,6 +67,7 @@ public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifec
|
|||||||
this.deadLetterService = deadLetterService;
|
this.deadLetterService = deadLetterService;
|
||||||
this.keySupport = keySupport;
|
this.keySupport = keySupport;
|
||||||
this.handlers = handlers;
|
this.handlers = handlers;
|
||||||
|
this.executorService = createExecutor(properties, handlers);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -77,7 +82,12 @@ public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifec
|
|||||||
int currentShard = shard;
|
int currentShard = shard;
|
||||||
LOG.info("启动 MQ 消费线程: topic={}, group={}, shard={}, handler={}",
|
LOG.info("启动 MQ 消费线程: topic={}, group={}, shard={}, handler={}",
|
||||||
subscription.getTopic(), subscription.getConsumerGroup(), currentShard, handler.getClass().getSimpleName());
|
subscription.getTopic(), subscription.getConsumerGroup(), currentShard, handler.getClass().getSimpleName());
|
||||||
executorService.submit(() -> consumeLoop(handler, subscription, currentShard));
|
try {
|
||||||
|
executorService.submit(() -> consumeLoop(handler, subscription, currentShard));
|
||||||
|
} catch (RuntimeException e) {
|
||||||
|
running = false;
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -108,15 +118,62 @@ public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifec
|
|||||||
stop();
|
stop();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private ExecutorService createExecutor(MQProperties properties, List<MQConsumerHandler> handlers) {
|
||||||
|
MQProperties.Redis.ConsumerExecutor config = properties.getRedis().getConsumerExecutor();
|
||||||
|
int consumerTaskCount = handlers.stream()
|
||||||
|
.map(MQConsumerHandler::subscription)
|
||||||
|
.filter(Objects::nonNull)
|
||||||
|
.mapToInt(subscription -> Math.max(subscription.getShardCount(), 1))
|
||||||
|
.sum();
|
||||||
|
if (config.getCoreSize() > config.getMaxSize()) {
|
||||||
|
throw new IllegalStateException("Redis MQ 消费线程池配置错误:core-size 不能大于 max-size");
|
||||||
|
}
|
||||||
|
if (consumerTaskCount > config.getMaxSize()) {
|
||||||
|
throw new IllegalStateException("Redis MQ 消费线程池配置错误:max-size="
|
||||||
|
+ config.getMaxSize() + " 小于消费循环数 " + consumerTaskCount
|
||||||
|
+ ",请调大 easyflow.mq.redis.consumer-executor.max-size");
|
||||||
|
}
|
||||||
|
int coreSize = Math.max(config.getCoreSize(), consumerTaskCount);
|
||||||
|
int maxSize = config.getMaxSize();
|
||||||
|
AtomicInteger threadIndex = new AtomicInteger(1);
|
||||||
|
ThreadPoolExecutor executor = new ThreadPoolExecutor(
|
||||||
|
coreSize,
|
||||||
|
maxSize,
|
||||||
|
config.getKeepAliveSeconds(),
|
||||||
|
TimeUnit.SECONDS,
|
||||||
|
new ArrayBlockingQueue<>(config.getQueueCapacity()),
|
||||||
|
task -> {
|
||||||
|
Thread thread = new Thread(task);
|
||||||
|
thread.setName("redis-mq-consumer-" + threadIndex.getAndIncrement());
|
||||||
|
thread.setDaemon(false);
|
||||||
|
return thread;
|
||||||
|
},
|
||||||
|
new ThreadPoolExecutor.AbortPolicy()
|
||||||
|
);
|
||||||
|
executor.allowCoreThreadTimeOut(true);
|
||||||
|
return executor;
|
||||||
|
}
|
||||||
|
|
||||||
private void consumeLoop(MQConsumerHandler handler, MQSubscription subscription, int shard) {
|
private void consumeLoop(MQConsumerHandler handler, MQSubscription subscription, int shard) {
|
||||||
String streamKey = keySupport.streamKey(subscription.getTopic(), shard);
|
String streamKey = keySupport.streamKey(subscription.getTopic(), shard);
|
||||||
String consumerName = subscription.getConsumerGroup() + "-" + shard;
|
String consumerName = buildConsumerName(subscription.getConsumerGroup(), shard);
|
||||||
ensureConsumerGroup(streamKey, subscription.getConsumerGroup());
|
ensureConsumerGroup(streamKey, subscription.getConsumerGroup());
|
||||||
LOG.info("MQ 消费循环已启动: topic={}, group={}, shard={}, consumer={}, streamKey={}, handler={}",
|
LOG.info("MQ 消费循环已启动: topic={}, group={}, shard={}, consumer={}, streamKey={}, handler={}",
|
||||||
subscription.getTopic(), subscription.getConsumerGroup(), shard, consumerName, streamKey, handler.getClass().getSimpleName());
|
subscription.getTopic(), subscription.getConsumerGroup(), shard, consumerName, streamKey, handler.getClass().getSimpleName());
|
||||||
while (running) {
|
while (running) {
|
||||||
try {
|
try {
|
||||||
reclaimPending(streamKey, subscription.getConsumerGroup(), consumerName);
|
List<MapRecord<String, Object, Object>> pendingRecords =
|
||||||
|
reclaimPending(streamKey, subscription.getConsumerGroup(), consumerName);
|
||||||
|
if (!pendingRecords.isEmpty()) {
|
||||||
|
List<MQMessage> pendingMessages = toMessages(streamKey, pendingRecords);
|
||||||
|
if (!pendingMessages.isEmpty()) {
|
||||||
|
LOG.info("MQ 收到重领 pending 消息批次: topic={}, group={}, shard={}, consumer={}, streamKey={}, count={}",
|
||||||
|
subscription.getTopic(), subscription.getConsumerGroup(), shard, consumerName,
|
||||||
|
streamKey, pendingMessages.size());
|
||||||
|
handleMessages(handler, subscription, streamKey, subscription.getConsumerGroup(), pendingMessages);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
List<MapRecord<String, Object, Object>> records = stringRedisTemplate.opsForStream().read(
|
List<MapRecord<String, Object, Object>> records = stringRedisTemplate.opsForStream().read(
|
||||||
Consumer.from(subscription.getConsumerGroup(), consumerName),
|
Consumer.from(subscription.getConsumerGroup(), consumerName),
|
||||||
StreamReadOptions.empty()
|
StreamReadOptions.empty()
|
||||||
@@ -133,7 +190,7 @@ public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifec
|
|||||||
}
|
}
|
||||||
LOG.info("MQ 收到消息批次: topic={}, group={}, shard={}, consumer={}, streamKey={}, count={}",
|
LOG.info("MQ 收到消息批次: topic={}, group={}, shard={}, consumer={}, streamKey={}, count={}",
|
||||||
subscription.getTopic(), subscription.getConsumerGroup(), shard, consumerName, streamKey, messages.size());
|
subscription.getTopic(), subscription.getConsumerGroup(), shard, consumerName, streamKey, messages.size());
|
||||||
handleMessages(handler, streamKey, subscription.getConsumerGroup(), messages);
|
handleMessages(handler, subscription, streamKey, subscription.getConsumerGroup(), messages);
|
||||||
} catch (Exception exception) {
|
} catch (Exception exception) {
|
||||||
LOG.error("MQ 消费循环异常: topic={}, group={}, shard={}, consumer={}, streamKey={}, handler={}",
|
LOG.error("MQ 消费循环异常: topic={}, group={}, shard={}, consumer={}, streamKey={}, handler={}",
|
||||||
subscription.getTopic(),
|
subscription.getTopic(),
|
||||||
@@ -148,7 +205,20 @@ public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifec
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void reclaimPending(String streamKey, String group, String consumerName) {
|
/**
|
||||||
|
* 构建 Redis Stream consumer name。
|
||||||
|
*
|
||||||
|
* @param consumerGroup 消费组
|
||||||
|
* @param shard 分片序号
|
||||||
|
* @return consumer name
|
||||||
|
*/
|
||||||
|
String buildConsumerName(String consumerGroup, int shard) {
|
||||||
|
String instanceId = properties.getRedis().getConsumerInstanceId();
|
||||||
|
String safeInstanceId = UNSAFE_CONSUMER_NAME_CHARS.matcher(instanceId).replaceAll("-");
|
||||||
|
return consumerGroup + "-" + shard + "-" + safeInstanceId;
|
||||||
|
}
|
||||||
|
|
||||||
|
List<MapRecord<String, Object, Object>> reclaimPending(String streamKey, String group, String consumerName) {
|
||||||
Duration idle = properties.getRedis().getPendingClaimIdle();
|
Duration idle = properties.getRedis().getPendingClaimIdle();
|
||||||
try (RedisConnection connection = redisConnectionFactory.getConnection()) {
|
try (RedisConnection connection = redisConnectionFactory.getConnection()) {
|
||||||
RedisStreamCommands.XPendingOptions options = RedisStreamCommands.XPendingOptions
|
RedisStreamCommands.XPendingOptions options = RedisStreamCommands.XPendingOptions
|
||||||
@@ -156,7 +226,7 @@ public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifec
|
|||||||
var pendingMessages = connection.streamCommands()
|
var pendingMessages = connection.streamCommands()
|
||||||
.xPending(streamKey.getBytes(StandardCharsets.UTF_8), group, options);
|
.xPending(streamKey.getBytes(StandardCharsets.UTF_8), group, options);
|
||||||
if (pendingMessages == null || pendingMessages.isEmpty()) {
|
if (pendingMessages == null || pendingMessages.isEmpty()) {
|
||||||
return;
|
return List.of();
|
||||||
}
|
}
|
||||||
List<RecordId> ids = new ArrayList<>();
|
List<RecordId> ids = new ArrayList<>();
|
||||||
for (PendingMessage pendingMessage : pendingMessages) {
|
for (PendingMessage pendingMessage : pendingMessages) {
|
||||||
@@ -165,15 +235,16 @@ public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifec
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (ids.isEmpty()) {
|
if (ids.isEmpty()) {
|
||||||
return;
|
return List.of();
|
||||||
}
|
}
|
||||||
stringRedisTemplate.opsForStream().claim(
|
List<MapRecord<String, Object, Object>> records = stringRedisTemplate.opsForStream().claim(
|
||||||
streamKey,
|
streamKey,
|
||||||
group,
|
group,
|
||||||
consumerName,
|
consumerName,
|
||||||
idle,
|
idle,
|
||||||
ids.toArray(new RecordId[0])
|
ids.toArray(new RecordId[0])
|
||||||
);
|
);
|
||||||
|
return records == null ? List.of() : records;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,7 +260,7 @@ public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifec
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<MQMessage> toMessages(String streamKey, List<MapRecord<String, Object, Object>> records) {
|
List<MQMessage> toMessages(String streamKey, List<MapRecord<String, Object, Object>> records) {
|
||||||
List<MQMessage> messages = new ArrayList<>(records.size());
|
List<MQMessage> messages = new ArrayList<>(records.size());
|
||||||
for (MapRecord<String, Object, Object> record : records) {
|
for (MapRecord<String, Object, Object> record : records) {
|
||||||
Object payload = record.getValue().get("payload");
|
Object payload = record.getValue().get("payload");
|
||||||
@@ -225,7 +296,15 @@ public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifec
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void handleMessages(MQConsumerHandler handler, String streamKey, String group, List<MQMessage> messages) throws Exception {
|
void handleMessages(MQConsumerHandler handler,
|
||||||
|
MQSubscription subscription,
|
||||||
|
String streamKey,
|
||||||
|
String group,
|
||||||
|
List<MQMessage> messages) throws Exception {
|
||||||
|
if (!subscription.isBatchEnabled()) {
|
||||||
|
handleMessagesIndividually(handler, streamKey, group, messages);
|
||||||
|
return;
|
||||||
|
}
|
||||||
try {
|
try {
|
||||||
LOG.info("MQ 开始批量处理消息: group={}, streamKey={}, count={}, handler={}",
|
LOG.info("MQ 开始批量处理消息: group={}, streamKey={}, count={}, handler={}",
|
||||||
group, streamKey, messages.size(), handler.getClass().getSimpleName());
|
group, streamKey, messages.size(), handler.getClass().getSimpleName());
|
||||||
@@ -244,6 +323,13 @@ public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifec
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
handleMessagesIndividually(handler, streamKey, group, messages);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void handleMessagesIndividually(MQConsumerHandler handler,
|
||||||
|
String streamKey,
|
||||||
|
String group,
|
||||||
|
List<MQMessage> messages) {
|
||||||
for (MQMessage message : messages) {
|
for (MQMessage message : messages) {
|
||||||
try {
|
try {
|
||||||
LOG.info("MQ 开始单条处理消息: group={}, streamKey={}, messageId={}, handler={}",
|
LOG.info("MQ 开始单条处理消息: group={}, streamKey={}, messageId={}, handler={}",
|
||||||
|
|||||||
@@ -0,0 +1,175 @@
|
|||||||
|
package tech.easyflow.common.mq.redis;
|
||||||
|
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.mockito.ArgumentMatchers;
|
||||||
|
import org.mockito.Mockito;
|
||||||
|
import org.springframework.data.redis.connection.RedisConnection;
|
||||||
|
import org.springframework.data.redis.connection.RedisConnectionFactory;
|
||||||
|
import org.springframework.data.redis.connection.RedisStreamCommands;
|
||||||
|
import org.springframework.data.redis.connection.stream.Consumer;
|
||||||
|
import org.springframework.data.redis.connection.stream.MapRecord;
|
||||||
|
import org.springframework.data.redis.connection.stream.PendingMessage;
|
||||||
|
import org.springframework.data.redis.connection.stream.PendingMessages;
|
||||||
|
import org.springframework.data.redis.connection.stream.RecordId;
|
||||||
|
import org.springframework.data.redis.core.StreamOperations;
|
||||||
|
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||||
|
import tech.easyflow.common.mq.config.MQProperties;
|
||||||
|
import tech.easyflow.common.mq.core.MQConsumerHandler;
|
||||||
|
import tech.easyflow.common.mq.core.MQDeadLetterService;
|
||||||
|
import tech.easyflow.common.mq.core.MQMessage;
|
||||||
|
import tech.easyflow.common.mq.core.MQMessageConverter;
|
||||||
|
import tech.easyflow.common.mq.core.MQSubscription;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@link RedisMQConsumerContainer} 回归测试。
|
||||||
|
*/
|
||||||
|
public class RedisMQConsumerContainerTest {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证 consumer name 包含稳定实例 ID,且消费组名称不被改变。
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void buildConsumerNameShouldAppendSanitizedInstanceId() {
|
||||||
|
MQProperties properties = new MQProperties();
|
||||||
|
properties.getRedis().setConsumerInstanceId("node/a:1");
|
||||||
|
RedisMQConsumerContainer container = new RedisMQConsumerContainer(
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
properties,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
List.of()
|
||||||
|
);
|
||||||
|
|
||||||
|
String consumerName = container.buildConsumerName("chat-persist", 2);
|
||||||
|
|
||||||
|
Assert.assertEquals("chat-persist-2-node-a-1", consumerName);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证关闭批量消费后,容器按单条处理并独立确认消息。
|
||||||
|
*
|
||||||
|
* @throws Exception 消息处理异常
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void handleMessagesShouldProcessIndividuallyWhenBatchDisabled() throws Exception {
|
||||||
|
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
StreamOperations<String, Object, Object> streamOperations = Mockito.mock(StreamOperations.class);
|
||||||
|
Mockito.when(redisTemplate.opsForStream()).thenReturn(streamOperations);
|
||||||
|
RecordingHandler handler = new RecordingHandler();
|
||||||
|
MQSubscription subscription = new MQSubscription();
|
||||||
|
subscription.setBatchEnabled(false);
|
||||||
|
RedisMQConsumerContainer container = container(redisTemplate, null);
|
||||||
|
MQMessage first = message("message-1", "1-0");
|
||||||
|
MQMessage second = message("message-2", "2-0");
|
||||||
|
|
||||||
|
container.handleMessages(handler, subscription, "stream-1", "group-1", List.of(first, second));
|
||||||
|
|
||||||
|
Assert.assertEquals(List.of(List.of("message-1"), List.of("message-2")), handler.calls);
|
||||||
|
Mockito.verify(streamOperations).acknowledge("stream-1", "group-1", "1-0");
|
||||||
|
Mockito.verify(streamOperations).acknowledge("stream-1", "group-1", "2-0");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证 pending 消息被 claim 后可以转换为 MQ 消息继续消费。
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void reclaimPendingShouldReturnClaimedRecordsForConsumption() {
|
||||||
|
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
StreamOperations<String, Object, Object> streamOperations = Mockito.mock(StreamOperations.class);
|
||||||
|
Mockito.when(redisTemplate.opsForStream()).thenReturn(streamOperations);
|
||||||
|
RedisConnectionFactory connectionFactory = Mockito.mock(RedisConnectionFactory.class);
|
||||||
|
RedisConnection connection = Mockito.mock(RedisConnection.class);
|
||||||
|
RedisStreamCommands streamCommands = Mockito.mock(RedisStreamCommands.class);
|
||||||
|
Mockito.when(connectionFactory.getConnection()).thenReturn(connection);
|
||||||
|
Mockito.when(connection.streamCommands()).thenReturn(streamCommands);
|
||||||
|
PendingMessage pendingMessage = new PendingMessage(
|
||||||
|
RecordId.of("1-0"), Consumer.from("group-1", "old-consumer"), Duration.ofMinutes(2), 1);
|
||||||
|
Mockito.when(streamCommands.xPending(
|
||||||
|
ArgumentMatchers.eq("stream-1".getBytes(java.nio.charset.StandardCharsets.UTF_8)),
|
||||||
|
ArgumentMatchers.eq("group-1"),
|
||||||
|
ArgumentMatchers.any(RedisStreamCommands.XPendingOptions.class)))
|
||||||
|
.thenReturn(new PendingMessages("group-1", List.of(pendingMessage)));
|
||||||
|
Map<Object, Object> payload = Map.of("payload", "message-1");
|
||||||
|
MapRecord<String, Object, Object> record = MapRecord
|
||||||
|
.create("stream-1", payload)
|
||||||
|
.withId(RecordId.of("1-0"));
|
||||||
|
Mockito.when(streamOperations.claim(
|
||||||
|
ArgumentMatchers.eq("stream-1"),
|
||||||
|
ArgumentMatchers.eq("group-1"),
|
||||||
|
ArgumentMatchers.eq("consumer-1"),
|
||||||
|
ArgumentMatchers.any(Duration.class),
|
||||||
|
ArgumentMatchers.any(RecordId[].class)))
|
||||||
|
.thenReturn(List.of(record));
|
||||||
|
RedisMQConsumerContainer container = container(redisTemplate, connectionFactory);
|
||||||
|
|
||||||
|
List<MapRecord<String, Object, Object>> records =
|
||||||
|
container.reclaimPending("stream-1", "group-1", "consumer-1");
|
||||||
|
List<MQMessage> messages = container.toMessages("stream-1", records);
|
||||||
|
|
||||||
|
Assert.assertEquals(1, records.size());
|
||||||
|
Assert.assertEquals(1, messages.size());
|
||||||
|
Assert.assertEquals("message-1", messages.get(0).getMessageId());
|
||||||
|
Assert.assertEquals("1-0", messages.get(0).getStreamMessageId());
|
||||||
|
}
|
||||||
|
|
||||||
|
private RedisMQConsumerContainer container(StringRedisTemplate redisTemplate,
|
||||||
|
RedisConnectionFactory connectionFactory) {
|
||||||
|
MQProperties properties = new MQProperties();
|
||||||
|
return new RedisMQConsumerContainer(
|
||||||
|
connectionFactory,
|
||||||
|
redisTemplate,
|
||||||
|
properties,
|
||||||
|
new PlainMessageConverter(),
|
||||||
|
Mockito.mock(MQDeadLetterService.class),
|
||||||
|
null,
|
||||||
|
List.of()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private MQMessage message(String messageId, String streamMessageId) {
|
||||||
|
MQMessage message = new MQMessage();
|
||||||
|
message.setMessageId(messageId);
|
||||||
|
message.setStreamMessageId(streamMessageId);
|
||||||
|
return message;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class RecordingHandler implements MQConsumerHandler {
|
||||||
|
|
||||||
|
private final List<List<String>> calls = new ArrayList<>();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MQSubscription subscription() {
|
||||||
|
return new MQSubscription();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handle(List<MQMessage> messages) {
|
||||||
|
calls.add(messages.stream().map(MQMessage::getMessageId).toList());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class PlainMessageConverter implements MQMessageConverter {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String serialize(MQMessage message) {
|
||||||
|
return message.getMessageId();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MQMessage deserialize(String payload) {
|
||||||
|
MQMessage message = new MQMessage();
|
||||||
|
message.setMessageId(payload);
|
||||||
|
return message;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -37,6 +37,10 @@
|
|||||||
<groupId>tech.easyflow</groupId>
|
<groupId>tech.easyflow</groupId>
|
||||||
<artifactId>easyflow-common-cache</artifactId>
|
<artifactId>easyflow-common-cache</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>tech.easyflow</groupId>
|
||||||
|
<artifactId>easyflow-common-mq</artifactId>
|
||||||
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>tech.easyflow</groupId>
|
<groupId>tech.easyflow</groupId>
|
||||||
<artifactId>easyflow-common-web</artifactId>
|
<artifactId>easyflow-common-web</artifactId>
|
||||||
@@ -63,5 +67,11 @@
|
|||||||
<version>${junit.version}</version>
|
<version>${junit.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.mockito</groupId>
|
||||||
|
<artifactId>mockito-core</artifactId>
|
||||||
|
<version>5.12.0</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</project>
|
</project>
|
||||||
|
|||||||
@@ -3,12 +3,14 @@ package tech.easyflow.agent.config;
|
|||||||
import org.mybatis.spring.annotation.MapperScan;
|
import org.mybatis.spring.annotation.MapperScan;
|
||||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||||
|
import org.springframework.context.annotation.ComponentScan;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Agent 模块自动配置。
|
* Agent 模块自动配置。
|
||||||
*/
|
*/
|
||||||
@AutoConfiguration
|
@AutoConfiguration
|
||||||
@MapperScan("tech.easyflow.agent.mapper")
|
@MapperScan("tech.easyflow.agent.mapper")
|
||||||
|
@ComponentScan("tech.easyflow.agent")
|
||||||
@EnableConfigurationProperties(AgentRuntimeProperties.class)
|
@EnableConfigurationProperties(AgentRuntimeProperties.class)
|
||||||
public class AgentModuleConfig {
|
public class AgentModuleConfig {
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package tech.easyflow.agent.config;
|
package tech.easyflow.agent.config;
|
||||||
|
|
||||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||||
|
import org.springframework.util.StringUtils;
|
||||||
|
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Agent 运行态生产化配置。
|
* Agent 运行态生产化配置。
|
||||||
@@ -15,6 +17,36 @@ public class AgentRuntimeProperties {
|
|||||||
*/
|
*/
|
||||||
private Duration sessionCacheTtl = Duration.ofHours(24);
|
private Duration sessionCacheTtl = Duration.ofHours(24);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 当前 Agent 运行实例 ID。
|
||||||
|
*/
|
||||||
|
private String instanceId = defaultInstanceId();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 运行路由 TTL。
|
||||||
|
*/
|
||||||
|
private Duration routeTtl = Duration.ofHours(24);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 运行命令 topic 前缀。
|
||||||
|
*/
|
||||||
|
private String commandTopicPrefix = "easyflow:agent-runtime-command";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 运行命令结果等待超时时间。
|
||||||
|
*/
|
||||||
|
private Duration commandResultTimeout = Duration.ofSeconds(5);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 运行命令结果缓存 TTL。
|
||||||
|
*/
|
||||||
|
private Duration commandResultTtl = Duration.ofMinutes(5);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 当前进程启动代 ID。
|
||||||
|
*/
|
||||||
|
private final String bootId = UUID.randomUUID().toString();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* HITL pending 默认过期时间。
|
* HITL pending 默认过期时间。
|
||||||
*/
|
*/
|
||||||
@@ -35,6 +67,11 @@ public class AgentRuntimeProperties {
|
|||||||
*/
|
*/
|
||||||
private Duration lockRenewInterval = Duration.ofMinutes(1);
|
private Duration lockRenewInterval = Duration.ofMinutes(1);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 异步工具任务 Redis 运行态 TTL。
|
||||||
|
*/
|
||||||
|
private Duration asyncToolTaskTtl = Duration.ofHours(24);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取 Redis 热态 session 缓存 TTL。
|
* 获取 Redis 热态 session 缓存 TTL。
|
||||||
*
|
*
|
||||||
@@ -53,6 +90,107 @@ public class AgentRuntimeProperties {
|
|||||||
this.sessionCacheTtl = sessionCacheTtl == null ? Duration.ofHours(24) : sessionCacheTtl;
|
this.sessionCacheTtl = sessionCacheTtl == null ? Duration.ofHours(24) : sessionCacheTtl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取当前 Agent 运行实例 ID。
|
||||||
|
*
|
||||||
|
* @return 实例 ID
|
||||||
|
*/
|
||||||
|
public String getInstanceId() {
|
||||||
|
return instanceId;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置当前 Agent 运行实例 ID。
|
||||||
|
*
|
||||||
|
* @param instanceId 实例 ID
|
||||||
|
*/
|
||||||
|
public void setInstanceId(String instanceId) {
|
||||||
|
this.instanceId = StringUtils.hasText(instanceId) ? instanceId.trim() : defaultInstanceId();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取 Agent 运行路由 TTL。
|
||||||
|
*
|
||||||
|
* @return 路由 TTL
|
||||||
|
*/
|
||||||
|
public Duration getRouteTtl() {
|
||||||
|
return routeTtl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置 Agent 运行路由 TTL。
|
||||||
|
*
|
||||||
|
* @param routeTtl 路由 TTL
|
||||||
|
*/
|
||||||
|
public void setRouteTtl(Duration routeTtl) {
|
||||||
|
this.routeTtl = routeTtl == null ? Duration.ofHours(24) : routeTtl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取 Agent 运行命令 topic 前缀。
|
||||||
|
*
|
||||||
|
* @return 命令 topic 前缀
|
||||||
|
*/
|
||||||
|
public String getCommandTopicPrefix() {
|
||||||
|
return commandTopicPrefix;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置 Agent 运行命令 topic 前缀。
|
||||||
|
*
|
||||||
|
* @param commandTopicPrefix 命令 topic 前缀
|
||||||
|
*/
|
||||||
|
public void setCommandTopicPrefix(String commandTopicPrefix) {
|
||||||
|
this.commandTopicPrefix = StringUtils.hasText(commandTopicPrefix)
|
||||||
|
? commandTopicPrefix.trim()
|
||||||
|
: "easyflow:agent-runtime-command";
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取 Agent 运行命令结果等待超时时间。
|
||||||
|
*
|
||||||
|
* @return 等待超时时间
|
||||||
|
*/
|
||||||
|
public Duration getCommandResultTimeout() {
|
||||||
|
return commandResultTimeout;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置 Agent 运行命令结果等待超时时间。
|
||||||
|
*
|
||||||
|
* @param commandResultTimeout 等待超时时间
|
||||||
|
*/
|
||||||
|
public void setCommandResultTimeout(Duration commandResultTimeout) {
|
||||||
|
this.commandResultTimeout = commandResultTimeout == null ? Duration.ofSeconds(5) : commandResultTimeout;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取 Agent 运行命令结果缓存 TTL。
|
||||||
|
*
|
||||||
|
* @return 结果缓存 TTL
|
||||||
|
*/
|
||||||
|
public Duration getCommandResultTtl() {
|
||||||
|
return commandResultTtl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置 Agent 运行命令结果缓存 TTL。
|
||||||
|
*
|
||||||
|
* @param commandResultTtl 结果缓存 TTL
|
||||||
|
*/
|
||||||
|
public void setCommandResultTtl(Duration commandResultTtl) {
|
||||||
|
this.commandResultTtl = commandResultTtl == null ? Duration.ofMinutes(5) : commandResultTtl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取当前进程启动代 ID。
|
||||||
|
*
|
||||||
|
* @return 启动代 ID
|
||||||
|
*/
|
||||||
|
public String getBootId() {
|
||||||
|
return bootId;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取 HITL pending 默认过期时间。
|
* 获取 HITL pending 默认过期时间。
|
||||||
*
|
*
|
||||||
@@ -124,4 +262,34 @@ public class AgentRuntimeProperties {
|
|||||||
public void setLockRenewInterval(Duration lockRenewInterval) {
|
public void setLockRenewInterval(Duration lockRenewInterval) {
|
||||||
this.lockRenewInterval = lockRenewInterval == null ? Duration.ofMinutes(1) : lockRenewInterval;
|
this.lockRenewInterval = lockRenewInterval == null ? Duration.ofMinutes(1) : lockRenewInterval;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取 Agent 异步工具任务 Redis 运行态 TTL。
|
||||||
|
*
|
||||||
|
* @return 任务 TTL
|
||||||
|
*/
|
||||||
|
public Duration getAsyncToolTaskTtl() {
|
||||||
|
return asyncToolTaskTtl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置 Agent 异步工具任务 Redis 运行态 TTL。
|
||||||
|
*
|
||||||
|
* @param asyncToolTaskTtl 任务 TTL
|
||||||
|
*/
|
||||||
|
public void setAsyncToolTaskTtl(Duration asyncToolTaskTtl) {
|
||||||
|
this.asyncToolTaskTtl = asyncToolTaskTtl == null ? Duration.ofHours(24) : asyncToolTaskTtl;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String defaultInstanceId() {
|
||||||
|
String envInstanceId = System.getenv("EASYFLOW_INSTANCE_ID");
|
||||||
|
if (StringUtils.hasText(envInstanceId)) {
|
||||||
|
return envInstanceId.trim();
|
||||||
|
}
|
||||||
|
String hostName = System.getenv("HOSTNAME");
|
||||||
|
if (StringUtils.hasText(hostName)) {
|
||||||
|
return hostName.trim();
|
||||||
|
}
|
||||||
|
return UUID.randomUUID().toString();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,17 @@
|
|||||||
|
package tech.easyflow.agent.distributed;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 运行态远程命令动作。
|
||||||
|
*/
|
||||||
|
public enum AgentRuntimeCommandAction {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 批准工具执行。
|
||||||
|
*/
|
||||||
|
APPROVE,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 拒绝工具执行。
|
||||||
|
*/
|
||||||
|
REJECT
|
||||||
|
}
|
||||||
@@ -0,0 +1,127 @@
|
|||||||
|
package tech.easyflow.agent.distributed;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
import tech.easyflow.agent.config.AgentRuntimeProperties;
|
||||||
|
import tech.easyflow.agent.runtime.AgentRunService;
|
||||||
|
import tech.easyflow.common.mq.config.MQProperties;
|
||||||
|
import tech.easyflow.common.mq.core.MQConsumerHandler;
|
||||||
|
import tech.easyflow.common.mq.core.MQMessage;
|
||||||
|
import tech.easyflow.common.mq.core.MQSubscription;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 运行态远程命令消费者。
|
||||||
|
*/
|
||||||
|
@Component
|
||||||
|
public class AgentRuntimeCommandConsumer implements MQConsumerHandler {
|
||||||
|
|
||||||
|
private static final Logger LOG = LoggerFactory.getLogger(AgentRuntimeCommandConsumer.class);
|
||||||
|
|
||||||
|
private final ObjectMapper objectMapper;
|
||||||
|
private final AgentRuntimeProperties properties;
|
||||||
|
private final MQProperties mqProperties;
|
||||||
|
private final AgentRunService agentRunService;
|
||||||
|
private final AgentRuntimeCommandResultRegistry resultRegistry;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建 Agent 运行态远程命令消费者。
|
||||||
|
*
|
||||||
|
* @param objectMapper JSON 序列化器
|
||||||
|
* @param properties Agent 运行配置
|
||||||
|
* @param mqProperties MQ 配置
|
||||||
|
* @param agentRunService Agent 运行服务
|
||||||
|
* @param resultRegistry 远程命令结果注册表
|
||||||
|
*/
|
||||||
|
public AgentRuntimeCommandConsumer(ObjectMapper objectMapper,
|
||||||
|
AgentRuntimeProperties properties,
|
||||||
|
MQProperties mqProperties,
|
||||||
|
AgentRunService agentRunService,
|
||||||
|
AgentRuntimeCommandResultRegistry resultRegistry) {
|
||||||
|
this.objectMapper = objectMapper;
|
||||||
|
this.properties = properties;
|
||||||
|
this.mqProperties = mqProperties;
|
||||||
|
this.agentRunService = agentRunService;
|
||||||
|
this.resultRegistry = resultRegistry;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MQSubscription subscription() {
|
||||||
|
MQSubscription subscription = new MQSubscription();
|
||||||
|
subscription.setTopic(commandTopic());
|
||||||
|
subscription.setConsumerGroup(commandTopic());
|
||||||
|
subscription.setShardCount(Math.max(mqProperties.getRedis().getChatPersistShardCount(), 1));
|
||||||
|
subscription.setBatchEnabled(false);
|
||||||
|
return subscription;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handle(List<MQMessage> messages) {
|
||||||
|
if (messages == null || messages.isEmpty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (MQMessage message : messages) {
|
||||||
|
try {
|
||||||
|
handleCommand(message, objectMapper.readValue(message.getBody(), AgentRuntimeCommandMessage.class));
|
||||||
|
} catch (Exception e) {
|
||||||
|
LOG.warn("Agent 远程运行命令解析失败: messageId={}", message.getMessageId(), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void handleCommand(MQMessage message, AgentRuntimeCommandMessage command) {
|
||||||
|
if (command == null || command.getAction() == null) {
|
||||||
|
LOG.warn("跳过非法 Agent 远程运行命令: messageId={}", message.getMessageId());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!properties.getInstanceId().equals(command.getTargetNodeId())) {
|
||||||
|
LOG.warn("跳过非本节点 Agent 远程运行命令: messageId={}, targetNodeId={}, currentNodeId={}",
|
||||||
|
message.getMessageId(), command.getTargetNodeId(), properties.getInstanceId());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
if (command.getAction() == AgentRuntimeCommandAction.APPROVE) {
|
||||||
|
agentRunService.approveRuntimeLocal(
|
||||||
|
command.getRequestId(), command.getResumeToken(), command.getOperatorId(), command.getUserId());
|
||||||
|
} else if (command.getAction() == AgentRuntimeCommandAction.REJECT) {
|
||||||
|
agentRunService.rejectRuntimeLocal(
|
||||||
|
command.getRequestId(), command.getResumeToken(), command.getReason(),
|
||||||
|
command.getOperatorId(), command.getUserId());
|
||||||
|
} else {
|
||||||
|
markFailureQuietly(command, new IllegalArgumentException("不支持的 Agent 远程运行命令"));
|
||||||
|
LOG.warn("跳过不支持的 Agent 远程运行命令: messageId={}, commandId={}, action={}",
|
||||||
|
message.getMessageId(), command.getCommandId(), command.getAction());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} catch (RuntimeException e) {
|
||||||
|
markFailureQuietly(command, e);
|
||||||
|
LOG.warn("Agent 远程运行命令处理失败: messageId={}, commandId={}",
|
||||||
|
message.getMessageId(), command.getCommandId(), e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
markSuccessQuietly(command);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String commandTopic() {
|
||||||
|
return properties.getCommandTopicPrefix() + ":" + properties.getInstanceId();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void markSuccessQuietly(AgentRuntimeCommandMessage command) {
|
||||||
|
try {
|
||||||
|
resultRegistry.markSuccess(command.getCommandId());
|
||||||
|
} catch (RuntimeException e) {
|
||||||
|
LOG.error("Agent 远程运行命令成功结果写入失败: commandId={}", command.getCommandId(), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void markFailureQuietly(AgentRuntimeCommandMessage command, RuntimeException cause) {
|
||||||
|
try {
|
||||||
|
resultRegistry.markFailure(command.getCommandId(), cause.getMessage());
|
||||||
|
} catch (RuntimeException e) {
|
||||||
|
LOG.error("Agent 远程运行命令失败结果写入失败: commandId={}", command.getCommandId(), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
package tech.easyflow.agent.distributed;
|
||||||
|
|
||||||
|
import java.math.BigInteger;
|
||||||
|
import java.util.Date;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 运行态远程恢复命令消息。
|
||||||
|
*/
|
||||||
|
public class AgentRuntimeCommandMessage {
|
||||||
|
|
||||||
|
private String commandId;
|
||||||
|
private String requestId;
|
||||||
|
private String resumeToken;
|
||||||
|
private AgentRuntimeCommandAction action;
|
||||||
|
private String reason;
|
||||||
|
private BigInteger operatorId;
|
||||||
|
private String userId;
|
||||||
|
private String targetNodeId;
|
||||||
|
private Date occurredAt;
|
||||||
|
|
||||||
|
public String getCommandId() {
|
||||||
|
return commandId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setCommandId(String commandId) {
|
||||||
|
this.commandId = commandId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getRequestId() {
|
||||||
|
return requestId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setRequestId(String requestId) {
|
||||||
|
this.requestId = requestId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getResumeToken() {
|
||||||
|
return resumeToken;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setResumeToken(String resumeToken) {
|
||||||
|
this.resumeToken = resumeToken;
|
||||||
|
}
|
||||||
|
|
||||||
|
public AgentRuntimeCommandAction getAction() {
|
||||||
|
return action;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setAction(AgentRuntimeCommandAction action) {
|
||||||
|
this.action = action;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getReason() {
|
||||||
|
return reason;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setReason(String reason) {
|
||||||
|
this.reason = reason;
|
||||||
|
}
|
||||||
|
|
||||||
|
public BigInteger getOperatorId() {
|
||||||
|
return operatorId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setOperatorId(BigInteger operatorId) {
|
||||||
|
this.operatorId = operatorId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getUserId() {
|
||||||
|
return userId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setUserId(String userId) {
|
||||||
|
this.userId = userId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getTargetNodeId() {
|
||||||
|
return targetNodeId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setTargetNodeId(String targetNodeId) {
|
||||||
|
this.targetNodeId = targetNodeId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Date getOccurredAt() {
|
||||||
|
return occurredAt;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setOccurredAt(Date occurredAt) {
|
||||||
|
this.occurredAt = occurredAt;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,153 @@
|
|||||||
|
package tech.easyflow.agent.distributed;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import tech.easyflow.agent.config.AgentRuntimeProperties;
|
||||||
|
import tech.easyflow.common.mq.core.MQMessage;
|
||||||
|
import tech.easyflow.common.mq.core.MQProducer;
|
||||||
|
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||||
|
|
||||||
|
import java.math.BigInteger;
|
||||||
|
import java.util.Date;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 运行态远程命令生产者。
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
public class AgentRuntimeCommandProducer {
|
||||||
|
|
||||||
|
private static final Logger LOG = LoggerFactory.getLogger(AgentRuntimeCommandProducer.class);
|
||||||
|
|
||||||
|
private final MQProducer mqProducer;
|
||||||
|
private final ObjectMapper objectMapper;
|
||||||
|
private final AgentRuntimeProperties properties;
|
||||||
|
private final AgentRuntimeCommandResultRegistry resultRegistry;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 测试子类构造器。
|
||||||
|
*/
|
||||||
|
protected AgentRuntimeCommandProducer() {
|
||||||
|
this.mqProducer = null;
|
||||||
|
this.objectMapper = null;
|
||||||
|
this.properties = null;
|
||||||
|
this.resultRegistry = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建 Agent 运行态远程命令生产者。
|
||||||
|
*
|
||||||
|
* @param mqProducer MQ 生产者
|
||||||
|
* @param objectMapper JSON 序列化器
|
||||||
|
* @param properties Agent 运行配置
|
||||||
|
* @param resultRegistry 远程命令结果注册表
|
||||||
|
*/
|
||||||
|
public AgentRuntimeCommandProducer(MQProducer mqProducer,
|
||||||
|
ObjectMapper objectMapper,
|
||||||
|
AgentRuntimeProperties properties,
|
||||||
|
AgentRuntimeCommandResultRegistry resultRegistry) {
|
||||||
|
this.mqProducer = mqProducer;
|
||||||
|
this.objectMapper = objectMapper;
|
||||||
|
this.properties = properties;
|
||||||
|
this.resultRegistry = resultRegistry;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 投递远程批准命令。
|
||||||
|
*
|
||||||
|
* @param targetNodeId 目标节点 ID
|
||||||
|
* @param requestId 请求 ID
|
||||||
|
* @param resumeToken 恢复令牌
|
||||||
|
* @param operatorId 操作人 ID
|
||||||
|
* @param userId 用户 ID
|
||||||
|
*/
|
||||||
|
public void sendApprove(String targetNodeId,
|
||||||
|
String requestId,
|
||||||
|
String resumeToken,
|
||||||
|
BigInteger operatorId,
|
||||||
|
String userId) {
|
||||||
|
sendAndWait(targetNodeId, requestId, resumeToken, AgentRuntimeCommandAction.APPROVE, null, operatorId, userId);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 投递远程拒绝命令。
|
||||||
|
*
|
||||||
|
* @param targetNodeId 目标节点 ID
|
||||||
|
* @param requestId 请求 ID
|
||||||
|
* @param resumeToken 恢复令牌
|
||||||
|
* @param reason 拒绝原因
|
||||||
|
* @param operatorId 操作人 ID
|
||||||
|
* @param userId 用户 ID
|
||||||
|
*/
|
||||||
|
public void sendReject(String targetNodeId,
|
||||||
|
String requestId,
|
||||||
|
String resumeToken,
|
||||||
|
String reason,
|
||||||
|
BigInteger operatorId,
|
||||||
|
String userId) {
|
||||||
|
sendAndWait(targetNodeId, requestId, resumeToken, AgentRuntimeCommandAction.REJECT, reason, operatorId, userId);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void sendAndWait(String targetNodeId,
|
||||||
|
String requestId,
|
||||||
|
String resumeToken,
|
||||||
|
AgentRuntimeCommandAction action,
|
||||||
|
String reason,
|
||||||
|
BigInteger operatorId,
|
||||||
|
String userId) {
|
||||||
|
if (targetNodeId == null || targetNodeId.isBlank()) {
|
||||||
|
throw new BusinessException("Agent 运行节点不可用,请重新发起对话");
|
||||||
|
}
|
||||||
|
AgentRuntimeCommandMessage command = new AgentRuntimeCommandMessage();
|
||||||
|
command.setCommandId(UUID.randomUUID().toString());
|
||||||
|
command.setRequestId(requestId);
|
||||||
|
command.setResumeToken(resumeToken);
|
||||||
|
command.setAction(action);
|
||||||
|
command.setReason(reason);
|
||||||
|
command.setOperatorId(operatorId);
|
||||||
|
command.setUserId(userId);
|
||||||
|
command.setTargetNodeId(targetNodeId);
|
||||||
|
command.setOccurredAt(new Date());
|
||||||
|
|
||||||
|
MQMessage message = new MQMessage();
|
||||||
|
message.setMessageId(command.getCommandId());
|
||||||
|
message.setTopic(commandTopic(targetNodeId));
|
||||||
|
message.setKey(command.getCommandId());
|
||||||
|
message.setCreatedAt(command.getOccurredAt());
|
||||||
|
try {
|
||||||
|
message.setBody(objectMapper.writeValueAsString(command));
|
||||||
|
String recordId = mqProducer.send(message);
|
||||||
|
LOG.info("Agent 远程运行命令已投递: action={}, requestId={}, targetNodeId={}, recordId={}",
|
||||||
|
action, requestId, targetNodeId, recordId);
|
||||||
|
AgentRuntimeCommandResult result = resultRegistry.waitForResult(command.getCommandId());
|
||||||
|
if (!result.isSuccess()) {
|
||||||
|
throw new BusinessException(result.getMessage());
|
||||||
|
}
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
throw new BusinessException("Agent 运行命令序列化失败");
|
||||||
|
} catch (BusinessException e) {
|
||||||
|
throw e;
|
||||||
|
} catch (RuntimeException e) {
|
||||||
|
LOG.error("Agent 远程运行命令投递失败: action={}, requestId={}, targetNodeId={}",
|
||||||
|
action, requestId, targetNodeId, e);
|
||||||
|
throw new BusinessException("Agent 运行节点不可用,请重新发起对话");
|
||||||
|
} finally {
|
||||||
|
deleteResultQuietly(command.getCommandId());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private String commandTopic(String nodeId) {
|
||||||
|
return properties.getCommandTopicPrefix() + ":" + nodeId;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void deleteResultQuietly(String commandId) {
|
||||||
|
try {
|
||||||
|
resultRegistry.deleteResult(commandId);
|
||||||
|
} catch (RuntimeException e) {
|
||||||
|
LOG.warn("Agent 远程运行命令结果清理失败,等待 TTL 兜底: commandId={}", commandId, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
package tech.easyflow.agent.distributed;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 运行态远程命令结果。
|
||||||
|
*/
|
||||||
|
public class AgentRuntimeCommandResult {
|
||||||
|
|
||||||
|
private boolean success;
|
||||||
|
private String message;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 判断命令是否执行成功。
|
||||||
|
*
|
||||||
|
* @return true 表示执行成功
|
||||||
|
*/
|
||||||
|
public boolean isSuccess() {
|
||||||
|
return success;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置命令是否执行成功。
|
||||||
|
*
|
||||||
|
* @param success 是否执行成功
|
||||||
|
*/
|
||||||
|
public void setSuccess(boolean success) {
|
||||||
|
this.success = success;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取结果消息。
|
||||||
|
*
|
||||||
|
* @return 结果消息
|
||||||
|
*/
|
||||||
|
public String getMessage() {
|
||||||
|
return message;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置结果消息。
|
||||||
|
*
|
||||||
|
* @param message 结果消息
|
||||||
|
*/
|
||||||
|
public void setMessage(String message) {
|
||||||
|
this.message = message;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,134 @@
|
|||||||
|
package tech.easyflow.agent.distributed;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
import tech.easyflow.agent.config.AgentRuntimeProperties;
|
||||||
|
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 运行态远程命令结果注册表。
|
||||||
|
*/
|
||||||
|
@Component
|
||||||
|
public class AgentRuntimeCommandResultRegistry {
|
||||||
|
|
||||||
|
private static final String RESULT_PREFIX = "easyflow:agent:runtime:command-result:";
|
||||||
|
private static final long POLL_INTERVAL_MILLIS = 50L;
|
||||||
|
|
||||||
|
private final StringRedisTemplate stringRedisTemplate;
|
||||||
|
private final ObjectMapper objectMapper;
|
||||||
|
private final AgentRuntimeProperties properties;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建 Agent 运行态远程命令结果注册表。
|
||||||
|
*
|
||||||
|
* @param stringRedisTemplate Redis 字符串模板
|
||||||
|
* @param objectMapper JSON 序列化器
|
||||||
|
* @param properties Agent 运行配置
|
||||||
|
*/
|
||||||
|
public AgentRuntimeCommandResultRegistry(StringRedisTemplate stringRedisTemplate,
|
||||||
|
ObjectMapper objectMapper,
|
||||||
|
AgentRuntimeProperties properties) {
|
||||||
|
this.stringRedisTemplate = stringRedisTemplate;
|
||||||
|
this.objectMapper = objectMapper;
|
||||||
|
this.properties = properties;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 写入成功结果。
|
||||||
|
*
|
||||||
|
* @param commandId 命令 ID
|
||||||
|
*/
|
||||||
|
public void markSuccess(String commandId) {
|
||||||
|
AgentRuntimeCommandResult result = new AgentRuntimeCommandResult();
|
||||||
|
result.setSuccess(true);
|
||||||
|
result.setMessage("OK");
|
||||||
|
writeResult(commandId, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 写入失败结果。
|
||||||
|
*
|
||||||
|
* @param commandId 命令 ID
|
||||||
|
* @param message 失败消息
|
||||||
|
*/
|
||||||
|
public void markFailure(String commandId, String message) {
|
||||||
|
AgentRuntimeCommandResult result = new AgentRuntimeCommandResult();
|
||||||
|
result.setSuccess(false);
|
||||||
|
result.setMessage(message == null || message.isBlank() ? "Agent 运行节点不可用,请重新发起对话" : message);
|
||||||
|
writeResult(commandId, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 等待远程命令结果。
|
||||||
|
*
|
||||||
|
* @param commandId 命令 ID
|
||||||
|
* @return 命令结果
|
||||||
|
*/
|
||||||
|
public AgentRuntimeCommandResult waitForResult(String commandId) {
|
||||||
|
long deadline = System.nanoTime() + properties.getCommandResultTimeout().toNanos();
|
||||||
|
while (System.nanoTime() <= deadline) {
|
||||||
|
AgentRuntimeCommandResult result = readResult(commandId);
|
||||||
|
if (result != null) {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
sleep();
|
||||||
|
}
|
||||||
|
throw new BusinessException("Agent 运行节点响应超时,请稍后重试");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 删除远程命令结果。
|
||||||
|
*
|
||||||
|
* @param commandId 命令 ID
|
||||||
|
*/
|
||||||
|
public void deleteResult(String commandId) {
|
||||||
|
if (commandId == null || commandId.isBlank()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
stringRedisTemplate.delete(resultKey(commandId));
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentRuntimeCommandResult readResult(String commandId) {
|
||||||
|
if (commandId == null || commandId.isBlank()) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
String value = stringRedisTemplate.opsForValue().get(resultKey(commandId));
|
||||||
|
if (value == null || value.isBlank()) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
return objectMapper.readValue(value, AgentRuntimeCommandResult.class);
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
throw new BusinessException("Agent 运行命令结果解析失败");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void writeResult(String commandId, AgentRuntimeCommandResult result) {
|
||||||
|
if (commandId == null || commandId.isBlank()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
stringRedisTemplate.opsForValue().set(
|
||||||
|
resultKey(commandId),
|
||||||
|
objectMapper.writeValueAsString(result),
|
||||||
|
properties.getCommandResultTtl());
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
throw new IllegalStateException("Agent 运行命令结果序列化失败", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private String resultKey(String commandId) {
|
||||||
|
return RESULT_PREFIX + commandId;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void sleep() {
|
||||||
|
try {
|
||||||
|
Thread.sleep(POLL_INTERVAL_MILLIS);
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
Thread.currentThread().interrupt();
|
||||||
|
throw new BusinessException("Agent 运行节点响应等待被中断");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
package tech.easyflow.agent.distributed;
|
||||||
|
|
||||||
|
import jakarta.annotation.PostConstruct;
|
||||||
|
import org.springframework.scheduling.annotation.Scheduled;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 运行节点心跳维护器。
|
||||||
|
*/
|
||||||
|
@Component
|
||||||
|
public class AgentRuntimeNodeHeartbeat {
|
||||||
|
|
||||||
|
private static final Duration HEARTBEAT_TTL = Duration.ofSeconds(90);
|
||||||
|
|
||||||
|
private final AgentRuntimeRouteRegistry routeRegistry;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建 Agent 运行节点心跳维护器。
|
||||||
|
*
|
||||||
|
* @param routeRegistry Agent 运行态 Redis 路由注册表
|
||||||
|
*/
|
||||||
|
public AgentRuntimeNodeHeartbeat(AgentRuntimeRouteRegistry routeRegistry) {
|
||||||
|
this.routeRegistry = routeRegistry;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 启动时立即写入一次当前节点心跳。
|
||||||
|
*/
|
||||||
|
@PostConstruct
|
||||||
|
public void init() {
|
||||||
|
refresh();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 定期刷新当前节点心跳。
|
||||||
|
*/
|
||||||
|
@Scheduled(fixedDelayString = "${easyflow.agent.runtime.node-heartbeat-delay:30000}", initialDelay = 30000L)
|
||||||
|
public void refresh() {
|
||||||
|
routeRegistry.heartbeat(HEARTBEAT_TTL);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
package tech.easyflow.agent.distributed;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 运行态 owner 路由。
|
||||||
|
*/
|
||||||
|
public class AgentRuntimeRoute {
|
||||||
|
|
||||||
|
private String nodeId;
|
||||||
|
private String bootId;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取 owner 节点 ID。
|
||||||
|
*
|
||||||
|
* @return owner 节点 ID
|
||||||
|
*/
|
||||||
|
public String getNodeId() {
|
||||||
|
return nodeId;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置 owner 节点 ID。
|
||||||
|
*
|
||||||
|
* @param nodeId owner 节点 ID
|
||||||
|
*/
|
||||||
|
public void setNodeId(String nodeId) {
|
||||||
|
this.nodeId = nodeId;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取 owner 启动代 ID。
|
||||||
|
*
|
||||||
|
* @return 启动代 ID
|
||||||
|
*/
|
||||||
|
public String getBootId() {
|
||||||
|
return bootId;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置 owner 启动代 ID。
|
||||||
|
*
|
||||||
|
* @param bootId 启动代 ID
|
||||||
|
*/
|
||||||
|
public void setBootId(String bootId) {
|
||||||
|
this.bootId = bootId;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,222 @@
|
|||||||
|
package tech.easyflow.agent.distributed;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
import tech.easyflow.agent.config.AgentRuntimeProperties;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 运行态 Redis 路由注册表。
|
||||||
|
*/
|
||||||
|
@Component
|
||||||
|
public class AgentRuntimeRouteRegistry {
|
||||||
|
|
||||||
|
private static final Logger LOG = LoggerFactory.getLogger(AgentRuntimeRouteRegistry.class);
|
||||||
|
|
||||||
|
private static final String REQUEST_ROUTE_PREFIX = "easyflow:agent:runtime:request:";
|
||||||
|
private static final String TOKEN_ROUTE_PREFIX = "easyflow:agent:runtime:resume-token:";
|
||||||
|
private static final String NODE_HEARTBEAT_PREFIX = "easyflow:agent:runtime:node:";
|
||||||
|
|
||||||
|
private final StringRedisTemplate stringRedisTemplate;
|
||||||
|
private final AgentRuntimeProperties properties;
|
||||||
|
private final ObjectMapper objectMapper;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建 Agent 运行态 Redis 路由注册表。
|
||||||
|
*
|
||||||
|
* @param stringRedisTemplate Redis 字符串模板
|
||||||
|
* @param properties Agent 运行配置
|
||||||
|
* @param objectMapper JSON 序列化器
|
||||||
|
*/
|
||||||
|
public AgentRuntimeRouteRegistry(StringRedisTemplate stringRedisTemplate,
|
||||||
|
AgentRuntimeProperties properties,
|
||||||
|
ObjectMapper objectMapper) {
|
||||||
|
this.stringRedisTemplate = stringRedisTemplate;
|
||||||
|
this.properties = properties;
|
||||||
|
this.objectMapper = objectMapper;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 注册运行请求 owner 节点。
|
||||||
|
*
|
||||||
|
* @param requestId 请求 ID
|
||||||
|
*/
|
||||||
|
public void registerRun(String requestId) {
|
||||||
|
if (requestId == null || requestId.isBlank()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
stringRedisTemplate.opsForValue().set(requestKey(requestId), serializeRoute(currentRoute()), properties.getRouteTtl());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 注册恢复令牌与请求 ID 的关系。
|
||||||
|
*
|
||||||
|
* @param requestId 请求 ID
|
||||||
|
* @param resumeToken 恢复令牌
|
||||||
|
*/
|
||||||
|
public void registerResumeToken(String requestId, String resumeToken) {
|
||||||
|
if (requestId == null || requestId.isBlank() || resumeToken == null || resumeToken.isBlank()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
stringRedisTemplate.opsForValue().set(tokenKey(resumeToken), requestId, properties.getRouteTtl());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 查询请求 ID 所属节点。
|
||||||
|
*
|
||||||
|
* @param requestId 请求 ID
|
||||||
|
* @return owner 节点 ID
|
||||||
|
*/
|
||||||
|
public String findOwnerNode(String requestId) {
|
||||||
|
AgentRuntimeRoute route = findOwnerRoute(requestId);
|
||||||
|
return route == null ? null : route.getNodeId();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 查询请求 ID 所属路由。
|
||||||
|
*
|
||||||
|
* @param requestId 请求 ID
|
||||||
|
* @return owner 路由
|
||||||
|
*/
|
||||||
|
public AgentRuntimeRoute findOwnerRoute(String requestId) {
|
||||||
|
if (requestId == null || requestId.isBlank()) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
String value = stringRedisTemplate.opsForValue().get(requestKey(requestId));
|
||||||
|
if (value == null || value.isBlank()) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return deserializeRoute(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 根据恢复令牌查询请求 ID。
|
||||||
|
*
|
||||||
|
* @param resumeToken 恢复令牌
|
||||||
|
* @return 请求 ID
|
||||||
|
*/
|
||||||
|
public String findRequestIdByResumeToken(String resumeToken) {
|
||||||
|
if (resumeToken == null || resumeToken.isBlank()) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return stringRedisTemplate.opsForValue().get(tokenKey(resumeToken));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 删除指定运行请求的路由。
|
||||||
|
*
|
||||||
|
* @param requestId 请求 ID
|
||||||
|
*/
|
||||||
|
public void removeRun(String requestId) {
|
||||||
|
if (requestId == null || requestId.isBlank()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
deleteQuietly(requestKey(requestId));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 删除指定恢复令牌的路由。
|
||||||
|
*
|
||||||
|
* @param resumeToken 恢复令牌
|
||||||
|
*/
|
||||||
|
public void removeResumeToken(String resumeToken) {
|
||||||
|
if (resumeToken == null || resumeToken.isBlank()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
deleteQuietly(tokenKey(resumeToken));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取当前节点 ID。
|
||||||
|
*
|
||||||
|
* @return 当前节点 ID
|
||||||
|
*/
|
||||||
|
public String currentNodeId() {
|
||||||
|
return properties.getInstanceId();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 刷新当前节点存活心跳。
|
||||||
|
*
|
||||||
|
* @param ttl 心跳 TTL
|
||||||
|
*/
|
||||||
|
public void heartbeat(Duration ttl) {
|
||||||
|
stringRedisTemplate.opsForValue().set(nodeKey(properties.getInstanceId()), properties.getBootId(), ttl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 查询指定节点是否仍有存活心跳。
|
||||||
|
*
|
||||||
|
* @param nodeId 节点 ID
|
||||||
|
* @return true 表示节点心跳仍有效
|
||||||
|
*/
|
||||||
|
public boolean isNodeAlive(String nodeId) {
|
||||||
|
return currentNodeBootId(nodeId) != null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 查询指定节点当前启动代 ID。
|
||||||
|
*
|
||||||
|
* @param nodeId 节点 ID
|
||||||
|
* @return 启动代 ID
|
||||||
|
*/
|
||||||
|
public String currentNodeBootId(String nodeId) {
|
||||||
|
if (nodeId == null || nodeId.isBlank()) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return stringRedisTemplate.opsForValue().get(nodeKey(nodeId));
|
||||||
|
}
|
||||||
|
|
||||||
|
private String requestKey(String requestId) {
|
||||||
|
return REQUEST_ROUTE_PREFIX + requestId;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String tokenKey(String resumeToken) {
|
||||||
|
return TOKEN_ROUTE_PREFIX + resumeToken;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String nodeKey(String nodeId) {
|
||||||
|
return NODE_HEARTBEAT_PREFIX + nodeId;
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentRuntimeRoute currentRoute() {
|
||||||
|
AgentRuntimeRoute route = new AgentRuntimeRoute();
|
||||||
|
route.setNodeId(properties.getInstanceId());
|
||||||
|
route.setBootId(properties.getBootId());
|
||||||
|
return route;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String serializeRoute(AgentRuntimeRoute route) {
|
||||||
|
try {
|
||||||
|
return objectMapper.writeValueAsString(route);
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
throw new IllegalStateException("Agent 运行路由序列化失败", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentRuntimeRoute deserializeRoute(String value) {
|
||||||
|
try {
|
||||||
|
if (value.trim().startsWith("{")) {
|
||||||
|
return objectMapper.readValue(value, AgentRuntimeRoute.class);
|
||||||
|
}
|
||||||
|
AgentRuntimeRoute legacyRoute = new AgentRuntimeRoute();
|
||||||
|
legacyRoute.setNodeId(value);
|
||||||
|
return legacyRoute;
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
throw new IllegalStateException("Agent 运行路由反序列化失败", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void deleteQuietly(String key) {
|
||||||
|
try {
|
||||||
|
stringRedisTemplate.delete(key);
|
||||||
|
} catch (RuntimeException e) {
|
||||||
|
LOG.warn("清理 Agent 运行态 Redis 路由失败: key={}", key, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
package tech.easyflow.agent.runtime;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.math.BigInteger;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 聊天临时能力请求项。
|
||||||
|
*/
|
||||||
|
public class AgentChatCapability implements Serializable {
|
||||||
|
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
|
|
||||||
|
private String type;
|
||||||
|
private List<BigInteger> resourceIds = new ArrayList<>();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取能力类型。
|
||||||
|
*
|
||||||
|
* @return 能力类型
|
||||||
|
*/
|
||||||
|
public String getType() {
|
||||||
|
return type;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置能力类型。
|
||||||
|
*
|
||||||
|
* @param type 能力类型
|
||||||
|
*/
|
||||||
|
public void setType(String type) {
|
||||||
|
this.type = type;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取资源 ID 列表。
|
||||||
|
*
|
||||||
|
* @return 资源 ID 列表
|
||||||
|
*/
|
||||||
|
public List<BigInteger> getResourceIds() {
|
||||||
|
return resourceIds;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置资源 ID 列表。
|
||||||
|
*
|
||||||
|
* @param resourceIds 资源 ID 列表
|
||||||
|
*/
|
||||||
|
public void setResourceIds(List<BigInteger> resourceIds) {
|
||||||
|
this.resourceIds = resourceIds == null ? new ArrayList<>() : new ArrayList<>(resourceIds);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,171 @@
|
|||||||
|
package tech.easyflow.agent.runtime;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import tech.easyflow.agent.entity.Agent;
|
||||||
|
import tech.easyflow.agent.entity.AgentKnowledgeBinding;
|
||||||
|
import tech.easyflow.ai.entity.DocumentCollection;
|
||||||
|
import tech.easyflow.ai.enums.PublishStatus;
|
||||||
|
import tech.easyflow.ai.service.DocumentCollectionService;
|
||||||
|
import tech.easyflow.common.entity.LoginAccount;
|
||||||
|
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||||
|
import tech.easyflow.system.enums.CategoryResourceType;
|
||||||
|
import tech.easyflow.system.enums.ResourceAction;
|
||||||
|
import tech.easyflow.system.service.ResourceAccessService;
|
||||||
|
|
||||||
|
import java.math.BigInteger;
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 聊天临时能力编排服务。
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
public class AgentChatCapabilityService {
|
||||||
|
|
||||||
|
private static final String KNOWLEDGE_CAPABILITY_TYPE = "KNOWLEDGE";
|
||||||
|
private static final String DEFAULT_RETRIEVAL_MODE = "HYBRID";
|
||||||
|
private static final int MAX_EXTRA_KNOWLEDGE_COUNT = 3;
|
||||||
|
|
||||||
|
private final DocumentCollectionService documentCollectionService;
|
||||||
|
private final ResourceAccessService resourceAccessService;
|
||||||
|
private final ObjectMapper objectMapper;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建 Agent 聊天临时能力编排服务。
|
||||||
|
*
|
||||||
|
* @param documentCollectionService 知识库服务
|
||||||
|
* @param resourceAccessService 资源访问服务
|
||||||
|
* @param objectMapper 对象复制工具
|
||||||
|
*/
|
||||||
|
public AgentChatCapabilityService(DocumentCollectionService documentCollectionService,
|
||||||
|
ResourceAccessService resourceAccessService,
|
||||||
|
ObjectMapper objectMapper) {
|
||||||
|
this.documentCollectionService = documentCollectionService;
|
||||||
|
this.resourceAccessService = resourceAccessService;
|
||||||
|
this.objectMapper = objectMapper;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 将临时聊天能力合并到运行时 Agent 定义中。
|
||||||
|
*
|
||||||
|
* @param agent 已发布 Agent 运行视图
|
||||||
|
* @param capabilities 临时能力请求
|
||||||
|
* @param account 当前登录账号
|
||||||
|
* @return 能力解析结果
|
||||||
|
*/
|
||||||
|
public AgentChatCapabilityResolution apply(Agent agent,
|
||||||
|
List<AgentChatCapability> capabilities,
|
||||||
|
LoginAccount account) {
|
||||||
|
List<BigInteger> extraKnowledgeIds = resolveKnowledgeIds(capabilities);
|
||||||
|
boolean knowledgeCapabilityProvided = hasKnowledgeCapability(capabilities);
|
||||||
|
if (agent == null || extraKnowledgeIds.isEmpty()) {
|
||||||
|
return new AgentChatCapabilityResolution(agent, extraKnowledgeIds, knowledgeCapabilityProvided);
|
||||||
|
}
|
||||||
|
Agent runtimeAgent = objectMapper.convertValue(agent, Agent.class);
|
||||||
|
List<AgentKnowledgeBinding> mergedBindings = new ArrayList<>();
|
||||||
|
Set<BigInteger> existingKnowledgeIds = new LinkedHashSet<>();
|
||||||
|
if (runtimeAgent.getKnowledgeBindings() != null) {
|
||||||
|
for (AgentKnowledgeBinding binding : runtimeAgent.getKnowledgeBindings()) {
|
||||||
|
if (binding == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
mergedBindings.add(binding);
|
||||||
|
if (Boolean.TRUE.equals(binding.getEnabled()) && binding.getKnowledgeId() != null) {
|
||||||
|
existingKnowledgeIds.add(binding.getKnowledgeId());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int sortNo = mergedBindings.size();
|
||||||
|
for (BigInteger knowledgeId : extraKnowledgeIds) {
|
||||||
|
if (existingKnowledgeIds.contains(knowledgeId)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
DocumentCollection knowledge = documentCollectionService.getById(knowledgeId);
|
||||||
|
validateKnowledge(knowledge);
|
||||||
|
resourceAccessService.assertAccess(
|
||||||
|
CategoryResourceType.KNOWLEDGE,
|
||||||
|
knowledge,
|
||||||
|
ResourceAction.USE,
|
||||||
|
"无权限使用所选知识库"
|
||||||
|
);
|
||||||
|
AgentKnowledgeBinding binding = new AgentKnowledgeBinding();
|
||||||
|
binding.setTenantId(account == null ? runtimeAgent.getTenantId() : account.getTenantId());
|
||||||
|
binding.setAgentId(runtimeAgent.getId());
|
||||||
|
binding.setKnowledgeId(knowledgeId);
|
||||||
|
binding.setRetrievalMode(DEFAULT_RETRIEVAL_MODE);
|
||||||
|
binding.setEnabled(true);
|
||||||
|
binding.setSortNo(sortNo++);
|
||||||
|
mergedBindings.add(binding);
|
||||||
|
existingKnowledgeIds.add(knowledgeId);
|
||||||
|
}
|
||||||
|
runtimeAgent.setKnowledgeBindings(mergedBindings);
|
||||||
|
return new AgentChatCapabilityResolution(runtimeAgent, extraKnowledgeIds, knowledgeCapabilityProvided);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 从能力列表提取知识库 ID。
|
||||||
|
*
|
||||||
|
* @param capabilities 临时能力列表
|
||||||
|
* @return 已去重知识库 ID
|
||||||
|
*/
|
||||||
|
public List<BigInteger> resolveKnowledgeIds(List<AgentChatCapability> capabilities) {
|
||||||
|
if (capabilities == null || capabilities.isEmpty()) {
|
||||||
|
return List.of();
|
||||||
|
}
|
||||||
|
LinkedHashSet<BigInteger> ids = new LinkedHashSet<>();
|
||||||
|
for (AgentChatCapability capability : capabilities) {
|
||||||
|
if (capability == null || !isKnowledgeCapability(capability.getType())) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (capability.getResourceIds() == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (BigInteger resourceId : capability.getResourceIds()) {
|
||||||
|
if (resourceId != null) {
|
||||||
|
ids.add(resourceId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (ids.size() > MAX_EXTRA_KNOWLEDGE_COUNT) {
|
||||||
|
throw new BusinessException("临时知识库最多选择 3 个");
|
||||||
|
}
|
||||||
|
return new ArrayList<>(ids);
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean isKnowledgeCapability(String type) {
|
||||||
|
return Objects.equals(KNOWLEDGE_CAPABILITY_TYPE, type == null ? null : type.trim().toUpperCase());
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean hasKnowledgeCapability(List<AgentChatCapability> capabilities) {
|
||||||
|
if (capabilities == null || capabilities.isEmpty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (AgentChatCapability capability : capabilities) {
|
||||||
|
if (capability != null && isKnowledgeCapability(capability.getType())) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void validateKnowledge(DocumentCollection knowledge) {
|
||||||
|
if (knowledge == null) {
|
||||||
|
throw new BusinessException("所选知识库不存在");
|
||||||
|
}
|
||||||
|
if (PublishStatus.from(knowledge.getPublishStatus()) != PublishStatus.PUBLISHED) {
|
||||||
|
throw new BusinessException("所选知识库未发布,无法用于聊天");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 聊天临时能力解析结果。
|
||||||
|
*
|
||||||
|
* @param agent 合并临时能力后的运行时 Agent
|
||||||
|
* @param extraKnowledgeIds 本次选择的临时知识库 ID
|
||||||
|
* @param knowledgeCapabilityProvided 请求是否显式传入知识库能力
|
||||||
|
*/
|
||||||
|
public record AgentChatCapabilityResolution(Agent agent,
|
||||||
|
List<BigInteger> extraKnowledgeIds,
|
||||||
|
boolean knowledgeCapabilityProvided) {
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
package tech.easyflow.agent.runtime;
|
package tech.easyflow.agent.runtime;
|
||||||
|
|
||||||
import java.math.BigInteger;
|
import java.math.BigInteger;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Agent 管理端运行请求。
|
* Agent 管理端运行请求。
|
||||||
@@ -10,6 +12,7 @@ public class AgentChatRequest {
|
|||||||
private BigInteger agentId;
|
private BigInteger agentId;
|
||||||
private BigInteger sessionId;
|
private BigInteger sessionId;
|
||||||
private String prompt;
|
private String prompt;
|
||||||
|
private List<AgentChatCapability> capabilities = new ArrayList<>();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取 Agent ID。
|
* 获取 Agent ID。
|
||||||
@@ -52,4 +55,22 @@ public class AgentChatRequest {
|
|||||||
* @param prompt 用户输入
|
* @param prompt 用户输入
|
||||||
*/
|
*/
|
||||||
public void setPrompt(String prompt) { this.prompt = prompt; }
|
public void setPrompt(String prompt) { this.prompt = prompt; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取本次聊天启用的临时能力。
|
||||||
|
*
|
||||||
|
* @return 临时能力列表
|
||||||
|
*/
|
||||||
|
public List<AgentChatCapability> getCapabilities() {
|
||||||
|
return capabilities;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置本次聊天启用的临时能力。
|
||||||
|
*
|
||||||
|
* @param capabilities 临时能力列表
|
||||||
|
*/
|
||||||
|
public void setCapabilities(List<AgentChatCapability> capabilities) {
|
||||||
|
this.capabilities = capabilities == null ? new ArrayList<>() : new ArrayList<>(capabilities);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,8 +4,12 @@ import com.easyagents.agent.runtime.AgentResumeRequest;
|
|||||||
import com.easyagents.agent.runtime.AgentRuntime;
|
import com.easyagents.agent.runtime.AgentRuntime;
|
||||||
import com.easyagents.agent.runtime.event.AgentRuntimeEvent;
|
import com.easyagents.agent.runtime.event.AgentRuntimeEvent;
|
||||||
import com.easyagents.agent.runtime.hitl.AgentResumeToken;
|
import com.easyagents.agent.runtime.hitl.AgentResumeToken;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
import reactor.core.Disposable;
|
import reactor.core.Disposable;
|
||||||
|
import tech.easyflow.agent.distributed.AgentRuntimeRouteRegistry;
|
||||||
import tech.easyflow.agent.runtime.lock.AgentRunLock;
|
import tech.easyflow.agent.runtime.lock.AgentRunLock;
|
||||||
import tech.easyflow.common.web.exceptions.BusinessException;
|
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||||
import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter;
|
import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter;
|
||||||
@@ -25,11 +29,24 @@ import java.util.function.Consumer;
|
|||||||
@Component
|
@Component
|
||||||
public class AgentRunRegistry {
|
public class AgentRunRegistry {
|
||||||
|
|
||||||
|
private static final Logger LOG = LoggerFactory.getLogger(AgentRunRegistry.class);
|
||||||
|
|
||||||
private final Map<String, AgentRunContext> runs = new ConcurrentHashMap<>();
|
private final Map<String, AgentRunContext> runs = new ConcurrentHashMap<>();
|
||||||
private final Map<String, String> sessionRuns = new ConcurrentHashMap<>();
|
private final Map<String, String> sessionRuns = new ConcurrentHashMap<>();
|
||||||
private final Map<String, String> resumeTokenIndex = new ConcurrentHashMap<>();
|
private final Map<String, String> resumeTokenIndex = new ConcurrentHashMap<>();
|
||||||
private final Map<String, Set<String>> requestTokens = new ConcurrentHashMap<>();
|
private final Map<String, Set<String>> requestTokens = new ConcurrentHashMap<>();
|
||||||
private final Map<String, RunOwner> owners = new ConcurrentHashMap<>();
|
private final Map<String, RunOwner> owners = new ConcurrentHashMap<>();
|
||||||
|
private AgentRuntimeRouteRegistry routeRegistry;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置 Agent 运行态 Redis 路由注册表。
|
||||||
|
*
|
||||||
|
* @param routeRegistry Redis 路由注册表
|
||||||
|
*/
|
||||||
|
@Autowired(required = false)
|
||||||
|
public void setRouteRegistry(AgentRuntimeRouteRegistry routeRegistry) {
|
||||||
|
this.routeRegistry = routeRegistry;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 注册运行态。
|
* 注册运行态。
|
||||||
@@ -53,6 +70,9 @@ public class AgentRunRegistry {
|
|||||||
throw new BusinessException("当前 Agent 运行请求已存在");
|
throw new BusinessException("当前 Agent 运行请求已存在");
|
||||||
}
|
}
|
||||||
owners.put(context.requestId(), context.owner());
|
owners.put(context.requestId(), context.owner());
|
||||||
|
if (routeRegistry != null) {
|
||||||
|
routeRegistry.registerRun(context.requestId());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -122,6 +142,9 @@ public class AgentRunRegistry {
|
|||||||
if (requestId != null && resumeToken != null && !resumeToken.isBlank()) {
|
if (requestId != null && resumeToken != null && !resumeToken.isBlank()) {
|
||||||
resumeTokenIndex.put(resumeToken, requestId);
|
resumeTokenIndex.put(resumeToken, requestId);
|
||||||
requestTokens.computeIfAbsent(requestId, ignored -> ConcurrentHashMap.newKeySet()).add(resumeToken);
|
requestTokens.computeIfAbsent(requestId, ignored -> ConcurrentHashMap.newKeySet()).add(resumeToken);
|
||||||
|
if (routeRegistry != null) {
|
||||||
|
routeRegistry.registerResumeToken(requestId, resumeToken);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -138,11 +161,20 @@ public class AgentRunRegistry {
|
|||||||
if (context != null) {
|
if (context != null) {
|
||||||
sessionRuns.remove(context.sessionId(), requestId);
|
sessionRuns.remove(context.sessionId(), requestId);
|
||||||
context.releaseLock();
|
context.releaseLock();
|
||||||
|
context.closeRuntime();
|
||||||
}
|
}
|
||||||
owners.remove(requestId);
|
owners.remove(requestId);
|
||||||
Set<String> tokens = requestTokens.remove(requestId);
|
Set<String> tokens = requestTokens.remove(requestId);
|
||||||
if (tokens != null) {
|
if (tokens != null) {
|
||||||
tokens.forEach(resumeTokenIndex::remove);
|
tokens.forEach(token -> {
|
||||||
|
resumeTokenIndex.remove(token);
|
||||||
|
if (routeRegistry != null) {
|
||||||
|
routeRegistry.removeResumeToken(token);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if (routeRegistry != null) {
|
||||||
|
routeRegistry.removeRun(requestId);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -210,6 +242,23 @@ public class AgentRunRegistry {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 当前恢复目标是否为草稿试运行。
|
||||||
|
*
|
||||||
|
* @param requestId 请求 ID,可为空
|
||||||
|
* @param resumeToken 恢复令牌
|
||||||
|
* @return true 表示目标为草稿试运行
|
||||||
|
*/
|
||||||
|
public boolean isDraftResumeTarget(String requestId, String resumeToken) {
|
||||||
|
try {
|
||||||
|
String resolvedRequestId = resolveRequestId(requestId, resumeToken);
|
||||||
|
AgentRunContext context = runs.get(resolvedRequestId);
|
||||||
|
return context != null && !context.persistChatlog();
|
||||||
|
} catch (BusinessException ignored) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private void submit(String requestId, String resumeToken, String userId, boolean approved, String reason) {
|
private void submit(String requestId, String resumeToken, String userId, boolean approved, String reason) {
|
||||||
submit(requestId, resumeToken, userId, approved, reason, null);
|
submit(requestId, resumeToken, userId, approved, reason, null);
|
||||||
}
|
}
|
||||||
@@ -235,6 +284,9 @@ public class AgentRunRegistry {
|
|||||||
tokens.remove(resumeToken);
|
tokens.remove(resumeToken);
|
||||||
}
|
}
|
||||||
resumeTokenIndex.remove(resumeToken);
|
resumeTokenIndex.remove(resumeToken);
|
||||||
|
if (routeRegistry != null) {
|
||||||
|
routeRegistry.removeResumeToken(resumeToken);
|
||||||
|
}
|
||||||
AgentResumeToken token = new AgentResumeToken();
|
AgentResumeToken token = new AgentResumeToken();
|
||||||
token.setValue(resumeToken);
|
token.setValue(resumeToken);
|
||||||
AgentResumeRequest request = new AgentResumeRequest();
|
AgentResumeRequest request = new AgentResumeRequest();
|
||||||
@@ -430,6 +482,15 @@ public class AgentRunRegistry {
|
|||||||
return suspended.get();
|
return suspended.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 当前运行是否持久化聊天日志与运行态。
|
||||||
|
*
|
||||||
|
* @return true 表示正式聊天持久化运行
|
||||||
|
*/
|
||||||
|
public boolean persistChatlog() {
|
||||||
|
return persistChatlog;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 绑定运行订阅。
|
* 绑定运行订阅。
|
||||||
*
|
*
|
||||||
@@ -477,6 +538,18 @@ public class AgentRunRegistry {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 关闭底层运行时并释放资源。
|
||||||
|
*/
|
||||||
|
public void closeRuntime() {
|
||||||
|
try {
|
||||||
|
runtime.close();
|
||||||
|
} catch (Exception e) {
|
||||||
|
LOG.warn("Close Agent runtime failed, requestId={}, sessionId={}, message={}",
|
||||||
|
requestId, sessionId, e.getMessage(), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 通过同一个 runtime 恢复挂起运行,事件继续写入原 SSE。
|
* 通过同一个 runtime 恢复挂起运行,事件继续写入原 SSE。
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -19,6 +19,10 @@ import tech.easyflow.agent.entity.Agent;
|
|||||||
import tech.easyflow.agent.entity.AgentKnowledgeBinding;
|
import tech.easyflow.agent.entity.AgentKnowledgeBinding;
|
||||||
import tech.easyflow.agent.entity.AgentToolBinding;
|
import tech.easyflow.agent.entity.AgentToolBinding;
|
||||||
import tech.easyflow.agent.enums.AgentToolType;
|
import tech.easyflow.agent.enums.AgentToolType;
|
||||||
|
import tech.easyflow.agent.distributed.AgentRuntimeCommandAction;
|
||||||
|
import tech.easyflow.agent.distributed.AgentRuntimeCommandProducer;
|
||||||
|
import tech.easyflow.agent.distributed.AgentRuntimeRoute;
|
||||||
|
import tech.easyflow.agent.distributed.AgentRuntimeRouteRegistry;
|
||||||
import tech.easyflow.agent.runtime.event.AgentRunEventRecorder;
|
import tech.easyflow.agent.runtime.event.AgentRunEventRecorder;
|
||||||
import tech.easyflow.agent.runtime.hitl.AgentHitlPendingService;
|
import tech.easyflow.agent.runtime.hitl.AgentHitlPendingService;
|
||||||
import tech.easyflow.agent.runtime.lock.AgentRunLock;
|
import tech.easyflow.agent.runtime.lock.AgentRunLock;
|
||||||
@@ -66,16 +70,22 @@ public class AgentRunService {
|
|||||||
@Resource
|
@Resource
|
||||||
private AgentService agentService;
|
private AgentService agentService;
|
||||||
@Resource
|
@Resource
|
||||||
private AgentDefinitionCompiler agentDefinitionCompiler;
|
private AgentRuntimeCompiler agentRuntimeCompiler;
|
||||||
@Resource
|
@Resource
|
||||||
private AgentRuntimeFactory agentRuntimeFactory;
|
private AgentRuntimeFactory agentRuntimeFactory;
|
||||||
@Resource
|
@Resource
|
||||||
private AgentSessionStore agentSessionStore;
|
private AgentChatCapabilityService agentChatCapabilityService;
|
||||||
@Resource
|
@Resource
|
||||||
private EasyFlowAgentSessionStore easyFlowAgentSessionStore;
|
private EasyFlowAgentSessionStore easyFlowAgentSessionStore;
|
||||||
@Resource
|
@Resource
|
||||||
|
private AgentSessionStore draftAgentSessionStore;
|
||||||
|
@Resource
|
||||||
private AgentRunRegistry agentRunRegistry;
|
private AgentRunRegistry agentRunRegistry;
|
||||||
@Resource
|
@Resource
|
||||||
|
private AgentRuntimeRouteRegistry agentRuntimeRouteRegistry;
|
||||||
|
@Resource
|
||||||
|
private AgentRuntimeCommandProducer agentRuntimeCommandProducer;
|
||||||
|
@Resource
|
||||||
private AgentRunLock agentRunLock;
|
private AgentRunLock agentRunLock;
|
||||||
@Resource
|
@Resource
|
||||||
private AgentHitlPendingService agentHitlPendingService;
|
private AgentHitlPendingService agentHitlPendingService;
|
||||||
@@ -121,14 +131,20 @@ public class AgentRunService {
|
|||||||
ChatSessionSummary existingSession = resolveExistingSession(account, sessionId, chatRequest.getAgentId());
|
ChatSessionSummary existingSession = resolveExistingSession(account, sessionId, chatRequest.getAgentId());
|
||||||
// 获取 Agent 发布快照
|
// 获取 Agent 发布快照
|
||||||
Agent agent = agentService.getPublishedView(chatRequest.getAgentId());
|
Agent agent = agentService.getPublishedView(chatRequest.getAgentId());
|
||||||
|
AgentChatCapabilityService.AgentChatCapabilityResolution capabilityResolution =
|
||||||
|
agentChatCapabilityService.apply(agent, chatRequest.getCapabilities(), account);
|
||||||
|
agent = capabilityResolution.agent();
|
||||||
String requestId = UUID.randomUUID().toString();
|
String requestId = UUID.randomUUID().toString();
|
||||||
String traceId = UUID.randomUUID().toString();
|
String traceId = UUID.randomUUID().toString();
|
||||||
// 组建会话上下文必要信息
|
// 组建会话上下文必要信息
|
||||||
ChatRuntimeContext chatContext = buildChatRuntimeContext(agent, sessionId, chatRequest.getPrompt(), account);
|
ChatRuntimeContext chatContext = buildChatRuntimeContext(agent, sessionId, chatRequest.getPrompt(), account);
|
||||||
|
if (capabilityResolution.knowledgeCapabilityProvided()) {
|
||||||
|
chatContext.getExt().put(ChatRuntimeExtKeys.EXTRA_KNOWLEDGE_IDS, capabilityResolution.extraKnowledgeIds());
|
||||||
|
}
|
||||||
applyFormalSessionTitle(chatContext, chatRequest.getPrompt(), existingSession);
|
applyFormalSessionTitle(chatContext, chatRequest.getPrompt(), existingSession);
|
||||||
// 执行对话
|
// 执行对话
|
||||||
return run(agent, chatRequest.getPrompt(), requestId, traceId, sessionId.toString(),
|
return run(agent, chatRequest.getPrompt(), requestId, traceId, sessionId.toString(),
|
||||||
ASSISTANT_CODE, chatContext, true);
|
ASSISTANT_CODE, chatContext, true, easyFlowAgentSessionStore);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -152,7 +168,7 @@ public class AgentRunService {
|
|||||||
String traceId = UUID.randomUUID().toString();
|
String traceId = UUID.randomUUID().toString();
|
||||||
ChatRuntimeContext chatContext = buildChatRuntimeContext(agent, chatSessionId, draftRequest.getPrompt(), account, DRAFT_ASSISTANT_CODE);
|
ChatRuntimeContext chatContext = buildChatRuntimeContext(agent, chatSessionId, draftRequest.getPrompt(), account, DRAFT_ASSISTANT_CODE);
|
||||||
return run(agent, draftRequest.getPrompt(), requestId, traceId, runtimeSessionId,
|
return run(agent, draftRequest.getPrompt(), requestId, traceId, runtimeSessionId,
|
||||||
DRAFT_ASSISTANT_CODE, chatContext, false);
|
DRAFT_ASSISTANT_CODE, chatContext, false, draftAgentSessionStore);
|
||||||
}
|
}
|
||||||
|
|
||||||
private SseEmitter run(Agent agent,
|
private SseEmitter run(Agent agent,
|
||||||
@@ -162,7 +178,8 @@ public class AgentRunService {
|
|||||||
String runtimeSessionId,
|
String runtimeSessionId,
|
||||||
String assistantCode,
|
String assistantCode,
|
||||||
ChatRuntimeContext chatContext,
|
ChatRuntimeContext chatContext,
|
||||||
boolean persistChatlog) {
|
boolean persistChatlog,
|
||||||
|
AgentSessionStore runtimeSessionStore) {
|
||||||
ChatSseEmitter chatSseEmitter = new ChatSseEmitter();
|
ChatSseEmitter chatSseEmitter = new ChatSseEmitter();
|
||||||
// 获取会话锁
|
// 获取会话锁
|
||||||
AgentRunLock.Handle lockHandle = acquireRunLock(agent, runtimeSessionId);
|
AgentRunLock.Handle lockHandle = acquireRunLock(agent, runtimeSessionId);
|
||||||
@@ -178,7 +195,7 @@ public class AgentRunService {
|
|||||||
chatRuntimeManager.recordUserMessage(chatContext, buildUserRuntimeMessage(chatContext, prompt));
|
chatRuntimeManager.recordUserMessage(chatContext, buildUserRuntimeMessage(chatContext, prompt));
|
||||||
}
|
}
|
||||||
threadPoolTaskExecutor.execute(() -> startRuntime(agent, prompt, requestId, traceId, runtimeSessionId,
|
threadPoolTaskExecutor.execute(() -> startRuntime(agent, prompt, requestId, traceId, runtimeSessionId,
|
||||||
assistantCode, chatContext, chatSseEmitter, persistChatlog, lockHandle));
|
assistantCode, chatContext, chatSseEmitter, persistChatlog, runtimeSessionStore, lockHandle));
|
||||||
submitted = true;
|
submitted = true;
|
||||||
return chatSseEmitter.getEmitter();
|
return chatSseEmitter.getEmitter();
|
||||||
} finally {
|
} finally {
|
||||||
@@ -202,11 +219,12 @@ public class AgentRunService {
|
|||||||
throw new BusinessException("仅允许清理 Agent 草稿试运行会话");
|
throw new BusinessException("仅允许清理 Agent 草稿试运行会话");
|
||||||
}
|
}
|
||||||
LoginAccount account = requireCurrentLoginAccount();
|
LoginAccount account = requireCurrentLoginAccount();
|
||||||
agentRunRegistry.cancelSession(sessionId, account.getId() == null ? null : account.getId().toString());
|
clearDraftSessionInternal(sessionId, account.getId() == null ? null : account.getId().toString());
|
||||||
agentSessionStore.delete(sessionId);
|
}
|
||||||
if (agentHitlPendingService != null) {
|
|
||||||
agentHitlPendingService.deleteByRuntimeSessionId(sessionId);
|
private void clearDraftSessionInternal(String sessionId, String userId) {
|
||||||
}
|
agentRunRegistry.cancelSession(sessionId, userId);
|
||||||
|
draftAgentSessionStore.delete(sessionId);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -217,9 +235,32 @@ public class AgentRunService {
|
|||||||
*/
|
*/
|
||||||
public void approve(String requestId, String resumeToken) {
|
public void approve(String requestId, String resumeToken) {
|
||||||
LoginAccount account = requireCurrentLoginAccount();
|
LoginAccount account = requireCurrentLoginAccount();
|
||||||
String userId = account.getId() == null ? null : account.getId().toString();
|
approveRuntime(requestId, resumeToken, account.getId(), account.getId() == null ? null : account.getId().toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
private void approveRuntime(String requestId, String resumeToken, BigInteger operatorId, String userId) {
|
||||||
|
if (!agentRunRegistry.containsResumeTarget(requestId, resumeToken)) {
|
||||||
|
dispatchRemoteRuntimeCommand(requestId, resumeToken, AgentRuntimeCommandAction.APPROVE, null, operatorId, userId);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
approveRuntimeLocal(requestId, resumeToken, operatorId, userId);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 在当前节点批准工具执行。
|
||||||
|
*
|
||||||
|
* @param requestId 请求 ID
|
||||||
|
* @param resumeToken 恢复令牌
|
||||||
|
* @param operatorId 操作人 ID
|
||||||
|
* @param userId 用户 ID
|
||||||
|
*/
|
||||||
|
public void approveRuntimeLocal(String requestId, String resumeToken, BigInteger operatorId, String userId) {
|
||||||
|
if (agentRunRegistry.isDraftResumeTarget(requestId, resumeToken)) {
|
||||||
|
agentRunRegistry.approve(requestId, resumeToken, userId);
|
||||||
|
return;
|
||||||
|
}
|
||||||
agentRunRegistry.approve(requestId, resumeToken, userId,
|
agentRunRegistry.approve(requestId, resumeToken, userId,
|
||||||
() -> agentHitlPendingService.approve(resumeToken, account.getId()));
|
() -> agentHitlPendingService.approve(resumeToken, operatorId));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -231,9 +272,73 @@ public class AgentRunService {
|
|||||||
*/
|
*/
|
||||||
public void reject(String requestId, String resumeToken, String reason) {
|
public void reject(String requestId, String resumeToken, String reason) {
|
||||||
LoginAccount account = requireCurrentLoginAccount();
|
LoginAccount account = requireCurrentLoginAccount();
|
||||||
String userId = account.getId() == null ? null : account.getId().toString();
|
rejectRuntime(requestId, resumeToken, reason, account.getId(), account.getId() == null ? null : account.getId().toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
private void rejectRuntime(String requestId, String resumeToken, String reason, BigInteger operatorId, String userId) {
|
||||||
|
if (!agentRunRegistry.containsResumeTarget(requestId, resumeToken)) {
|
||||||
|
dispatchRemoteRuntimeCommand(requestId, resumeToken, AgentRuntimeCommandAction.REJECT, reason, operatorId, userId);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
rejectRuntimeLocal(requestId, resumeToken, reason, operatorId, userId);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 在当前节点拒绝工具执行。
|
||||||
|
*
|
||||||
|
* @param requestId 请求 ID
|
||||||
|
* @param resumeToken 恢复令牌
|
||||||
|
* @param reason 拒绝原因
|
||||||
|
* @param operatorId 操作人 ID
|
||||||
|
* @param userId 用户 ID
|
||||||
|
*/
|
||||||
|
public void rejectRuntimeLocal(String requestId, String resumeToken, String reason, BigInteger operatorId, String userId) {
|
||||||
|
if (agentRunRegistry.isDraftResumeTarget(requestId, resumeToken)) {
|
||||||
|
agentRunRegistry.reject(requestId, resumeToken, userId, reason);
|
||||||
|
return;
|
||||||
|
}
|
||||||
agentRunRegistry.reject(requestId, resumeToken, userId, reason,
|
agentRunRegistry.reject(requestId, resumeToken, userId, reason,
|
||||||
() -> agentHitlPendingService.reject(resumeToken, account.getId(), reason));
|
() -> agentHitlPendingService.reject(resumeToken, operatorId, reason));
|
||||||
|
}
|
||||||
|
|
||||||
|
private void dispatchRemoteRuntimeCommand(String requestId,
|
||||||
|
String resumeToken,
|
||||||
|
AgentRuntimeCommandAction action,
|
||||||
|
String reason,
|
||||||
|
BigInteger operatorId,
|
||||||
|
String userId) {
|
||||||
|
String resolvedRequestId = resolveRequestIdForRemoteDispatch(requestId, resumeToken);
|
||||||
|
AgentRuntimeRoute ownerRoute = agentRuntimeRouteRegistry.findOwnerRoute(resolvedRequestId);
|
||||||
|
String ownerNodeId = ownerRoute == null ? null : ownerRoute.getNodeId();
|
||||||
|
if (ownerNodeId == null || ownerNodeId.isBlank()) {
|
||||||
|
throw new BusinessException("Agent 运行节点不可用,请重新发起对话");
|
||||||
|
}
|
||||||
|
if (ownerNodeId.equals(agentRuntimeRouteRegistry.currentNodeId())) {
|
||||||
|
throw new BusinessException("Agent 运行节点不可用,请重新发起对话");
|
||||||
|
}
|
||||||
|
if (!agentRuntimeRouteRegistry.isNodeAlive(ownerNodeId)) {
|
||||||
|
throw new BusinessException("Agent 运行节点不可用,请重新发起对话");
|
||||||
|
}
|
||||||
|
String currentOwnerBootId = agentRuntimeRouteRegistry.currentNodeBootId(ownerNodeId);
|
||||||
|
if (ownerRoute.getBootId() == null || !ownerRoute.getBootId().equals(currentOwnerBootId)) {
|
||||||
|
throw new BusinessException("Agent 运行节点不可用,请重新发起对话");
|
||||||
|
}
|
||||||
|
if (action == AgentRuntimeCommandAction.APPROVE) {
|
||||||
|
agentRuntimeCommandProducer.sendApprove(ownerNodeId, resolvedRequestId, resumeToken, operatorId, userId);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
agentRuntimeCommandProducer.sendReject(ownerNodeId, resolvedRequestId, resumeToken, reason, operatorId, userId);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String resolveRequestIdForRemoteDispatch(String requestId, String resumeToken) {
|
||||||
|
if (requestId != null && !requestId.isBlank()) {
|
||||||
|
return requestId;
|
||||||
|
}
|
||||||
|
String resolvedRequestId = agentRuntimeRouteRegistry.findRequestIdByResumeToken(resumeToken);
|
||||||
|
if (resolvedRequestId == null || resolvedRequestId.isBlank()) {
|
||||||
|
throw new BusinessException("Agent 运行节点不可用,请重新发起对话");
|
||||||
|
}
|
||||||
|
return resolvedRequestId;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void startRuntime(Agent agent,
|
private void startRuntime(Agent agent,
|
||||||
@@ -245,6 +350,7 @@ public class AgentRunService {
|
|||||||
ChatRuntimeContext chatContext,
|
ChatRuntimeContext chatContext,
|
||||||
ChatSseEmitter chatSseEmitter,
|
ChatSseEmitter chatSseEmitter,
|
||||||
boolean persistChatlog,
|
boolean persistChatlog,
|
||||||
|
AgentSessionStore runtimeSessionStore,
|
||||||
AgentRunLock.Handle initialLockHandle) {
|
AgentRunLock.Handle initialLockHandle) {
|
||||||
AtomicBoolean finished = new AtomicBoolean(false);
|
AtomicBoolean finished = new AtomicBoolean(false);
|
||||||
StringBuilder answer = new StringBuilder();
|
StringBuilder answer = new StringBuilder();
|
||||||
@@ -254,8 +360,10 @@ public class AgentRunService {
|
|||||||
assistantAccumulator, finished, persistChatlog);
|
assistantAccumulator, finished, persistChatlog);
|
||||||
AgentRunLock.Handle lockHandle = initialLockHandle;
|
AgentRunLock.Handle lockHandle = initialLockHandle;
|
||||||
try {
|
try {
|
||||||
bindAgentSession(agent, runtimeSessionId, chatContext);
|
if (persistChatlog) {
|
||||||
AgentRuntimeBundle bundle = agentDefinitionCompiler.compile(agent);
|
bindAgentSession(agent, runtimeSessionId, chatContext);
|
||||||
|
}
|
||||||
|
AgentRuntimeBundle bundle = agentRuntimeCompiler.compile(agent);
|
||||||
AgentRuntime runtime = agentRuntimeFactory.create();
|
AgentRuntime runtime = agentRuntimeFactory.create();
|
||||||
// 会话初始化请求
|
// 会话初始化请求
|
||||||
AgentInitRequest request = new AgentInitRequest();
|
AgentInitRequest request = new AgentInitRequest();
|
||||||
@@ -264,7 +372,7 @@ public class AgentRunService {
|
|||||||
request.setRuntimeContext(buildAgentRuntimeContext(chatContext, traceId, runtimeSessionId));
|
request.setRuntimeContext(buildAgentRuntimeContext(chatContext, traceId, runtimeSessionId));
|
||||||
request.setToolInvokers(bundle.getToolInvokers());
|
request.setToolInvokers(bundle.getToolInvokers());
|
||||||
request.setKnowledgeRetrievers(bundle.getKnowledgeRetrievers());
|
request.setKnowledgeRetrievers(bundle.getKnowledgeRetrievers());
|
||||||
request.setSessionStore(agentSessionStore);
|
request.setSessionStore(runtimeSessionStore);
|
||||||
request.getMetadata().put("assistantCode", assistantCode);
|
request.getMetadata().put("assistantCode", assistantCode);
|
||||||
runtime.init(request);
|
runtime.init(request);
|
||||||
// 注册会话运行时管理
|
// 注册会话运行时管理
|
||||||
@@ -338,20 +446,20 @@ public class AgentRunService {
|
|||||||
return agentRunLock.acquire(agent == null ? null : agent.getId(), runtimeSessionId);
|
return agentRunLock.acquire(agent == null ? null : agent.getId(), runtimeSessionId);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void recordRuntimeEvent(String requestId, ChatRuntimeContext chatContext, AgentRuntimeEvent event) {
|
private void recordRuntimeEvent(String requestId, ChatRuntimeContext chatContext, AgentRuntimeEvent event, boolean persistChatlog) {
|
||||||
if (agentRunEventRecorder != null) {
|
if (persistChatlog && agentRunEventRecorder != null) {
|
||||||
agentRunEventRecorder.record(requestId, chatContext, event);
|
agentRunEventRecorder.record(requestId, chatContext, event);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void recordApprovalRequired(String requestId, ChatRuntimeContext chatContext, AgentRuntimeEvent event) {
|
private void recordApprovalRequired(String requestId, ChatRuntimeContext chatContext, AgentRuntimeEvent event, boolean persistChatlog) {
|
||||||
if (agentHitlPendingService != null) {
|
if (persistChatlog && agentHitlPendingService != null) {
|
||||||
agentHitlPendingService.recordApprovalRequired(requestId, chatContext, event);
|
agentHitlPendingService.recordApprovalRequired(requestId, chatContext, event);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void cancelPending(String requestId, String reason) {
|
private void cancelPending(String requestId, String reason, boolean persistChatlog) {
|
||||||
if (agentHitlPendingService != null) {
|
if (persistChatlog && agentHitlPendingService != null) {
|
||||||
agentHitlPendingService.cancelByRequestId(requestId, reason);
|
agentHitlPendingService.cancelByRequestId(requestId, reason);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -389,7 +497,7 @@ public class AgentRunService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
agentRunRegistry.remove(requestId);
|
agentRunRegistry.remove(requestId);
|
||||||
cancelPending(requestId, "客户端连接已断开,Agent 运行已取消");
|
cancelPending(requestId, "客户端连接已断开,Agent 运行已取消", persistChatlog);
|
||||||
if (!persistChatlog) {
|
if (!persistChatlog) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -412,7 +520,7 @@ public class AgentRunService {
|
|||||||
if (event == null || event.getEventType() == null) {
|
if (event == null || event.getEventType() == null) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
recordRuntimeEvent(requestId, chatContext, event);
|
recordRuntimeEvent(requestId, chatContext, event, persistChatlog);
|
||||||
if (event.getEventType() == AgentRuntimeEventType.MESSAGE_DELTA) {
|
if (event.getEventType() == AgentRuntimeEventType.MESSAGE_DELTA) {
|
||||||
String text = stringPayload(event, "text");
|
String text = stringPayload(event, "text");
|
||||||
if (text != null) {
|
if (text != null) {
|
||||||
@@ -440,21 +548,29 @@ public class AgentRunService {
|
|||||||
if (event.getEventType() == AgentRuntimeEventType.TOOL_APPROVAL_REQUIRED) {
|
if (event.getEventType() == AgentRuntimeEventType.TOOL_APPROVAL_REQUIRED) {
|
||||||
String resumeToken = stringPayload(event, "resumeToken");
|
String resumeToken = stringPayload(event, "resumeToken");
|
||||||
agentRunRegistry.registerResumeToken(requestId, resumeToken);
|
agentRunRegistry.registerResumeToken(requestId, resumeToken);
|
||||||
recordApprovalRequired(requestId, chatContext, event);
|
recordApprovalRequired(requestId, chatContext, event, persistChatlog);
|
||||||
if (!sendEnvelope(chatSseEmitter, ChatDomain.TOOL, ChatType.FORM_REQUEST, buildToolHitlPayload(requestId, event))) {
|
if (!sendEnvelope(chatSseEmitter, ChatDomain.TOOL, ChatType.FORM_REQUEST, buildToolHitlPayload(requestId, event))) {
|
||||||
cancelDisconnectedRun(requestId, chatContext, answer, assistantAccumulator, finished, persistChatlog);
|
cancelDisconnectedRun(requestId, chatContext, answer, assistantAccumulator, finished, persistChatlog);
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (isAsyncToolEvent(event.getEventType())) {
|
||||||
|
if (!sendEnvelope(chatSseEmitter, ChatDomain.TOOL, asyncToolChatType(event), buildAsyncToolEventPayload(event))) {
|
||||||
|
cancelDisconnectedRun(requestId, chatContext, answer, assistantAccumulator, finished, persistChatlog);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (event.getEventType() == AgentRuntimeEventType.TOOL_CALL) {
|
if (event.getEventType() == AgentRuntimeEventType.TOOL_CALL) {
|
||||||
LOG.info("Agent runtime tool call, requestId={}, toolCallId={}, payload={}, metadata={}",
|
LOG.info("Agent runtime tool call, requestId={}, toolCallId={}, payload={}, metadata={}",
|
||||||
requestId, event.getToolCallId(), event.getPayload(), event.getMetadata());
|
requestId, event.getToolCallId(), event.getPayload(), event.getMetadata());
|
||||||
|
Map<String, Object> toolPayload = buildToolEventPayload(event);
|
||||||
assistantAccumulator.appendToolCall(
|
assistantAccumulator.appendToolCall(
|
||||||
firstText(event.getToolCallId(), stringPayload(event, "toolCallId")),
|
firstText(stringValue(toolPayload, "toolCallId"), event.getToolCallId()),
|
||||||
firstText(stringPayload(event, "toolName"), stringPayload(event, "name")),
|
firstText(stringValue(toolPayload, "toolName"), stringValue(toolPayload, "name")),
|
||||||
firstNonNull(event.getPayload().get("input"), event.getPayload().get("toolInput"))
|
stringValue(toolPayload, "toolDisplayName"),
|
||||||
|
firstNonNull(toolPayload.get("input"), toolPayload.get("toolInput"))
|
||||||
);
|
);
|
||||||
if (!sendEnvelope(chatSseEmitter, ChatDomain.TOOL, ChatType.TOOL_CALL, buildToolEventPayload(event))) {
|
if (!sendEnvelope(chatSseEmitter, ChatDomain.TOOL, ChatType.TOOL_CALL, toolPayload)) {
|
||||||
cancelDisconnectedRun(requestId, chatContext, answer, assistantAccumulator, finished, persistChatlog);
|
cancelDisconnectedRun(requestId, chatContext, answer, assistantAccumulator, finished, persistChatlog);
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
@@ -462,13 +578,15 @@ public class AgentRunService {
|
|||||||
if (event.getEventType() == AgentRuntimeEventType.TOOL_RESULT) {
|
if (event.getEventType() == AgentRuntimeEventType.TOOL_RESULT) {
|
||||||
LOG.info("Agent runtime tool result, requestId={}, toolCallId={}, payload={}, metadata={}",
|
LOG.info("Agent runtime tool result, requestId={}, toolCallId={}, payload={}, metadata={}",
|
||||||
requestId, event.getToolCallId(), event.getPayload(), event.getMetadata());
|
requestId, event.getToolCallId(), event.getPayload(), event.getMetadata());
|
||||||
|
Map<String, Object> toolPayload = buildToolEventPayload(event);
|
||||||
assistantAccumulator.appendToolResult(
|
assistantAccumulator.appendToolResult(
|
||||||
firstText(event.getToolCallId(), stringPayload(event, "toolCallId")),
|
firstText(stringValue(toolPayload, "toolCallId"), event.getToolCallId()),
|
||||||
firstText(stringPayload(event, "toolName"), stringPayload(event, "name")),
|
firstText(stringValue(toolPayload, "toolName"), stringValue(toolPayload, "name")),
|
||||||
firstNonNull(firstNonNull(event.getPayload().get("output"), event.getPayload().get("result")),
|
stringValue(toolPayload, "toolDisplayName"),
|
||||||
event.getPayload().get("text"))
|
firstNonNull(firstNonNull(toolPayload.get("output"), toolPayload.get("result")),
|
||||||
|
toolPayload.get("text"))
|
||||||
);
|
);
|
||||||
if (!sendEnvelope(chatSseEmitter, ChatDomain.TOOL, ChatType.TOOL_RESULT, buildToolEventPayload(event))) {
|
if (!sendEnvelope(chatSseEmitter, ChatDomain.TOOL, ChatType.TOOL_RESULT, toolPayload)) {
|
||||||
cancelDisconnectedRun(requestId, chatContext, answer, assistantAccumulator, finished, persistChatlog);
|
cancelDisconnectedRun(requestId, chatContext, answer, assistantAccumulator, finished, persistChatlog);
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
@@ -579,7 +697,7 @@ public class AgentRunService {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
agentRunRegistry.remove(requestId);
|
agentRunRegistry.remove(requestId);
|
||||||
cancelPending(requestId, safeErrorMessage(error));
|
cancelPending(requestId, safeErrorMessage(error), persistChatlog);
|
||||||
Throwable safeError = error == null ? new BusinessException("Agent 运行失败") : error;
|
Throwable safeError = error == null ? new BusinessException("Agent 运行失败") : error;
|
||||||
LOG.error("Agent run failed, requestId={}, message={}, exception={}", requestId,
|
LOG.error("Agent run failed, requestId={}, message={}, exception={}", requestId,
|
||||||
safeError.getMessage(), safeError.toString(), safeError);
|
safeError.getMessage(), safeError.toString(), safeError);
|
||||||
@@ -613,7 +731,7 @@ public class AgentRunService {
|
|||||||
}
|
}
|
||||||
agentRunRegistry.remove(requestId);
|
agentRunRegistry.remove(requestId);
|
||||||
String reason = errorMessage(event);
|
String reason = errorMessage(event);
|
||||||
cancelPending(requestId, reason);
|
cancelPending(requestId, reason, persistChatlog);
|
||||||
LOG.info("Agent run cancelled, requestId={}, reason={}", requestId, reason);
|
LOG.info("Agent run cancelled, requestId={}, reason={}", requestId, reason);
|
||||||
if (persistChatlog) {
|
if (persistChatlog) {
|
||||||
recordPartialAssistantIfPresent(chatContext, answer, assistantAccumulator, reason);
|
recordPartialAssistantIfPresent(chatContext, answer, assistantAccumulator, reason);
|
||||||
@@ -1071,9 +1189,81 @@ public class AgentRunService {
|
|||||||
if (toolCallId != null && !toolCallId.isBlank()) {
|
if (toolCallId != null && !toolCallId.isBlank()) {
|
||||||
payload.put("toolCallId", toolCallId);
|
payload.put("toolCallId", toolCallId);
|
||||||
}
|
}
|
||||||
|
if (Boolean.TRUE.equals(event.getMetadata().get("asyncTool"))) {
|
||||||
|
enrichAsyncToolPayload(payload, event.getMetadata(), toolCallId);
|
||||||
|
String taskId = stringValue(payload, "taskId");
|
||||||
|
if (taskId != null && !taskId.isBlank()) {
|
||||||
|
payload.put("toolCallId", taskId);
|
||||||
|
}
|
||||||
|
}
|
||||||
return payload;
|
return payload;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private boolean isAsyncToolEvent(AgentRuntimeEventType type) {
|
||||||
|
return type == AgentRuntimeEventType.ASYNC_TOOL_SUBMITTED
|
||||||
|
|| type == AgentRuntimeEventType.ASYNC_TOOL_OBSERVED
|
||||||
|
|| type == AgentRuntimeEventType.ASYNC_TOOL_RESULT
|
||||||
|
|| type == AgentRuntimeEventType.ASYNC_TOOL_CANCELLED
|
||||||
|
|| type == AgentRuntimeEventType.ASYNC_TOOL_LISTED
|
||||||
|
|| type == AgentRuntimeEventType.ASYNC_TOOL_FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
private ChatType asyncToolChatType(AgentRuntimeEvent event) {
|
||||||
|
String status = stringPayload(event, "status");
|
||||||
|
if ("SUCCEEDED".equalsIgnoreCase(status)
|
||||||
|
|| "FAILED".equalsIgnoreCase(status)
|
||||||
|
|| "CANCELLED".equalsIgnoreCase(status)
|
||||||
|
|| "TIMEOUT".equalsIgnoreCase(status)
|
||||||
|
|| event.getEventType() == AgentRuntimeEventType.ASYNC_TOOL_RESULT
|
||||||
|
|| event.getEventType() == AgentRuntimeEventType.ASYNC_TOOL_FAILED
|
||||||
|
|| event.getEventType() == AgentRuntimeEventType.ASYNC_TOOL_CANCELLED) {
|
||||||
|
return ChatType.TOOL_RESULT;
|
||||||
|
}
|
||||||
|
return ChatType.TOOL_CALL;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, Object> buildAsyncToolEventPayload(AgentRuntimeEvent event) {
|
||||||
|
Map<String, Object> payload = new LinkedHashMap<>(event.getPayload() == null ? Map.of() : event.getPayload());
|
||||||
|
String taskId = stringValue(payload, "taskId");
|
||||||
|
String toolCallId = firstText(taskId, event.getToolCallId());
|
||||||
|
if (toolCallId != null && !toolCallId.isBlank()) {
|
||||||
|
payload.put("toolCallId", toolCallId);
|
||||||
|
}
|
||||||
|
enrichAsyncToolPayload(payload, event.getMetadata(), toolCallId);
|
||||||
|
return payload;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void enrichAsyncToolPayload(Map<String, Object> payload, Map<String, Object> metadata, String fallbackId) {
|
||||||
|
Map<String, Object> safeMetadata = metadata == null ? Map.of() : metadata;
|
||||||
|
payload.put("asyncTool", true);
|
||||||
|
putIfPresent(payload, "asyncToolName", firstText(stringValue(payload, "asyncToolName"), stringValue(safeMetadata, "asyncToolName")));
|
||||||
|
putIfPresent(payload, "phase", firstText(stringValue(payload, "phase"), stringValue(safeMetadata, "asyncToolPhase")));
|
||||||
|
putIfPresent(payload, "taskId", firstText(stringValue(payload, "taskId"), stringValue(safeMetadata, "taskId")));
|
||||||
|
putIfPresent(payload, "status", firstText(stringValue(payload, "status"), stringValue(safeMetadata, "status")));
|
||||||
|
String displayName = firstText(stringValue(payload, "toolDisplayName"),
|
||||||
|
firstText(stringValue(safeMetadata, "toolDisplayName"), stringValue(payload, "asyncToolName")));
|
||||||
|
putIfPresent(payload, "toolDisplayName", displayName);
|
||||||
|
putIfPresent(payload, "toolName", displayName);
|
||||||
|
putIfPresent(payload, "name", displayName);
|
||||||
|
String statusKey = "async-tool:" + firstText(stringValue(payload, "taskId"), fallbackId);
|
||||||
|
payload.put("statusKey", statusKey);
|
||||||
|
payload.put("label", asyncToolLabel(stringValue(payload, "status"), stringValue(payload, "phase"), displayName));
|
||||||
|
}
|
||||||
|
|
||||||
|
private String asyncToolLabel(String status, String phase, String displayName) {
|
||||||
|
String name = displayName == null || displayName.isBlank() ? "异步工具" : displayName;
|
||||||
|
if ("SUCCEEDED".equalsIgnoreCase(status)) {
|
||||||
|
return name + "已完成";
|
||||||
|
}
|
||||||
|
if ("FAILED".equalsIgnoreCase(status) || "TIMEOUT".equalsIgnoreCase(status)) {
|
||||||
|
return name + "执行失败";
|
||||||
|
}
|
||||||
|
if ("PENDING".equalsIgnoreCase(status) || "submit".equalsIgnoreCase(phase)) {
|
||||||
|
return name + "已提交";
|
||||||
|
}
|
||||||
|
return name + "执行中";
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 构建知识库检索状态载荷,确保前端可按稳定 key 合并同一轮状态行。
|
* 构建知识库检索状态载荷,确保前端可按稳定 key 合并同一轮状态行。
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ import com.easyagents.agent.runtime.knowledge.AgentKnowledgeSpec;
|
|||||||
import com.easyagents.agent.runtime.memory.AgentMemoryCompressionParameter;
|
import com.easyagents.agent.runtime.memory.AgentMemoryCompressionParameter;
|
||||||
import com.easyagents.agent.runtime.memory.AgentMemoryPolicy;
|
import com.easyagents.agent.runtime.memory.AgentMemoryPolicy;
|
||||||
import com.easyagents.agent.runtime.memory.AgentMemoryType;
|
import com.easyagents.agent.runtime.memory.AgentMemoryType;
|
||||||
|
import com.easyagents.agent.runtime.mcp.McpSpec;
|
||||||
|
import com.easyagents.agent.runtime.mcp.McpTransportType;
|
||||||
import com.easyagents.agent.runtime.model.AgentGenerationOptions;
|
import com.easyagents.agent.runtime.model.AgentGenerationOptions;
|
||||||
import com.easyagents.agent.runtime.model.AgentModelProviderType;
|
import com.easyagents.agent.runtime.model.AgentModelProviderType;
|
||||||
import com.easyagents.agent.runtime.model.AgentModelSpec;
|
import com.easyagents.agent.runtime.model.AgentModelSpec;
|
||||||
@@ -27,8 +29,9 @@ import tech.easyflow.agent.entity.Agent;
|
|||||||
import tech.easyflow.agent.entity.AgentKnowledgeBinding;
|
import tech.easyflow.agent.entity.AgentKnowledgeBinding;
|
||||||
import tech.easyflow.agent.entity.AgentToolBinding;
|
import tech.easyflow.agent.entity.AgentToolBinding;
|
||||||
import tech.easyflow.agent.enums.AgentToolType;
|
import tech.easyflow.agent.enums.AgentToolType;
|
||||||
|
import tech.easyflow.agent.runtime.tool.AgentToolRuntimeCompilation;
|
||||||
|
import tech.easyflow.agent.runtime.tool.AgentToolRuntimeCompiler;
|
||||||
import tech.easyflow.ai.easyagents.tool.ChatToolNameHelper;
|
import tech.easyflow.ai.easyagents.tool.ChatToolNameHelper;
|
||||||
import tech.easyflow.ai.easyagents.tool.McpTool;
|
|
||||||
import tech.easyflow.ai.easyagents.tool.WorkflowTool;
|
import tech.easyflow.ai.easyagents.tool.WorkflowTool;
|
||||||
import tech.easyflow.ai.easyagentsflow.support.PublishedWorkflowDefinitionIds;
|
import tech.easyflow.ai.easyagentsflow.support.PublishedWorkflowDefinitionIds;
|
||||||
import tech.easyflow.ai.entity.*;
|
import tech.easyflow.ai.entity.*;
|
||||||
@@ -40,16 +43,19 @@ import tech.easyflow.common.web.exceptions.BusinessException;
|
|||||||
import javax.annotation.Resource;
|
import javax.annotation.Resource;
|
||||||
import java.math.BigInteger;
|
import java.math.BigInteger;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
|
import java.util.regex.Matcher;
|
||||||
|
import java.util.regex.Pattern;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 将 Agent 发布快照编译为 easy-agents-agent-runtime 可执行定义。
|
* 将 Agent 发布快照编译为可执行定义。
|
||||||
*/
|
*/
|
||||||
@Component
|
@Component
|
||||||
public class AgentDefinitionCompiler {
|
public class AgentRuntimeCompiler {
|
||||||
|
|
||||||
private static final Logger LOG = LoggerFactory.getLogger(AgentDefinitionCompiler.class);
|
private static final Logger LOG = LoggerFactory.getLogger(AgentRuntimeCompiler.class);
|
||||||
private static final int LOG_TEXT_MAX_LENGTH = 500;
|
private static final int LOG_TEXT_MAX_LENGTH = 500;
|
||||||
|
private static final Pattern MCP_INPUT_PATTERN = Pattern.compile("\\$\\{input:([A-Za-z0-9_.-]+)}");
|
||||||
|
|
||||||
@Resource
|
@Resource
|
||||||
private ModelService modelService;
|
private ModelService modelService;
|
||||||
@@ -63,6 +69,8 @@ public class AgentDefinitionCompiler {
|
|||||||
private DocumentCollectionService documentCollectionService;
|
private DocumentCollectionService documentCollectionService;
|
||||||
@Resource
|
@Resource
|
||||||
private ObjectMapper objectMapper;
|
private ObjectMapper objectMapper;
|
||||||
|
@Resource
|
||||||
|
private AgentToolRuntimeCompiler agentToolRuntimeCompiler;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 编译 Agent 运行时定义和调用器。
|
* 编译 Agent 运行时定义和调用器。
|
||||||
@@ -205,22 +213,10 @@ public class AgentDefinitionCompiler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void compileTools(Agent agent, AgentDefinition definition, AgentRuntimeBundle bundle) {
|
private void compileTools(Agent agent, AgentDefinition definition, AgentRuntimeBundle bundle) {
|
||||||
if (agent.getToolBindings() == null) {
|
AgentToolRuntimeCompilation compilation = agentToolRuntimeCompiler.compile(agent);
|
||||||
return;
|
definition.setToolSpecs(compilation.getToolSpecs());
|
||||||
}
|
definition.setMcpSpecs(compilation.getMcpSpecs());
|
||||||
List<AgentToolSpec> specs = new ArrayList<>();
|
bundle.setToolInvokers(compilation.getToolInvokers());
|
||||||
Map<String, com.easyagents.agent.runtime.tool.AgentToolInvoker> invokers = new LinkedHashMap<>();
|
|
||||||
for (AgentToolBinding binding : agent.getToolBindings()) {
|
|
||||||
if (!Boolean.TRUE.equals(binding.getEnabled())) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
Tool tool = buildTool(binding);
|
|
||||||
AgentToolSpec spec = toToolSpec(tool, binding);
|
|
||||||
specs.add(spec);
|
|
||||||
invokers.put(spec.getName(), (arguments, context) -> invokeTool(tool, arguments));
|
|
||||||
}
|
|
||||||
definition.setToolSpecs(specs);
|
|
||||||
bundle.setToolInvokers(invokers);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Tool buildTool(AgentToolBinding binding) {
|
private Tool buildTool(AgentToolBinding binding) {
|
||||||
@@ -243,16 +239,74 @@ public class AgentDefinitionCompiler {
|
|||||||
}
|
}
|
||||||
return pluginItem.toFunction();
|
return pluginItem.toFunction();
|
||||||
}
|
}
|
||||||
|
throw new BusinessException("不支持的 Agent 工具类型:" + type.name());
|
||||||
|
}
|
||||||
|
|
||||||
|
private McpSpec buildMcpSpec(AgentToolBinding binding) {
|
||||||
Mcp mcp = snapshotOrCurrentMcp(binding);
|
Mcp mcp = snapshotOrCurrentMcp(binding);
|
||||||
if (mcp == null) {
|
if (mcp == null) {
|
||||||
throw new BusinessException("绑定 MCP 不存在");
|
throw new BusinessException("绑定 MCP 不存在");
|
||||||
}
|
}
|
||||||
McpTool tool = new McpTool();
|
Map.Entry<String, Map<String, Object>> server = firstMcpServer(mcp);
|
||||||
tool.setMcpId(mcp.getId());
|
Map<String, Object> serverConfig = server.getValue();
|
||||||
tool.setName(binding.getToolName());
|
McpTransportType transportType = parseMcpTransportType(mcp, serverConfig);
|
||||||
tool.setDescription(mcp.getDescription());
|
|
||||||
tool.setParameters(new Parameter[0]);
|
McpSpec spec = new McpSpec();
|
||||||
return tool;
|
spec.setName(mcpRuntimeName(mcp));
|
||||||
|
spec.setDescription(firstNonBlank(mcp.getDescription(), mcp.getTitle()));
|
||||||
|
spec.setTransportType(transportType);
|
||||||
|
spec.setCommand(resolveMcpInput(stringValue(serverConfig, "command", null)));
|
||||||
|
spec.setArgs(resolveMcpInputs(stringListValue(serverConfig, "args")));
|
||||||
|
spec.setEnv(resolveMcpInputMap(stringMapValue(serverConfig, "env")));
|
||||||
|
spec.setUrl(resolveMcpInput(stringValue(serverConfig, "url", null)));
|
||||||
|
spec.setHeaders(resolveMcpInputMap(stringMapValue(serverConfig, "headers")));
|
||||||
|
spec.setQueryParams(resolveMcpInputMap(stringMapValue(serverConfig, "queryParams")));
|
||||||
|
Duration timeout = durationValue(serverConfig, "timeout");
|
||||||
|
if (timeout != null) {
|
||||||
|
spec.setTimeout(timeout);
|
||||||
|
}
|
||||||
|
Duration initializationTimeout = durationValue(serverConfig, "initializationTimeout");
|
||||||
|
if (initializationTimeout != null) {
|
||||||
|
spec.setInitializationTimeout(initializationTimeout);
|
||||||
|
}
|
||||||
|
spec.setGroupName(mcpRuntimeName(mcp));
|
||||||
|
spec.setApprovalRequired(Boolean.TRUE.equals(mcp.getApprovalRequired()));
|
||||||
|
spec.setApprovalRequest(buildMcpApprovalRequest(mcp));
|
||||||
|
spec.setToolNamePrefix(mcpRuntimeToolPrefix(mcp.getId()));
|
||||||
|
spec.getMetadata().put("toolType", AgentToolType.MCP.name());
|
||||||
|
spec.getMetadata().put("mcpId", String.valueOf(mcp.getId()));
|
||||||
|
spec.getMetadata().put("mcpTitle", mcp.getTitle());
|
||||||
|
spec.getMetadata().put("serverName", server.getKey());
|
||||||
|
return spec;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void applyMcpToolBinding(McpSpec spec, AgentToolBinding binding) {
|
||||||
|
if (Boolean.TRUE.equals(binding.getHitlEnabled())) {
|
||||||
|
spec.setApprovalRequired(true);
|
||||||
|
spec.setApprovalRequest(buildBindingApprovalRequest(binding));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentToolApprovalRequest buildMcpApprovalRequest(Mcp mcp) {
|
||||||
|
AgentToolApprovalRequest request = new AgentToolApprovalRequest();
|
||||||
|
request.setApprovalPrompt("是否批准执行 MCP 工具:" + firstNonBlank(mcp.getTitle(), mcpRuntimeName(mcp)));
|
||||||
|
Map<String, Object> metadata = new LinkedHashMap<>();
|
||||||
|
metadata.put("toolType", AgentToolType.MCP.name());
|
||||||
|
metadata.put("mcpId", String.valueOf(mcp.getId()));
|
||||||
|
metadata.put("mcpTitle", mcp.getTitle());
|
||||||
|
request.setMetadata(metadata);
|
||||||
|
return request;
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentToolApprovalRequest buildBindingApprovalRequest(AgentToolBinding binding) {
|
||||||
|
AgentToolApprovalRequest request = new AgentToolApprovalRequest();
|
||||||
|
request.setApprovalPrompt(stringValue(binding.getHitlConfigJson(), "prompt", "是否批准执行 MCP 工具"));
|
||||||
|
Map<String, Object> metadata = sanitizedHitlMetadata(binding.getHitlConfigJson());
|
||||||
|
metadata.put("toolType", binding.getToolType());
|
||||||
|
metadata.put("bindingId", binding.getId());
|
||||||
|
metadata.put("targetId", binding.getTargetId());
|
||||||
|
request.setMetadata(metadata);
|
||||||
|
return request;
|
||||||
}
|
}
|
||||||
|
|
||||||
private AgentToolSpec toToolSpec(Tool tool, AgentToolBinding binding) {
|
private AgentToolSpec toToolSpec(Tool tool, AgentToolBinding binding) {
|
||||||
@@ -477,6 +531,138 @@ public class AgentDefinitionCompiler {
|
|||||||
return mcpService.getById(binding.getTargetId());
|
return mcpService.getById(binding.getTargetId());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Map.Entry<String, Map<String, Object>> firstMcpServer(Mcp mcp) {
|
||||||
|
Map<String, Object> config = parseMcpConfig(mcp);
|
||||||
|
Map<String, Object> servers = mapValue(config, "mcpServers");
|
||||||
|
if (servers.isEmpty()) {
|
||||||
|
throw new BusinessException("MCP 配置 JSON 中没有找到任何 MCP 服务名称");
|
||||||
|
}
|
||||||
|
Map.Entry<String, Object> first = servers.entrySet().iterator().next();
|
||||||
|
if (!(first.getValue() instanceof Map<?, ?> rawServer)) {
|
||||||
|
throw new BusinessException("MCP 服务配置必须是对象:" + first.getKey());
|
||||||
|
}
|
||||||
|
Map<String, Object> serverConfig = new LinkedHashMap<>();
|
||||||
|
rawServer.forEach((key, value) -> serverConfig.put(String.valueOf(key), value));
|
||||||
|
return Map.entry(first.getKey(), serverConfig);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, Object> parseMcpConfig(Mcp mcp) {
|
||||||
|
String configJson = mcp == null ? null : mcp.getConfigJson();
|
||||||
|
if (configJson == null || configJson.isBlank()) {
|
||||||
|
throw new BusinessException("MCP 配置 JSON 不能为空");
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
return objectMapper.readValue(configJson, new com.fasterxml.jackson.core.type.TypeReference<>() {});
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new BusinessException("MCP 配置 JSON 格式错误");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private McpTransportType parseMcpTransportType(Mcp mcp, Map<String, Object> serverConfig) {
|
||||||
|
String transport = firstNonBlank(
|
||||||
|
mcp == null ? null : mcp.getTransportType(),
|
||||||
|
stringValue(serverConfig, "transport", null)
|
||||||
|
);
|
||||||
|
return McpTransportType.from(transport);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String mcpRuntimeName(Mcp mcp) {
|
||||||
|
BigInteger id = mcp == null ? null : mcp.getId();
|
||||||
|
return "mcp_" + safeToolNameSegment(id == null ? "unknown" : String.valueOf(id));
|
||||||
|
}
|
||||||
|
|
||||||
|
private String mcpRuntimeToolPrefix(BigInteger mcpId) {
|
||||||
|
return "mcp_" + safeToolNameSegment(String.valueOf(mcpId)) + "_";
|
||||||
|
}
|
||||||
|
|
||||||
|
private String safeToolNameSegment(String value) {
|
||||||
|
String normalized = String.valueOf(value == null ? "" : value).trim()
|
||||||
|
.replaceAll("[^A-Za-z0-9_-]", "_")
|
||||||
|
.replaceAll("_+", "_");
|
||||||
|
if (normalized.isBlank()) {
|
||||||
|
return "tool";
|
||||||
|
}
|
||||||
|
return normalized;
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<String> stringListValue(Map<String, Object> map, String key) {
|
||||||
|
Object value = map == null ? null : map.get(key);
|
||||||
|
if (value == null) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
if (value instanceof Collection<?> collection) {
|
||||||
|
List<String> result = new ArrayList<>();
|
||||||
|
for (Object item : collection) {
|
||||||
|
if (item != null) {
|
||||||
|
result.add(String.valueOf(item));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
throw new BusinessException("Agent 配置字段必须是数组:" + key);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Duration durationValue(Map<String, Object> map, String key) {
|
||||||
|
Object value = map == null ? null : map.get(key);
|
||||||
|
if (value == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (value instanceof Number number) {
|
||||||
|
return Duration.ofSeconds(number.longValue());
|
||||||
|
}
|
||||||
|
String text = String.valueOf(value).trim();
|
||||||
|
if (text.isEmpty()) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
return Duration.parse(text);
|
||||||
|
} catch (Exception ignored) {
|
||||||
|
try {
|
||||||
|
return Duration.ofSeconds(Long.parseLong(text));
|
||||||
|
} catch (NumberFormatException e) {
|
||||||
|
throw new BusinessException("Agent 配置字段必须是秒数或 Duration:" + key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<String> resolveMcpInputs(List<String> values) {
|
||||||
|
if (values == null || values.isEmpty()) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
List<String> result = new ArrayList<>(values.size());
|
||||||
|
for (String value : values) {
|
||||||
|
result.add(resolveMcpInput(value));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, String> resolveMcpInputMap(Map<String, String> values) {
|
||||||
|
if (values == null || values.isEmpty()) {
|
||||||
|
return new LinkedHashMap<>();
|
||||||
|
}
|
||||||
|
Map<String, String> result = new LinkedHashMap<>();
|
||||||
|
values.forEach((key, value) -> result.put(key, resolveMcpInput(value)));
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String resolveMcpInput(String value) {
|
||||||
|
if (value == null || value.isBlank()) {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
Matcher matcher = MCP_INPUT_PATTERN.matcher(value);
|
||||||
|
StringBuffer resolved = new StringBuffer();
|
||||||
|
while (matcher.find()) {
|
||||||
|
String inputKey = matcher.group(1);
|
||||||
|
String resolvedValue = System.getProperty("mcp.input." + inputKey);
|
||||||
|
if (resolvedValue == null || resolvedValue.isBlank()) {
|
||||||
|
throw new BusinessException("MCP 输入变量未解析:" + inputKey);
|
||||||
|
}
|
||||||
|
matcher.appendReplacement(resolved, Matcher.quoteReplacement(resolvedValue));
|
||||||
|
}
|
||||||
|
matcher.appendTail(resolved);
|
||||||
|
return resolved.toString();
|
||||||
|
}
|
||||||
|
|
||||||
private DocumentCollection snapshotOrPublishedKnowledge(AgentKnowledgeBinding binding) {
|
private DocumentCollection snapshotOrPublishedKnowledge(AgentKnowledgeBinding binding) {
|
||||||
if (binding.getResourceSnapshot() != null && !binding.getResourceSnapshot().isEmpty()) {
|
if (binding.getResourceSnapshot() != null && !binding.getResourceSnapshot().isEmpty()) {
|
||||||
DocumentCollection knowledge = objectMapper.convertValue(binding.getResourceSnapshot(), DocumentCollection.class);
|
DocumentCollection knowledge = objectMapper.convertValue(binding.getResourceSnapshot(), DocumentCollection.class);
|
||||||
@@ -0,0 +1,310 @@
|
|||||||
|
package tech.easyflow.agent.runtime.asynctool;
|
||||||
|
|
||||||
|
import com.easyagents.agent.runtime.tool.AgentToolContext;
|
||||||
|
import com.easyagents.agent.runtime.tool.asynctool.*;
|
||||||
|
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
||||||
|
import org.springframework.util.StringUtils;
|
||||||
|
import tech.easyflow.agent.runtime.tool.AgentToolExecutionResult;
|
||||||
|
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||||
|
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* EasyFlow Agent 异步业务工具基类。
|
||||||
|
*/
|
||||||
|
public abstract class AbstractAgentAsyncSubTools implements AsyncSubTools {
|
||||||
|
|
||||||
|
private static final String ERROR_TYPE_NOT_FOUND = "TASK_NOT_FOUND";
|
||||||
|
private static final String ERROR_TYPE_EXCEPTION = "EXCEPTION";
|
||||||
|
|
||||||
|
private final AgentAsyncToolTaskStore taskStore;
|
||||||
|
private final ThreadPoolTaskExecutor taskExecutor;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建异步业务工具基类。
|
||||||
|
*
|
||||||
|
* @param taskStore 任务存储
|
||||||
|
* @param taskExecutor 后台执行器
|
||||||
|
*/
|
||||||
|
protected AbstractAgentAsyncSubTools(AgentAsyncToolTaskStore taskStore,
|
||||||
|
ThreadPoolTaskExecutor taskExecutor) {
|
||||||
|
this.taskStore = taskStore;
|
||||||
|
this.taskExecutor = taskExecutor;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取工具类型。
|
||||||
|
*
|
||||||
|
* @return 工具类型
|
||||||
|
*/
|
||||||
|
protected abstract String toolType();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取运行时工具名。
|
||||||
|
*
|
||||||
|
* @return 运行时工具名
|
||||||
|
*/
|
||||||
|
protected abstract String toolName();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取用户可见工具名称。
|
||||||
|
*
|
||||||
|
* @return 用户可见工具名称
|
||||||
|
*/
|
||||||
|
protected abstract String displayName();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取业务资源 ID。
|
||||||
|
*
|
||||||
|
* @return 业务资源 ID
|
||||||
|
*/
|
||||||
|
protected abstract String businessId();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 执行业务工具。
|
||||||
|
*
|
||||||
|
* @param arguments 调用参数
|
||||||
|
* @return 执行结果
|
||||||
|
*/
|
||||||
|
protected abstract AgentToolExecutionResult executeBusiness(Map<String, Object> arguments);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public AsyncToolSubmitResult submit(Map<String, Object> arguments, AgentToolContext context) {
|
||||||
|
String sessionId = requireSessionId(context);
|
||||||
|
String taskId = newTaskId();
|
||||||
|
AgentAsyncToolTaskRecord record = new AgentAsyncToolTaskRecord();
|
||||||
|
record.setTaskId(taskId);
|
||||||
|
record.setToolType(toolType());
|
||||||
|
record.setToolName(toolName());
|
||||||
|
record.setBusinessId(businessId());
|
||||||
|
record.setStatus(AsyncToolTaskStatus.PENDING);
|
||||||
|
record.setArguments(arguments == null ? Map.of() : new LinkedHashMap<>(arguments));
|
||||||
|
record.setSummary(displayName() + "任务已提交");
|
||||||
|
record.setRequestId(context == null ? null : context.getRequestId());
|
||||||
|
record.setTraceId(context == null ? null : context.getTraceId());
|
||||||
|
record.setSessionId(sessionId);
|
||||||
|
record.setAgentId(context == null ? null : context.getAgentId());
|
||||||
|
record.setToolCallId(context == null ? null : context.getToolCallId());
|
||||||
|
record.getMetadata().put("toolDisplayName", displayName());
|
||||||
|
appendEvent(record, "SUBMITTED", displayName() + "任务已提交");
|
||||||
|
taskStore.create(record);
|
||||||
|
dispatch(sessionId, record.getTaskId(), record.getArguments());
|
||||||
|
|
||||||
|
AsyncToolSubmitResult result = new AsyncToolSubmitResult();
|
||||||
|
result.setTaskId(taskId);
|
||||||
|
result.setStatus(AsyncToolTaskStatus.PENDING);
|
||||||
|
result.setCursor(0L);
|
||||||
|
result.setSummary(record.getSummary());
|
||||||
|
result.setNextAction(toolName() + "_observe 查看任务进度。");
|
||||||
|
result.getMetadata().put("toolDisplayName", displayName());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public AsyncToolTaskView observe(AsyncToolObserveRequest request, AgentToolContext context) {
|
||||||
|
return taskView(request == null ? null : request.getTaskId(),
|
||||||
|
request == null ? null : request.getCursor(),
|
||||||
|
request == null ? null : request.getLimit(),
|
||||||
|
context);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public AsyncToolTaskView result(AsyncToolResultRequest request, AgentToolContext context) {
|
||||||
|
return taskView(request == null ? null : request.getTaskId(),
|
||||||
|
request == null ? null : request.getCursor(),
|
||||||
|
request == null ? null : request.getLimit(),
|
||||||
|
context);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public AsyncToolCancelResult cancel(AsyncToolCancelRequest request, AgentToolContext context) {
|
||||||
|
AsyncToolCancelResult result = new AsyncToolCancelResult();
|
||||||
|
result.setTaskId(request == null ? null : request.getTaskId());
|
||||||
|
result.setStatus(AsyncToolTaskStatus.FAILED);
|
||||||
|
result.setErrorMessage("当前异步工具不支持取消正在执行的任务");
|
||||||
|
result.setMessage("不支持取消");
|
||||||
|
result.getMetadata().put("toolDisplayName", displayName());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public AsyncToolTaskListResult list(AsyncToolListRequest request, AgentToolContext context) {
|
||||||
|
String sessionId = requireSessionId(context);
|
||||||
|
AsyncToolTaskStatus status = request == null ? null : request.getStatus();
|
||||||
|
List<AsyncToolTaskSummary> tasks = new ArrayList<>();
|
||||||
|
for (AgentAsyncToolTaskRecord record : taskStore.list(sessionId, status)) {
|
||||||
|
tasks.add(summary(record));
|
||||||
|
}
|
||||||
|
AsyncToolTaskListResult result = new AsyncToolTaskListResult();
|
||||||
|
result.setTasks(tasks);
|
||||||
|
result.getMetadata().put("toolDisplayName", displayName());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void dispatch(String sessionId, String taskId, Map<String, Object> arguments) {
|
||||||
|
try {
|
||||||
|
taskExecutor.execute(() -> executeTask(sessionId, taskId, arguments));
|
||||||
|
} catch (Exception e) {
|
||||||
|
taskStore.update(sessionId, taskId, record -> fail(record, e));
|
||||||
|
throw new BusinessException("提交异步工具任务失败:" + safeMessage(e));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void executeTask(String sessionId, String taskId, Map<String, Object> arguments) {
|
||||||
|
try {
|
||||||
|
taskStore.update(sessionId, taskId, record -> {
|
||||||
|
record.setStatus(AsyncToolTaskStatus.RUNNING);
|
||||||
|
record.setSummary(displayName() + "任务执行中");
|
||||||
|
appendEvent(record, "RUNNING", displayName() + "任务执行中");
|
||||||
|
return record;
|
||||||
|
});
|
||||||
|
AgentToolExecutionResult executionResult = executeBusiness(arguments);
|
||||||
|
taskStore.update(sessionId, taskId, record -> {
|
||||||
|
record.setStatus(AsyncToolTaskStatus.SUCCEEDED);
|
||||||
|
record.setSummary(displayName() + "任务已完成");
|
||||||
|
record.setResult(executionResult == null ? null : executionResult.getResult());
|
||||||
|
record.setBusinessExecutionId(executionResult == null ? null : executionResult.getBusinessExecutionId());
|
||||||
|
appendEvent(record, "SUCCEEDED", displayName() + "任务已完成");
|
||||||
|
return record;
|
||||||
|
});
|
||||||
|
} catch (Exception e) {
|
||||||
|
taskStore.update(sessionId, taskId, record -> fail(record, e));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentAsyncToolTaskRecord fail(AgentAsyncToolTaskRecord record, Exception error) {
|
||||||
|
record.setStatus(AsyncToolTaskStatus.FAILED);
|
||||||
|
record.setSummary(displayName() + "任务执行失败");
|
||||||
|
record.setErrorType(ERROR_TYPE_EXCEPTION);
|
||||||
|
record.setErrorMessage(safeMessage(error));
|
||||||
|
appendEvent(record, "FAILED", record.getErrorMessage());
|
||||||
|
return record;
|
||||||
|
}
|
||||||
|
|
||||||
|
private AsyncToolTaskView taskView(String taskId, Long cursor, Integer limit, AgentToolContext context) {
|
||||||
|
String sessionId = requireSessionId(context);
|
||||||
|
if (!StringUtils.hasText(taskId)) {
|
||||||
|
return notFoundView(taskId, cursor, "任务 ID 不能为空");
|
||||||
|
}
|
||||||
|
return taskStore.get(sessionId, taskId)
|
||||||
|
.map(record -> toView(record, cursor, limit))
|
||||||
|
.orElseGet(() -> notFoundView(taskId, cursor, "异步工具任务不存在或已过期"));
|
||||||
|
}
|
||||||
|
|
||||||
|
private AsyncToolTaskView toView(AgentAsyncToolTaskRecord record, Long cursor, Integer limit) {
|
||||||
|
long safeCursor = cursor == null ? 0L : Math.max(0L, cursor);
|
||||||
|
int safeLimit = limit == null || limit <= 0 ? 20 : Math.min(limit, 100);
|
||||||
|
List<AsyncToolTaskEvent> events = new ArrayList<>();
|
||||||
|
for (AsyncToolTaskEvent event : record.getEvents()) {
|
||||||
|
if (event.getSequence() != null && event.getSequence() > safeCursor) {
|
||||||
|
events.add(event);
|
||||||
|
}
|
||||||
|
if (events.size() >= safeLimit) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Long nextCursor = events.isEmpty()
|
||||||
|
? safeCursor
|
||||||
|
: events.get(events.size() - 1).getSequence();
|
||||||
|
AsyncToolTaskView view = new AsyncToolTaskView();
|
||||||
|
view.setTaskId(record.getTaskId());
|
||||||
|
view.setStatus(record.getStatus());
|
||||||
|
view.setCursor(safeCursor);
|
||||||
|
view.setNextCursor(nextCursor);
|
||||||
|
view.setSummary(record.getSummary());
|
||||||
|
view.setNextAction(nextAction(record.getStatus()));
|
||||||
|
view.setEvents(events);
|
||||||
|
view.setResult(record.getResult());
|
||||||
|
view.setErrorMessage(record.getErrorMessage());
|
||||||
|
view.setErrorType(record.getErrorType());
|
||||||
|
view.setTerminal(record.getStatus() != null && record.getStatus().isTerminal());
|
||||||
|
view.setResultAvailable(record.getStatus() == AsyncToolTaskStatus.SUCCEEDED && record.getResult() != null);
|
||||||
|
view.getMetadata().put("toolDisplayName", displayName());
|
||||||
|
putIfNotNull(view.getPayload(), "businessId", record.getBusinessId());
|
||||||
|
putIfNotNull(view.getPayload(), "businessExecutionId", record.getBusinessExecutionId());
|
||||||
|
return view;
|
||||||
|
}
|
||||||
|
|
||||||
|
private AsyncToolTaskView notFoundView(String taskId, Long cursor, String message) {
|
||||||
|
AsyncToolTaskView view = new AsyncToolTaskView();
|
||||||
|
view.setTaskId(taskId);
|
||||||
|
view.setStatus(AsyncToolTaskStatus.FAILED);
|
||||||
|
view.setCursor(cursor == null ? 0L : cursor);
|
||||||
|
view.setNextCursor(cursor == null ? 0L : cursor);
|
||||||
|
view.setSummary(message);
|
||||||
|
view.setErrorType(ERROR_TYPE_NOT_FOUND);
|
||||||
|
view.setErrorMessage(message);
|
||||||
|
view.setTerminal(true);
|
||||||
|
view.setResultAvailable(false);
|
||||||
|
view.getMetadata().put("toolDisplayName", displayName());
|
||||||
|
return view;
|
||||||
|
}
|
||||||
|
|
||||||
|
private AsyncToolTaskSummary summary(AgentAsyncToolTaskRecord record) {
|
||||||
|
AsyncToolTaskSummary summary = new AsyncToolTaskSummary();
|
||||||
|
summary.setTaskId(record.getTaskId());
|
||||||
|
summary.setStatus(record.getStatus());
|
||||||
|
summary.setSummary(record.getSummary());
|
||||||
|
summary.setCreatedAt(record.getCreatedAt());
|
||||||
|
summary.setUpdatedAt(record.getUpdatedAt());
|
||||||
|
summary.getPayload().put("toolName", record.getToolName());
|
||||||
|
summary.getPayload().put("toolDisplayName", displayName());
|
||||||
|
return summary;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void appendEvent(AgentAsyncToolTaskRecord record, String type, String text) {
|
||||||
|
AsyncToolTaskEvent event = new AsyncToolTaskEvent();
|
||||||
|
event.setSequence((long) record.getEvents().size() + 1L);
|
||||||
|
event.setType(type);
|
||||||
|
event.setText(text);
|
||||||
|
event.setCreatedAt(Instant.now());
|
||||||
|
record.getEvents().add(event);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String nextAction(AsyncToolTaskStatus status) {
|
||||||
|
if (status != null && status.isTerminal()) {
|
||||||
|
return "任务已结束。";
|
||||||
|
}
|
||||||
|
return toolName() + "_observe 继续查看任务进度。";
|
||||||
|
}
|
||||||
|
|
||||||
|
private String requireSessionId(AgentToolContext context) {
|
||||||
|
if (context == null || !StringUtils.hasText(context.getSessionId())) {
|
||||||
|
throw new BusinessException("异步工具任务缺少 Agent session 上下文");
|
||||||
|
}
|
||||||
|
return context.getSessionId();
|
||||||
|
}
|
||||||
|
|
||||||
|
private String newTaskId() {
|
||||||
|
String idPart = UUID.randomUUID().toString().replace("-", "");
|
||||||
|
return "async_" + idPart;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void putIfNotNull(Map<String, Object> target, String key, Object value) {
|
||||||
|
if (value != null) {
|
||||||
|
target.put(key, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private String safeMessage(Exception e) {
|
||||||
|
return e == null || e.getMessage() == null || e.getMessage().isBlank()
|
||||||
|
? "异步工具任务执行失败"
|
||||||
|
: e.getMessage();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,362 @@
|
|||||||
|
package tech.easyflow.agent.runtime.asynctool;
|
||||||
|
|
||||||
|
import com.easyagents.agent.runtime.tool.asynctool.AsyncToolTaskEvent;
|
||||||
|
import com.easyagents.agent.runtime.tool.asynctool.AsyncToolTaskStatus;
|
||||||
|
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 异步工具任务 Redis 运行态记录。
|
||||||
|
*/
|
||||||
|
public class AgentAsyncToolTaskRecord {
|
||||||
|
|
||||||
|
private String taskId;
|
||||||
|
private String toolType;
|
||||||
|
private String toolName;
|
||||||
|
private String businessId;
|
||||||
|
private String businessExecutionId;
|
||||||
|
private String sessionScopedKey;
|
||||||
|
private Long ttlSeconds;
|
||||||
|
private AsyncToolTaskStatus status = AsyncToolTaskStatus.PENDING;
|
||||||
|
private Map<String, Object> arguments = new LinkedHashMap<>();
|
||||||
|
private String summary;
|
||||||
|
private Object result;
|
||||||
|
private String errorMessage;
|
||||||
|
private String errorType;
|
||||||
|
private List<AsyncToolTaskEvent> events = new ArrayList<>();
|
||||||
|
private String requestId;
|
||||||
|
private String traceId;
|
||||||
|
private String sessionId;
|
||||||
|
private String agentId;
|
||||||
|
private String toolCallId;
|
||||||
|
private Instant createdAt = Instant.now();
|
||||||
|
private Instant updatedAt = Instant.now();
|
||||||
|
private Map<String, Object> payload = new LinkedHashMap<>();
|
||||||
|
private Map<String, Object> metadata = new LinkedHashMap<>();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取任务 ID。
|
||||||
|
*
|
||||||
|
* @return 任务 ID
|
||||||
|
*/
|
||||||
|
public String getTaskId() { return taskId; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置任务 ID。
|
||||||
|
*
|
||||||
|
* @param taskId 任务 ID
|
||||||
|
*/
|
||||||
|
public void setTaskId(String taskId) { this.taskId = taskId; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取工具类型。
|
||||||
|
*
|
||||||
|
* @return 工具类型
|
||||||
|
*/
|
||||||
|
public String getToolType() { return toolType; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置工具类型。
|
||||||
|
*
|
||||||
|
* @param toolType 工具类型
|
||||||
|
*/
|
||||||
|
public void setToolType(String toolType) { this.toolType = toolType; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取工具名称。
|
||||||
|
*
|
||||||
|
* @return 工具名称
|
||||||
|
*/
|
||||||
|
public String getToolName() { return toolName; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置工具名称。
|
||||||
|
*
|
||||||
|
* @param toolName 工具名称
|
||||||
|
*/
|
||||||
|
public void setToolName(String toolName) { this.toolName = toolName; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取业务资源 ID。
|
||||||
|
*
|
||||||
|
* @return 业务资源 ID
|
||||||
|
*/
|
||||||
|
public String getBusinessId() { return businessId; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置业务资源 ID。
|
||||||
|
*
|
||||||
|
* @param businessId 业务资源 ID
|
||||||
|
*/
|
||||||
|
public void setBusinessId(String businessId) { this.businessId = businessId; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取业务执行记录 ID。
|
||||||
|
*
|
||||||
|
* @return 业务执行记录 ID
|
||||||
|
*/
|
||||||
|
public String getBusinessExecutionId() { return businessExecutionId; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置业务执行记录 ID。
|
||||||
|
*
|
||||||
|
* @param businessExecutionId 业务执行记录 ID
|
||||||
|
*/
|
||||||
|
public void setBusinessExecutionId(String businessExecutionId) { this.businessExecutionId = businessExecutionId; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取会话内任务存储 key。
|
||||||
|
*
|
||||||
|
* @return 会话内任务存储 key
|
||||||
|
*/
|
||||||
|
public String getSessionScopedKey() { return sessionScopedKey; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置会话内任务存储 key。
|
||||||
|
*
|
||||||
|
* @param sessionScopedKey 会话内任务存储 key
|
||||||
|
*/
|
||||||
|
public void setSessionScopedKey(String sessionScopedKey) { this.sessionScopedKey = sessionScopedKey; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取任务 TTL 秒数。
|
||||||
|
*
|
||||||
|
* @return TTL 秒数
|
||||||
|
*/
|
||||||
|
public Long getTtlSeconds() { return ttlSeconds; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置任务 TTL 秒数。
|
||||||
|
*
|
||||||
|
* @param ttlSeconds TTL 秒数
|
||||||
|
*/
|
||||||
|
public void setTtlSeconds(Long ttlSeconds) { this.ttlSeconds = ttlSeconds; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取任务状态。
|
||||||
|
*
|
||||||
|
* @return 任务状态
|
||||||
|
*/
|
||||||
|
public AsyncToolTaskStatus getStatus() { return status; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置任务状态。
|
||||||
|
*
|
||||||
|
* @param status 任务状态
|
||||||
|
*/
|
||||||
|
public void setStatus(AsyncToolTaskStatus status) { this.status = status == null ? AsyncToolTaskStatus.PENDING : status; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取任务参数。
|
||||||
|
*
|
||||||
|
* @return 任务参数
|
||||||
|
*/
|
||||||
|
public Map<String, Object> getArguments() { return arguments; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置任务参数。
|
||||||
|
*
|
||||||
|
* @param arguments 任务参数
|
||||||
|
*/
|
||||||
|
public void setArguments(Map<String, Object> arguments) { this.arguments = arguments == null ? new LinkedHashMap<>() : arguments; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取任务摘要。
|
||||||
|
*
|
||||||
|
* @return 任务摘要
|
||||||
|
*/
|
||||||
|
public String getSummary() { return summary; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置任务摘要。
|
||||||
|
*
|
||||||
|
* @param summary 任务摘要
|
||||||
|
*/
|
||||||
|
public void setSummary(String summary) { this.summary = summary; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取任务结果。
|
||||||
|
*
|
||||||
|
* @return 任务结果
|
||||||
|
*/
|
||||||
|
public Object getResult() { return result; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置任务结果。
|
||||||
|
*
|
||||||
|
* @param result 任务结果
|
||||||
|
*/
|
||||||
|
public void setResult(Object result) { this.result = result; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取错误消息。
|
||||||
|
*
|
||||||
|
* @return 错误消息
|
||||||
|
*/
|
||||||
|
public String getErrorMessage() { return errorMessage; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置错误消息。
|
||||||
|
*
|
||||||
|
* @param errorMessage 错误消息
|
||||||
|
*/
|
||||||
|
public void setErrorMessage(String errorMessage) { this.errorMessage = errorMessage; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取错误类型。
|
||||||
|
*
|
||||||
|
* @return 错误类型
|
||||||
|
*/
|
||||||
|
public String getErrorType() { return errorType; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置错误类型。
|
||||||
|
*
|
||||||
|
* @param errorType 错误类型
|
||||||
|
*/
|
||||||
|
public void setErrorType(String errorType) { this.errorType = errorType; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取任务事件列表。
|
||||||
|
*
|
||||||
|
* @return 任务事件列表
|
||||||
|
*/
|
||||||
|
public List<AsyncToolTaskEvent> getEvents() { return events; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置任务事件列表。
|
||||||
|
*
|
||||||
|
* @param events 任务事件列表
|
||||||
|
*/
|
||||||
|
public void setEvents(List<AsyncToolTaskEvent> events) { this.events = events == null ? new ArrayList<>() : events; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取请求 ID。
|
||||||
|
*
|
||||||
|
* @return 请求 ID
|
||||||
|
*/
|
||||||
|
public String getRequestId() { return requestId; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置请求 ID。
|
||||||
|
*
|
||||||
|
* @param requestId 请求 ID
|
||||||
|
*/
|
||||||
|
public void setRequestId(String requestId) { this.requestId = requestId; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取链路 ID。
|
||||||
|
*
|
||||||
|
* @return 链路 ID
|
||||||
|
*/
|
||||||
|
public String getTraceId() { return traceId; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置链路 ID。
|
||||||
|
*
|
||||||
|
* @param traceId 链路 ID
|
||||||
|
*/
|
||||||
|
public void setTraceId(String traceId) { this.traceId = traceId; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取 Agent Runtime session ID。
|
||||||
|
*
|
||||||
|
* @return session ID
|
||||||
|
*/
|
||||||
|
public String getSessionId() { return sessionId; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置 Agent Runtime session ID。
|
||||||
|
*
|
||||||
|
* @param sessionId session ID
|
||||||
|
*/
|
||||||
|
public void setSessionId(String sessionId) { this.sessionId = sessionId; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取 Agent ID。
|
||||||
|
*
|
||||||
|
* @return Agent ID
|
||||||
|
*/
|
||||||
|
public String getAgentId() { return agentId; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置 Agent ID。
|
||||||
|
*
|
||||||
|
* @param agentId Agent ID
|
||||||
|
*/
|
||||||
|
public void setAgentId(String agentId) { this.agentId = agentId; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取工具调用 ID。
|
||||||
|
*
|
||||||
|
* @return 工具调用 ID
|
||||||
|
*/
|
||||||
|
public String getToolCallId() { return toolCallId; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置工具调用 ID。
|
||||||
|
*
|
||||||
|
* @param toolCallId 工具调用 ID
|
||||||
|
*/
|
||||||
|
public void setToolCallId(String toolCallId) { this.toolCallId = toolCallId; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取创建时间。
|
||||||
|
*
|
||||||
|
* @return 创建时间
|
||||||
|
*/
|
||||||
|
public Instant getCreatedAt() { return createdAt; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置创建时间。
|
||||||
|
*
|
||||||
|
* @param createdAt 创建时间
|
||||||
|
*/
|
||||||
|
public void setCreatedAt(Instant createdAt) { this.createdAt = createdAt == null ? Instant.now() : createdAt; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取更新时间。
|
||||||
|
*
|
||||||
|
* @return 更新时间
|
||||||
|
*/
|
||||||
|
public Instant getUpdatedAt() { return updatedAt; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置更新时间。
|
||||||
|
*
|
||||||
|
* @param updatedAt 更新时间
|
||||||
|
*/
|
||||||
|
public void setUpdatedAt(Instant updatedAt) { this.updatedAt = updatedAt == null ? Instant.now() : updatedAt; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取业务载荷。
|
||||||
|
*
|
||||||
|
* @return 业务载荷
|
||||||
|
*/
|
||||||
|
public Map<String, Object> getPayload() { return payload; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置业务载荷。
|
||||||
|
*
|
||||||
|
* @param payload 业务载荷
|
||||||
|
*/
|
||||||
|
public void setPayload(Map<String, Object> payload) { this.payload = payload == null ? new LinkedHashMap<>() : payload; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取元数据。
|
||||||
|
*
|
||||||
|
* @return 元数据
|
||||||
|
*/
|
||||||
|
public Map<String, Object> getMetadata() { return metadata; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置元数据。
|
||||||
|
*
|
||||||
|
* @param metadata 元数据
|
||||||
|
*/
|
||||||
|
public void setMetadata(Map<String, Object> metadata) { this.metadata = metadata == null ? new LinkedHashMap<>() : metadata; }
|
||||||
|
}
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
package tech.easyflow.agent.runtime.asynctool;
|
||||||
|
|
||||||
|
import com.easyagents.agent.runtime.tool.asynctool.AsyncToolTaskStatus;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.function.UnaryOperator;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 异步工具任务运行态存储。
|
||||||
|
*/
|
||||||
|
public interface AgentAsyncToolTaskStore {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建任务记录。
|
||||||
|
*
|
||||||
|
* @param record 任务记录
|
||||||
|
*/
|
||||||
|
void create(AgentAsyncToolTaskRecord record);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取当前 session 下的任务记录。
|
||||||
|
*
|
||||||
|
* @param sessionId Agent Runtime session ID
|
||||||
|
* @param taskId 任务 ID
|
||||||
|
* @return 任务记录
|
||||||
|
*/
|
||||||
|
Optional<AgentAsyncToolTaskRecord> get(String sessionId, String taskId);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 更新当前 session 下的任务记录。
|
||||||
|
*
|
||||||
|
* @param sessionId Agent Runtime session ID
|
||||||
|
* @param taskId 任务 ID
|
||||||
|
* @param updater 更新函数
|
||||||
|
* @return 更新后的任务记录
|
||||||
|
*/
|
||||||
|
Optional<AgentAsyncToolTaskRecord> update(String sessionId, String taskId, UnaryOperator<AgentAsyncToolTaskRecord> updater);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 查询当前 session 下可见任务。
|
||||||
|
*
|
||||||
|
* @param sessionId Agent Runtime session ID
|
||||||
|
* @param status 状态过滤;为空时返回全部未过期任务
|
||||||
|
* @return 任务列表
|
||||||
|
*/
|
||||||
|
List<AgentAsyncToolTaskRecord> list(String sessionId, AsyncToolTaskStatus status);
|
||||||
|
}
|
||||||
@@ -0,0 +1,83 @@
|
|||||||
|
package tech.easyflow.agent.runtime.asynctool;
|
||||||
|
|
||||||
|
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
||||||
|
import tech.easyflow.agent.enums.AgentToolType;
|
||||||
|
import tech.easyflow.agent.runtime.tool.AgentToolExecutionResult;
|
||||||
|
import tech.easyflow.agent.runtime.tool.PluginToolExecutor;
|
||||||
|
import tech.easyflow.ai.entity.PluginItem;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Plugin 异步工具子能力实现。
|
||||||
|
*/
|
||||||
|
public class PluginAsyncSubTools extends AbstractAgentAsyncSubTools {
|
||||||
|
|
||||||
|
private final PluginItem pluginItem;
|
||||||
|
private final String toolName;
|
||||||
|
private final String displayName;
|
||||||
|
private final PluginToolExecutor pluginToolExecutor;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建 Plugin 异步工具子能力。
|
||||||
|
*
|
||||||
|
* @param pluginItem 插件工具快照
|
||||||
|
* @param toolName runtime 工具名
|
||||||
|
* @param displayName 用户可见名称
|
||||||
|
* @param pluginToolExecutor Plugin 执行器
|
||||||
|
* @param taskStore 任务存储
|
||||||
|
* @param taskExecutor 后台执行器
|
||||||
|
*/
|
||||||
|
public PluginAsyncSubTools(PluginItem pluginItem,
|
||||||
|
String toolName,
|
||||||
|
String displayName,
|
||||||
|
PluginToolExecutor pluginToolExecutor,
|
||||||
|
AgentAsyncToolTaskStore taskStore,
|
||||||
|
ThreadPoolTaskExecutor taskExecutor) {
|
||||||
|
super(taskStore, taskExecutor);
|
||||||
|
this.pluginItem = pluginItem;
|
||||||
|
this.toolName = toolName;
|
||||||
|
this.displayName = displayName;
|
||||||
|
this.pluginToolExecutor = pluginToolExecutor;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
protected String toolType() {
|
||||||
|
return AgentToolType.PLUGIN.name();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
protected String toolName() {
|
||||||
|
return toolName;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
protected String displayName() {
|
||||||
|
return displayName;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
protected String businessId() {
|
||||||
|
return pluginItem == null || pluginItem.getId() == null ? null : String.valueOf(pluginItem.getId());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
protected AgentToolExecutionResult executeBusiness(Map<String, Object> arguments) {
|
||||||
|
return pluginToolExecutor.execute(pluginItem, arguments);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,172 @@
|
|||||||
|
package tech.easyflow.agent.runtime.asynctool;
|
||||||
|
|
||||||
|
import com.easyagents.agent.runtime.tool.asynctool.AsyncToolTaskStatus;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import org.springframework.data.redis.core.Cursor;
|
||||||
|
import org.springframework.data.redis.core.ScanOptions;
|
||||||
|
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.util.StringUtils;
|
||||||
|
import tech.easyflow.agent.config.AgentRuntimeProperties;
|
||||||
|
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.function.UnaryOperator;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 基于 Redis 单 key 的 Agent 异步工具任务存储。
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
public class RedisAgentAsyncToolTaskStore implements AgentAsyncToolTaskStore {
|
||||||
|
|
||||||
|
private static final String KEY_PREFIX = "easyflow:agent:async-tool:";
|
||||||
|
|
||||||
|
private final StringRedisTemplate stringRedisTemplate;
|
||||||
|
private final ObjectMapper objectMapper;
|
||||||
|
private final AgentRuntimeProperties properties;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建 Redis 任务存储。
|
||||||
|
*
|
||||||
|
* @param stringRedisTemplate Redis 字符串模板
|
||||||
|
* @param objectMapper JSON mapper
|
||||||
|
* @param properties Agent runtime 配置
|
||||||
|
*/
|
||||||
|
public RedisAgentAsyncToolTaskStore(StringRedisTemplate stringRedisTemplate,
|
||||||
|
ObjectMapper objectMapper,
|
||||||
|
AgentRuntimeProperties properties) {
|
||||||
|
this.stringRedisTemplate = stringRedisTemplate;
|
||||||
|
this.objectMapper = objectMapper;
|
||||||
|
this.properties = properties;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void create(AgentAsyncToolTaskRecord record) {
|
||||||
|
if (record == null) {
|
||||||
|
throw new BusinessException("异步工具任务不能为空");
|
||||||
|
}
|
||||||
|
String sessionId = requireText(record.getSessionId(), "异步工具任务 sessionId 不能为空");
|
||||||
|
String taskId = requireText(record.getTaskId(), "异步工具任务 taskId 不能为空");
|
||||||
|
record.setSessionScopedKey(key(sessionId, taskId));
|
||||||
|
Duration ttl = taskTtl();
|
||||||
|
record.setTtlSeconds(ttl.toSeconds());
|
||||||
|
write(record, ttl);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public Optional<AgentAsyncToolTaskRecord> get(String sessionId, String taskId) {
|
||||||
|
String value = stringRedisTemplate.opsForValue().get(key(sessionId, taskId));
|
||||||
|
if (!StringUtils.hasText(value)) {
|
||||||
|
return Optional.empty();
|
||||||
|
}
|
||||||
|
return Optional.of(read(value));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public Optional<AgentAsyncToolTaskRecord> update(String sessionId,
|
||||||
|
String taskId,
|
||||||
|
UnaryOperator<AgentAsyncToolTaskRecord> updater) {
|
||||||
|
Optional<AgentAsyncToolTaskRecord> existing = get(sessionId, taskId);
|
||||||
|
if (existing.isEmpty()) {
|
||||||
|
return Optional.empty();
|
||||||
|
}
|
||||||
|
AgentAsyncToolTaskRecord updated = updater == null ? existing.get() : updater.apply(existing.get());
|
||||||
|
if (updated == null) {
|
||||||
|
return Optional.empty();
|
||||||
|
}
|
||||||
|
updated.setUpdatedAt(Instant.now());
|
||||||
|
write(updated, remainingTtl(sessionId, taskId));
|
||||||
|
return Optional.of(updated);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public List<AgentAsyncToolTaskRecord> list(String sessionId, AsyncToolTaskStatus status) {
|
||||||
|
String safeSessionId = requireText(sessionId, "异步工具任务 sessionId 不能为空");
|
||||||
|
List<AgentAsyncToolTaskRecord> result = new ArrayList<>();
|
||||||
|
ScanOptions options = ScanOptions.scanOptions().match(KEY_PREFIX + safeSessionId + ":*").count(100).build();
|
||||||
|
try (Cursor<String> cursor = stringRedisTemplate.scan(options)) {
|
||||||
|
while (cursor.hasNext()) {
|
||||||
|
String key = cursor.next();
|
||||||
|
String value = stringRedisTemplate.opsForValue().get(key);
|
||||||
|
if (!StringUtils.hasText(value)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
AgentAsyncToolTaskRecord record = read(value);
|
||||||
|
if (status == null || status == record.getStatus()) {
|
||||||
|
result.add(record);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.sort(Comparator.comparing(AgentAsyncToolTaskRecord::getCreatedAt,
|
||||||
|
Comparator.nullsLast(Comparator.reverseOrder())));
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void write(AgentAsyncToolTaskRecord record, Duration ttl) {
|
||||||
|
try {
|
||||||
|
stringRedisTemplate.opsForValue().set(record.getSessionScopedKey(),
|
||||||
|
objectMapper.writeValueAsString(record), Math.max(1L, ttl.toSeconds()), TimeUnit.SECONDS);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new BusinessException("写入异步工具任务状态失败:" + safeMessage(e));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentAsyncToolTaskRecord read(String value) {
|
||||||
|
try {
|
||||||
|
return objectMapper.readValue(value, AgentAsyncToolTaskRecord.class);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new BusinessException("读取异步工具任务状态失败:" + safeMessage(e));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private Duration remainingTtl(String sessionId, String taskId) {
|
||||||
|
Long seconds = stringRedisTemplate.getExpire(key(sessionId, taskId), TimeUnit.SECONDS);
|
||||||
|
if (seconds == null || seconds <= 0L) {
|
||||||
|
return taskTtl();
|
||||||
|
}
|
||||||
|
return Duration.ofSeconds(seconds);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Duration taskTtl() {
|
||||||
|
Duration ttl = properties == null ? Duration.ofHours(24) : properties.getAsyncToolTaskTtl();
|
||||||
|
return ttl == null || ttl.isZero() || ttl.isNegative() ? Duration.ofHours(24) : ttl;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String key(String sessionId, String taskId) {
|
||||||
|
return KEY_PREFIX
|
||||||
|
+ requireText(sessionId, "异步工具任务 sessionId 不能为空")
|
||||||
|
+ ":"
|
||||||
|
+ requireText(taskId, "异步工具任务 taskId 不能为空");
|
||||||
|
}
|
||||||
|
|
||||||
|
private String requireText(String value, String message) {
|
||||||
|
if (!StringUtils.hasText(value)) {
|
||||||
|
throw new BusinessException(message);
|
||||||
|
}
|
||||||
|
return value.trim();
|
||||||
|
}
|
||||||
|
|
||||||
|
private String safeMessage(Exception e) {
|
||||||
|
return e == null || e.getMessage() == null || e.getMessage().isBlank()
|
||||||
|
? "未知错误"
|
||||||
|
: e.getMessage();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,83 @@
|
|||||||
|
package tech.easyflow.agent.runtime.asynctool;
|
||||||
|
|
||||||
|
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
||||||
|
import tech.easyflow.agent.enums.AgentToolType;
|
||||||
|
import tech.easyflow.agent.runtime.tool.AgentToolExecutionResult;
|
||||||
|
import tech.easyflow.agent.runtime.tool.WorkflowToolExecutor;
|
||||||
|
import tech.easyflow.ai.entity.Workflow;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Workflow 异步工具子能力实现。
|
||||||
|
*/
|
||||||
|
public class WorkflowAsyncSubTools extends AbstractAgentAsyncSubTools {
|
||||||
|
|
||||||
|
private final Workflow workflow;
|
||||||
|
private final String toolName;
|
||||||
|
private final String displayName;
|
||||||
|
private final WorkflowToolExecutor workflowToolExecutor;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建 Workflow 异步工具子能力。
|
||||||
|
*
|
||||||
|
* @param workflow 工作流快照
|
||||||
|
* @param toolName runtime 工具名
|
||||||
|
* @param displayName 用户可见名称
|
||||||
|
* @param workflowToolExecutor Workflow 执行器
|
||||||
|
* @param taskStore 任务存储
|
||||||
|
* @param taskExecutor 后台执行器
|
||||||
|
*/
|
||||||
|
public WorkflowAsyncSubTools(Workflow workflow,
|
||||||
|
String toolName,
|
||||||
|
String displayName,
|
||||||
|
WorkflowToolExecutor workflowToolExecutor,
|
||||||
|
AgentAsyncToolTaskStore taskStore,
|
||||||
|
ThreadPoolTaskExecutor taskExecutor) {
|
||||||
|
super(taskStore, taskExecutor);
|
||||||
|
this.workflow = workflow;
|
||||||
|
this.toolName = toolName;
|
||||||
|
this.displayName = displayName;
|
||||||
|
this.workflowToolExecutor = workflowToolExecutor;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
protected String toolType() {
|
||||||
|
return AgentToolType.WORKFLOW.name();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
protected String toolName() {
|
||||||
|
return toolName;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
protected String displayName() {
|
||||||
|
return displayName;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
protected String businessId() {
|
||||||
|
return workflow == null || workflow.getId() == null ? null : String.valueOf(workflow.getId());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@inheritDoc}
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
protected AgentToolExecutionResult executeBusiness(Map<String, Object> arguments) {
|
||||||
|
return workflowToolExecutor.execute(workflow, arguments);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,6 +5,7 @@ import org.slf4j.LoggerFactory;
|
|||||||
import org.springframework.scheduling.annotation.Scheduled;
|
import org.springframework.scheduling.annotation.Scheduled;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
import tech.easyflow.agent.entity.AgentHitlPending;
|
import tech.easyflow.agent.entity.AgentHitlPending;
|
||||||
|
import tech.easyflow.common.cache.DistributedScheduledLock;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@@ -32,6 +33,7 @@ public class AgentHitlPendingExpirationTask {
|
|||||||
* 定期将超时 pending 标记为 EXPIRED。
|
* 定期将超时 pending 标记为 EXPIRED。
|
||||||
*/
|
*/
|
||||||
@Scheduled(fixedDelayString = "${easyflow.agent.runtime.hitl-expire-scan-delay:60000}", initialDelay = 60000L)
|
@Scheduled(fixedDelayString = "${easyflow.agent.runtime.hitl-expire-scan-delay:60000}", initialDelay = 60000L)
|
||||||
|
@DistributedScheduledLock(key = "easyflow:schedule:agent-hitl:expire-pending", leaseSeconds = 300L)
|
||||||
public void expirePending() {
|
public void expirePending() {
|
||||||
try {
|
try {
|
||||||
List<AgentHitlPending> expired = pendingService.expirePending(BATCH_SIZE);
|
List<AgentHitlPending> expired = pendingService.expirePending(BATCH_SIZE);
|
||||||
|
|||||||
@@ -0,0 +1,241 @@
|
|||||||
|
package tech.easyflow.agent.runtime.session;
|
||||||
|
|
||||||
|
import com.easyagents.agent.runtime.persistence.session.AgentSessionStore;
|
||||||
|
import io.agentscope.core.state.State;
|
||||||
|
import io.agentscope.core.util.JsonUtils;
|
||||||
|
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.util.StringUtils;
|
||||||
|
import tech.easyflow.agent.config.AgentRuntimeProperties;
|
||||||
|
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.security.MessageDigest;
|
||||||
|
import java.security.NoSuchAlgorithmException;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Base64;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.LinkedHashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 草稿试运行 Redis-only session store。
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
public class DraftAgentSessionStore implements AgentSessionStore {
|
||||||
|
|
||||||
|
private static final String REDIS_PREFIX = "easyflow:agent:draft-session:";
|
||||||
|
private static final String ENVELOPE_VERSION = "1";
|
||||||
|
private static final String SINGLE_STATES = "singleStates";
|
||||||
|
private static final String LIST_STATES = "listStates";
|
||||||
|
|
||||||
|
private final StringRedisTemplate stringRedisTemplate;
|
||||||
|
private final AgentRuntimeProperties properties;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建草稿试运行 session store。
|
||||||
|
*
|
||||||
|
* @param stringRedisTemplate Redis 模板
|
||||||
|
* @param properties Agent 运行态配置
|
||||||
|
*/
|
||||||
|
public DraftAgentSessionStore(StringRedisTemplate stringRedisTemplate,
|
||||||
|
AgentRuntimeProperties properties) {
|
||||||
|
this.stringRedisTemplate = stringRedisTemplate;
|
||||||
|
this.properties = properties;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 保存单个状态项。
|
||||||
|
*
|
||||||
|
* @param sessionKey 会话键
|
||||||
|
* @param name 状态名称
|
||||||
|
* @param state 状态值
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void save(String sessionKey, String name, State state) {
|
||||||
|
if (!StringUtils.hasText(sessionKey) || !StringUtils.hasText(name) || state == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Map<String, Object> envelope = loadEnvelope(sessionKey);
|
||||||
|
singleStates(envelope).put(name, JsonUtils.getJsonCodec().toJson(state));
|
||||||
|
writeCache(sessionKey, envelope);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 保存状态列表。
|
||||||
|
*
|
||||||
|
* @param sessionKey 会话键
|
||||||
|
* @param name 状态名称
|
||||||
|
* @param states 状态列表
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void saveList(String sessionKey, String name, List<? extends State> states) {
|
||||||
|
if (!StringUtils.hasText(sessionKey) || !StringUtils.hasText(name)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
List<String> values = new ArrayList<>();
|
||||||
|
if (states != null) {
|
||||||
|
for (State state : states) {
|
||||||
|
values.add(JsonUtils.getJsonCodec().toJson(state));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Map<String, Object> envelope = loadEnvelope(sessionKey);
|
||||||
|
listStates(envelope).put(name, values);
|
||||||
|
writeCache(sessionKey, envelope);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取单个状态项。
|
||||||
|
*
|
||||||
|
* @param sessionKey 会话键
|
||||||
|
* @param name 状态名称
|
||||||
|
* @param type 状态类型
|
||||||
|
* @param <T> 状态类型
|
||||||
|
* @return 可选状态
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public <T extends State> Optional<T> get(String sessionKey, String name, Class<T> type) {
|
||||||
|
if (!StringUtils.hasText(sessionKey) || !StringUtils.hasText(name) || type == null) {
|
||||||
|
return Optional.empty();
|
||||||
|
}
|
||||||
|
Object json = singleStates(loadEnvelope(sessionKey)).get(name);
|
||||||
|
if (!(json instanceof String text) || text.isBlank()) {
|
||||||
|
return Optional.empty();
|
||||||
|
}
|
||||||
|
return Optional.of(JsonUtils.getJsonCodec().fromJson(text, type));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取状态列表。
|
||||||
|
*
|
||||||
|
* @param sessionKey 会话键
|
||||||
|
* @param name 状态名称
|
||||||
|
* @param itemType 状态元素类型
|
||||||
|
* @param <T> 状态元素类型
|
||||||
|
* @return 状态列表
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public <T extends State> List<T> getList(String sessionKey, String name, Class<T> itemType) {
|
||||||
|
if (!StringUtils.hasText(sessionKey) || !StringUtils.hasText(name) || itemType == null) {
|
||||||
|
return List.of();
|
||||||
|
}
|
||||||
|
Object raw = listStates(loadEnvelope(sessionKey)).get(name);
|
||||||
|
if (!(raw instanceof List<?> values) || values.isEmpty()) {
|
||||||
|
return List.of();
|
||||||
|
}
|
||||||
|
List<T> result = new ArrayList<>();
|
||||||
|
for (Object value : values) {
|
||||||
|
if (value instanceof String text && !text.isBlank()) {
|
||||||
|
result.add(JsonUtils.getJsonCodec().fromJson(text, itemType));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 判断会话键是否存在。
|
||||||
|
*
|
||||||
|
* @param sessionKey 会话键
|
||||||
|
* @return 存在时为 true
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public boolean exists(String sessionKey) {
|
||||||
|
return StringUtils.hasText(sessionKey) && readCache(sessionKey) != null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 删除指定会话键下的全部状态。
|
||||||
|
*
|
||||||
|
* @param sessionKey 会话键
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void delete(String sessionKey) {
|
||||||
|
if (!StringUtils.hasText(sessionKey)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
deleteCache(sessionKey);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 列出当前存储中的会话键。
|
||||||
|
*
|
||||||
|
* <p>草稿 session 使用哈希 Redis key,不维护反向索引,避免为试运行引入额外持久化状态。</p>
|
||||||
|
*
|
||||||
|
* @return 空集合
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public Set<String> listSessionKeys() {
|
||||||
|
return new LinkedHashSet<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, Object> loadEnvelope(String sessionKey) {
|
||||||
|
Map<String, Object> cached = readCache(sessionKey);
|
||||||
|
return cached == null ? emptyEnvelope() : deepCopy(cached);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, Object> emptyEnvelope() {
|
||||||
|
Map<String, Object> envelope = new LinkedHashMap<>();
|
||||||
|
envelope.put("version", ENVELOPE_VERSION);
|
||||||
|
envelope.put(SINGLE_STATES, new LinkedHashMap<String, Object>());
|
||||||
|
envelope.put(LIST_STATES, new LinkedHashMap<String, Object>());
|
||||||
|
return envelope;
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
private Map<String, Object> singleStates(Map<String, Object> envelope) {
|
||||||
|
return (Map<String, Object>) envelope.computeIfAbsent(SINGLE_STATES, key -> new LinkedHashMap<String, Object>());
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
private Map<String, Object> listStates(Map<String, Object> envelope) {
|
||||||
|
return (Map<String, Object>) envelope.computeIfAbsent(LIST_STATES, key -> new LinkedHashMap<String, Object>());
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
private Map<String, Object> readCache(String sessionKey) {
|
||||||
|
try {
|
||||||
|
String value = stringRedisTemplate.opsForValue().get(cacheKey(sessionKey));
|
||||||
|
if (!StringUtils.hasText(value)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return JsonUtils.getJsonCodec().fromJson(value, Map.class);
|
||||||
|
} catch (RuntimeException e) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void writeCache(String sessionKey, Map<String, Object> envelope) {
|
||||||
|
long seconds = Math.max(1L, properties.getSessionCacheTtl().toSeconds());
|
||||||
|
stringRedisTemplate.opsForValue().set(cacheKey(sessionKey), JsonUtils.getJsonCodec().toJson(envelope),
|
||||||
|
seconds, TimeUnit.SECONDS);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void deleteCache(String sessionKey) {
|
||||||
|
stringRedisTemplate.delete(cacheKey(sessionKey));
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
private Map<String, Object> deepCopy(Map<String, Object> source) {
|
||||||
|
if (source == null || source.isEmpty()) {
|
||||||
|
return emptyEnvelope();
|
||||||
|
}
|
||||||
|
return JsonUtils.getJsonCodec().fromJson(JsonUtils.getJsonCodec().toJson(source), Map.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String cacheKey(String sessionKey) {
|
||||||
|
return REDIS_PREFIX + hash(sessionKey);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String hash(String value) {
|
||||||
|
try {
|
||||||
|
MessageDigest digest = MessageDigest.getInstance("SHA-256");
|
||||||
|
byte[] bytes = digest.digest(value.getBytes(StandardCharsets.UTF_8));
|
||||||
|
return Base64.getUrlEncoder().withoutPadding().encodeToString(bytes);
|
||||||
|
} catch (NoSuchAlgorithmException e) {
|
||||||
|
return value.replace(':', '_');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
package tech.easyflow.agent.runtime.tool;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 工具执行模式。
|
||||||
|
*/
|
||||||
|
public enum AgentToolExecutionMode {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 同步执行。
|
||||||
|
*/
|
||||||
|
SYNC,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 异步执行。
|
||||||
|
*/
|
||||||
|
ASYNC;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析执行模式。
|
||||||
|
*
|
||||||
|
* @param value 原始配置值
|
||||||
|
* @return 执行模式;非法或为空时返回 SYNC
|
||||||
|
*/
|
||||||
|
public static AgentToolExecutionMode from(String value) {
|
||||||
|
if (value == null || value.isBlank()) {
|
||||||
|
return SYNC;
|
||||||
|
}
|
||||||
|
for (AgentToolExecutionMode mode : values()) {
|
||||||
|
if (mode.name().equalsIgnoreCase(value.trim())) {
|
||||||
|
return mode;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return SYNC;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
package tech.easyflow.agent.runtime.tool;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 业务工具执行结果。
|
||||||
|
*/
|
||||||
|
public class AgentToolExecutionResult {
|
||||||
|
|
||||||
|
private Object result;
|
||||||
|
private String businessExecutionId;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建执行结果。
|
||||||
|
*
|
||||||
|
* @param result 业务结果
|
||||||
|
* @param businessExecutionId 业务执行记录 ID
|
||||||
|
*/
|
||||||
|
public AgentToolExecutionResult(Object result, String businessExecutionId) {
|
||||||
|
this.result = result;
|
||||||
|
this.businessExecutionId = businessExecutionId;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取业务结果。
|
||||||
|
*
|
||||||
|
* @return 业务结果
|
||||||
|
*/
|
||||||
|
public Object getResult() {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置业务结果。
|
||||||
|
*
|
||||||
|
* @param result 业务结果
|
||||||
|
*/
|
||||||
|
public void setResult(Object result) {
|
||||||
|
this.result = result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取业务执行记录 ID。
|
||||||
|
*
|
||||||
|
* @return 业务执行记录 ID
|
||||||
|
*/
|
||||||
|
public String getBusinessExecutionId() {
|
||||||
|
return businessExecutionId;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置业务执行记录 ID。
|
||||||
|
*
|
||||||
|
* @param businessExecutionId 业务执行记录 ID
|
||||||
|
*/
|
||||||
|
public void setBusinessExecutionId(String businessExecutionId) {
|
||||||
|
this.businessExecutionId = businessExecutionId;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,74 @@
|
|||||||
|
package tech.easyflow.agent.runtime.tool;
|
||||||
|
|
||||||
|
import com.easyagents.agent.runtime.mcp.McpSpec;
|
||||||
|
import com.easyagents.agent.runtime.tool.AgentToolInvoker;
|
||||||
|
import com.easyagents.agent.runtime.tool.AgentToolSpec;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 工具运行时编译结果。
|
||||||
|
*/
|
||||||
|
public class AgentToolRuntimeCompilation {
|
||||||
|
|
||||||
|
private List<AgentToolSpec> toolSpecs = new ArrayList<>();
|
||||||
|
private List<McpSpec> mcpSpecs = new ArrayList<>();
|
||||||
|
private Map<String, AgentToolInvoker> toolInvokers = new LinkedHashMap<>();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取普通工具声明。
|
||||||
|
*
|
||||||
|
* @return 普通工具声明
|
||||||
|
*/
|
||||||
|
public List<AgentToolSpec> getToolSpecs() {
|
||||||
|
return toolSpecs;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置普通工具声明。
|
||||||
|
*
|
||||||
|
* @param toolSpecs 普通工具声明
|
||||||
|
*/
|
||||||
|
public void setToolSpecs(List<AgentToolSpec> toolSpecs) {
|
||||||
|
this.toolSpecs = toolSpecs == null ? new ArrayList<>() : toolSpecs;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取 MCP 声明。
|
||||||
|
*
|
||||||
|
* @return MCP 声明
|
||||||
|
*/
|
||||||
|
public List<McpSpec> getMcpSpecs() {
|
||||||
|
return mcpSpecs;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置 MCP 声明。
|
||||||
|
*
|
||||||
|
* @param mcpSpecs MCP 声明
|
||||||
|
*/
|
||||||
|
public void setMcpSpecs(List<McpSpec> mcpSpecs) {
|
||||||
|
this.mcpSpecs = mcpSpecs == null ? new ArrayList<>() : mcpSpecs;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取工具调用器。
|
||||||
|
*
|
||||||
|
* @return 工具调用器
|
||||||
|
*/
|
||||||
|
public Map<String, AgentToolInvoker> getToolInvokers() {
|
||||||
|
return toolInvokers;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置工具调用器。
|
||||||
|
*
|
||||||
|
* @param toolInvokers 工具调用器
|
||||||
|
*/
|
||||||
|
public void setToolInvokers(Map<String, AgentToolInvoker> toolInvokers) {
|
||||||
|
this.toolInvokers = toolInvokers == null ? new LinkedHashMap<>() : toolInvokers;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,627 @@
|
|||||||
|
package tech.easyflow.agent.runtime.tool;
|
||||||
|
|
||||||
|
import com.easyagents.agent.runtime.hitl.AgentToolApprovalRequest;
|
||||||
|
import com.easyagents.agent.runtime.mcp.McpSpec;
|
||||||
|
import com.easyagents.agent.runtime.mcp.McpTransportType;
|
||||||
|
import com.easyagents.agent.runtime.tool.*;
|
||||||
|
import com.easyagents.agent.runtime.tool.asynctool.AsyncToolSpec;
|
||||||
|
import com.easyagents.agent.runtime.tool.asynctool.AsyncToolSpecExpander;
|
||||||
|
import com.easyagents.core.model.chat.tool.Parameter;
|
||||||
|
import com.easyagents.core.model.chat.tool.Tool;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
import tech.easyflow.agent.entity.Agent;
|
||||||
|
import tech.easyflow.agent.entity.AgentToolBinding;
|
||||||
|
import tech.easyflow.agent.enums.AgentToolType;
|
||||||
|
import tech.easyflow.agent.runtime.asynctool.AgentAsyncToolTaskStore;
|
||||||
|
import tech.easyflow.agent.runtime.asynctool.PluginAsyncSubTools;
|
||||||
|
import tech.easyflow.agent.runtime.asynctool.WorkflowAsyncSubTools;
|
||||||
|
import tech.easyflow.ai.easyagents.tool.ChatToolNameHelper;
|
||||||
|
import tech.easyflow.ai.entity.Mcp;
|
||||||
|
import tech.easyflow.ai.entity.PluginItem;
|
||||||
|
import tech.easyflow.ai.entity.Workflow;
|
||||||
|
import tech.easyflow.ai.service.McpService;
|
||||||
|
import tech.easyflow.ai.service.PluginItemService;
|
||||||
|
import tech.easyflow.ai.service.WorkflowService;
|
||||||
|
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||||
|
|
||||||
|
import javax.annotation.Resource;
|
||||||
|
import java.math.BigInteger;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.util.*;
|
||||||
|
import java.util.regex.Matcher;
|
||||||
|
import java.util.regex.Pattern;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 工具运行时编译器。
|
||||||
|
*/
|
||||||
|
@Component
|
||||||
|
public class AgentToolRuntimeCompiler {
|
||||||
|
|
||||||
|
private static final Pattern MCP_INPUT_PATTERN = Pattern.compile("\\$\\{input:([A-Za-z0-9_.-]+)}");
|
||||||
|
private static final Pattern ASYNC_SAFE_NAME = Pattern.compile("^[a-z][a-z0-9_]*$");
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private WorkflowService workflowService;
|
||||||
|
@Resource
|
||||||
|
private PluginItemService pluginItemService;
|
||||||
|
@Resource
|
||||||
|
private McpService mcpService;
|
||||||
|
@Resource
|
||||||
|
private ObjectMapper objectMapper;
|
||||||
|
@Resource
|
||||||
|
private WorkflowToolExecutor workflowToolExecutor;
|
||||||
|
@Resource
|
||||||
|
private PluginToolExecutor pluginToolExecutor;
|
||||||
|
@Resource
|
||||||
|
private AgentAsyncToolTaskStore asyncToolTaskStore;
|
||||||
|
@Resource(name = "agentAsyncToolExecutor")
|
||||||
|
private ThreadPoolTaskExecutor agentAsyncToolExecutor;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 编译 Agent 工具配置。
|
||||||
|
*
|
||||||
|
* @param agent Agent 业务定义
|
||||||
|
* @return 工具编译结果
|
||||||
|
*/
|
||||||
|
public AgentToolRuntimeCompilation compile(Agent agent) {
|
||||||
|
AgentToolRuntimeCompilation compilation = new AgentToolRuntimeCompilation();
|
||||||
|
if (agent == null || agent.getToolBindings() == null) {
|
||||||
|
return compilation;
|
||||||
|
}
|
||||||
|
List<AgentToolSpec> specs = new ArrayList<>();
|
||||||
|
Map<String, AgentToolInvoker> invokers = new LinkedHashMap<>();
|
||||||
|
List<McpSpec> mcpSpecs = new ArrayList<>();
|
||||||
|
Map<BigInteger, McpSpec> mcpSpecMap = new LinkedHashMap<>();
|
||||||
|
Set<String> compiledToolNames = new LinkedHashSet<>();
|
||||||
|
AsyncToolSpecExpander asyncExpander = new AsyncToolSpecExpander();
|
||||||
|
for (AgentToolBinding binding : agent.getToolBindings()) {
|
||||||
|
if (!Boolean.TRUE.equals(binding.getEnabled())) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
AgentToolType type = AgentToolType.from(binding.getToolType());
|
||||||
|
if (type == AgentToolType.MCP) {
|
||||||
|
McpSpec mcpSpec = mcpSpecMap.computeIfAbsent(binding.getTargetId(), ignored -> buildMcpSpec(binding));
|
||||||
|
applyMcpToolBinding(mcpSpec, binding);
|
||||||
|
if (!mcpSpecs.contains(mcpSpec)) {
|
||||||
|
mcpSpecs.add(mcpSpec);
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (executionMode(binding) == AgentToolExecutionMode.ASYNC) {
|
||||||
|
AsyncToolSpec asyncSpec = buildAsyncToolSpec(type, binding);
|
||||||
|
addExpandedTools(specs, invokers, compiledToolNames,
|
||||||
|
asyncExpander.expandSpecs(asyncSpec),
|
||||||
|
asyncExpander.expandInvokers(asyncSpec));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
CompiledSyncTool syncTool = buildSyncTool(type, binding);
|
||||||
|
addCompiledTool(specs, invokers, compiledToolNames, syncTool.spec(), syncTool.invoker());
|
||||||
|
}
|
||||||
|
compilation.setToolSpecs(specs);
|
||||||
|
compilation.setMcpSpecs(mcpSpecs);
|
||||||
|
compilation.setToolInvokers(invokers);
|
||||||
|
return compilation;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addExpandedTools(List<AgentToolSpec> specs,
|
||||||
|
Map<String, AgentToolInvoker> invokers,
|
||||||
|
Set<String> compiledToolNames,
|
||||||
|
List<AgentToolSpec> expandedSpecs,
|
||||||
|
Map<String, AgentToolInvoker> expandedInvokers) {
|
||||||
|
for (AgentToolSpec spec : expandedSpecs) {
|
||||||
|
addCompiledTool(specs, invokers, compiledToolNames, spec,
|
||||||
|
expandedInvokers == null ? null : expandedInvokers.get(spec.getName()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addCompiledTool(List<AgentToolSpec> specs,
|
||||||
|
Map<String, AgentToolInvoker> invokers,
|
||||||
|
Set<String> compiledToolNames,
|
||||||
|
AgentToolSpec spec,
|
||||||
|
AgentToolInvoker invoker) {
|
||||||
|
String name = spec == null ? null : spec.getName();
|
||||||
|
if (name == null || name.isBlank()) {
|
||||||
|
throw new BusinessException("Agent 工具运行名不能为空");
|
||||||
|
}
|
||||||
|
if (!compiledToolNames.add(name)) {
|
||||||
|
throw new BusinessException("Agent 工具运行名冲突:" + name + ",请调整工具名称");
|
||||||
|
}
|
||||||
|
specs.add(spec);
|
||||||
|
if (invoker != null) {
|
||||||
|
invokers.put(name, invoker);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private CompiledSyncTool buildSyncTool(AgentToolType type, AgentToolBinding binding) {
|
||||||
|
if (type == AgentToolType.WORKFLOW) {
|
||||||
|
Workflow workflow = requireWorkflow(binding);
|
||||||
|
Tool tool = workflowToolExecutor.buildTool(workflow);
|
||||||
|
AgentToolSpec spec = toToolSpec(tool, binding);
|
||||||
|
AgentToolInvoker invoker = (arguments, context) -> invokeSafely(spec.getName(),
|
||||||
|
() -> workflowToolExecutor.execute(workflow, arguments).getResult());
|
||||||
|
return new CompiledSyncTool(spec, invoker);
|
||||||
|
}
|
||||||
|
if (type == AgentToolType.PLUGIN) {
|
||||||
|
PluginItem pluginItem = requirePlugin(binding);
|
||||||
|
Tool tool = pluginToolExecutor.buildTool(pluginItem);
|
||||||
|
AgentToolSpec spec = toToolSpec(tool, binding);
|
||||||
|
AgentToolInvoker invoker = (arguments, context) -> invokeSafely(spec.getName(),
|
||||||
|
() -> pluginToolExecutor.execute(pluginItem, arguments).getResult());
|
||||||
|
return new CompiledSyncTool(spec, invoker);
|
||||||
|
}
|
||||||
|
throw new BusinessException("不支持的 Agent 工具类型:" + type.name());
|
||||||
|
}
|
||||||
|
|
||||||
|
private AsyncToolSpec buildAsyncToolSpec(AgentToolType type, AgentToolBinding binding) {
|
||||||
|
if (type == AgentToolType.WORKFLOW) {
|
||||||
|
Workflow workflow = requireWorkflow(binding);
|
||||||
|
Tool tool = workflowToolExecutor.buildTool(workflow);
|
||||||
|
String asyncName = asyncToolName(tool, binding, "workflow");
|
||||||
|
String toolDisplayName = displayName(tool, workflow.getTitle());
|
||||||
|
AsyncToolSpec spec = baseAsyncSpec(asyncName, tool, binding, toolDisplayName);
|
||||||
|
spec.setSubTools(new WorkflowAsyncSubTools(workflow, asyncName, toolDisplayName,
|
||||||
|
workflowToolExecutor, asyncToolTaskStore, agentAsyncToolExecutor));
|
||||||
|
return spec;
|
||||||
|
}
|
||||||
|
if (type == AgentToolType.PLUGIN) {
|
||||||
|
PluginItem pluginItem = requirePlugin(binding);
|
||||||
|
Tool tool = pluginToolExecutor.buildTool(pluginItem);
|
||||||
|
String asyncName = asyncToolName(tool, binding, "plugin");
|
||||||
|
String toolDisplayName = displayName(tool, pluginItem.getName());
|
||||||
|
AsyncToolSpec spec = baseAsyncSpec(asyncName, tool, binding, toolDisplayName);
|
||||||
|
spec.setSubTools(new PluginAsyncSubTools(pluginItem, asyncName, toolDisplayName,
|
||||||
|
pluginToolExecutor, asyncToolTaskStore, agentAsyncToolExecutor));
|
||||||
|
return spec;
|
||||||
|
}
|
||||||
|
throw new BusinessException("不支持的 Agent 异步工具类型:" + type.name());
|
||||||
|
}
|
||||||
|
|
||||||
|
private AsyncToolSpec baseAsyncSpec(String asyncName, Tool tool, AgentToolBinding binding, String toolDisplayName) {
|
||||||
|
AsyncToolSpec spec = new AsyncToolSpec();
|
||||||
|
spec.setName(asyncName);
|
||||||
|
spec.setDescription(safeDescription(tool == null ? null : tool.getDescription()));
|
||||||
|
spec.setSubmitParametersSchema(toSchema(tool == null ? null : tool.getParameters()));
|
||||||
|
spec.setApprovalRequired(Boolean.TRUE.equals(binding.getHitlEnabled()));
|
||||||
|
if (Boolean.TRUE.equals(binding.getHitlEnabled())) {
|
||||||
|
spec.setApprovalRequest(buildBindingApprovalRequest(binding));
|
||||||
|
}
|
||||||
|
spec.getMetadata().put("bindingId", binding.getId());
|
||||||
|
spec.getMetadata().put("targetId", binding.getTargetId());
|
||||||
|
spec.getMetadata().put("toolType", binding.getToolType());
|
||||||
|
// 异步子工具名服务 runtime 调用,事件和聊天展示必须保留业务名称。
|
||||||
|
spec.getMetadata().put("toolDisplayName", toolDisplayName);
|
||||||
|
return spec;
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentToolResult invokeSafely(String toolName, ToolCall call) {
|
||||||
|
try {
|
||||||
|
Object result = call.invoke();
|
||||||
|
return AgentToolResult.success(result == null ? "" : String.valueOf(result));
|
||||||
|
} catch (Exception e) {
|
||||||
|
return AgentToolResult.failure(e.getMessage() == null ? "工具执行失败" : e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentToolExecutionMode executionMode(AgentToolBinding binding) {
|
||||||
|
Object value = binding == null || binding.getOptionsJson() == null ? null : binding.getOptionsJson().get("executionMode");
|
||||||
|
return AgentToolExecutionMode.from(value == null ? null : String.valueOf(value));
|
||||||
|
}
|
||||||
|
|
||||||
|
private Workflow requireWorkflow(AgentToolBinding binding) {
|
||||||
|
Workflow workflow = snapshotOrPublishedWorkflow(binding);
|
||||||
|
if (workflow == null) {
|
||||||
|
throw new BusinessException("绑定工作流不存在");
|
||||||
|
}
|
||||||
|
return workflow;
|
||||||
|
}
|
||||||
|
|
||||||
|
private PluginItem requirePlugin(AgentToolBinding binding) {
|
||||||
|
PluginItem pluginItem = snapshotOrCurrentPlugin(binding);
|
||||||
|
if (pluginItem == null) {
|
||||||
|
throw new BusinessException("绑定插件不存在");
|
||||||
|
}
|
||||||
|
return pluginItem;
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentToolSpec toToolSpec(Tool tool, AgentToolBinding binding) {
|
||||||
|
AgentToolSpec spec = new AgentToolSpec();
|
||||||
|
String name = resolveRuntimeToolName(tool, binding);
|
||||||
|
spec.setName(name);
|
||||||
|
spec.setDescription(safeDescription(tool == null ? null : tool.getDescription()));
|
||||||
|
spec.setCategory(AgentToolCategory.valueOf(AgentToolType.from(binding.getToolType()).name()));
|
||||||
|
spec.setParametersSchema(toSchema(tool == null ? null : tool.getParameters()));
|
||||||
|
spec.setApprovalRequired(Boolean.TRUE.equals(binding.getHitlEnabled()));
|
||||||
|
if (Boolean.TRUE.equals(binding.getHitlEnabled())) {
|
||||||
|
spec.setApprovalRequest(buildBindingApprovalRequest(binding));
|
||||||
|
}
|
||||||
|
spec.getMetadata().put("bindingId", binding.getId());
|
||||||
|
spec.getMetadata().put("targetId", binding.getTargetId());
|
||||||
|
spec.getMetadata().put("toolType", binding.getToolType());
|
||||||
|
spec.getMetadata().put("toolDisplayName", displayName(tool, binding.getToolName()));
|
||||||
|
return spec;
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentToolApprovalRequest buildBindingApprovalRequest(AgentToolBinding binding) {
|
||||||
|
AgentToolApprovalRequest request = new AgentToolApprovalRequest();
|
||||||
|
String name = binding == null ? "工具" : binding.getToolName();
|
||||||
|
request.setApprovalPrompt(stringValue(binding == null ? null : binding.getHitlConfigJson(), "prompt", "是否批准执行工具:" + name));
|
||||||
|
Map<String, Object> metadata = sanitizedHitlMetadata(binding == null ? null : binding.getHitlConfigJson());
|
||||||
|
if (binding != null) {
|
||||||
|
metadata.put("toolType", binding.getToolType());
|
||||||
|
metadata.put("bindingId", binding.getId());
|
||||||
|
metadata.put("targetId", binding.getTargetId());
|
||||||
|
}
|
||||||
|
request.setMetadata(metadata);
|
||||||
|
return request;
|
||||||
|
}
|
||||||
|
|
||||||
|
private McpSpec buildMcpSpec(AgentToolBinding binding) {
|
||||||
|
Mcp mcp = snapshotOrCurrentMcp(binding);
|
||||||
|
if (mcp == null) {
|
||||||
|
throw new BusinessException("绑定 MCP 不存在");
|
||||||
|
}
|
||||||
|
Map.Entry<String, Map<String, Object>> server = firstMcpServer(mcp);
|
||||||
|
Map<String, Object> serverConfig = server.getValue();
|
||||||
|
McpSpec spec = new McpSpec();
|
||||||
|
spec.setName(mcpRuntimeName(mcp));
|
||||||
|
spec.setDescription(firstNonBlank(mcp.getDescription(), mcp.getTitle()));
|
||||||
|
spec.setTransportType(parseMcpTransportType(mcp, serverConfig));
|
||||||
|
spec.setCommand(resolveMcpInput(stringValue(serverConfig, "command", null)));
|
||||||
|
spec.setArgs(resolveMcpInputs(stringListValue(serverConfig, "args")));
|
||||||
|
spec.setEnv(resolveMcpInputMap(stringMapValue(serverConfig, "env")));
|
||||||
|
spec.setUrl(resolveMcpInput(stringValue(serverConfig, "url", null)));
|
||||||
|
spec.setHeaders(resolveMcpInputMap(stringMapValue(serverConfig, "headers")));
|
||||||
|
spec.setQueryParams(resolveMcpInputMap(stringMapValue(serverConfig, "queryParams")));
|
||||||
|
Duration timeout = durationValue(serverConfig, "timeout");
|
||||||
|
if (timeout != null) {
|
||||||
|
spec.setTimeout(timeout);
|
||||||
|
}
|
||||||
|
Duration initializationTimeout = durationValue(serverConfig, "initializationTimeout");
|
||||||
|
if (initializationTimeout != null) {
|
||||||
|
spec.setInitializationTimeout(initializationTimeout);
|
||||||
|
}
|
||||||
|
spec.setGroupName(mcpRuntimeName(mcp));
|
||||||
|
spec.setApprovalRequired(Boolean.TRUE.equals(mcp.getApprovalRequired()));
|
||||||
|
spec.setApprovalRequest(buildMcpApprovalRequest(mcp));
|
||||||
|
spec.setToolNamePrefix(mcpRuntimeToolPrefix(mcp.getId()));
|
||||||
|
spec.getMetadata().put("toolType", AgentToolType.MCP.name());
|
||||||
|
spec.getMetadata().put("mcpId", String.valueOf(mcp.getId()));
|
||||||
|
spec.getMetadata().put("mcpTitle", mcp.getTitle());
|
||||||
|
spec.getMetadata().put("serverName", server.getKey());
|
||||||
|
return spec;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void applyMcpToolBinding(McpSpec spec, AgentToolBinding binding) {
|
||||||
|
if (Boolean.TRUE.equals(binding.getHitlEnabled())) {
|
||||||
|
spec.setApprovalRequired(true);
|
||||||
|
spec.setApprovalRequest(buildBindingApprovalRequest(binding));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentToolApprovalRequest buildMcpApprovalRequest(Mcp mcp) {
|
||||||
|
AgentToolApprovalRequest request = new AgentToolApprovalRequest();
|
||||||
|
request.setApprovalPrompt("是否批准执行 MCP 工具:" + firstNonBlank(mcp.getTitle(), mcpRuntimeName(mcp)));
|
||||||
|
Map<String, Object> metadata = new LinkedHashMap<>();
|
||||||
|
metadata.put("toolType", AgentToolType.MCP.name());
|
||||||
|
metadata.put("mcpId", String.valueOf(mcp.getId()));
|
||||||
|
metadata.put("mcpTitle", mcp.getTitle());
|
||||||
|
request.setMetadata(metadata);
|
||||||
|
return request;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Workflow snapshotOrPublishedWorkflow(AgentToolBinding binding) {
|
||||||
|
if (binding.getResourceSnapshot() != null && !binding.getResourceSnapshot().isEmpty()) {
|
||||||
|
Workflow workflow = objectMapper.convertValue(binding.getResourceSnapshot(), Workflow.class);
|
||||||
|
workflow.setId(firstNonNull(workflow.getId(), binding.getTargetId()));
|
||||||
|
return workflow;
|
||||||
|
}
|
||||||
|
return workflowService.getPublishedById(binding.getTargetId());
|
||||||
|
}
|
||||||
|
|
||||||
|
private PluginItem snapshotOrCurrentPlugin(AgentToolBinding binding) {
|
||||||
|
if (binding.getResourceSnapshot() != null && !binding.getResourceSnapshot().isEmpty()) {
|
||||||
|
PluginItem pluginItem = objectMapper.convertValue(binding.getResourceSnapshot(), PluginItem.class);
|
||||||
|
pluginItem.setId(firstNonNull(pluginItem.getId(), binding.getTargetId()));
|
||||||
|
return pluginItem;
|
||||||
|
}
|
||||||
|
return pluginItemService.getById(binding.getTargetId());
|
||||||
|
}
|
||||||
|
|
||||||
|
private Mcp snapshotOrCurrentMcp(AgentToolBinding binding) {
|
||||||
|
if (binding.getResourceSnapshot() != null && !binding.getResourceSnapshot().isEmpty()) {
|
||||||
|
Mcp mcp = objectMapper.convertValue(binding.getResourceSnapshot(), Mcp.class);
|
||||||
|
mcp.setId(firstNonNull(mcp.getId(), binding.getTargetId()));
|
||||||
|
return mcp;
|
||||||
|
}
|
||||||
|
return mcpService.getById(binding.getTargetId());
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, Object> toSchema(Parameter[] parameters) {
|
||||||
|
Map<String, Object> schema = new LinkedHashMap<>();
|
||||||
|
Map<String, Object> properties = new LinkedHashMap<>();
|
||||||
|
List<String> required = new ArrayList<>();
|
||||||
|
if (parameters != null) {
|
||||||
|
for (Parameter parameter : parameters) {
|
||||||
|
properties.put(parameter.getName(), parameterSchema(parameter));
|
||||||
|
if (parameter.isRequired()) {
|
||||||
|
required.add(parameter.getName());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
schema.put("type", "object");
|
||||||
|
schema.put("properties", properties);
|
||||||
|
schema.put("required", required);
|
||||||
|
return schema;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, Object> parameterSchema(Parameter parameter) {
|
||||||
|
Map<String, Object> schema = new LinkedHashMap<>();
|
||||||
|
schema.put("type", parameter.getType() == null ? "string" : parameter.getType());
|
||||||
|
putOptionalString(schema, "description", parameter.getDescription());
|
||||||
|
if (parameter.getChildren() != null && !parameter.getChildren().isEmpty()) {
|
||||||
|
Map<String, Object> children = new LinkedHashMap<>();
|
||||||
|
for (Parameter child : parameter.getChildren()) {
|
||||||
|
if (child != null && child.getName() != null && !child.getName().isBlank()) {
|
||||||
|
children.put(child.getName(), parameterSchema(child));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ("array".equalsIgnoreCase(parameter.getType())) {
|
||||||
|
schema.put("items", firstArrayItemSchema(parameter.getChildren()));
|
||||||
|
} else {
|
||||||
|
schema.put("properties", children);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return schema;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, Object> firstArrayItemSchema(List<Parameter> children) {
|
||||||
|
return children.stream().filter(Objects::nonNull).findFirst()
|
||||||
|
.map(this::parameterSchema)
|
||||||
|
.orElse(Map.of("type", "string"));
|
||||||
|
}
|
||||||
|
|
||||||
|
private void putOptionalString(Map<String, Object> target, String key, String value) {
|
||||||
|
if (value != null && !value.isBlank()) {
|
||||||
|
target.put(key, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private String resolveRuntimeToolName(Tool tool, AgentToolBinding binding) {
|
||||||
|
String bindingName = binding == null ? null : binding.getToolName();
|
||||||
|
if (ChatToolNameHelper.isSafeToolName(bindingName)) {
|
||||||
|
return bindingName;
|
||||||
|
}
|
||||||
|
String toolName = tool == null ? null : tool.getName();
|
||||||
|
if (ChatToolNameHelper.isSafeToolName(toolName)) {
|
||||||
|
return toolName;
|
||||||
|
}
|
||||||
|
BigInteger targetId = binding == null ? null : binding.getTargetId();
|
||||||
|
return ChatToolNameHelper.buildFallbackName("tool", targetId);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String asyncToolName(Tool tool, AgentToolBinding binding, String fallbackPrefix) {
|
||||||
|
String base = resolveRuntimeToolName(tool, binding).toLowerCase(Locale.ROOT)
|
||||||
|
.replaceAll("[^a-z0-9_]", "_")
|
||||||
|
.replaceAll("_+", "_");
|
||||||
|
if (!base.isBlank() && Character.isDigit(base.charAt(0))) {
|
||||||
|
base = fallbackPrefix + "_" + base;
|
||||||
|
}
|
||||||
|
if (ASYNC_SAFE_NAME.matcher(base).matches()) {
|
||||||
|
return base;
|
||||||
|
}
|
||||||
|
return fallbackPrefix + "_" + (binding == null || binding.getTargetId() == null ? "unknown" : binding.getTargetId());
|
||||||
|
}
|
||||||
|
|
||||||
|
private String displayName(Tool tool, String fallback) {
|
||||||
|
String value = tool == null ? null : tool.getName();
|
||||||
|
return firstNonBlank(firstNonBlank(fallback, value), "工具调用");
|
||||||
|
}
|
||||||
|
|
||||||
|
private String safeDescription(String description) {
|
||||||
|
return description == null || description.isBlank() ? "EasyFlow Agent 工具" : description;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, Object> sanitizedHitlMetadata(Map<String, Object> config) {
|
||||||
|
Map<String, Object> metadata = new LinkedHashMap<>();
|
||||||
|
if (config != null) {
|
||||||
|
config.forEach((key, value) -> {
|
||||||
|
if (!isHitlPromptKey(key)) {
|
||||||
|
metadata.put(key, value);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean isHitlPromptKey(String key) {
|
||||||
|
if (key == null) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
String normalized = key.trim();
|
||||||
|
return "prompt".equalsIgnoreCase(normalized)
|
||||||
|
|| "question".equalsIgnoreCase(normalized)
|
||||||
|
|| "approvalPrompt".equalsIgnoreCase(normalized);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map.Entry<String, Map<String, Object>> firstMcpServer(Mcp mcp) {
|
||||||
|
Map<String, Object> config = parseMcpConfig(mcp);
|
||||||
|
Map<String, Object> servers = mapValue(config, "mcpServers");
|
||||||
|
if (servers.isEmpty()) {
|
||||||
|
throw new BusinessException("MCP 配置 JSON 中没有找到任何 MCP 服务名称");
|
||||||
|
}
|
||||||
|
Map.Entry<String, Object> first = servers.entrySet().iterator().next();
|
||||||
|
if (!(first.getValue() instanceof Map<?, ?> rawServer)) {
|
||||||
|
throw new BusinessException("MCP 服务配置必须是对象:" + first.getKey());
|
||||||
|
}
|
||||||
|
Map<String, Object> serverConfig = new LinkedHashMap<>();
|
||||||
|
rawServer.forEach((key, value) -> serverConfig.put(String.valueOf(key), value));
|
||||||
|
return Map.entry(first.getKey(), serverConfig);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, Object> parseMcpConfig(Mcp mcp) {
|
||||||
|
String configJson = mcp == null ? null : mcp.getConfigJson();
|
||||||
|
if (configJson == null || configJson.isBlank()) {
|
||||||
|
throw new BusinessException("MCP 配置 JSON 不能为空");
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
return objectMapper.readValue(configJson, new com.fasterxml.jackson.core.type.TypeReference<>() {});
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new BusinessException("MCP 配置 JSON 格式错误");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private McpTransportType parseMcpTransportType(Mcp mcp, Map<String, Object> serverConfig) {
|
||||||
|
String transport = firstNonBlank(mcp == null ? null : mcp.getTransportType(), stringValue(serverConfig, "transport", null));
|
||||||
|
return McpTransportType.from(transport);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String mcpRuntimeName(Mcp mcp) {
|
||||||
|
BigInteger id = mcp == null ? null : mcp.getId();
|
||||||
|
return "mcp_" + safeToolNameSegment(id == null ? "unknown" : String.valueOf(id));
|
||||||
|
}
|
||||||
|
|
||||||
|
private String mcpRuntimeToolPrefix(BigInteger mcpId) {
|
||||||
|
return "mcp_" + safeToolNameSegment(String.valueOf(mcpId)) + "_";
|
||||||
|
}
|
||||||
|
|
||||||
|
private String safeToolNameSegment(String value) {
|
||||||
|
String normalized = String.valueOf(value == null ? "" : value).trim()
|
||||||
|
.replaceAll("[^A-Za-z0-9_-]", "_")
|
||||||
|
.replaceAll("_+", "_");
|
||||||
|
return normalized.isBlank() ? "tool" : normalized;
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<String> stringListValue(Map<String, Object> map, String key) {
|
||||||
|
Object value = map == null ? null : map.get(key);
|
||||||
|
if (value == null) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
if (value instanceof Collection<?> collection) {
|
||||||
|
List<String> result = new ArrayList<>();
|
||||||
|
for (Object item : collection) {
|
||||||
|
if (item != null) {
|
||||||
|
result.add(String.valueOf(item));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
throw new BusinessException("Agent 配置字段必须是数组:" + key);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Duration durationValue(Map<String, Object> map, String key) {
|
||||||
|
Object value = map == null ? null : map.get(key);
|
||||||
|
if (value == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (value instanceof Number number) {
|
||||||
|
return Duration.ofSeconds(number.longValue());
|
||||||
|
}
|
||||||
|
String text = String.valueOf(value).trim();
|
||||||
|
if (text.isEmpty()) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
return Duration.parse(text);
|
||||||
|
} catch (Exception ignored) {
|
||||||
|
try {
|
||||||
|
return Duration.ofSeconds(Long.parseLong(text));
|
||||||
|
} catch (NumberFormatException e) {
|
||||||
|
throw new BusinessException("Agent 配置字段必须是秒数或 Duration:" + key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<String> resolveMcpInputs(List<String> values) {
|
||||||
|
if (values == null || values.isEmpty()) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
List<String> result = new ArrayList<>(values.size());
|
||||||
|
for (String value : values) {
|
||||||
|
result.add(resolveMcpInput(value));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, String> resolveMcpInputMap(Map<String, String> values) {
|
||||||
|
if (values == null || values.isEmpty()) {
|
||||||
|
return new LinkedHashMap<>();
|
||||||
|
}
|
||||||
|
Map<String, String> result = new LinkedHashMap<>();
|
||||||
|
values.forEach((key, value) -> result.put(key, resolveMcpInput(value)));
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String resolveMcpInput(String value) {
|
||||||
|
if (value == null || value.isBlank()) {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
Matcher matcher = MCP_INPUT_PATTERN.matcher(value);
|
||||||
|
StringBuffer resolved = new StringBuffer();
|
||||||
|
while (matcher.find()) {
|
||||||
|
String inputKey = matcher.group(1);
|
||||||
|
String resolvedValue = System.getProperty("mcp.input." + inputKey);
|
||||||
|
if (resolvedValue == null || resolvedValue.isBlank()) {
|
||||||
|
throw new BusinessException("MCP 输入变量未解析:" + inputKey);
|
||||||
|
}
|
||||||
|
matcher.appendReplacement(resolved, Matcher.quoteReplacement(resolvedValue));
|
||||||
|
}
|
||||||
|
matcher.appendTail(resolved);
|
||||||
|
return resolved.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, Object> mapValue(Map<String, Object> map, String key) {
|
||||||
|
Object value = map == null ? null : map.get(key);
|
||||||
|
if (value == null) {
|
||||||
|
return new LinkedHashMap<>();
|
||||||
|
}
|
||||||
|
if (value instanceof Map<?, ?> raw) {
|
||||||
|
Map<String, Object> result = new LinkedHashMap<>();
|
||||||
|
raw.forEach((rawKey, rawValue) -> result.put(String.valueOf(rawKey), rawValue));
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
throw new BusinessException("Agent 配置字段必须是对象:" + key);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, String> stringMapValue(Map<String, Object> map, String key) {
|
||||||
|
Map<String, Object> raw = mapValue(map, key);
|
||||||
|
Map<String, String> result = new LinkedHashMap<>();
|
||||||
|
raw.forEach((rawKey, rawValue) -> {
|
||||||
|
if (rawValue != null) {
|
||||||
|
result.put(rawKey, String.valueOf(rawValue));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String stringValue(Map<String, Object> map, String key, String defaultValue) {
|
||||||
|
Object value = map == null ? null : map.get(key);
|
||||||
|
if (value == null) {
|
||||||
|
return defaultValue;
|
||||||
|
}
|
||||||
|
String text = String.valueOf(value);
|
||||||
|
return text.isBlank() ? defaultValue : text;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String firstNonBlank(String first, String second) {
|
||||||
|
return first == null || first.isBlank() ? second : first;
|
||||||
|
}
|
||||||
|
|
||||||
|
private BigInteger firstNonNull(BigInteger first, BigInteger second) {
|
||||||
|
return first == null ? second : first;
|
||||||
|
}
|
||||||
|
|
||||||
|
private record CompiledSyncTool(AgentToolSpec spec, AgentToolInvoker invoker) {
|
||||||
|
}
|
||||||
|
|
||||||
|
private interface ToolCall {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 调用工具。
|
||||||
|
*
|
||||||
|
* @return 工具结果
|
||||||
|
*/
|
||||||
|
Object invoke();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
package tech.easyflow.agent.runtime.tool;
|
||||||
|
|
||||||
|
import com.easyagents.core.model.chat.tool.Tool;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import tech.easyflow.ai.entity.PluginItem;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent Plugin 工具执行器。
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
public class PluginToolExecutor {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 构建 Plugin 工具声明来源。
|
||||||
|
*
|
||||||
|
* @param pluginItem 插件工具
|
||||||
|
* @return 工具声明来源
|
||||||
|
*/
|
||||||
|
public Tool buildTool(PluginItem pluginItem) {
|
||||||
|
return pluginItem.toFunction();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 执行 Plugin 工具。
|
||||||
|
*
|
||||||
|
* @param pluginItem 插件工具
|
||||||
|
* @param arguments 执行参数
|
||||||
|
* @return 执行结果
|
||||||
|
*/
|
||||||
|
public AgentToolExecutionResult execute(PluginItem pluginItem, Map<String, Object> arguments) {
|
||||||
|
Object result = buildTool(pluginItem).invoke(arguments == null ? Map.of() : arguments);
|
||||||
|
return new AgentToolExecutionResult(result, null);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
package tech.easyflow.agent.runtime.tool;
|
||||||
|
|
||||||
|
import com.easyagents.flow.core.chain.runtime.ChainExecutor;
|
||||||
|
import com.easyagents.core.model.chat.tool.Tool;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import tech.easyflow.ai.easyagents.tool.WorkflowTool;
|
||||||
|
import tech.easyflow.ai.easyagentsflow.support.PublishedWorkflowDefinitionIds;
|
||||||
|
import tech.easyflow.ai.entity.Workflow;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent Workflow 工具执行器。
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
public class WorkflowToolExecutor {
|
||||||
|
|
||||||
|
private final ChainExecutor chainExecutor;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建 Workflow 工具执行器。
|
||||||
|
*
|
||||||
|
* @param chainExecutor 工作流执行器
|
||||||
|
*/
|
||||||
|
public WorkflowToolExecutor(ChainExecutor chainExecutor) {
|
||||||
|
this.chainExecutor = chainExecutor;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 构建 Workflow 工具声明来源。
|
||||||
|
*
|
||||||
|
* @param workflow 工作流
|
||||||
|
* @return 工具声明来源
|
||||||
|
*/
|
||||||
|
public Tool buildTool(Workflow workflow) {
|
||||||
|
return new WorkflowTool(workflow, true, definitionId(workflow));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 执行 Workflow 工具。
|
||||||
|
*
|
||||||
|
* @param workflow 工作流
|
||||||
|
* @param arguments 执行参数
|
||||||
|
* @return 执行结果
|
||||||
|
*/
|
||||||
|
public AgentToolExecutionResult execute(Workflow workflow, Map<String, Object> arguments) {
|
||||||
|
Object result = chainExecutor.execute(definitionId(workflow), arguments == null ? Map.of() : arguments);
|
||||||
|
return new AgentToolExecutionResult(result, resolveBusinessExecutionId(result));
|
||||||
|
}
|
||||||
|
|
||||||
|
private String definitionId(Workflow workflow) {
|
||||||
|
return PublishedWorkflowDefinitionIds.published(String.valueOf(workflow == null ? null : workflow.getId()));
|
||||||
|
}
|
||||||
|
|
||||||
|
private String resolveBusinessExecutionId(Object result) {
|
||||||
|
if (result instanceof Map<?, ?> map) {
|
||||||
|
Object value = firstValue(map, "executionId", "executeId", "chainId", "runId");
|
||||||
|
return value == null ? null : String.valueOf(value);
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Object firstValue(Map<?, ?> map, String... keys) {
|
||||||
|
for (String key : keys) {
|
||||||
|
if (map.containsKey(key)) {
|
||||||
|
return map.get(key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -276,20 +276,22 @@ public class AgentServiceImpl extends ServiceImpl<AgentMapper, Agent> implements
|
|||||||
summary.put("bindingId", binding.getId());
|
summary.put("bindingId", binding.getId());
|
||||||
summary.put("toolType", binding.getToolType());
|
summary.put("toolType", binding.getToolType());
|
||||||
summary.put("targetId", binding.getTargetId());
|
summary.put("targetId", binding.getTargetId());
|
||||||
summary.put("toolName", binding.getToolName());
|
|
||||||
summary.put("enabled", Boolean.TRUE.equals(binding.getEnabled()));
|
summary.put("enabled", Boolean.TRUE.equals(binding.getEnabled()));
|
||||||
summary.put("hitlEnabled", Boolean.TRUE.equals(binding.getHitlEnabled()));
|
summary.put("hitlEnabled", Boolean.TRUE.equals(binding.getHitlEnabled()));
|
||||||
summary.put("hitlConfigJson", binding.getHitlConfigJson());
|
summary.put("hitlConfigJson", binding.getHitlConfigJson());
|
||||||
summary.put("sortNo", binding.getSortNo());
|
summary.put("sortNo", binding.getSortNo());
|
||||||
if ("WORKFLOW".equalsIgnoreCase(binding.getToolType())) {
|
if ("WORKFLOW".equalsIgnoreCase(binding.getToolType())) {
|
||||||
|
summary.put("toolName", binding.getToolName());
|
||||||
Workflow workflow = workflowService.getById(binding.getTargetId());
|
Workflow workflow = workflowService.getById(binding.getTargetId());
|
||||||
summary.put("title", workflow == null ? null : workflow.getTitle());
|
summary.put("title", workflow == null ? null : workflow.getTitle());
|
||||||
} else if ("PLUGIN".equalsIgnoreCase(binding.getToolType())) {
|
} else if ("PLUGIN".equalsIgnoreCase(binding.getToolType())) {
|
||||||
|
summary.put("toolName", binding.getToolName());
|
||||||
PluginItem pluginItem = pluginItemService.getById(binding.getTargetId());
|
PluginItem pluginItem = pluginItemService.getById(binding.getTargetId());
|
||||||
summary.put("title", pluginItem == null ? null : pluginItem.getName());
|
summary.put("title", pluginItem == null ? null : pluginItem.getName());
|
||||||
} else {
|
} else {
|
||||||
Mcp mcp = mcpService.getById(binding.getTargetId());
|
Mcp mcp = mcpService.getById(binding.getTargetId());
|
||||||
summary.put("title", mcp == null ? null : mcp.getTitle());
|
summary.put("title", mcp == null ? null : mcp.getTitle());
|
||||||
|
summary.put("tools", mcp == null || mcp.getTools() == null ? List.of() : mcp.getTools());
|
||||||
}
|
}
|
||||||
return summary;
|
return summary;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,159 @@
|
|||||||
|
package tech.easyflow.agent.distributed;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import tech.easyflow.agent.config.AgentRuntimeProperties;
|
||||||
|
import tech.easyflow.agent.distributed.AgentRuntimeCommandAction;
|
||||||
|
import tech.easyflow.agent.distributed.AgentRuntimeCommandConsumer;
|
||||||
|
import tech.easyflow.agent.distributed.AgentRuntimeCommandMessage;
|
||||||
|
import tech.easyflow.agent.distributed.AgentRuntimeCommandResultRegistry;
|
||||||
|
import tech.easyflow.agent.runtime.AgentRunService;
|
||||||
|
import tech.easyflow.common.mq.config.MQProperties;
|
||||||
|
import tech.easyflow.common.mq.core.MQMessage;
|
||||||
|
|
||||||
|
import java.math.BigInteger;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@link AgentRuntimeCommandConsumer} 回归测试。
|
||||||
|
*/
|
||||||
|
public class AgentRuntimeCommandConsumerTest {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证消费者只处理发给当前节点的命令。
|
||||||
|
*
|
||||||
|
* @throws Exception 消息序列化异常
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void consumerShouldHandleOnlyCurrentNodeCommand() throws Exception {
|
||||||
|
AgentRuntimeProperties properties = new AgentRuntimeProperties();
|
||||||
|
properties.setInstanceId("node-a");
|
||||||
|
MQProperties mqProperties = new MQProperties();
|
||||||
|
mqProperties.getRedis().setChatPersistShardCount(4);
|
||||||
|
RecordingAgentRunService service = new RecordingAgentRunService();
|
||||||
|
RecordingCommandResultRegistry resultRegistry = new RecordingCommandResultRegistry();
|
||||||
|
AgentRuntimeCommandConsumer consumer =
|
||||||
|
new AgentRuntimeCommandConsumer(new ObjectMapper(), properties, mqProperties, service, resultRegistry);
|
||||||
|
|
||||||
|
consumer.handle(List.of(message(command("cmd-1", "node-b")), message(command("cmd-2", "node-a"))));
|
||||||
|
|
||||||
|
Assert.assertEquals(1, service.approveCount);
|
||||||
|
Assert.assertEquals("request-cmd-2", service.lastRequestId);
|
||||||
|
Assert.assertEquals(4, consumer.subscription().getShardCount());
|
||||||
|
Assert.assertFalse(consumer.subscription().isBatchEnabled());
|
||||||
|
Assert.assertEquals("cmd-2", resultRegistry.lastSuccessCommandId);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证 owner 本机执行失败时写入失败结果,避免 MQ 重试重复消费一次性 token。
|
||||||
|
*
|
||||||
|
* @throws Exception 消息序列化异常
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void consumerShouldMarkFailureWhenLocalRuntimeFails() throws Exception {
|
||||||
|
AgentRuntimeProperties properties = new AgentRuntimeProperties();
|
||||||
|
properties.setInstanceId("node-a");
|
||||||
|
MQProperties mqProperties = new MQProperties();
|
||||||
|
FailingAgentRunService service = new FailingAgentRunService();
|
||||||
|
RecordingCommandResultRegistry resultRegistry = new RecordingCommandResultRegistry();
|
||||||
|
AgentRuntimeCommandConsumer consumer =
|
||||||
|
new AgentRuntimeCommandConsumer(new ObjectMapper(), properties, mqProperties, service, resultRegistry);
|
||||||
|
|
||||||
|
consumer.handle(List.of(message(command("cmd-1", "node-a"))));
|
||||||
|
|
||||||
|
Assert.assertEquals("cmd-1", resultRegistry.lastFailureCommandId);
|
||||||
|
Assert.assertEquals("runtime missing", resultRegistry.lastFailureMessage);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证成功结果写入失败不会再次执行或改写为失败结果。
|
||||||
|
*
|
||||||
|
* @throws Exception 消息序列化异常
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void consumerShouldNotMarkFailureWhenSuccessResultWriteFails() throws Exception {
|
||||||
|
AgentRuntimeProperties properties = new AgentRuntimeProperties();
|
||||||
|
properties.setInstanceId("node-a");
|
||||||
|
MQProperties mqProperties = new MQProperties();
|
||||||
|
RecordingAgentRunService service = new RecordingAgentRunService();
|
||||||
|
FailingSuccessResultRegistry resultRegistry = new FailingSuccessResultRegistry();
|
||||||
|
AgentRuntimeCommandConsumer consumer =
|
||||||
|
new AgentRuntimeCommandConsumer(new ObjectMapper(), properties, mqProperties, service, resultRegistry);
|
||||||
|
|
||||||
|
consumer.handle(List.of(message(command("cmd-1", "node-a"))));
|
||||||
|
|
||||||
|
Assert.assertEquals(1, service.approveCount);
|
||||||
|
Assert.assertNull(resultRegistry.lastFailureCommandId);
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentRuntimeCommandMessage command(String commandId, String targetNodeId) {
|
||||||
|
AgentRuntimeCommandMessage command = new AgentRuntimeCommandMessage();
|
||||||
|
command.setCommandId(commandId);
|
||||||
|
command.setRequestId("request-" + commandId);
|
||||||
|
command.setResumeToken("token-" + commandId);
|
||||||
|
command.setAction(AgentRuntimeCommandAction.APPROVE);
|
||||||
|
command.setOperatorId(BigInteger.ONE);
|
||||||
|
command.setUserId("1");
|
||||||
|
command.setTargetNodeId(targetNodeId);
|
||||||
|
return command;
|
||||||
|
}
|
||||||
|
|
||||||
|
private MQMessage message(AgentRuntimeCommandMessage command) throws Exception {
|
||||||
|
MQMessage message = new MQMessage();
|
||||||
|
message.setMessageId(command.getCommandId());
|
||||||
|
message.setBody(new ObjectMapper().writeValueAsString(command));
|
||||||
|
return message;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class RecordingAgentRunService extends AgentRunService {
|
||||||
|
|
||||||
|
private int approveCount;
|
||||||
|
private String lastRequestId;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void approveRuntimeLocal(String requestId, String resumeToken, BigInteger operatorId, String userId) {
|
||||||
|
approveCount++;
|
||||||
|
lastRequestId = requestId;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class RecordingCommandResultRegistry extends AgentRuntimeCommandResultRegistry {
|
||||||
|
|
||||||
|
private String lastSuccessCommandId;
|
||||||
|
String lastFailureCommandId;
|
||||||
|
private String lastFailureMessage;
|
||||||
|
|
||||||
|
private RecordingCommandResultRegistry() {
|
||||||
|
super(null, null, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void markSuccess(String commandId) {
|
||||||
|
lastSuccessCommandId = commandId;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void markFailure(String commandId, String message) {
|
||||||
|
lastFailureCommandId = commandId;
|
||||||
|
lastFailureMessage = message;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class FailingAgentRunService extends AgentRunService {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void approveRuntimeLocal(String requestId, String resumeToken, BigInteger operatorId, String userId) {
|
||||||
|
throw new RuntimeException("runtime missing");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class FailingSuccessResultRegistry extends RecordingCommandResultRegistry {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void markSuccess(String commandId) {
|
||||||
|
super.markSuccess(commandId);
|
||||||
|
throw new RuntimeException("redis down");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,91 @@
|
|||||||
|
package tech.easyflow.agent.distributed;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.mockito.ArgumentMatchers;
|
||||||
|
import org.mockito.Mockito;
|
||||||
|
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||||
|
import org.springframework.data.redis.core.ValueOperations;
|
||||||
|
import tech.easyflow.agent.config.AgentRuntimeProperties;
|
||||||
|
import tech.easyflow.agent.distributed.AgentRuntimeCommandResult;
|
||||||
|
import tech.easyflow.agent.distributed.AgentRuntimeCommandResultRegistry;
|
||||||
|
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@link AgentRuntimeCommandResultRegistry} 回归测试。
|
||||||
|
*/
|
||||||
|
public class AgentRuntimeCommandResultRegistryTest {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证成功结果可被等待方读取。
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void waitForResultShouldReturnSuccessResult() {
|
||||||
|
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
ValueOperations<String, String> valueOperations = Mockito.mock(ValueOperations.class);
|
||||||
|
Mockito.when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||||
|
Mockito.when(valueOperations.get("easyflow:agent:runtime:command-result:cmd-1"))
|
||||||
|
.thenReturn("{\"success\":true,\"message\":\"OK\"}");
|
||||||
|
AgentRuntimeCommandResultRegistry registry = registry(redisTemplate);
|
||||||
|
|
||||||
|
AgentRuntimeCommandResult result = registry.waitForResult("cmd-1");
|
||||||
|
|
||||||
|
Assert.assertTrue(result.isSuccess());
|
||||||
|
Assert.assertEquals("OK", result.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证写入失败结果时使用配置的 TTL。
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void markFailureShouldWriteResultWithTtl() {
|
||||||
|
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
ValueOperations<String, String> valueOperations = Mockito.mock(ValueOperations.class);
|
||||||
|
Mockito.when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||||
|
AgentRuntimeProperties properties = properties();
|
||||||
|
AgentRuntimeCommandResultRegistry registry =
|
||||||
|
new AgentRuntimeCommandResultRegistry(redisTemplate, new ObjectMapper(), properties);
|
||||||
|
|
||||||
|
registry.markFailure("cmd-1", "failed");
|
||||||
|
|
||||||
|
Mockito.verify(valueOperations).set(
|
||||||
|
ArgumentMatchers.eq("easyflow:agent:runtime:command-result:cmd-1"),
|
||||||
|
ArgumentMatchers.contains("\"success\":false"),
|
||||||
|
ArgumentMatchers.eq(properties.getCommandResultTtl()));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证等待超时时抛出明确业务异常。
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void waitForResultShouldThrowBusinessExceptionWhenTimeout() {
|
||||||
|
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
ValueOperations<String, String> valueOperations = Mockito.mock(ValueOperations.class);
|
||||||
|
Mockito.when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||||
|
Mockito.when(valueOperations.get(ArgumentMatchers.anyString())).thenReturn(null);
|
||||||
|
AgentRuntimeCommandResultRegistry registry = registry(redisTemplate);
|
||||||
|
|
||||||
|
BusinessException exception = Assert.assertThrows(
|
||||||
|
BusinessException.class,
|
||||||
|
() -> registry.waitForResult("cmd-1"));
|
||||||
|
|
||||||
|
Assert.assertEquals("Agent 运行节点响应超时,请稍后重试", exception.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentRuntimeCommandResultRegistry registry(StringRedisTemplate redisTemplate) {
|
||||||
|
return new AgentRuntimeCommandResultRegistry(redisTemplate, new ObjectMapper(), properties());
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentRuntimeProperties properties() {
|
||||||
|
AgentRuntimeProperties properties = new AgentRuntimeProperties();
|
||||||
|
properties.setCommandResultTimeout(Duration.ofMillis(10));
|
||||||
|
properties.setCommandResultTtl(Duration.ofMinutes(5));
|
||||||
|
return properties;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,108 @@
|
|||||||
|
package tech.easyflow.agent.distributed;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.mockito.ArgumentMatchers;
|
||||||
|
import org.mockito.Mockito;
|
||||||
|
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||||
|
import org.springframework.data.redis.core.ValueOperations;
|
||||||
|
import tech.easyflow.agent.config.AgentRuntimeProperties;
|
||||||
|
import tech.easyflow.agent.distributed.AgentRuntimeRouteRegistry;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@link AgentRuntimeRouteRegistry} 回归测试。
|
||||||
|
*/
|
||||||
|
public class AgentRuntimeRouteRegistryTest {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证注册运行态和恢复令牌时写入 Redis 路由。
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void registerShouldWriteRunAndTokenRoutes() {
|
||||||
|
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
ValueOperations<String, String> valueOperations = Mockito.mock(ValueOperations.class);
|
||||||
|
Mockito.when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||||
|
AgentRuntimeProperties properties = properties("node-a");
|
||||||
|
AgentRuntimeRouteRegistry registry = registry(redisTemplate, properties);
|
||||||
|
|
||||||
|
registry.registerRun("request-1");
|
||||||
|
registry.registerResumeToken("request-1", "token-1");
|
||||||
|
|
||||||
|
Mockito.verify(valueOperations).set(
|
||||||
|
ArgumentMatchers.eq("easyflow:agent:runtime:request:request-1"),
|
||||||
|
ArgumentMatchers.contains("\"nodeId\":\"node-a\""),
|
||||||
|
ArgumentMatchers.eq(Duration.ofHours(24)));
|
||||||
|
Mockito.verify(valueOperations).set(
|
||||||
|
"easyflow:agent:runtime:resume-token:token-1", "request-1", Duration.ofHours(24));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证运行结束时清理 Redis 路由。
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void removeShouldDeleteRunAndTokenRoutes() {
|
||||||
|
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||||
|
AgentRuntimeRouteRegistry registry = registry(redisTemplate, properties("node-a"));
|
||||||
|
|
||||||
|
registry.removeRun("request-1");
|
||||||
|
registry.removeResumeToken("token-1");
|
||||||
|
|
||||||
|
Mockito.verify(redisTemplate).delete("easyflow:agent:runtime:request:request-1");
|
||||||
|
Mockito.verify(redisTemplate).delete("easyflow:agent:runtime:resume-token:token-1");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证查询 owner 节点和 token 反查请求 ID。
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void findShouldReadRoutes() {
|
||||||
|
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
ValueOperations<String, String> valueOperations = Mockito.mock(ValueOperations.class);
|
||||||
|
Mockito.when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||||
|
Mockito.when(valueOperations.get(ArgumentMatchers.eq("easyflow:agent:runtime:request:request-1")))
|
||||||
|
.thenReturn("{\"nodeId\":\"node-a\",\"bootId\":\"boot-a\"}");
|
||||||
|
Mockito.when(valueOperations.get(ArgumentMatchers.eq("easyflow:agent:runtime:resume-token:token-1")))
|
||||||
|
.thenReturn("request-1");
|
||||||
|
AgentRuntimeRouteRegistry registry = registry(redisTemplate, properties("node-a"));
|
||||||
|
|
||||||
|
Assert.assertEquals("node-a", registry.findOwnerNode("request-1"));
|
||||||
|
Assert.assertEquals("boot-a", registry.findOwnerRoute("request-1").getBootId());
|
||||||
|
Assert.assertEquals("request-1", registry.findRequestIdByResumeToken("token-1"));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证节点心跳写入和存活查询。
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void heartbeatShouldWriteAndReadNodeAliveState() {
|
||||||
|
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
ValueOperations<String, String> valueOperations = Mockito.mock(ValueOperations.class);
|
||||||
|
Mockito.when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||||
|
AgentRuntimeProperties properties = properties("node-a");
|
||||||
|
Mockito.when(valueOperations.get("easyflow:agent:runtime:node:node-a")).thenReturn(properties.getBootId());
|
||||||
|
AgentRuntimeRouteRegistry registry = registry(redisTemplate, properties);
|
||||||
|
|
||||||
|
registry.heartbeat(Duration.ofSeconds(90));
|
||||||
|
|
||||||
|
Mockito.verify(valueOperations).set("easyflow:agent:runtime:node:node-a", properties.getBootId(), Duration.ofSeconds(90));
|
||||||
|
Assert.assertTrue(registry.isNodeAlive("node-a"));
|
||||||
|
Assert.assertEquals(properties.getBootId(), registry.currentNodeBootId("node-a"));
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentRuntimeProperties properties(String instanceId) {
|
||||||
|
AgentRuntimeProperties properties = new AgentRuntimeProperties();
|
||||||
|
properties.setInstanceId(instanceId);
|
||||||
|
properties.setRouteTtl(Duration.ofHours(24));
|
||||||
|
return properties;
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentRuntimeRouteRegistry registry(StringRedisTemplate redisTemplate, AgentRuntimeProperties properties) {
|
||||||
|
return new AgentRuntimeRouteRegistry(redisTemplate, properties, new ObjectMapper());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,187 @@
|
|||||||
|
package tech.easyflow.agent.runtime;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import tech.easyflow.agent.entity.Agent;
|
||||||
|
import tech.easyflow.agent.entity.AgentKnowledgeBinding;
|
||||||
|
import tech.easyflow.ai.entity.DocumentCollection;
|
||||||
|
import tech.easyflow.ai.enums.PublishStatus;
|
||||||
|
import tech.easyflow.ai.service.DocumentCollectionService;
|
||||||
|
import tech.easyflow.common.entity.LoginAccount;
|
||||||
|
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||||
|
import tech.easyflow.system.enums.CategoryResourceType;
|
||||||
|
import tech.easyflow.system.enums.ResourceAction;
|
||||||
|
import tech.easyflow.system.permission.resource.VisibilityResource;
|
||||||
|
import tech.easyflow.system.service.ResourceAccessService;
|
||||||
|
|
||||||
|
import java.lang.reflect.Proxy;
|
||||||
|
import java.math.BigInteger;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 聊天临时能力编排服务测试。
|
||||||
|
*/
|
||||||
|
public class AgentChatCapabilityServiceTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void applyShouldAppendPublishedKnowledgeAndSkipBoundDuplicate() {
|
||||||
|
DocumentCollectionService documentService = documentService(
|
||||||
|
knowledge(1, PublishStatus.PUBLISHED),
|
||||||
|
knowledge(2, PublishStatus.PUBLISHED)
|
||||||
|
);
|
||||||
|
AgentChatCapabilityService service = new AgentChatCapabilityService(
|
||||||
|
documentService,
|
||||||
|
new AllowResourceAccessService(),
|
||||||
|
new ObjectMapper()
|
||||||
|
);
|
||||||
|
Agent agent = agentWithBoundKnowledge(1);
|
||||||
|
|
||||||
|
AgentChatCapabilityService.AgentChatCapabilityResolution resolution = service.apply(
|
||||||
|
agent,
|
||||||
|
List.of(capability(1, 2, 2)),
|
||||||
|
account()
|
||||||
|
);
|
||||||
|
|
||||||
|
List<AgentKnowledgeBinding> bindings = resolution.agent().getKnowledgeBindings();
|
||||||
|
Assert.assertEquals(List.of(BigInteger.ONE, BigInteger.valueOf(2)), resolution.extraKnowledgeIds());
|
||||||
|
Assert.assertEquals(2, bindings.size());
|
||||||
|
Assert.assertEquals(BigInteger.ONE, bindings.get(0).getKnowledgeId());
|
||||||
|
Assert.assertEquals(BigInteger.valueOf(2), bindings.get(1).getKnowledgeId());
|
||||||
|
Assert.assertEquals("HYBRID", bindings.get(1).getRetrievalMode());
|
||||||
|
Assert.assertTrue(bindings.get(1).getEnabled());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void applyShouldAppendWhenBoundKnowledgeIsDisabled() {
|
||||||
|
DocumentCollectionService documentService = documentService(knowledge(2, PublishStatus.PUBLISHED));
|
||||||
|
AgentChatCapabilityService service = new AgentChatCapabilityService(
|
||||||
|
documentService,
|
||||||
|
new AllowResourceAccessService(),
|
||||||
|
new ObjectMapper()
|
||||||
|
);
|
||||||
|
Agent agent = agentWithBoundKnowledge(2);
|
||||||
|
agent.getKnowledgeBindings().get(0).setEnabled(false);
|
||||||
|
|
||||||
|
AgentChatCapabilityService.AgentChatCapabilityResolution resolution = service.apply(
|
||||||
|
agent,
|
||||||
|
List.of(capability(2)),
|
||||||
|
account()
|
||||||
|
);
|
||||||
|
|
||||||
|
List<AgentKnowledgeBinding> bindings = resolution.agent().getKnowledgeBindings();
|
||||||
|
Assert.assertEquals(2, bindings.size());
|
||||||
|
Assert.assertFalse(bindings.get(0).getEnabled());
|
||||||
|
Assert.assertTrue(bindings.get(1).getEnabled());
|
||||||
|
Assert.assertEquals(BigInteger.valueOf(2), bindings.get(1).getKnowledgeId());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void resolveKnowledgeIdsShouldRejectTooManyItems() {
|
||||||
|
AgentChatCapabilityService service = new AgentChatCapabilityService(
|
||||||
|
documentService(),
|
||||||
|
new AllowResourceAccessService(),
|
||||||
|
new ObjectMapper()
|
||||||
|
);
|
||||||
|
|
||||||
|
Assert.assertThrows(BusinessException.class,
|
||||||
|
() -> service.resolveKnowledgeIds(List.of(capability(1, 2, 3, 4))));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void applyShouldRejectUnpublishedKnowledge() {
|
||||||
|
DocumentCollectionService documentService = documentService(knowledge(2, PublishStatus.OFFLINE));
|
||||||
|
AgentChatCapabilityService service = new AgentChatCapabilityService(
|
||||||
|
documentService,
|
||||||
|
new AllowResourceAccessService(),
|
||||||
|
new ObjectMapper()
|
||||||
|
);
|
||||||
|
|
||||||
|
Assert.assertThrows(BusinessException.class,
|
||||||
|
() -> service.apply(agentWithBoundKnowledge(1), List.of(capability(2)), account()));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void applyShouldRejectUnauthorizedKnowledge() {
|
||||||
|
DocumentCollectionService documentService = documentService(knowledge(2, PublishStatus.PUBLISHED));
|
||||||
|
AgentChatCapabilityService service = new AgentChatCapabilityService(
|
||||||
|
documentService,
|
||||||
|
new DenyResourceAccessService(),
|
||||||
|
new ObjectMapper()
|
||||||
|
);
|
||||||
|
|
||||||
|
Assert.assertThrows(BusinessException.class,
|
||||||
|
() -> service.apply(agentWithBoundKnowledge(1), List.of(capability(2)), account()));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Agent agentWithBoundKnowledge(int knowledgeId) {
|
||||||
|
Agent agent = new Agent();
|
||||||
|
agent.setId(BigInteger.TEN);
|
||||||
|
agent.setTenantId(BigInteger.ONE);
|
||||||
|
AgentKnowledgeBinding binding = new AgentKnowledgeBinding();
|
||||||
|
binding.setAgentId(agent.getId());
|
||||||
|
binding.setKnowledgeId(BigInteger.valueOf(knowledgeId));
|
||||||
|
binding.setRetrievalMode("HYBRID");
|
||||||
|
binding.setEnabled(true);
|
||||||
|
agent.setKnowledgeBindings(List.of(binding));
|
||||||
|
return agent;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static AgentChatCapability capability(int... ids) {
|
||||||
|
AgentChatCapability capability = new AgentChatCapability();
|
||||||
|
capability.setType("KNOWLEDGE");
|
||||||
|
capability.setResourceIds(java.util.Arrays.stream(ids)
|
||||||
|
.mapToObj(BigInteger::valueOf)
|
||||||
|
.toList());
|
||||||
|
return capability;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static DocumentCollection knowledge(int id, PublishStatus status) {
|
||||||
|
DocumentCollection collection = new DocumentCollection();
|
||||||
|
collection.setId(BigInteger.valueOf(id));
|
||||||
|
collection.setTitle("知识库" + id);
|
||||||
|
collection.setPublishStatus(status.name());
|
||||||
|
return collection;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static LoginAccount account() {
|
||||||
|
LoginAccount account = new LoginAccount();
|
||||||
|
account.setId(BigInteger.valueOf(100));
|
||||||
|
account.setTenantId(BigInteger.ONE);
|
||||||
|
return account;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static DocumentCollectionService documentService(DocumentCollection... collections) {
|
||||||
|
Map<BigInteger, DocumentCollection> collectionMap = new HashMap<>();
|
||||||
|
for (DocumentCollection collection : collections) {
|
||||||
|
collectionMap.put(collection.getId(), collection);
|
||||||
|
}
|
||||||
|
return (DocumentCollectionService) Proxy.newProxyInstance(
|
||||||
|
AgentChatCapabilityServiceTest.class.getClassLoader(),
|
||||||
|
new Class<?>[]{DocumentCollectionService.class},
|
||||||
|
(proxy, method, args) -> switch (method.getName()) {
|
||||||
|
case "getById" -> collectionMap.get(new BigInteger(String.valueOf(args[0])));
|
||||||
|
case "toPublishedView" -> args[0];
|
||||||
|
case "listByIds" -> ((java.util.Collection<?>) args[0]).stream()
|
||||||
|
.map(id -> collectionMap.get(new BigInteger(String.valueOf(id))))
|
||||||
|
.filter(java.util.Objects::nonNull)
|
||||||
|
.toList();
|
||||||
|
default -> throw new UnsupportedOperationException(method.getName());
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class AllowResourceAccessService implements ResourceAccessService {
|
||||||
|
@Override public boolean canAccess(CategoryResourceType resourceType, VisibilityResource resource, ResourceAction action) { return true; }
|
||||||
|
@Override public boolean canAccess(LoginAccount loginAccount, CategoryResourceType resourceType, VisibilityResource resource, ResourceAction action) { return true; }
|
||||||
|
@Override public void assertAccess(CategoryResourceType resourceType, VisibilityResource resource, ResourceAction action, String message) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class DenyResourceAccessService extends AllowResourceAccessService {
|
||||||
|
@Override public void assertAccess(CategoryResourceType resourceType, VisibilityResource resource, ResourceAction action, String message) {
|
||||||
|
throw new BusinessException(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,155 @@
|
|||||||
|
package tech.easyflow.agent.runtime;
|
||||||
|
|
||||||
|
import com.easyagents.agent.runtime.AgentDefinition;
|
||||||
|
import com.easyagents.agent.runtime.mcp.McpSpec;
|
||||||
|
import com.easyagents.agent.runtime.mcp.McpTransportType;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import tech.easyflow.agent.entity.Agent;
|
||||||
|
import tech.easyflow.agent.entity.AgentToolBinding;
|
||||||
|
import tech.easyflow.agent.enums.AgentToolType;
|
||||||
|
import tech.easyflow.agent.runtime.tool.AgentToolRuntimeCompiler;
|
||||||
|
import tech.easyflow.ai.entity.Mcp;
|
||||||
|
import tech.easyflow.ai.entity.Model;
|
||||||
|
import tech.easyflow.ai.entity.ModelProvider;
|
||||||
|
import tech.easyflow.ai.service.McpService;
|
||||||
|
import tech.easyflow.ai.service.ModelService;
|
||||||
|
|
||||||
|
import java.lang.reflect.Field;
|
||||||
|
import java.lang.reflect.Proxy;
|
||||||
|
import java.math.BigInteger;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent MCP 运行时定义编译测试。
|
||||||
|
*/
|
||||||
|
public class AgentDefinitionCompilerMcpTest {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证 Agent 绑定 MCP 后会编译为 runtime 原生 MCP 声明,并按整个 MCP 暴露工具。
|
||||||
|
*
|
||||||
|
* @throws Exception 反射注入依赖失败时抛出
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void compileShouldBuildWholeMcpSpecWithDynamicPrefixAndApproval() throws Exception {
|
||||||
|
BigInteger modelId = BigInteger.valueOf(10L);
|
||||||
|
BigInteger mcpId = BigInteger.valueOf(20L);
|
||||||
|
Model model = model(modelId);
|
||||||
|
Mcp mcp = mcp(mcpId);
|
||||||
|
AgentRuntimeCompiler compiler = new AgentRuntimeCompiler();
|
||||||
|
AgentToolRuntimeCompiler toolCompiler = new AgentToolRuntimeCompiler();
|
||||||
|
setField(compiler, "objectMapper", new com.fasterxml.jackson.databind.ObjectMapper());
|
||||||
|
setField(compiler, "modelService", modelService(model));
|
||||||
|
setField(toolCompiler, "objectMapper", new com.fasterxml.jackson.databind.ObjectMapper());
|
||||||
|
setField(toolCompiler, "mcpService", mcpService(mcp));
|
||||||
|
setField(compiler, "agentToolRuntimeCompiler", toolCompiler);
|
||||||
|
|
||||||
|
Agent agent = agent(modelId, mcpId);
|
||||||
|
|
||||||
|
AgentRuntimeBundle bundle = compiler.compile(agent);
|
||||||
|
AgentDefinition definition = bundle.getDefinition();
|
||||||
|
|
||||||
|
Assert.assertTrue(definition.getToolSpecs().isEmpty());
|
||||||
|
Assert.assertTrue(bundle.getToolInvokers().isEmpty());
|
||||||
|
Assert.assertEquals(1, definition.getMcpSpecs().size());
|
||||||
|
McpSpec spec = definition.getMcpSpecs().get(0);
|
||||||
|
Assert.assertEquals("mcp_20", spec.getName());
|
||||||
|
Assert.assertEquals(McpTransportType.STDIO, spec.getTransportType());
|
||||||
|
Assert.assertEquals("npx", spec.getCommand());
|
||||||
|
Assert.assertEquals(List.of("-y", "@modelcontextprotocol/server-everything"), spec.getArgs());
|
||||||
|
Assert.assertTrue(spec.isApprovalRequired());
|
||||||
|
Assert.assertEquals("mcp_20_", spec.getToolNamePrefix());
|
||||||
|
Assert.assertTrue(spec.getToolAliases().isEmpty());
|
||||||
|
Assert.assertTrue(spec.getEnableTools().isEmpty());
|
||||||
|
Assert.assertEquals(AgentToolType.MCP.name(), spec.getMetadata().get("toolType"));
|
||||||
|
Assert.assertEquals(String.valueOf(mcpId), spec.getMetadata().get("mcpId"));
|
||||||
|
Assert.assertEquals("everything", spec.getMetadata().get("serverName"));
|
||||||
|
Assert.assertTrue(spec.getToolApprovalRequests().isEmpty());
|
||||||
|
Assert.assertEquals("确认调用 MCP 工具?", spec.getApprovalRequest().getApprovalPrompt());
|
||||||
|
}
|
||||||
|
|
||||||
|
private Agent agent(BigInteger modelId, BigInteger mcpId) {
|
||||||
|
AgentToolBinding binding = new AgentToolBinding();
|
||||||
|
binding.setToolType(AgentToolType.MCP.name());
|
||||||
|
binding.setTargetId(mcpId);
|
||||||
|
binding.setEnabled(true);
|
||||||
|
binding.setHitlEnabled(true);
|
||||||
|
binding.setHitlConfigJson(Map.of("prompt", "确认调用 MCP 工具?"));
|
||||||
|
|
||||||
|
Agent agent = new Agent();
|
||||||
|
agent.setId(BigInteger.valueOf(1L));
|
||||||
|
agent.setName("MCP Agent");
|
||||||
|
agent.setModelId(modelId);
|
||||||
|
agent.setToolBindings(List.of(binding));
|
||||||
|
return agent;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Model model(BigInteger modelId) {
|
||||||
|
ModelProvider provider = new ModelProvider();
|
||||||
|
provider.setProviderType("openai");
|
||||||
|
provider.setProviderName("OpenAI");
|
||||||
|
Model model = new Model();
|
||||||
|
model.setId(modelId);
|
||||||
|
model.setModelProvider(provider);
|
||||||
|
model.setModelName("gpt-test");
|
||||||
|
model.setEndpoint("https://example.com");
|
||||||
|
model.setRequestPath("/v1/chat/completions");
|
||||||
|
model.setApiKey("test-key");
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Mcp mcp(BigInteger mcpId) {
|
||||||
|
Mcp mcp = new Mcp();
|
||||||
|
mcp.setId(mcpId);
|
||||||
|
mcp.setTitle("Everything");
|
||||||
|
mcp.setDescription("MCP Everything");
|
||||||
|
mcp.setApprovalRequired(true);
|
||||||
|
mcp.setStatus(true);
|
||||||
|
mcp.setConfigJson("""
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"everything": {
|
||||||
|
"transport": "stdio",
|
||||||
|
"command": "npx",
|
||||||
|
"args": ["-y", "@modelcontextprotocol/server-everything"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
""");
|
||||||
|
return mcp;
|
||||||
|
}
|
||||||
|
|
||||||
|
private ModelService modelService(Model model) {
|
||||||
|
return (ModelService) Proxy.newProxyInstance(
|
||||||
|
ModelService.class.getClassLoader(),
|
||||||
|
new Class<?>[]{ModelService.class},
|
||||||
|
(proxy, method, args) -> "getModelInstance".equals(method.getName()) ? model : defaultValue(method.getReturnType()));
|
||||||
|
}
|
||||||
|
|
||||||
|
private McpService mcpService(Mcp mcp) {
|
||||||
|
return (McpService) Proxy.newProxyInstance(
|
||||||
|
McpService.class.getClassLoader(),
|
||||||
|
new Class<?>[]{McpService.class},
|
||||||
|
(proxy, method, args) -> "getById".equals(method.getName()) ? mcp : defaultValue(method.getReturnType()));
|
||||||
|
}
|
||||||
|
|
||||||
|
private Object defaultValue(Class<?> type) {
|
||||||
|
if (type == boolean.class) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (type == int.class || type == long.class || type == short.class || type == byte.class) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
if (type == double.class || type == float.class) {
|
||||||
|
return 0D;
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void setField(Object target, String fieldName, Object value) throws Exception {
|
||||||
|
Field field = target.getClass().getDeclaredField(fieldName);
|
||||||
|
field.setAccessible(true);
|
||||||
|
field.set(target, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,15 +1,25 @@
|
|||||||
package tech.easyflow.agent.runtime;
|
package tech.easyflow.agent.runtime;
|
||||||
|
|
||||||
|
import com.easyagents.agent.runtime.AgentInitRequest;
|
||||||
|
import com.easyagents.agent.runtime.AgentRuntime;
|
||||||
import com.easyagents.agent.runtime.event.AgentRuntimeEvent;
|
import com.easyagents.agent.runtime.event.AgentRuntimeEvent;
|
||||||
import com.easyagents.agent.runtime.event.AgentRuntimeEventType;
|
import com.easyagents.agent.runtime.event.AgentRuntimeEventType;
|
||||||
import com.easyagents.agent.runtime.message.AgentKnowledgeReference;
|
import com.easyagents.agent.runtime.message.AgentKnowledgeReference;
|
||||||
import com.easyagents.agent.runtime.message.AgentMessage;
|
import com.easyagents.agent.runtime.message.AgentMessage;
|
||||||
import com.easyagents.agent.runtime.message.AgentMessageRole;
|
import com.easyagents.agent.runtime.message.AgentMessageRole;
|
||||||
|
import com.easyagents.agent.runtime.persistence.session.AgentSessionStore;
|
||||||
|
import com.easyagents.agent.runtime.persistence.session.memory.InMemoryAgentSessionStore;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import tech.easyflow.agent.entity.AgentHitlPending;
|
||||||
import tech.easyflow.agent.entity.Agent;
|
import tech.easyflow.agent.entity.Agent;
|
||||||
import tech.easyflow.agent.entity.AgentKnowledgeBinding;
|
import tech.easyflow.agent.entity.AgentKnowledgeBinding;
|
||||||
import tech.easyflow.agent.entity.AgentToolBinding;
|
import tech.easyflow.agent.entity.AgentToolBinding;
|
||||||
|
import tech.easyflow.agent.distributed.AgentRuntimeCommandProducer;
|
||||||
|
import tech.easyflow.agent.distributed.AgentRuntimeRoute;
|
||||||
|
import tech.easyflow.agent.distributed.AgentRuntimeRouteRegistry;
|
||||||
|
import tech.easyflow.agent.runtime.event.AgentRunEventRecorder;
|
||||||
|
import tech.easyflow.agent.runtime.hitl.AgentHitlPendingService;
|
||||||
import tech.easyflow.agent.runtime.lock.AgentRunLock;
|
import tech.easyflow.agent.runtime.lock.AgentRunLock;
|
||||||
import tech.easyflow.chatlog.domain.dto.ChatSessionSummary;
|
import tech.easyflow.chatlog.domain.dto.ChatSessionSummary;
|
||||||
import tech.easyflow.common.entity.LoginAccount;
|
import tech.easyflow.common.entity.LoginAccount;
|
||||||
@@ -402,14 +412,283 @@ public class AgentRunServiceDraftAndHitlTest {
|
|||||||
|
|
||||||
Exception thrown = Assert.assertThrows(Exception.class, () -> invoke(service, "run",
|
Exception thrown = Assert.assertThrows(Exception.class, () -> invoke(service, "run",
|
||||||
new Class<?>[]{Agent.class, String.class, String.class, String.class, String.class,
|
new Class<?>[]{Agent.class, String.class, String.class, String.class, String.class,
|
||||||
String.class, ChatRuntimeContext.class, boolean.class},
|
String.class, ChatRuntimeContext.class, boolean.class, AgentSessionStore.class},
|
||||||
agent, "你好", "request-lock", "trace-lock", "session-lock", "AGENT", context, true));
|
agent, "你好", "request-lock", "trace-lock", "session-lock", "AGENT", context, true,
|
||||||
|
new InMemoryAgentSessionStore()));
|
||||||
|
|
||||||
Assert.assertTrue(rootCause(thrown) instanceof BusinessException);
|
Assert.assertTrue(rootCause(thrown) instanceof BusinessException);
|
||||||
Assert.assertEquals(0, chatRuntimeManager.prepareSessionCount);
|
Assert.assertEquals(0, chatRuntimeManager.prepareSessionCount);
|
||||||
Assert.assertEquals(0, chatRuntimeManager.recordUserMessageCount);
|
Assert.assertEquals(0, chatRuntimeManager.recordUserMessageCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证草稿运行会使用独立 session store,且不会绑定 MySQL session 元信息。
|
||||||
|
*
|
||||||
|
* @throws Exception 反射调用失败时抛出
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void startRuntimeShouldUseDraftSessionStoreWithoutBindingMysqlSession() throws Exception {
|
||||||
|
AgentRunService service = new AgentRunService();
|
||||||
|
RecordingAgentRuntimeCompiler compiler = new RecordingAgentRuntimeCompiler();
|
||||||
|
RecordingAgentRuntime runtime = new RecordingAgentRuntime();
|
||||||
|
RecordingAgentRuntimeFactory runtimeFactory = new RecordingAgentRuntimeFactory(runtime);
|
||||||
|
AgentSessionStore draftStore = new InMemoryAgentSessionStore();
|
||||||
|
setField(service, "agentRuntimeCompiler", compiler);
|
||||||
|
setField(service, "agentRuntimeFactory", runtimeFactory);
|
||||||
|
setField(service, "agentRunRegistry", new AgentRunRegistry());
|
||||||
|
|
||||||
|
Agent agent = new Agent();
|
||||||
|
agent.setId(BigInteger.valueOf(100));
|
||||||
|
invoke(service, "startRuntime",
|
||||||
|
new Class<?>[]{Agent.class, String.class, String.class, String.class, String.class, String.class,
|
||||||
|
ChatRuntimeContext.class, ChatSseEmitter.class, boolean.class, AgentSessionStore.class,
|
||||||
|
AgentRunLock.Handle.class},
|
||||||
|
agent, "你好", "request-draft", "trace-draft", "agent-draft-100", "AGENT_DRAFT",
|
||||||
|
chatContext(), new RecordingChatSseEmitter(), false, draftStore, null);
|
||||||
|
|
||||||
|
Assert.assertSame(draftStore, runtime.initRequest.getSessionStore());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证草稿事件不会写运行事件表,正式事件仍会记录。
|
||||||
|
*
|
||||||
|
* @throws Exception 反射调用失败时抛出
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void handleRuntimeEventShouldOnlyPersistEventsForFormalChat() throws Exception {
|
||||||
|
AgentRunService service = new AgentRunService();
|
||||||
|
setField(service, "agentRunRegistry", new AgentRunRegistry());
|
||||||
|
RecordingAgentRunEventRecorder recorder = new RecordingAgentRunEventRecorder();
|
||||||
|
setField(service, "agentRunEventRecorder", recorder);
|
||||||
|
AgentRuntimeEvent draftEvent = AgentRuntimeEvent.of(AgentRuntimeEventType.TOOL_CALL);
|
||||||
|
draftEvent.getPayload().put("toolName", "search");
|
||||||
|
|
||||||
|
invoke(service, "handleRuntimeEvent",
|
||||||
|
runtimeEventParameterTypes(),
|
||||||
|
draftEvent, "request-draft", new RecordingChatSseEmitter(), new StringBuilder(),
|
||||||
|
new ChatAssistantAccumulator(), chatContext(), new AtomicBoolean(false), false);
|
||||||
|
|
||||||
|
Assert.assertEquals(0, recorder.recordCount);
|
||||||
|
|
||||||
|
AgentRuntimeEvent formalEvent = AgentRuntimeEvent.of(AgentRuntimeEventType.TOOL_CALL);
|
||||||
|
formalEvent.getPayload().put("toolName", "search");
|
||||||
|
invoke(service, "handleRuntimeEvent",
|
||||||
|
runtimeEventParameterTypes(),
|
||||||
|
formalEvent, "request-formal", new RecordingChatSseEmitter(), new StringBuilder(),
|
||||||
|
new ChatAssistantAccumulator(), chatContext(), new AtomicBoolean(false), true);
|
||||||
|
|
||||||
|
Assert.assertEquals(1, recorder.recordCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证草稿工具审批只注册内存恢复令牌,不写 HITL pending 表。
|
||||||
|
*
|
||||||
|
* @throws Exception 反射调用失败时抛出
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void draftToolApprovalShouldNotPersistPending() throws Exception {
|
||||||
|
AgentRunService service = new AgentRunService();
|
||||||
|
AgentRunRegistry registry = new AgentRunRegistry();
|
||||||
|
RecordingAgentHitlPendingService pendingService = new RecordingAgentHitlPendingService();
|
||||||
|
setField(service, "agentRunRegistry", registry);
|
||||||
|
setField(service, "agentHitlPendingService", pendingService);
|
||||||
|
registry.register(runContext("request-draft", "agent-draft-tool", false));
|
||||||
|
AgentRuntimeEvent event = AgentRuntimeEvent.of(AgentRuntimeEventType.TOOL_APPROVAL_REQUIRED);
|
||||||
|
event.getPayload().put("resumeToken", "token-draft");
|
||||||
|
|
||||||
|
invoke(service, "handleRuntimeEvent",
|
||||||
|
runtimeEventParameterTypes(),
|
||||||
|
event, "request-draft", new RecordingChatSseEmitter(), new StringBuilder(),
|
||||||
|
new ChatAssistantAccumulator(), chatContext(), new AtomicBoolean(false), false);
|
||||||
|
|
||||||
|
Assert.assertTrue(registry.containsResumeTarget("request-draft", "token-draft"));
|
||||||
|
Assert.assertEquals(0, pendingService.recordApprovalRequiredCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证草稿审批恢复不执行 pending 表消费,正式审批仍执行。
|
||||||
|
*
|
||||||
|
* @throws Exception 反射调用失败时抛出
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void approveShouldSkipPendingConsumeOnlyForDraftRun() throws Exception {
|
||||||
|
AgentRunService service = new AgentRunService();
|
||||||
|
AgentRunRegistry registry = new AgentRunRegistry();
|
||||||
|
RecordingAgentHitlPendingService pendingService = new RecordingAgentHitlPendingService();
|
||||||
|
setField(service, "agentRunRegistry", registry);
|
||||||
|
setField(service, "agentHitlPendingService", pendingService);
|
||||||
|
|
||||||
|
registry.register(runContext("request-draft-approve", "agent-draft-approve", false));
|
||||||
|
registry.registerResumeToken("request-draft-approve", "token-draft-approve");
|
||||||
|
invoke(service, "approveRuntime",
|
||||||
|
new Class<?>[]{String.class, String.class, BigInteger.class, String.class},
|
||||||
|
"request-draft-approve", "token-draft-approve", BigInteger.ONE, "1");
|
||||||
|
|
||||||
|
Assert.assertEquals(0, pendingService.approveCount);
|
||||||
|
|
||||||
|
registry.register(runContext("request-formal-approve", "session-formal-approve", true));
|
||||||
|
registry.registerResumeToken("request-formal-approve", "token-formal-approve");
|
||||||
|
invoke(service, "approveRuntime",
|
||||||
|
new Class<?>[]{String.class, String.class, BigInteger.class, String.class},
|
||||||
|
"request-formal-approve", "token-formal-approve", BigInteger.ONE, "1");
|
||||||
|
|
||||||
|
Assert.assertEquals(1, pendingService.approveCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证本机存在恢复目标时不投递远程命令。
|
||||||
|
*
|
||||||
|
* @throws Exception 反射调用失败时抛出
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void approveShouldNotDispatchRemoteWhenLocalRuntimeExists() throws Exception {
|
||||||
|
AgentRunService service = new AgentRunService();
|
||||||
|
AgentRunRegistry registry = new AgentRunRegistry();
|
||||||
|
RecordingAgentHitlPendingService pendingService = new RecordingAgentHitlPendingService();
|
||||||
|
RecordingRouteRegistry routeRegistry = new RecordingRouteRegistry("node-a");
|
||||||
|
RecordingCommandProducer commandProducer = new RecordingCommandProducer();
|
||||||
|
setField(service, "agentRunRegistry", registry);
|
||||||
|
setField(service, "agentHitlPendingService", pendingService);
|
||||||
|
setField(service, "agentRuntimeRouteRegistry", routeRegistry);
|
||||||
|
setField(service, "agentRuntimeCommandProducer", commandProducer);
|
||||||
|
|
||||||
|
registry.register(runContext("request-local-approve", "session-local-approve", true));
|
||||||
|
registry.registerResumeToken("request-local-approve", "token-local-approve");
|
||||||
|
invoke(service, "approveRuntime",
|
||||||
|
new Class<?>[]{String.class, String.class, BigInteger.class, String.class},
|
||||||
|
"request-local-approve", "token-local-approve", BigInteger.ONE, "1");
|
||||||
|
|
||||||
|
Assert.assertEquals(1, pendingService.approveCount);
|
||||||
|
Assert.assertEquals(0, commandProducer.approveCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证本机无运行态但 Redis owner 存在时投递远程命令。
|
||||||
|
*
|
||||||
|
* @throws Exception 反射调用失败时抛出
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void approveShouldDispatchRemoteWhenOwnerIsRemoteNode() throws Exception {
|
||||||
|
AgentRunService service = new AgentRunService();
|
||||||
|
RecordingRouteRegistry routeRegistry = new RecordingRouteRegistry("node-b");
|
||||||
|
routeRegistry.requestIdByToken = "request-remote-approve";
|
||||||
|
routeRegistry.ownerNode = "node-a";
|
||||||
|
routeRegistry.ownerBootId = "boot-a";
|
||||||
|
routeRegistry.currentOwnerBootId = "boot-a";
|
||||||
|
routeRegistry.nodeAlive = true;
|
||||||
|
RecordingCommandProducer commandProducer = new RecordingCommandProducer();
|
||||||
|
setField(service, "agentRunRegistry", new AgentRunRegistry());
|
||||||
|
setField(service, "agentRuntimeRouteRegistry", routeRegistry);
|
||||||
|
setField(service, "agentRuntimeCommandProducer", commandProducer);
|
||||||
|
|
||||||
|
invoke(service, "approveRuntime",
|
||||||
|
new Class<?>[]{String.class, String.class, BigInteger.class, String.class},
|
||||||
|
null, "token-remote-approve", BigInteger.ONE, "1");
|
||||||
|
|
||||||
|
Assert.assertEquals(1, commandProducer.approveCount);
|
||||||
|
Assert.assertEquals("node-a", commandProducer.lastTargetNodeId);
|
||||||
|
Assert.assertEquals("request-remote-approve", commandProducer.lastRequestId);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证 owner 缺失时明确失败。
|
||||||
|
*
|
||||||
|
* @throws Exception 反射调用失败时抛出
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void approveShouldFailWhenOwnerRouteMissing() throws Exception {
|
||||||
|
AgentRunService service = new AgentRunService();
|
||||||
|
RecordingRouteRegistry routeRegistry = new RecordingRouteRegistry("node-b");
|
||||||
|
routeRegistry.requestIdByToken = "request-missing-owner";
|
||||||
|
setField(service, "agentRunRegistry", new AgentRunRegistry());
|
||||||
|
setField(service, "agentRuntimeRouteRegistry", routeRegistry);
|
||||||
|
setField(service, "agentRuntimeCommandProducer", new RecordingCommandProducer());
|
||||||
|
|
||||||
|
try {
|
||||||
|
invoke(service, "approveRuntime",
|
||||||
|
new Class<?>[]{String.class, String.class, BigInteger.class, String.class},
|
||||||
|
null, "token-missing-owner", BigInteger.ONE, "1");
|
||||||
|
Assert.fail("expected BusinessException");
|
||||||
|
} catch (Exception e) {
|
||||||
|
Assert.assertTrue(rootCause(e) instanceof BusinessException);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证 owner 重启后启动代不匹配会明确失败。
|
||||||
|
*
|
||||||
|
* @throws Exception 反射调用失败时抛出
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void approveShouldFailWhenOwnerBootIdChanged() throws Exception {
|
||||||
|
AgentRunService service = new AgentRunService();
|
||||||
|
RecordingRouteRegistry routeRegistry = new RecordingRouteRegistry("node-b");
|
||||||
|
routeRegistry.requestIdByToken = "request-restarted-owner";
|
||||||
|
routeRegistry.ownerNode = "node-a";
|
||||||
|
routeRegistry.ownerBootId = "boot-old";
|
||||||
|
routeRegistry.currentOwnerBootId = "boot-new";
|
||||||
|
routeRegistry.nodeAlive = true;
|
||||||
|
setField(service, "agentRunRegistry", new AgentRunRegistry());
|
||||||
|
setField(service, "agentRuntimeRouteRegistry", routeRegistry);
|
||||||
|
setField(service, "agentRuntimeCommandProducer", new RecordingCommandProducer());
|
||||||
|
|
||||||
|
try {
|
||||||
|
invoke(service, "approveRuntime",
|
||||||
|
new Class<?>[]{String.class, String.class, BigInteger.class, String.class},
|
||||||
|
null, "token-restarted-owner", BigInteger.ONE, "1");
|
||||||
|
Assert.fail("expected BusinessException");
|
||||||
|
} catch (Exception e) {
|
||||||
|
Assert.assertTrue(rootCause(e) instanceof BusinessException);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证 owner 路由存在但节点心跳缺失时明确失败。
|
||||||
|
*
|
||||||
|
* @throws Exception 反射调用失败时抛出
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void approveShouldFailWhenOwnerNodeHeartbeatMissing() throws Exception {
|
||||||
|
AgentRunService service = new AgentRunService();
|
||||||
|
RecordingRouteRegistry routeRegistry = new RecordingRouteRegistry("node-b");
|
||||||
|
routeRegistry.requestIdByToken = "request-offline-owner";
|
||||||
|
routeRegistry.ownerNode = "node-a";
|
||||||
|
routeRegistry.nodeAlive = false;
|
||||||
|
setField(service, "agentRunRegistry", new AgentRunRegistry());
|
||||||
|
setField(service, "agentRuntimeRouteRegistry", routeRegistry);
|
||||||
|
setField(service, "agentRuntimeCommandProducer", new RecordingCommandProducer());
|
||||||
|
|
||||||
|
try {
|
||||||
|
invoke(service, "approveRuntime",
|
||||||
|
new Class<?>[]{String.class, String.class, BigInteger.class, String.class},
|
||||||
|
null, "token-offline-owner", BigInteger.ONE, "1");
|
||||||
|
Assert.fail("expected BusinessException");
|
||||||
|
} catch (Exception e) {
|
||||||
|
Assert.assertTrue(rootCause(e) instanceof BusinessException);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证清理草稿会话只清草稿 store,不触碰 MySQL pending 清理。
|
||||||
|
*
|
||||||
|
* @throws Exception 反射调用失败时抛出
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void clearDraftSessionShouldOnlyDeleteDraftStore() throws Exception {
|
||||||
|
AgentRunService service = new AgentRunService();
|
||||||
|
RecordingAgentHitlPendingService pendingService = new RecordingAgentHitlPendingService();
|
||||||
|
RecordingAgentSessionStore draftStore = new RecordingAgentSessionStore();
|
||||||
|
setField(service, "agentRunRegistry", new AgentRunRegistry());
|
||||||
|
setField(service, "agentHitlPendingService", pendingService);
|
||||||
|
setField(service, "draftAgentSessionStore", draftStore);
|
||||||
|
|
||||||
|
invoke(service, "clearDraftSessionInternal",
|
||||||
|
new Class<?>[]{String.class, String.class}, "agent-draft-clear", "1");
|
||||||
|
|
||||||
|
Assert.assertEquals("agent-draft-clear", draftStore.deletedSessionKey);
|
||||||
|
Assert.assertEquals(0, pendingService.deleteByRuntimeSessionIdCount);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 验证正式聊天会在会话准备完成后向前端返回真实会话 ID。
|
* 验证正式聊天会在会话准备完成后向前端返回真实会话 ID。
|
||||||
*
|
*
|
||||||
@@ -530,6 +809,28 @@ public class AgentRunServiceDraftAndHitlTest {
|
|||||||
ChatRuntimeContext.class, AtomicBoolean.class, boolean.class};
|
ChatRuntimeContext.class, AtomicBoolean.class, boolean.class};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private AgentRunRegistry.AgentRunContext runContext(String requestId, String sessionId, boolean persistChatlog) {
|
||||||
|
return new AgentRunRegistry.AgentRunContext(
|
||||||
|
requestId,
|
||||||
|
sessionId,
|
||||||
|
new RecordingAgentRuntime(),
|
||||||
|
new RecordingChatSseEmitter(),
|
||||||
|
chatContext(),
|
||||||
|
new StringBuilder(),
|
||||||
|
new ChatAssistantAccumulator(),
|
||||||
|
new AtomicBoolean(false),
|
||||||
|
persistChatlog,
|
||||||
|
new AgentRunRegistry.RunOwner("agent-1", sessionId, "1"),
|
||||||
|
null,
|
||||||
|
event -> {
|
||||||
|
},
|
||||||
|
error -> {
|
||||||
|
},
|
||||||
|
() -> {
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
private ChatRuntimeContext chatContext() {
|
private ChatRuntimeContext chatContext() {
|
||||||
ChatRuntimeContext context = new ChatRuntimeContext();
|
ChatRuntimeContext context = new ChatRuntimeContext();
|
||||||
context.setAssistantId(BigInteger.valueOf(100));
|
context.setAssistantId(BigInteger.valueOf(100));
|
||||||
@@ -598,6 +899,214 @@ public class AgentRunServiceDraftAndHitlTest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static class RecordingAgentRuntime implements AgentRuntime {
|
||||||
|
|
||||||
|
private AgentInitRequest initRequest;
|
||||||
|
private int resumeCount;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void init(AgentInitRequest request) {
|
||||||
|
initRequest = request;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public reactor.core.publisher.Flux<AgentRuntimeEvent> stream(AgentMessage userMessage) {
|
||||||
|
return reactor.core.publisher.Flux.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public reactor.core.publisher.Flux<AgentRuntimeEvent> resume(com.easyagents.agent.runtime.AgentResumeRequest request) {
|
||||||
|
resumeCount++;
|
||||||
|
return reactor.core.publisher.Flux.empty();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class RecordingRouteRegistry extends AgentRuntimeRouteRegistry {
|
||||||
|
|
||||||
|
private final String currentNodeId;
|
||||||
|
private String ownerNode;
|
||||||
|
private String ownerBootId;
|
||||||
|
private String currentOwnerBootId;
|
||||||
|
private String requestIdByToken;
|
||||||
|
private boolean nodeAlive;
|
||||||
|
|
||||||
|
private RecordingRouteRegistry(String currentNodeId) {
|
||||||
|
super(null, null, null);
|
||||||
|
this.currentNodeId = currentNodeId;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String findOwnerNode(String requestId) {
|
||||||
|
return ownerNode;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AgentRuntimeRoute findOwnerRoute(String requestId) {
|
||||||
|
AgentRuntimeRoute route = new AgentRuntimeRoute();
|
||||||
|
route.setNodeId(ownerNode);
|
||||||
|
route.setBootId(ownerBootId);
|
||||||
|
return route;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String findRequestIdByResumeToken(String resumeToken) {
|
||||||
|
return requestIdByToken;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String currentNodeId() {
|
||||||
|
return currentNodeId;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isNodeAlive(String nodeId) {
|
||||||
|
return nodeAlive;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String currentNodeBootId(String nodeId) {
|
||||||
|
return currentOwnerBootId;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class RecordingCommandProducer extends AgentRuntimeCommandProducer {
|
||||||
|
|
||||||
|
private int approveCount;
|
||||||
|
private String lastTargetNodeId;
|
||||||
|
private String lastRequestId;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void sendApprove(String targetNodeId,
|
||||||
|
String requestId,
|
||||||
|
String resumeToken,
|
||||||
|
BigInteger operatorId,
|
||||||
|
String userId) {
|
||||||
|
approveCount++;
|
||||||
|
lastTargetNodeId = targetNodeId;
|
||||||
|
lastRequestId = requestId;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class RecordingAgentRuntimeFactory implements AgentRuntimeFactory {
|
||||||
|
|
||||||
|
private final AgentRuntime runtime;
|
||||||
|
|
||||||
|
private RecordingAgentRuntimeFactory(AgentRuntime runtime) {
|
||||||
|
this.runtime = runtime;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AgentRuntime create() {
|
||||||
|
return runtime;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class RecordingAgentRuntimeCompiler extends AgentRuntimeCompiler {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AgentRuntimeBundle compile(Agent agent) {
|
||||||
|
AgentRuntimeBundle bundle = new AgentRuntimeBundle();
|
||||||
|
bundle.setDefinition(new com.easyagents.agent.runtime.AgentDefinition());
|
||||||
|
return bundle;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class RecordingAgentRunEventRecorder implements AgentRunEventRecorder {
|
||||||
|
|
||||||
|
private int recordCount;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void record(String requestId, ChatRuntimeContext chatContext, AgentRuntimeEvent event) {
|
||||||
|
recordCount++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class RecordingAgentHitlPendingService implements AgentHitlPendingService {
|
||||||
|
|
||||||
|
private int recordApprovalRequiredCount;
|
||||||
|
private int approveCount;
|
||||||
|
private int rejectCount;
|
||||||
|
private int cancelByRequestIdCount;
|
||||||
|
private int deleteByRuntimeSessionIdCount;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void recordApprovalRequired(String requestId, ChatRuntimeContext chatContext, AgentRuntimeEvent event) {
|
||||||
|
recordApprovalRequiredCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AgentHitlPending approve(String resumeToken, BigInteger operatorId) {
|
||||||
|
approveCount++;
|
||||||
|
return new AgentHitlPending();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AgentHitlPending reject(String resumeToken, BigInteger operatorId, String reason) {
|
||||||
|
rejectCount++;
|
||||||
|
return new AgentHitlPending();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void cancelByRequestId(String requestId, String reason) {
|
||||||
|
cancelByRequestIdCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void deleteByChatSessionId(BigInteger chatSessionId) {
|
||||||
|
// 测试桩无需处理。
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void deleteByRuntimeSessionId(String runtimeSessionId) {
|
||||||
|
deleteByRuntimeSessionIdCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<AgentHitlPending> expirePending(int limit) {
|
||||||
|
return List.of();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class RecordingAgentSessionStore implements AgentSessionStore {
|
||||||
|
|
||||||
|
private String deletedSessionKey;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void save(String sessionKey, String name, io.agentscope.core.state.State state) {
|
||||||
|
// 测试桩无需处理。
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void saveList(String sessionKey, String name, List<? extends io.agentscope.core.state.State> states) {
|
||||||
|
// 测试桩无需处理。
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public <T extends io.agentscope.core.state.State> java.util.Optional<T> get(String sessionKey, String name, Class<T> type) {
|
||||||
|
return java.util.Optional.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public <T extends io.agentscope.core.state.State> List<T> getList(String sessionKey, String name, Class<T> itemType) {
|
||||||
|
return List.of();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean exists(String sessionKey) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void delete(String sessionKey) {
|
||||||
|
deletedSessionKey = sessionKey;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public java.util.Set<String> listSessionKeys() {
|
||||||
|
return java.util.Set.of();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 记录 chatlog 写入动作的测试桩。
|
* 记录 chatlog 写入动作的测试桩。
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -0,0 +1,195 @@
|
|||||||
|
package tech.easyflow.agent.runtime.asynctool;
|
||||||
|
|
||||||
|
import com.easyagents.agent.runtime.tool.AgentToolContext;
|
||||||
|
import com.easyagents.agent.runtime.tool.asynctool.AsyncToolCancelRequest;
|
||||||
|
import com.easyagents.agent.runtime.tool.asynctool.AsyncToolObserveRequest;
|
||||||
|
import com.easyagents.agent.runtime.tool.asynctool.AsyncToolResultRequest;
|
||||||
|
import com.easyagents.agent.runtime.tool.asynctool.AsyncToolSubmitResult;
|
||||||
|
import com.easyagents.agent.runtime.tool.asynctool.AsyncToolTaskStatus;
|
||||||
|
import com.easyagents.agent.runtime.tool.asynctool.AsyncToolTaskView;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
||||||
|
import tech.easyflow.agent.runtime.tool.AgentToolExecutionResult;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
import java.util.function.UnaryOperator;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* EasyFlow 异步业务工具基类测试。
|
||||||
|
*/
|
||||||
|
public class AbstractAgentAsyncSubToolsTest {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证 submit、observe、result 与 list 的基础任务生命周期。
|
||||||
|
*
|
||||||
|
* @throws Exception 等待后台执行超时时抛出
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void asyncSubToolsShouldSubmitObserveResultAndListCurrentSessionTasks() throws Exception {
|
||||||
|
ThreadPoolTaskExecutor executor = executor();
|
||||||
|
try {
|
||||||
|
InMemoryTaskStore store = new InMemoryTaskStore();
|
||||||
|
TestAsyncSubTools subTools = new TestAsyncSubTools(store, executor);
|
||||||
|
AgentToolContext context = context("session-a");
|
||||||
|
|
||||||
|
AsyncToolSubmitResult submitted = subTools.submit(Map.of("keyword", "hello"), context);
|
||||||
|
Assert.assertEquals(AsyncToolTaskStatus.PENDING, submitted.getStatus());
|
||||||
|
Assert.assertTrue(submitted.getTaskId().startsWith("async_"));
|
||||||
|
|
||||||
|
AsyncToolTaskView completed = waitTerminal(subTools, submitted.getTaskId(), context);
|
||||||
|
Assert.assertEquals(AsyncToolTaskStatus.SUCCEEDED, completed.getStatus());
|
||||||
|
Assert.assertEquals(Map.of("echo", "hello"), completed.getResult());
|
||||||
|
Assert.assertTrue(completed.getNextCursor() >= 2L);
|
||||||
|
|
||||||
|
AsyncToolResultRequest resultRequest = new AsyncToolResultRequest();
|
||||||
|
resultRequest.setTaskId(submitted.getTaskId());
|
||||||
|
resultRequest.setCursor(1L);
|
||||||
|
AsyncToolTaskView result = subTools.result(resultRequest, context);
|
||||||
|
Assert.assertEquals(AsyncToolTaskStatus.SUCCEEDED, result.getStatus());
|
||||||
|
Assert.assertEquals(Map.of("echo", "hello"), result.getResult());
|
||||||
|
Assert.assertFalse(result.getEvents().isEmpty());
|
||||||
|
|
||||||
|
Assert.assertEquals(1, subTools.list(null, context).getTasks().size());
|
||||||
|
Assert.assertTrue(subTools.list(null, context("session-b")).getTasks().isEmpty());
|
||||||
|
|
||||||
|
AsyncToolTaskView crossedSessionView = observe(subTools, submitted.getTaskId(), context("session-b"));
|
||||||
|
Assert.assertEquals(AsyncToolTaskStatus.FAILED, crossedSessionView.getStatus());
|
||||||
|
Assert.assertEquals("TASK_NOT_FOUND", crossedSessionView.getErrorType());
|
||||||
|
} finally {
|
||||||
|
executor.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证首版取消语义返回明确失败结果。
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void cancelShouldReturnUnsupportedFailure() {
|
||||||
|
TestAsyncSubTools subTools = new TestAsyncSubTools(new InMemoryTaskStore(), executor());
|
||||||
|
AsyncToolCancelRequest request = new AsyncToolCancelRequest();
|
||||||
|
request.setTaskId("task-1");
|
||||||
|
|
||||||
|
var result = subTools.cancel(request, context("session-a"));
|
||||||
|
|
||||||
|
Assert.assertEquals(AsyncToolTaskStatus.FAILED, result.getStatus());
|
||||||
|
Assert.assertEquals("不支持取消", result.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
private AsyncToolTaskView waitTerminal(TestAsyncSubTools subTools, String taskId, AgentToolContext context) throws Exception {
|
||||||
|
long deadline = System.currentTimeMillis() + 3000L;
|
||||||
|
AsyncToolTaskView view = observe(subTools, taskId, context);
|
||||||
|
while (!Boolean.TRUE.equals(view.getTerminal()) && System.currentTimeMillis() < deadline) {
|
||||||
|
Thread.sleep(20L);
|
||||||
|
view = observe(subTools, taskId, context);
|
||||||
|
}
|
||||||
|
Assert.assertTrue("异步任务应在测试超时前完成", Boolean.TRUE.equals(view.getTerminal()));
|
||||||
|
return view;
|
||||||
|
}
|
||||||
|
|
||||||
|
private AsyncToolTaskView observe(TestAsyncSubTools subTools, String taskId, AgentToolContext context) {
|
||||||
|
AsyncToolObserveRequest request = new AsyncToolObserveRequest();
|
||||||
|
request.setTaskId(taskId);
|
||||||
|
request.setCursor(0L);
|
||||||
|
return subTools.observe(request, context);
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentToolContext context(String sessionId) {
|
||||||
|
AgentToolContext context = new AgentToolContext();
|
||||||
|
context.setRequestId("request-1");
|
||||||
|
context.setTraceId("trace-1");
|
||||||
|
context.setSessionId(sessionId);
|
||||||
|
context.setAgentId("agent-1");
|
||||||
|
context.setToolCallId("tool-call-1");
|
||||||
|
return context;
|
||||||
|
}
|
||||||
|
|
||||||
|
private ThreadPoolTaskExecutor executor() {
|
||||||
|
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
|
||||||
|
executor.setCorePoolSize(1);
|
||||||
|
executor.setMaxPoolSize(1);
|
||||||
|
executor.setQueueCapacity(4);
|
||||||
|
executor.setThreadNamePrefix("async-sub-tools-test-");
|
||||||
|
executor.initialize();
|
||||||
|
return executor;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class TestAsyncSubTools extends AbstractAgentAsyncSubTools {
|
||||||
|
|
||||||
|
private TestAsyncSubTools(AgentAsyncToolTaskStore taskStore, ThreadPoolTaskExecutor taskExecutor) {
|
||||||
|
super(taskStore, taskExecutor);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected String toolType() {
|
||||||
|
return "PLUGIN";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected String toolName() {
|
||||||
|
return "test_tool";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected String displayName() {
|
||||||
|
return "测试工具";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected String businessId() {
|
||||||
|
return "business-1";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected AgentToolExecutionResult executeBusiness(Map<String, Object> arguments) {
|
||||||
|
return new AgentToolExecutionResult(Map.of("echo", arguments.get("keyword")), "business-run-1");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class InMemoryTaskStore implements AgentAsyncToolTaskStore {
|
||||||
|
|
||||||
|
private final Map<String, AgentAsyncToolTaskRecord> records = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void create(AgentAsyncToolTaskRecord record) {
|
||||||
|
record.setSessionScopedKey(key(record.getSessionId(), record.getTaskId()));
|
||||||
|
records.put(record.getSessionScopedKey(), record);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Optional<AgentAsyncToolTaskRecord> get(String sessionId, String taskId) {
|
||||||
|
return Optional.ofNullable(records.get(key(sessionId, taskId)));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Optional<AgentAsyncToolTaskRecord> update(String sessionId,
|
||||||
|
String taskId,
|
||||||
|
UnaryOperator<AgentAsyncToolTaskRecord> updater) {
|
||||||
|
String key = key(sessionId, taskId);
|
||||||
|
AgentAsyncToolTaskRecord updated = records.computeIfPresent(key,
|
||||||
|
(ignored, existing) -> updater == null ? existing : updater.apply(existing));
|
||||||
|
return Optional.ofNullable(updated);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<AgentAsyncToolTaskRecord> list(String sessionId, AsyncToolTaskStatus status) {
|
||||||
|
List<AgentAsyncToolTaskRecord> result = new ArrayList<>();
|
||||||
|
for (AgentAsyncToolTaskRecord record : records.values()) {
|
||||||
|
if (sessionId.equals(record.getSessionId()) && (status == null || status == record.getStatus())) {
|
||||||
|
result.add(record);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.sort(Comparator.comparing(AgentAsyncToolTaskRecord::getCreatedAt).reversed());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String key(String sessionId, String taskId) {
|
||||||
|
return sessionId + ":" + taskId;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,213 @@
|
|||||||
|
package tech.easyflow.agent.runtime.asynctool;
|
||||||
|
|
||||||
|
import com.easyagents.agent.runtime.tool.AgentToolContext;
|
||||||
|
import com.easyagents.agent.runtime.tool.asynctool.AsyncToolObserveRequest;
|
||||||
|
import com.easyagents.agent.runtime.tool.asynctool.AsyncToolResultRequest;
|
||||||
|
import com.easyagents.agent.runtime.tool.asynctool.AsyncToolSubmitResult;
|
||||||
|
import com.easyagents.agent.runtime.tool.asynctool.AsyncToolTaskStatus;
|
||||||
|
import com.easyagents.agent.runtime.tool.asynctool.AsyncToolTaskView;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
||||||
|
import tech.easyflow.agent.runtime.tool.AgentToolExecutionResult;
|
||||||
|
import tech.easyflow.agent.runtime.tool.PluginToolExecutor;
|
||||||
|
import tech.easyflow.agent.runtime.tool.WorkflowToolExecutor;
|
||||||
|
import tech.easyflow.ai.entity.PluginItem;
|
||||||
|
import tech.easyflow.ai.entity.Workflow;
|
||||||
|
|
||||||
|
import java.math.BigInteger;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
import java.util.function.UnaryOperator;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Workflow 与 Plugin 异步子工具测试。
|
||||||
|
*/
|
||||||
|
public class WorkflowPluginAsyncSubToolsTest {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证 Workflow 异步子工具会把业务执行结果保留到任务视图。
|
||||||
|
*
|
||||||
|
* @throws Exception 等待后台执行超时时抛出
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void workflowAsyncSubToolsShouldKeepBusinessResultInTaskView() throws Exception {
|
||||||
|
ThreadPoolTaskExecutor executor = executor();
|
||||||
|
try {
|
||||||
|
Map<String, Object> businessResult = Map.of("workflowOutput", "ok");
|
||||||
|
WorkflowAsyncSubTools subTools = new WorkflowAsyncSubTools(workflow(),
|
||||||
|
"workflow_demo",
|
||||||
|
"测试工作流",
|
||||||
|
new StubWorkflowToolExecutor(businessResult),
|
||||||
|
new InMemoryTaskStore(),
|
||||||
|
executor);
|
||||||
|
|
||||||
|
AsyncToolTaskView view = submitAndResult(subTools);
|
||||||
|
|
||||||
|
Assert.assertEquals(AsyncToolTaskStatus.SUCCEEDED, view.getStatus());
|
||||||
|
Assert.assertEquals(businessResult, view.getResult());
|
||||||
|
Assert.assertEquals("workflow-run-1", view.getPayload().get("businessExecutionId"));
|
||||||
|
} finally {
|
||||||
|
executor.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证 Plugin 异步子工具会把业务执行结果保留到任务视图。
|
||||||
|
*
|
||||||
|
* @throws Exception 等待后台执行超时时抛出
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void pluginAsyncSubToolsShouldKeepBusinessResultInTaskView() throws Exception {
|
||||||
|
ThreadPoolTaskExecutor executor = executor();
|
||||||
|
try {
|
||||||
|
Map<String, Object> businessResult = Map.of("pluginOutput", List.of("a", "b"));
|
||||||
|
PluginAsyncSubTools subTools = new PluginAsyncSubTools(pluginItem(),
|
||||||
|
"plugin_demo",
|
||||||
|
"测试插件",
|
||||||
|
new StubPluginToolExecutor(businessResult),
|
||||||
|
new InMemoryTaskStore(),
|
||||||
|
executor);
|
||||||
|
|
||||||
|
AsyncToolTaskView view = submitAndResult(subTools);
|
||||||
|
|
||||||
|
Assert.assertEquals(AsyncToolTaskStatus.SUCCEEDED, view.getStatus());
|
||||||
|
Assert.assertEquals(businessResult, view.getResult());
|
||||||
|
} finally {
|
||||||
|
executor.shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private AsyncToolTaskView submitAndResult(AbstractAgentAsyncSubTools subTools) throws Exception {
|
||||||
|
AgentToolContext context = context();
|
||||||
|
AsyncToolSubmitResult submitted = subTools.submit(Map.of("keyword", "hello"), context);
|
||||||
|
waitTerminal(subTools, submitted.getTaskId(), context);
|
||||||
|
AsyncToolResultRequest request = new AsyncToolResultRequest();
|
||||||
|
request.setTaskId(submitted.getTaskId());
|
||||||
|
return subTools.result(request, context);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void waitTerminal(AbstractAgentAsyncSubTools subTools, String taskId, AgentToolContext context) throws Exception {
|
||||||
|
long deadline = System.currentTimeMillis() + 3000L;
|
||||||
|
AsyncToolTaskView view = observe(subTools, taskId, context);
|
||||||
|
while (!Boolean.TRUE.equals(view.getTerminal()) && System.currentTimeMillis() < deadline) {
|
||||||
|
Thread.sleep(20L);
|
||||||
|
view = observe(subTools, taskId, context);
|
||||||
|
}
|
||||||
|
Assert.assertTrue("异步任务应在测试超时前完成", Boolean.TRUE.equals(view.getTerminal()));
|
||||||
|
}
|
||||||
|
|
||||||
|
private AsyncToolTaskView observe(AbstractAgentAsyncSubTools subTools, String taskId, AgentToolContext context) {
|
||||||
|
AsyncToolObserveRequest request = new AsyncToolObserveRequest();
|
||||||
|
request.setTaskId(taskId);
|
||||||
|
return subTools.observe(request, context);
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentToolContext context() {
|
||||||
|
AgentToolContext context = new AgentToolContext();
|
||||||
|
context.setRequestId("request-1");
|
||||||
|
context.setTraceId("trace-1");
|
||||||
|
context.setSessionId("session-1");
|
||||||
|
context.setAgentId("agent-1");
|
||||||
|
context.setToolCallId("tool-call-1");
|
||||||
|
return context;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Workflow workflow() {
|
||||||
|
Workflow workflow = new Workflow();
|
||||||
|
workflow.setId(BigInteger.valueOf(101L));
|
||||||
|
workflow.setTitle("测试工作流");
|
||||||
|
return workflow;
|
||||||
|
}
|
||||||
|
|
||||||
|
private PluginItem pluginItem() {
|
||||||
|
PluginItem pluginItem = new PluginItem();
|
||||||
|
pluginItem.setId(BigInteger.valueOf(102L));
|
||||||
|
pluginItem.setName("测试插件");
|
||||||
|
return pluginItem;
|
||||||
|
}
|
||||||
|
|
||||||
|
private ThreadPoolTaskExecutor executor() {
|
||||||
|
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
|
||||||
|
executor.setCorePoolSize(1);
|
||||||
|
executor.setMaxPoolSize(1);
|
||||||
|
executor.setQueueCapacity(4);
|
||||||
|
executor.setThreadNamePrefix("workflow-plugin-async-test-");
|
||||||
|
executor.initialize();
|
||||||
|
return executor;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class StubWorkflowToolExecutor extends WorkflowToolExecutor {
|
||||||
|
|
||||||
|
private final Map<String, Object> businessResult;
|
||||||
|
|
||||||
|
private StubWorkflowToolExecutor(Map<String, Object> businessResult) {
|
||||||
|
super(null);
|
||||||
|
this.businessResult = businessResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AgentToolExecutionResult execute(Workflow workflow, Map<String, Object> arguments) {
|
||||||
|
return new AgentToolExecutionResult(businessResult, "workflow-run-1");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class StubPluginToolExecutor extends PluginToolExecutor {
|
||||||
|
|
||||||
|
private final Map<String, Object> businessResult;
|
||||||
|
|
||||||
|
private StubPluginToolExecutor(Map<String, Object> businessResult) {
|
||||||
|
this.businessResult = businessResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AgentToolExecutionResult execute(PluginItem pluginItem, Map<String, Object> arguments) {
|
||||||
|
return new AgentToolExecutionResult(businessResult, null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class InMemoryTaskStore implements AgentAsyncToolTaskStore {
|
||||||
|
|
||||||
|
private final Map<String, AgentAsyncToolTaskRecord> records = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void create(AgentAsyncToolTaskRecord record) {
|
||||||
|
record.setSessionScopedKey(key(record.getSessionId(), record.getTaskId()));
|
||||||
|
records.put(record.getSessionScopedKey(), record);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Optional<AgentAsyncToolTaskRecord> get(String sessionId, String taskId) {
|
||||||
|
return Optional.ofNullable(records.get(key(sessionId, taskId)));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Optional<AgentAsyncToolTaskRecord> update(String sessionId,
|
||||||
|
String taskId,
|
||||||
|
UnaryOperator<AgentAsyncToolTaskRecord> updater) {
|
||||||
|
AgentAsyncToolTaskRecord updated = records.computeIfPresent(key(sessionId, taskId),
|
||||||
|
(ignored, existing) -> updater == null ? existing : updater.apply(existing));
|
||||||
|
return Optional.ofNullable(updated);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<AgentAsyncToolTaskRecord> list(String sessionId, AsyncToolTaskStatus status) {
|
||||||
|
List<AgentAsyncToolTaskRecord> result = new ArrayList<>();
|
||||||
|
for (AgentAsyncToolTaskRecord record : records.values()) {
|
||||||
|
if (sessionId.equals(record.getSessionId()) && (status == null || status == record.getStatus())) {
|
||||||
|
result.add(record);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.sort(Comparator.comparing(AgentAsyncToolTaskRecord::getCreatedAt).reversed());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String key(String sessionId, String taskId) {
|
||||||
|
return sessionId + ":" + taskId;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,239 @@
|
|||||||
|
package tech.easyflow.agent.runtime.tool;
|
||||||
|
|
||||||
|
import com.easyagents.agent.runtime.tool.AgentToolSpec;
|
||||||
|
import com.easyagents.core.model.chat.tool.Parameter;
|
||||||
|
import com.easyagents.core.model.chat.tool.Tool;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import tech.easyflow.agent.entity.Agent;
|
||||||
|
import tech.easyflow.agent.entity.AgentToolBinding;
|
||||||
|
import tech.easyflow.agent.enums.AgentToolType;
|
||||||
|
import tech.easyflow.ai.entity.PluginItem;
|
||||||
|
import tech.easyflow.ai.entity.Workflow;
|
||||||
|
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||||
|
|
||||||
|
import java.lang.reflect.Field;
|
||||||
|
import java.math.BigInteger;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Agent 工具运行时编译测试。
|
||||||
|
*/
|
||||||
|
public class AgentToolRuntimeCompilerTest {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证 Workflow 默认按同步工具编译。
|
||||||
|
*
|
||||||
|
* @throws Exception 反射注入依赖失败时抛出
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void compileShouldUseSyncModeByDefault() throws Exception {
|
||||||
|
AgentToolRuntimeCompiler compiler = compiler();
|
||||||
|
|
||||||
|
AgentToolRuntimeCompilation compilation = compiler.compile(agent(workflowBinding(null, false, "flow-sync")));
|
||||||
|
|
||||||
|
Assert.assertEquals(List.of("flow-sync"), toolNames(compilation));
|
||||||
|
Assert.assertEquals(1, compilation.getToolInvokers().size());
|
||||||
|
Assert.assertFalse(compilation.getToolSpecs().get(0).isApprovalRequired());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证非法执行模式会回退为同步工具。
|
||||||
|
*
|
||||||
|
* @throws Exception 反射注入依赖失败时抛出
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void compileShouldFallbackToSyncWhenExecutionModeInvalid() throws Exception {
|
||||||
|
AgentToolRuntimeCompiler compiler = compiler();
|
||||||
|
|
||||||
|
AgentToolRuntimeCompilation compilation = compiler.compile(agent(workflowBinding("BAD", false, "flow-sync")));
|
||||||
|
|
||||||
|
Assert.assertEquals(List.of("flow-sync"), toolNames(compilation));
|
||||||
|
Assert.assertEquals(1, compilation.getToolInvokers().size());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证 Workflow 异步模式会展开为五个固定子工具。
|
||||||
|
*
|
||||||
|
* @throws Exception 反射注入依赖失败时抛出
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void compileShouldExpandWorkflowAsyncSubToolsAndNormalizeName() throws Exception {
|
||||||
|
AgentToolRuntimeCompiler compiler = compiler();
|
||||||
|
|
||||||
|
AgentToolRuntimeCompilation compilation = compiler.compile(agent(workflowBinding("ASYNC", true, "flow-alpha")));
|
||||||
|
|
||||||
|
Assert.assertEquals(List.of(
|
||||||
|
"flow_alpha_submit",
|
||||||
|
"flow_alpha_observe",
|
||||||
|
"flow_alpha_result",
|
||||||
|
"flow_alpha_cancel",
|
||||||
|
"flow_alpha_list"
|
||||||
|
), toolNames(compilation));
|
||||||
|
Assert.assertEquals(5, compilation.getToolInvokers().size());
|
||||||
|
Assert.assertEquals(List.of("keyword"), compilation.getToolSpecs().get(0).getParametersSchema().get("required"));
|
||||||
|
Assert.assertTrue(compilation.getToolSpecs().get(0).isApprovalRequired());
|
||||||
|
Assert.assertEquals("确认执行?", compilation.getToolSpecs().get(0).getApprovalRequest().getApprovalPrompt());
|
||||||
|
Assert.assertFalse(compilation.getToolSpecs().get(1).isApprovalRequired());
|
||||||
|
Assert.assertEquals("flow_alpha", compilation.getToolSpecs().get(0).getMetadata().get("asyncToolName"));
|
||||||
|
Assert.assertEquals("submit", compilation.getToolSpecs().get(0).getMetadata().get("asyncToolPhase"));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证 Plugin 异步模式同样展开为五个固定子工具。
|
||||||
|
*
|
||||||
|
* @throws Exception 反射注入依赖失败时抛出
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void compileShouldExpandPluginAsyncSubTools() throws Exception {
|
||||||
|
AgentToolRuntimeCompiler compiler = compiler();
|
||||||
|
|
||||||
|
AgentToolRuntimeCompilation compilation = compiler.compile(agent(pluginBinding("ASYNC", "plugin-tool")));
|
||||||
|
|
||||||
|
Assert.assertEquals(List.of(
|
||||||
|
"plugin_tool_submit",
|
||||||
|
"plugin_tool_observe",
|
||||||
|
"plugin_tool_result",
|
||||||
|
"plugin_tool_cancel",
|
||||||
|
"plugin_tool_list"
|
||||||
|
), toolNames(compilation));
|
||||||
|
Assert.assertEquals(5, compilation.getToolInvokers().size());
|
||||||
|
for (AgentToolSpec spec : compilation.getToolSpecs()) {
|
||||||
|
Assert.assertEquals(Boolean.TRUE, spec.getMetadata().get("asyncTool"));
|
||||||
|
Assert.assertEquals("插件工具", spec.getMetadata().get("toolDisplayName"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证异步工具名归一化后发生冲突时会在编译阶段失败。
|
||||||
|
*
|
||||||
|
* @throws Exception 反射注入依赖失败时抛出
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void compileShouldRejectNormalizedAsyncToolNameCollision() throws Exception {
|
||||||
|
AgentToolRuntimeCompiler compiler = compiler();
|
||||||
|
AgentToolBinding first = workflowBinding("ASYNC", false, "flow-alpha");
|
||||||
|
AgentToolBinding second = workflowBinding("ASYNC", false, "flow_alpha");
|
||||||
|
second.setId(BigInteger.valueOf(13L));
|
||||||
|
second.setTargetId(BigInteger.valueOf(103L));
|
||||||
|
|
||||||
|
try {
|
||||||
|
compiler.compile(agent(List.of(first, second)));
|
||||||
|
Assert.fail("异步工具名冲突时应编译失败");
|
||||||
|
} catch (BusinessException e) {
|
||||||
|
Assert.assertTrue(e.getMessage().contains("flow_alpha_submit"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentToolRuntimeCompiler compiler() throws Exception {
|
||||||
|
AgentToolRuntimeCompiler compiler = new AgentToolRuntimeCompiler();
|
||||||
|
setField(compiler, "objectMapper", new ObjectMapper());
|
||||||
|
setField(compiler, "workflowToolExecutor", new StubWorkflowToolExecutor());
|
||||||
|
setField(compiler, "pluginToolExecutor", new StubPluginToolExecutor());
|
||||||
|
return compiler;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Agent agent(AgentToolBinding binding) {
|
||||||
|
return agent(List.of(binding));
|
||||||
|
}
|
||||||
|
|
||||||
|
private Agent agent(List<AgentToolBinding> bindings) {
|
||||||
|
Agent agent = new Agent();
|
||||||
|
agent.setId(BigInteger.ONE);
|
||||||
|
agent.setToolBindings(bindings);
|
||||||
|
return agent;
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentToolBinding workflowBinding(String executionMode, boolean hitlEnabled, String toolName) {
|
||||||
|
AgentToolBinding binding = new AgentToolBinding();
|
||||||
|
binding.setId(BigInteger.valueOf(11L));
|
||||||
|
binding.setToolType(AgentToolType.WORKFLOW.name());
|
||||||
|
binding.setTargetId(BigInteger.valueOf(101L));
|
||||||
|
binding.setToolName(toolName);
|
||||||
|
binding.setEnabled(true);
|
||||||
|
binding.setHitlEnabled(hitlEnabled);
|
||||||
|
binding.setHitlConfigJson(Map.of("prompt", "确认执行?"));
|
||||||
|
binding.setOptionsJson(executionMode == null ? Map.of() : Map.of("executionMode", executionMode));
|
||||||
|
binding.setResourceSnapshot(Map.of(
|
||||||
|
"id", BigInteger.valueOf(101L),
|
||||||
|
"title", "客户检索工作流",
|
||||||
|
"description", "按关键词检索客户",
|
||||||
|
"englishName", "flow-alpha"
|
||||||
|
));
|
||||||
|
return binding;
|
||||||
|
}
|
||||||
|
|
||||||
|
private AgentToolBinding pluginBinding(String executionMode, String toolName) {
|
||||||
|
AgentToolBinding binding = new AgentToolBinding();
|
||||||
|
binding.setId(BigInteger.valueOf(12L));
|
||||||
|
binding.setToolType(AgentToolType.PLUGIN.name());
|
||||||
|
binding.setTargetId(BigInteger.valueOf(102L));
|
||||||
|
binding.setToolName(toolName);
|
||||||
|
binding.setEnabled(true);
|
||||||
|
binding.setOptionsJson(Map.of("executionMode", executionMode));
|
||||||
|
binding.setResourceSnapshot(Map.of(
|
||||||
|
"id", BigInteger.valueOf(102L),
|
||||||
|
"name", "插件工具",
|
||||||
|
"description", "调用插件",
|
||||||
|
"englishName", "plugin-tool"
|
||||||
|
));
|
||||||
|
return binding;
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<String> toolNames(AgentToolRuntimeCompilation compilation) {
|
||||||
|
return compilation.getToolSpecs().stream().map(AgentToolSpec::getName).collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
private void setField(Object target, String fieldName, Object value) throws Exception {
|
||||||
|
Field field = target.getClass().getDeclaredField(fieldName);
|
||||||
|
field.setAccessible(true);
|
||||||
|
field.set(target, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Tool testTool(String name, String description) {
|
||||||
|
Parameter parameter = new Parameter();
|
||||||
|
parameter.setName("keyword");
|
||||||
|
parameter.setDescription("关键词");
|
||||||
|
parameter.setType("string");
|
||||||
|
parameter.setRequired(true);
|
||||||
|
return Tool.builder()
|
||||||
|
.name(name)
|
||||||
|
.description(description)
|
||||||
|
.addParameter(parameter)
|
||||||
|
.function(arguments -> Map.of("ok", true))
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private final class StubWorkflowToolExecutor extends WorkflowToolExecutor {
|
||||||
|
|
||||||
|
private StubWorkflowToolExecutor() {
|
||||||
|
super(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Tool buildTool(Workflow workflow) {
|
||||||
|
return testTool(workflow.getEnglishName(), workflow.getDescription());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AgentToolExecutionResult execute(Workflow workflow, Map<String, Object> arguments) {
|
||||||
|
return new AgentToolExecutionResult(Map.of("ok", true), "wf-run-1");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private final class StubPluginToolExecutor extends PluginToolExecutor {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Tool buildTool(PluginItem pluginItem) {
|
||||||
|
return testTool(pluginItem.getEnglishName(), pluginItem.getDescription());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AgentToolExecutionResult execute(PluginItem pluginItem, Map<String, Object> arguments) {
|
||||||
|
return new AgentToolExecutionResult(Map.of("ok", true), null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -115,6 +115,11 @@
|
|||||||
<groupId>tech.easyflow</groupId>
|
<groupId>tech.easyflow</groupId>
|
||||||
<artifactId>easyflow-common-mq</artifactId>
|
<artifactId>easyflow-common-mq</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework.boot</groupId>
|
||||||
|
<artifactId>spring-boot-actuator</artifactId>
|
||||||
|
<version>${spring-boot.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.easyagents</groupId>
|
<groupId>com.easyagents</groupId>
|
||||||
@@ -126,5 +131,58 @@
|
|||||||
<version>${junit.version}</version>
|
<version>${junit.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.mockito</groupId>
|
||||||
|
<artifactId>mockito-core</artifactId>
|
||||||
|
<version>5.12.0</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
|
<profiles>
|
||||||
|
<profile>
|
||||||
|
<id>release-obfuscation</id>
|
||||||
|
<build>
|
||||||
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>com.github.wvengen</groupId>
|
||||||
|
<artifactId>proguard-maven-plugin</artifactId>
|
||||||
|
<version>${proguard.maven.plugin.version}</version>
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.guardsquare</groupId>
|
||||||
|
<artifactId>proguard-base</artifactId>
|
||||||
|
<version>${proguard.version}</version>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<id>release-obfuscation</id>
|
||||||
|
<phase>package</phase>
|
||||||
|
<goals>
|
||||||
|
<goal>proguard</goal>
|
||||||
|
</goals>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
<configuration>
|
||||||
|
<proguardVersion>${proguard.version}</proguardVersion>
|
||||||
|
<proguardInclude>${maven.multiModuleProjectDirectory}/config/proguard/easyflow-module-ai.pro</proguardInclude>
|
||||||
|
<mappingFileName>proguard-map-${project.artifactId}.txt</mappingFileName>
|
||||||
|
<seedFileName>proguard-seed-${project.artifactId}.txt</seedFileName>
|
||||||
|
<includeDependency>true</includeDependency>
|
||||||
|
<includeDependencyInjar>false</includeDependencyInjar>
|
||||||
|
<attach>false</attach>
|
||||||
|
<attachMap>false</attachMap>
|
||||||
|
<appendClassifier>false</appendClassifier>
|
||||||
|
<addMavenDescriptor>false</addMavenDescriptor>
|
||||||
|
<addManifest>true</addManifest>
|
||||||
|
<putLibraryJarsInTempDir>true</putLibraryJarsInTempDir>
|
||||||
|
<generateTemporaryConfigurationFile>true</generateTemporaryConfigurationFile>
|
||||||
|
<bindToMavenLogging>true</bindToMavenLogging>
|
||||||
|
</configuration>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
</build>
|
||||||
|
</profile>
|
||||||
|
</profiles>
|
||||||
</project>
|
</project>
|
||||||
|
|||||||
@@ -2,8 +2,18 @@ package tech.easyflow.ai.config;
|
|||||||
|
|
||||||
import org.mybatis.spring.annotation.MapperScan;
|
import org.mybatis.spring.annotation.MapperScan;
|
||||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||||
|
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||||
|
import org.springframework.context.annotation.ComponentScan;
|
||||||
|
import tech.easyflow.ai.documentimport.task.DocumentImportParseMonitorProperties;
|
||||||
|
import tech.easyflow.ai.documentimport.task.DocumentImportStatusBroadcastProperties;
|
||||||
|
|
||||||
@MapperScan("tech.easyflow.ai.mapper")
|
@MapperScan("tech.easyflow.ai.mapper")
|
||||||
|
@ComponentScan("tech.easyflow.ai")
|
||||||
|
@EnableConfigurationProperties({
|
||||||
|
DocumentImportParseMonitorProperties.class,
|
||||||
|
DocumentImportStatusBroadcastProperties.class,
|
||||||
|
RagHealthProperties.class
|
||||||
|
})
|
||||||
@AutoConfiguration
|
@AutoConfiguration
|
||||||
public class AiModuleConfig {
|
public class AiModuleConfig {
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,84 @@
|
|||||||
|
package tech.easyflow.ai.config;
|
||||||
|
|
||||||
|
import org.springframework.boot.actuate.health.Health;
|
||||||
|
|
||||||
|
import java.time.Clock;
|
||||||
|
import java.time.Duration;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 健康检查短缓存支持。
|
||||||
|
*/
|
||||||
|
public abstract class CachedHealthIndicatorSupport {
|
||||||
|
|
||||||
|
private final RagHealthProperties properties;
|
||||||
|
private final Clock clock;
|
||||||
|
private volatile CacheEntry cacheEntry;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建健康检查缓存支持。
|
||||||
|
*
|
||||||
|
* @param properties RAG 健康检查配置
|
||||||
|
*/
|
||||||
|
protected CachedHealthIndicatorSupport(RagHealthProperties properties) {
|
||||||
|
this(properties, Clock.systemUTC());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建健康检查缓存支持。
|
||||||
|
*
|
||||||
|
* @param properties RAG 健康检查配置
|
||||||
|
* @param clock 时钟
|
||||||
|
*/
|
||||||
|
protected CachedHealthIndicatorSupport(RagHealthProperties properties, Clock clock) {
|
||||||
|
this.properties = properties;
|
||||||
|
this.clock = clock;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 执行带短缓存的健康检查。
|
||||||
|
*
|
||||||
|
* @return 健康状态
|
||||||
|
*/
|
||||||
|
protected Health cachedHealth() {
|
||||||
|
long now = clock.millis();
|
||||||
|
CacheEntry current = cacheEntry;
|
||||||
|
if (current != null && current.expireAtMillis > now) {
|
||||||
|
return current.health;
|
||||||
|
}
|
||||||
|
synchronized (this) {
|
||||||
|
current = cacheEntry;
|
||||||
|
if (current != null && current.expireAtMillis > now) {
|
||||||
|
return current.health;
|
||||||
|
}
|
||||||
|
Health health = doHealthCheck();
|
||||||
|
cacheEntry = new CacheEntry(health, now + cacheTtlMillis());
|
||||||
|
return health;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 执行实际健康检查。
|
||||||
|
*
|
||||||
|
* @return 健康状态
|
||||||
|
*/
|
||||||
|
protected abstract Health doHealthCheck();
|
||||||
|
|
||||||
|
private long cacheTtlMillis() {
|
||||||
|
Duration cacheTtl = properties.getCacheTtl();
|
||||||
|
if (cacheTtl == null || cacheTtl.isZero() || cacheTtl.isNegative()) {
|
||||||
|
return 0L;
|
||||||
|
}
|
||||||
|
return cacheTtl.toMillis();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class CacheEntry {
|
||||||
|
|
||||||
|
private final Health health;
|
||||||
|
private final long expireAtMillis;
|
||||||
|
|
||||||
|
private CacheEntry(Health health, long expireAtMillis) {
|
||||||
|
this.health = health;
|
||||||
|
this.expireAtMillis = expireAtMillis;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,143 @@
|
|||||||
|
package tech.easyflow.ai.config;
|
||||||
|
|
||||||
|
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* EasyFlow 业务线程池配置。
|
||||||
|
*/
|
||||||
|
@ConfigurationProperties(prefix = "easyflow.thread-pool")
|
||||||
|
public class EasyFlowThreadPoolProperties {
|
||||||
|
|
||||||
|
private Pool sse = new Pool(4, 16, 2000, 30, true);
|
||||||
|
private Pool documentImport = new Pool(2, 4, 200, 60, true);
|
||||||
|
private Pool agentAsyncTool = new Pool(2, 8, 200, 60, true);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取 SSE 线程池配置。
|
||||||
|
*
|
||||||
|
* @return SSE 线程池配置
|
||||||
|
*/
|
||||||
|
public Pool getSse() {
|
||||||
|
return sse;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置 SSE 线程池配置。
|
||||||
|
*
|
||||||
|
* @param sse SSE 线程池配置
|
||||||
|
*/
|
||||||
|
public void setSse(Pool sse) {
|
||||||
|
this.sse = sse;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取文档导入线程池配置。
|
||||||
|
*
|
||||||
|
* @return 文档导入线程池配置
|
||||||
|
*/
|
||||||
|
public Pool getDocumentImport() {
|
||||||
|
return documentImport;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置文档导入线程池配置。
|
||||||
|
*
|
||||||
|
* @param documentImport 文档导入线程池配置
|
||||||
|
*/
|
||||||
|
public void setDocumentImport(Pool documentImport) {
|
||||||
|
this.documentImport = documentImport;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取 Agent 异步工具后台执行线程池配置。
|
||||||
|
*
|
||||||
|
* @return Agent 异步工具线程池配置
|
||||||
|
*/
|
||||||
|
public Pool getAgentAsyncTool() {
|
||||||
|
return agentAsyncTool;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置 Agent 异步工具后台执行线程池配置。
|
||||||
|
*
|
||||||
|
* @param agentAsyncTool Agent 异步工具线程池配置
|
||||||
|
*/
|
||||||
|
public void setAgentAsyncTool(Pool agentAsyncTool) {
|
||||||
|
this.agentAsyncTool = agentAsyncTool;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 线程池配置项。
|
||||||
|
*/
|
||||||
|
public static class Pool {
|
||||||
|
|
||||||
|
private int coreSize;
|
||||||
|
private int maxSize;
|
||||||
|
private int queueCapacity;
|
||||||
|
private int keepAliveSeconds;
|
||||||
|
private boolean allowCoreThreadTimeout;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建默认线程池配置。
|
||||||
|
*/
|
||||||
|
public Pool() {
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建线程池配置。
|
||||||
|
*
|
||||||
|
* @param coreSize 核心线程数
|
||||||
|
* @param maxSize 最大线程数
|
||||||
|
* @param queueCapacity 队列容量
|
||||||
|
* @param keepAliveSeconds 空闲线程存活时间
|
||||||
|
* @param allowCoreThreadTimeout 是否允许核心线程超时
|
||||||
|
*/
|
||||||
|
public Pool(int coreSize, int maxSize, int queueCapacity, int keepAliveSeconds, boolean allowCoreThreadTimeout) {
|
||||||
|
this.coreSize = coreSize;
|
||||||
|
this.maxSize = maxSize;
|
||||||
|
this.queueCapacity = queueCapacity;
|
||||||
|
this.keepAliveSeconds = keepAliveSeconds;
|
||||||
|
this.allowCoreThreadTimeout = allowCoreThreadTimeout;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getCoreSize() {
|
||||||
|
return coreSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setCoreSize(int coreSize) {
|
||||||
|
this.coreSize = coreSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getMaxSize() {
|
||||||
|
return maxSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setMaxSize(int maxSize) {
|
||||||
|
this.maxSize = maxSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getQueueCapacity() {
|
||||||
|
return queueCapacity;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setQueueCapacity(int queueCapacity) {
|
||||||
|
this.queueCapacity = queueCapacity;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getKeepAliveSeconds() {
|
||||||
|
return keepAliveSeconds;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setKeepAliveSeconds(int keepAliveSeconds) {
|
||||||
|
this.keepAliveSeconds = keepAliveSeconds;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isAllowCoreThreadTimeout() {
|
||||||
|
return allowCoreThreadTimeout;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setAllowCoreThreadTimeout(boolean allowCoreThreadTimeout) {
|
||||||
|
this.allowCoreThreadTimeout = allowCoreThreadTimeout;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,138 @@
|
|||||||
|
package tech.easyflow.ai.config;
|
||||||
|
|
||||||
|
import com.easyagents.engine.es.ElasticSearcher;
|
||||||
|
import com.easyagents.search.engine.service.DocumentSearcher;
|
||||||
|
import com.easyagents.store.milvus.MilvusVectorStore;
|
||||||
|
import org.springframework.boot.actuate.health.Health;
|
||||||
|
import org.springframework.boot.actuate.health.HealthIndicator;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
import tech.easyflow.ai.rag.KeywordEngineType;
|
||||||
|
import tech.easyflow.ai.support.DocumentStoreLifecycleSupport;
|
||||||
|
import tech.easyflow.common.util.SpringContextUtil;
|
||||||
|
import tech.easyflow.common.util.StringUtil;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* RAG 依赖中间件健康检查。
|
||||||
|
*/
|
||||||
|
public class RagHealthIndicator {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Milvus 健康检查。
|
||||||
|
*/
|
||||||
|
@Component("ragMilvusHealthIndicator")
|
||||||
|
public static class RagMilvusHealthIndicator extends CachedHealthIndicatorSupport implements HealthIndicator {
|
||||||
|
|
||||||
|
private final AiMilvusConfig aiMilvusConfig;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建 Milvus 健康检查器。
|
||||||
|
*
|
||||||
|
* @param aiMilvusConfig Milvus 配置
|
||||||
|
* @param healthProperties RAG 健康检查配置
|
||||||
|
*/
|
||||||
|
public RagMilvusHealthIndicator(AiMilvusConfig aiMilvusConfig, RagHealthProperties healthProperties) {
|
||||||
|
super(healthProperties);
|
||||||
|
this.aiMilvusConfig = aiMilvusConfig;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检查 Milvus 是否可连接。
|
||||||
|
*
|
||||||
|
* @return 健康状态
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public Health health() {
|
||||||
|
return cachedHealth();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Health doHealthCheck() {
|
||||||
|
MilvusVectorStore vectorStore = null;
|
||||||
|
try {
|
||||||
|
vectorStore = new MilvusVectorStore(
|
||||||
|
aiMilvusConfig.copyForCollection("__rag_health_probe__")
|
||||||
|
);
|
||||||
|
if (vectorStore.checkAvailable()) {
|
||||||
|
return Health.up().withDetail("uri", aiMilvusConfig.getUri()).build();
|
||||||
|
}
|
||||||
|
return Health.down().withDetail("uri", aiMilvusConfig.getUri()).build();
|
||||||
|
} catch (Exception e) {
|
||||||
|
return Health.down(e).withDetail("uri", aiMilvusConfig.getUri()).build();
|
||||||
|
} finally {
|
||||||
|
DocumentStoreLifecycleSupport.closeQuietly(vectorStore);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 关键词检索健康检查。
|
||||||
|
*/
|
||||||
|
@Component("ragKeywordSearchHealthIndicator")
|
||||||
|
public static class RagKeywordSearchHealthIndicator extends CachedHealthIndicatorSupport implements HealthIndicator {
|
||||||
|
|
||||||
|
private final SearcherFactory searcherFactory;
|
||||||
|
private final AiLuceneConfig aiLuceneConfig;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建关键词检索健康检查器。
|
||||||
|
*
|
||||||
|
* @param searcherFactory 检索器工厂
|
||||||
|
* @param aiLuceneConfig Lucene 配置
|
||||||
|
* @param healthProperties RAG 健康检查配置
|
||||||
|
*/
|
||||||
|
public RagKeywordSearchHealthIndicator(SearcherFactory searcherFactory,
|
||||||
|
AiLuceneConfig aiLuceneConfig,
|
||||||
|
RagHealthProperties healthProperties) {
|
||||||
|
super(healthProperties);
|
||||||
|
this.searcherFactory = searcherFactory;
|
||||||
|
this.aiLuceneConfig = aiLuceneConfig;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检查当前关键词检索引擎是否可用。
|
||||||
|
*
|
||||||
|
* @return 健康状态
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public Health health() {
|
||||||
|
return cachedHealth();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Health doHealthCheck() {
|
||||||
|
KeywordEngineType engineType = KeywordEngineType.from(SpringContextUtil.getProperty("rag.engine", "ES"));
|
||||||
|
if (engineType == KeywordEngineType.LUCENE) {
|
||||||
|
return checkLuceneDirectory(engineType);
|
||||||
|
}
|
||||||
|
DocumentSearcher searcher = searcherFactory.getSearcher();
|
||||||
|
if (searcher instanceof ElasticSearcher elasticSearcher && elasticSearcher.checkAvailable()) {
|
||||||
|
return Health.up().withDetail("engine", engineType.name()).build();
|
||||||
|
}
|
||||||
|
return Health.down().withDetail("engine", engineType.name()).build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private Health checkLuceneDirectory(KeywordEngineType engineType) {
|
||||||
|
String indexDirPath = aiLuceneConfig.getIndexDirPath();
|
||||||
|
if (StringUtil.noText(indexDirPath)) {
|
||||||
|
return Health.down()
|
||||||
|
.withDetail("engine", engineType.name())
|
||||||
|
.withDetail("reason", "Lucene 索引目录未配置")
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
File indexDir = new File(indexDirPath);
|
||||||
|
if (indexDir.exists() && indexDir.isDirectory() && indexDir.canRead() && indexDir.canWrite()) {
|
||||||
|
return Health.up()
|
||||||
|
.withDetail("engine", engineType.name())
|
||||||
|
.withDetail("indexDir", indexDirPath)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
return Health.down()
|
||||||
|
.withDetail("engine", engineType.name())
|
||||||
|
.withDetail("indexDir", indexDirPath)
|
||||||
|
.withDetail("reason", "Lucene 索引目录不可读写")
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
package tech.easyflow.ai.config;
|
||||||
|
|
||||||
|
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* RAG 健康检查配置。
|
||||||
|
*/
|
||||||
|
@ConfigurationProperties(prefix = "easyflow.ai.rag.health")
|
||||||
|
public class RagHealthProperties {
|
||||||
|
|
||||||
|
private Duration cacheTtl = Duration.ofSeconds(5);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取健康检查结果缓存时间。
|
||||||
|
*
|
||||||
|
* @return 缓存时间
|
||||||
|
*/
|
||||||
|
public Duration getCacheTtl() {
|
||||||
|
return cacheTtl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置健康检查结果缓存时间。
|
||||||
|
*
|
||||||
|
* @param cacheTtl 缓存时间
|
||||||
|
*/
|
||||||
|
public void setCacheTtl(Duration cacheTtl) {
|
||||||
|
this.cacheTtl = cacheTtl == null || cacheTtl.isNegative() ? Duration.ofSeconds(5) : cacheTtl;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,8 +1,5 @@
|
|||||||
package tech.easyflow.ai.config;
|
package tech.easyflow.ai.config;
|
||||||
|
|
||||||
import com.easyagents.engine.es.ElasticSearcher;
|
|
||||||
import com.easyagents.search.engine.service.DocumentSearcher;
|
|
||||||
import com.easyagents.store.milvus.MilvusVectorStore;
|
|
||||||
import org.springframework.beans.factory.SmartInitializingSingleton;
|
import org.springframework.beans.factory.SmartInitializingSingleton;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
import tech.easyflow.ai.rag.KeywordEngineType;
|
import tech.easyflow.ai.rag.KeywordEngineType;
|
||||||
@@ -16,9 +13,6 @@ import java.io.File;
|
|||||||
@Component
|
@Component
|
||||||
public class RagInfrastructureValidator implements SmartInitializingSingleton {
|
public class RagInfrastructureValidator implements SmartInitializingSingleton {
|
||||||
|
|
||||||
private static final int STARTUP_CHECK_RETRY_TIMES = 10;
|
|
||||||
private static final long STARTUP_CHECK_RETRY_INTERVAL_MS = 1000L;
|
|
||||||
|
|
||||||
@Resource
|
@Resource
|
||||||
private AiMilvusConfig aiMilvusConfig;
|
private AiMilvusConfig aiMilvusConfig;
|
||||||
|
|
||||||
@@ -26,31 +20,21 @@ public class RagInfrastructureValidator implements SmartInitializingSingleton {
|
|||||||
private AiLuceneConfig aiLuceneConfig;
|
private AiLuceneConfig aiLuceneConfig;
|
||||||
|
|
||||||
@Resource
|
@Resource
|
||||||
private SearcherFactory searcherFactory;
|
private AiEsConfig aiEsConfig;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 校验 RAG 基础配置。
|
||||||
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void afterSingletonsInstantiated() {
|
public void afterSingletonsInstantiated() {
|
||||||
validateMilvus();
|
validateMilvusConfig();
|
||||||
validateKeywordSearcher();
|
validateKeywordSearcher();
|
||||||
}
|
}
|
||||||
|
|
||||||
private void validateMilvus() {
|
private void validateMilvusConfig() {
|
||||||
Exception lastException = null;
|
if (StringUtil.noText(aiMilvusConfig.getUri())) {
|
||||||
for (int i = 0; i < STARTUP_CHECK_RETRY_TIMES; i++) {
|
throw new BusinessException("Milvus uri 未配置,请检查 rag.milvus.uri");
|
||||||
try {
|
|
||||||
MilvusVectorStore vectorStore = new MilvusVectorStore(aiMilvusConfig.copyForCollection("__rag_boot_probe__"));
|
|
||||||
if (vectorStore.checkAvailable()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
} catch (Exception e) {
|
|
||||||
lastException = e;
|
|
||||||
}
|
|
||||||
sleepBeforeRetry();
|
|
||||||
}
|
}
|
||||||
if (lastException != null) {
|
|
||||||
throw new BusinessException("Milvus 服务不可用,项目启动失败,请检查 rag.milvus 配置与服务状态: " + lastException.getMessage());
|
|
||||||
}
|
|
||||||
throw new BusinessException("Milvus 服务不可用,项目启动失败,请检查 rag.milvus 配置与服务状态");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void validateKeywordSearcher() {
|
private void validateKeywordSearcher() {
|
||||||
@@ -61,21 +45,12 @@ public class RagInfrastructureValidator implements SmartInitializingSingleton {
|
|||||||
validateLuceneDirectory();
|
validateLuceneDirectory();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (StringUtil.noText(aiEsConfig.getHost())) {
|
||||||
DocumentSearcher searcher = searcherFactory.getSearcher();
|
throw new BusinessException("ES 地址未配置,请检查 rag.searcher.elastic.host");
|
||||||
if (!(searcher instanceof ElasticSearcher) || !checkElasticAvailable((ElasticSearcher) searcher)) {
|
|
||||||
throw new BusinessException("ES 服务不可用,项目启动失败,请检查 rag.engine 与 rag.searcher.elastic 配置");
|
|
||||||
}
|
}
|
||||||
}
|
if (StringUtil.noText(aiEsConfig.getIndexName())) {
|
||||||
|
throw new BusinessException("ES 索引未配置,请检查 rag.searcher.elastic.indexName");
|
||||||
private boolean checkElasticAvailable(ElasticSearcher elasticSearcher) {
|
|
||||||
for (int i = 0; i < STARTUP_CHECK_RETRY_TIMES; i++) {
|
|
||||||
if (elasticSearcher.checkAvailable()) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
sleepBeforeRetry();
|
|
||||||
}
|
}
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void validateLuceneDirectory() {
|
private void validateLuceneDirectory() {
|
||||||
@@ -92,12 +67,4 @@ public class RagInfrastructureValidator implements SmartInitializingSingleton {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void sleepBeforeRetry() {
|
|
||||||
try {
|
|
||||||
Thread.sleep(STARTUP_CHECK_RETRY_INTERVAL_MS);
|
|
||||||
} catch (InterruptedException e) {
|
|
||||||
Thread.currentThread().interrupt();
|
|
||||||
throw new BusinessException("中间件启动校验被中断");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,15 +2,28 @@ package tech.easyflow.ai.config;
|
|||||||
|
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||||
import org.springframework.context.annotation.Bean;
|
import org.springframework.context.annotation.Bean;
|
||||||
import org.springframework.context.annotation.Configuration;
|
import org.springframework.context.annotation.Configuration;
|
||||||
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
||||||
import tech.easyflow.common.web.exceptions.BusinessException;
|
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
|
@EnableConfigurationProperties(EasyFlowThreadPoolProperties.class)
|
||||||
public class ThreadPoolConfig {
|
public class ThreadPoolConfig {
|
||||||
private static final Logger log = LoggerFactory.getLogger(ThreadPoolConfig.class);
|
private static final Logger log = LoggerFactory.getLogger(ThreadPoolConfig.class);
|
||||||
|
|
||||||
|
private final EasyFlowThreadPoolProperties properties;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建线程池配置。
|
||||||
|
*
|
||||||
|
* @param properties 线程池配置属性
|
||||||
|
*/
|
||||||
|
public ThreadPoolConfig(EasyFlowThreadPoolProperties properties) {
|
||||||
|
this.properties = properties;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 创建 SSE 消息发送线程池。
|
* 创建 SSE 消息发送线程池。
|
||||||
*
|
*
|
||||||
@@ -19,11 +32,12 @@ public class ThreadPoolConfig {
|
|||||||
@Bean(name = "sseThreadPool")
|
@Bean(name = "sseThreadPool")
|
||||||
public ThreadPoolTaskExecutor sseThreadPool() {
|
public ThreadPoolTaskExecutor sseThreadPool() {
|
||||||
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
|
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
|
||||||
int cpuCoreNum = Runtime.getRuntime().availableProcessors(); // 获取CPU核心数(4核返回4)
|
EasyFlowThreadPoolProperties.Pool pool = properties.getSse();
|
||||||
executor.setCorePoolSize(cpuCoreNum * 2); // 核心线程数
|
executor.setCorePoolSize(pool.getCoreSize());
|
||||||
executor.setMaxPoolSize(cpuCoreNum * 10); // 最大线程数(峰值时扩容,避免线程过多导致上下文切换)
|
executor.setMaxPoolSize(pool.getMaxSize());
|
||||||
executor.setQueueCapacity(8000); // 任务队列容量
|
executor.setQueueCapacity(pool.getQueueCapacity());
|
||||||
executor.setKeepAliveSeconds(30); // 空闲线程存活时间:30秒(非核心线程空闲后销毁,节省资源)
|
executor.setKeepAliveSeconds(pool.getKeepAliveSeconds());
|
||||||
|
executor.setAllowCoreThreadTimeOut(pool.isAllowCoreThreadTimeout());
|
||||||
executor.setThreadNamePrefix("sse-sender-");
|
executor.setThreadNamePrefix("sse-sender-");
|
||||||
|
|
||||||
// 拒绝策略
|
// 拒绝策略
|
||||||
@@ -47,11 +61,12 @@ public class ThreadPoolConfig {
|
|||||||
@Bean(name = "documentImportTaskExecutor")
|
@Bean(name = "documentImportTaskExecutor")
|
||||||
public ThreadPoolTaskExecutor documentImportTaskExecutor() {
|
public ThreadPoolTaskExecutor documentImportTaskExecutor() {
|
||||||
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
|
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
|
||||||
int cpuCoreNum = Runtime.getRuntime().availableProcessors();
|
EasyFlowThreadPoolProperties.Pool pool = properties.getDocumentImport();
|
||||||
executor.setCorePoolSize(Math.max(2, cpuCoreNum));
|
executor.setCorePoolSize(pool.getCoreSize());
|
||||||
executor.setMaxPoolSize(Math.max(4, cpuCoreNum * 2));
|
executor.setMaxPoolSize(pool.getMaxSize());
|
||||||
executor.setQueueCapacity(200);
|
executor.setQueueCapacity(pool.getQueueCapacity());
|
||||||
executor.setKeepAliveSeconds(60);
|
executor.setKeepAliveSeconds(pool.getKeepAliveSeconds());
|
||||||
|
executor.setAllowCoreThreadTimeOut(pool.isAllowCoreThreadTimeout());
|
||||||
executor.setThreadNamePrefix("document-import-");
|
executor.setThreadNamePrefix("document-import-");
|
||||||
executor.setRejectedExecutionHandler((runnable, executorService) -> {
|
executor.setRejectedExecutionHandler((runnable, executorService) -> {
|
||||||
log.error("文档导入线程池过载!核心线程数:{},最大线程数:{},队列任务数:{}",
|
log.error("文档导入线程池过载!核心线程数:{},最大线程数:{},队列任务数:{}",
|
||||||
@@ -63,4 +78,30 @@ public class ThreadPoolConfig {
|
|||||||
executor.initialize();
|
executor.initialize();
|
||||||
return executor;
|
return executor;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建 Agent 异步工具后台执行线程池。
|
||||||
|
*
|
||||||
|
* @return Agent 异步工具执行线程池
|
||||||
|
*/
|
||||||
|
@Bean(name = "agentAsyncToolExecutor")
|
||||||
|
public ThreadPoolTaskExecutor agentAsyncToolExecutor() {
|
||||||
|
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
|
||||||
|
EasyFlowThreadPoolProperties.Pool pool = properties.getAgentAsyncTool();
|
||||||
|
executor.setCorePoolSize(pool.getCoreSize());
|
||||||
|
executor.setMaxPoolSize(pool.getMaxSize());
|
||||||
|
executor.setQueueCapacity(pool.getQueueCapacity());
|
||||||
|
executor.setKeepAliveSeconds(pool.getKeepAliveSeconds());
|
||||||
|
executor.setAllowCoreThreadTimeOut(pool.isAllowCoreThreadTimeout());
|
||||||
|
executor.setThreadNamePrefix("agent-async-tool-");
|
||||||
|
executor.setRejectedExecutionHandler((runnable, executorService) -> {
|
||||||
|
log.error("Agent异步工具线程池过载!核心线程数:{},最大线程数:{},队列任务数:{}",
|
||||||
|
executorService.getCorePoolSize(),
|
||||||
|
executorService.getMaximumPoolSize(),
|
||||||
|
executorService.getQueue().size());
|
||||||
|
throw new BusinessException("Agent 异步工具任务繁忙,请稍后重试");
|
||||||
|
});
|
||||||
|
executor.initialize();
|
||||||
|
return executor;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package tech.easyflow.ai.documentimport.task;
|
|||||||
|
|
||||||
import org.springframework.scheduling.annotation.Scheduled;
|
import org.springframework.scheduling.annotation.Scheduled;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
import tech.easyflow.common.cache.DistributedScheduledLock;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 知识库文档解析任务收敛器。
|
* 知识库文档解析任务收敛器。
|
||||||
@@ -24,9 +25,10 @@ public class DocumentImportParseMonitor {
|
|||||||
* 定时收敛运行中的桥接解析任务状态。
|
* 定时收敛运行中的桥接解析任务状态。
|
||||||
*/
|
*/
|
||||||
@Scheduled(
|
@Scheduled(
|
||||||
fixedDelayString = "${easyflow.ai.document-import.parse-monitor.fixed-delay:3000}",
|
fixedDelayString = "${easyflow.ai.document-import.parse-monitor.fixed-delay:10000}",
|
||||||
initialDelayString = "${easyflow.ai.document-import.parse-monitor.initial-delay:5000}"
|
initialDelayString = "${easyflow.ai.document-import.parse-monitor.initial-delay:10000}"
|
||||||
)
|
)
|
||||||
|
@DistributedScheduledLock(key = "easyflow:schedule:document-import:parse-monitor", leaseSeconds = 300L)
|
||||||
public void reconcileRunningParseTasks() {
|
public void reconcileRunningParseTasks() {
|
||||||
appService.monitorRunningParseTasks();
|
appService.monitorRunningParseTasks();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,30 @@
|
|||||||
|
package tech.easyflow.ai.documentimport.task;
|
||||||
|
|
||||||
|
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 文档解析任务监控配置。
|
||||||
|
*/
|
||||||
|
@ConfigurationProperties(prefix = "easyflow.ai.document-import.parse-monitor")
|
||||||
|
public class DocumentImportParseMonitorProperties {
|
||||||
|
|
||||||
|
private int batchSize = 10;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取单次监控批量。
|
||||||
|
*
|
||||||
|
* @return 单次监控批量
|
||||||
|
*/
|
||||||
|
public int getBatchSize() {
|
||||||
|
return batchSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置单次监控批量。
|
||||||
|
*
|
||||||
|
* @param batchSize 单次监控批量
|
||||||
|
*/
|
||||||
|
public void setBatchSize(int batchSize) {
|
||||||
|
this.batchSize = batchSize <= 0 ? 10 : batchSize;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,79 @@
|
|||||||
|
package tech.easyflow.ai.documentimport.task;
|
||||||
|
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.context.annotation.Bean;
|
||||||
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
import org.springframework.data.redis.connection.Message;
|
||||||
|
import org.springframework.data.redis.connection.MessageListener;
|
||||||
|
import org.springframework.data.redis.connection.RedisConnectionFactory;
|
||||||
|
import org.springframework.data.redis.listener.ChannelTopic;
|
||||||
|
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
|
||||||
|
|
||||||
|
import java.math.BigInteger;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 文档导入状态 Redis 广播配置。
|
||||||
|
*/
|
||||||
|
@Configuration
|
||||||
|
public class DocumentImportStatusBroadcastConfig {
|
||||||
|
|
||||||
|
private static final Logger LOG = LoggerFactory.getLogger(DocumentImportStatusBroadcastConfig.class);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建文档导入状态广播监听容器。
|
||||||
|
*
|
||||||
|
* @param connectionFactory Redis 连接工厂
|
||||||
|
* @param streamService 文档导入状态流服务
|
||||||
|
* @param properties 文档导入监控配置
|
||||||
|
* @return Redis 消息监听容器
|
||||||
|
*/
|
||||||
|
@Bean
|
||||||
|
public RedisMessageListenerContainer documentImportStatusListenerContainer(
|
||||||
|
RedisConnectionFactory connectionFactory,
|
||||||
|
DocumentImportTaskStatusStreamService streamService,
|
||||||
|
DocumentImportStatusBroadcastProperties properties
|
||||||
|
) {
|
||||||
|
RedisMessageListenerContainer container = new RedisMessageListenerContainer();
|
||||||
|
container.setConnectionFactory(connectionFactory);
|
||||||
|
container.addMessageListener(
|
||||||
|
new DocumentImportStatusMessageListener(streamService),
|
||||||
|
new ChannelTopic(properties.getStatusBroadcastChannel())
|
||||||
|
);
|
||||||
|
return container;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 文档导入状态广播监听器。
|
||||||
|
*/
|
||||||
|
private static final class DocumentImportStatusMessageListener implements MessageListener {
|
||||||
|
|
||||||
|
private final DocumentImportTaskStatusStreamService streamService;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建监听器。
|
||||||
|
*
|
||||||
|
* @param streamService 文档导入状态流服务
|
||||||
|
*/
|
||||||
|
private DocumentImportStatusMessageListener(DocumentImportTaskStatusStreamService streamService) {
|
||||||
|
this.streamService = streamService;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 处理 Redis 广播消息。
|
||||||
|
*
|
||||||
|
* @param message 消息
|
||||||
|
* @param pattern 订阅模式
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void onMessage(Message message, byte[] pattern) {
|
||||||
|
String payload = new String(message.getBody(), StandardCharsets.UTF_8);
|
||||||
|
try {
|
||||||
|
streamService.publishLocal(new BigInteger(payload));
|
||||||
|
} catch (RuntimeException e) {
|
||||||
|
LOG.warn("处理文档导入状态广播失败: payload={}", payload, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
package tech.easyflow.ai.documentimport.task;
|
||||||
|
|
||||||
|
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 文档导入状态广播配置。
|
||||||
|
*/
|
||||||
|
@ConfigurationProperties(prefix = "easyflow.ai.document-import")
|
||||||
|
public class DocumentImportStatusBroadcastProperties {
|
||||||
|
|
||||||
|
private String statusBroadcastChannel = "easyflow:document-import:status";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取文档导入状态广播通道。
|
||||||
|
*
|
||||||
|
* @return Redis 广播通道
|
||||||
|
*/
|
||||||
|
public String getStatusBroadcastChannel() {
|
||||||
|
return statusBroadcastChannel;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置文档导入状态广播通道。
|
||||||
|
*
|
||||||
|
* @param statusBroadcastChannel Redis 广播通道
|
||||||
|
*/
|
||||||
|
public void setStatusBroadcastChannel(String statusBroadcastChannel) {
|
||||||
|
if (statusBroadcastChannel == null || statusBroadcastChannel.trim().isEmpty()) {
|
||||||
|
this.statusBroadcastChannel = "easyflow:document-import:status";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
this.statusBroadcastChannel = statusBroadcastChannel.trim();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,17 +1,22 @@
|
|||||||
package tech.easyflow.ai.documentimport.task;
|
package tech.easyflow.ai.documentimport.task;
|
||||||
|
|
||||||
import org.springframework.http.MediaType;
|
import org.springframework.http.MediaType;
|
||||||
|
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||||
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.support.TransactionSynchronization;
|
import org.springframework.transaction.support.TransactionSynchronization;
|
||||||
import org.springframework.transaction.support.TransactionSynchronizationManager;
|
import org.springframework.transaction.support.TransactionSynchronizationManager;
|
||||||
|
import org.springframework.web.context.request.async.AsyncRequestNotUsableException;
|
||||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
import tech.easyflow.ai.documentimport.DocumentImportKeys;
|
import tech.easyflow.ai.documentimport.DocumentImportKeys;
|
||||||
import tech.easyflow.ai.entity.Document;
|
import tech.easyflow.ai.entity.Document;
|
||||||
import tech.easyflow.ai.mapper.DocumentMapper;
|
import tech.easyflow.ai.mapper.DocumentMapper;
|
||||||
import tech.easyflow.common.web.exceptions.BusinessException;
|
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||||
|
|
||||||
import javax.annotation.Resource;
|
import javax.annotation.Resource;
|
||||||
|
import java.io.IOException;
|
||||||
import java.math.BigInteger;
|
import java.math.BigInteger;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
@@ -28,6 +33,7 @@ import java.util.concurrent.ConcurrentHashMap;
|
|||||||
@Service
|
@Service
|
||||||
public class DocumentImportTaskStatusStreamService {
|
public class DocumentImportTaskStatusStreamService {
|
||||||
|
|
||||||
|
private static final Logger LOG = LoggerFactory.getLogger(DocumentImportTaskStatusStreamService.class);
|
||||||
private static final long SSE_TIMEOUT_MS = Duration.ofMinutes(30).toMillis();
|
private static final long SSE_TIMEOUT_MS = Duration.ofMinutes(30).toMillis();
|
||||||
|
|
||||||
private final Map<String, Set<SseEmitter>> knowledgeEmitters = new ConcurrentHashMap<String, Set<SseEmitter>>();
|
private final Map<String, Set<SseEmitter>> knowledgeEmitters = new ConcurrentHashMap<String, Set<SseEmitter>>();
|
||||||
@@ -38,6 +44,12 @@ public class DocumentImportTaskStatusStreamService {
|
|||||||
@Resource(name = "sseThreadPool")
|
@Resource(name = "sseThreadPool")
|
||||||
private ThreadPoolTaskExecutor sseThreadPool;
|
private ThreadPoolTaskExecutor sseThreadPool;
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private StringRedisTemplate stringRedisTemplate;
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private DocumentImportStatusBroadcastProperties statusBroadcastProperties;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 订阅知识库文档任务状态流。
|
* 订阅知识库文档任务状态流。
|
||||||
*
|
*
|
||||||
@@ -70,7 +82,7 @@ public class DocumentImportTaskStatusStreamService {
|
|||||||
if (documentId == null) {
|
if (documentId == null) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
Runnable publishAction = () -> publishNow(documentId);
|
Runnable publishAction = () -> publishStatusChange(documentId);
|
||||||
if (TransactionSynchronizationManager.isSynchronizationActive()
|
if (TransactionSynchronizationManager.isSynchronizationActive()
|
||||||
&& TransactionSynchronizationManager.isActualTransactionActive()) {
|
&& TransactionSynchronizationManager.isActualTransactionActive()) {
|
||||||
TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {
|
TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {
|
||||||
@@ -84,7 +96,22 @@ public class DocumentImportTaskStatusStreamService {
|
|||||||
publishAction.run();
|
publishAction.run();
|
||||||
}
|
}
|
||||||
|
|
||||||
private void publishNow(BigInteger documentId) {
|
/**
|
||||||
|
* 处理 Redis 广播收到的文档状态变更。
|
||||||
|
*
|
||||||
|
* @param documentId 文档 ID
|
||||||
|
*/
|
||||||
|
public void publishLocal(BigInteger documentId) {
|
||||||
|
publishNow(documentId);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void publishStatusChange(BigInteger documentId) {
|
||||||
|
// 先推送本机连接,降低单机部署和广播链路延迟。
|
||||||
|
publishNow(documentId);
|
||||||
|
stringRedisTemplate.convertAndSend(statusBroadcastProperties.getStatusBroadcastChannel(), documentId.toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
void publishNow(BigInteger documentId) {
|
||||||
Document document = documentMapper.selectOneById(documentId);
|
Document document = documentMapper.selectOneById(documentId);
|
||||||
if (document == null || document.getCollectionId() == null) {
|
if (document == null || document.getCollectionId() == null) {
|
||||||
return;
|
return;
|
||||||
@@ -134,6 +161,9 @@ public class DocumentImportTaskStatusStreamService {
|
|||||||
|
|
||||||
private void sendAsync(String topicKey, SseEmitter emitter, String eventName, Map<String, Object> payload) {
|
private void sendAsync(String topicKey, SseEmitter emitter, String eventName, Map<String, Object> payload) {
|
||||||
sseThreadPool.execute(() -> {
|
sseThreadPool.execute(() -> {
|
||||||
|
if (!isEmitterRegistered(topicKey, emitter)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
try {
|
try {
|
||||||
emitter.send(
|
emitter.send(
|
||||||
SseEmitter.event()
|
SseEmitter.event()
|
||||||
@@ -142,14 +172,29 @@ public class DocumentImportTaskStatusStreamService {
|
|||||||
);
|
);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
removeEmitter(topicKey, emitter);
|
removeEmitter(topicKey, emitter);
|
||||||
try {
|
if (isClientDisconnected(e)) {
|
||||||
emitter.completeWithError(e);
|
LOG.debug("文档导入状态流客户端已断开: topicKey={}, eventName={}, message={}",
|
||||||
} catch (Exception ignored) {
|
topicKey, eventName, e.getMessage());
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
LOG.warn("文档导入状态流推送失败: topicKey={}, eventName={}", topicKey, eventName, e);
|
||||||
|
completeQuietly(emitter);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 判断指定 SSE 连接是否仍注册在主题下,避免已清理连接继续被异步任务写入。
|
||||||
|
*
|
||||||
|
* @param topicKey 主题键
|
||||||
|
* @param emitter SSE 连接
|
||||||
|
* @return 是否仍处于注册状态
|
||||||
|
*/
|
||||||
|
private boolean isEmitterRegistered(String topicKey, SseEmitter emitter) {
|
||||||
|
Set<SseEmitter> emitters = knowledgeEmitters.get(topicKey);
|
||||||
|
return emitters != null && emitters.contains(emitter);
|
||||||
|
}
|
||||||
|
|
||||||
private void removeEmitter(String topicKey, SseEmitter emitter) {
|
private void removeEmitter(String topicKey, SseEmitter emitter) {
|
||||||
Set<SseEmitter> emitters = knowledgeEmitters.get(topicKey);
|
Set<SseEmitter> emitters = knowledgeEmitters.get(topicKey);
|
||||||
if (emitters == null) {
|
if (emitters == null) {
|
||||||
@@ -161,6 +206,46 @@ public class DocumentImportTaskStatusStreamService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 判断异常是否由客户端断开 SSE 连接导致。
|
||||||
|
*
|
||||||
|
* @param throwable 异常
|
||||||
|
* @return 是否为客户端断连
|
||||||
|
*/
|
||||||
|
private boolean isClientDisconnected(Throwable throwable) {
|
||||||
|
Throwable current = throwable;
|
||||||
|
while (current != null) {
|
||||||
|
if (current instanceof AsyncRequestNotUsableException || current instanceof IOException) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
String message = current.getMessage();
|
||||||
|
if (message != null) {
|
||||||
|
String lowerMessage = message.toLowerCase();
|
||||||
|
if (lowerMessage.contains("broken pipe")
|
||||||
|
|| lowerMessage.contains("connection reset")
|
||||||
|
|| lowerMessage.contains("response not usable")
|
||||||
|
|| lowerMessage.contains("client abort")) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
current = current.getCause();
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 安静关闭 SSE 连接。
|
||||||
|
*
|
||||||
|
* @param emitter SSE 连接
|
||||||
|
*/
|
||||||
|
private void completeQuietly(SseEmitter emitter) {
|
||||||
|
try {
|
||||||
|
emitter.complete();
|
||||||
|
} catch (Exception e) {
|
||||||
|
LOG.debug("关闭文档导入状态流失败: message={}", e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private String toTopicKey(BigInteger knowledgeId) {
|
private String toTopicKey(BigInteger knowledgeId) {
|
||||||
return String.valueOf(knowledgeId);
|
return String.valueOf(knowledgeId);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ import tech.easyflow.ai.service.DocumentChunkService;
|
|||||||
import tech.easyflow.ai.service.DocumentCollectionService;
|
import tech.easyflow.ai.service.DocumentCollectionService;
|
||||||
import tech.easyflow.ai.service.DocumentImportTaskService;
|
import tech.easyflow.ai.service.DocumentImportTaskService;
|
||||||
import tech.easyflow.ai.service.ModelService;
|
import tech.easyflow.ai.service.ModelService;
|
||||||
|
import tech.easyflow.ai.support.DocumentStoreLifecycleSupport;
|
||||||
import tech.easyflow.common.domain.Result;
|
import tech.easyflow.common.domain.Result;
|
||||||
import tech.easyflow.common.filestorage.FileStorageService;
|
import tech.easyflow.common.filestorage.FileStorageService;
|
||||||
import tech.easyflow.common.util.FileUtil;
|
import tech.easyflow.common.util.FileUtil;
|
||||||
@@ -92,7 +93,6 @@ import java.util.regex.Pattern;
|
|||||||
public class KnowledgeDocumentImportTaskAppService {
|
public class KnowledgeDocumentImportTaskAppService {
|
||||||
|
|
||||||
private static final Logger LOG = LoggerFactory.getLogger(KnowledgeDocumentImportTaskAppService.class);
|
private static final Logger LOG = LoggerFactory.getLogger(KnowledgeDocumentImportTaskAppService.class);
|
||||||
private static final int PARSE_MONITOR_BATCH_SIZE = 20;
|
|
||||||
private static final int INDEX_BATCH_SIZE = 20;
|
private static final int INDEX_BATCH_SIZE = 20;
|
||||||
private static final String SOURCE_RANGES_KEY = "sourceRanges";
|
private static final String SOURCE_RANGES_KEY = "sourceRanges";
|
||||||
private static final String KNOWLEDGE_PARSE_IMAGE_CATEGORY = "knowledge-parse";
|
private static final String KNOWLEDGE_PARSE_IMAGE_CATEGORY = "knowledge-parse";
|
||||||
@@ -122,6 +122,9 @@ public class KnowledgeDocumentImportTaskAppService {
|
|||||||
@Resource
|
@Resource
|
||||||
private DocumentImportTaskService documentImportTaskService;
|
private DocumentImportTaskService documentImportTaskService;
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private DocumentImportParseMonitorProperties parseMonitorProperties;
|
||||||
|
|
||||||
@Resource
|
@Resource
|
||||||
private DocumentImportPreviewService documentImportPreviewService;
|
private DocumentImportPreviewService documentImportPreviewService;
|
||||||
|
|
||||||
@@ -403,7 +406,7 @@ public class KnowledgeDocumentImportTaskAppService {
|
|||||||
.eq(DocumentImportTask::getPhase, DocumentImportTaskPhase.PARSE.name())
|
.eq(DocumentImportTask::getPhase, DocumentImportTaskPhase.PARSE.name())
|
||||||
.eq(DocumentImportTask::getStatus, DocumentImportTaskStatus.RUNNING.name())
|
.eq(DocumentImportTask::getStatus, DocumentImportTaskStatus.RUNNING.name())
|
||||||
.orderBy(DocumentImportTask::getModified, true)
|
.orderBy(DocumentImportTask::getModified, true)
|
||||||
.limit(PARSE_MONITOR_BATCH_SIZE);
|
.limit(parseMonitorProperties.getBatchSize());
|
||||||
List<DocumentImportTask> runningTasks = documentImportTaskService.list(queryWrapper);
|
List<DocumentImportTask> runningTasks = documentImportTaskService.list(queryWrapper);
|
||||||
if (runningTasks == null || runningTasks.isEmpty()) {
|
if (runningTasks == null || runningTasks.isEmpty()) {
|
||||||
return;
|
return;
|
||||||
@@ -516,6 +519,8 @@ public class KnowledgeDocumentImportTaskAppService {
|
|||||||
rollbackStoredChunks(taskId, document.getId(), storeContext, storedChunks);
|
rollbackStoredChunks(taskId, document.getId(), storeContext, storedChunks);
|
||||||
}
|
}
|
||||||
markIndexFailed(task, document, truncateError(e.getMessage()));
|
markIndexFailed(task, document, truncateError(e.getMessage()));
|
||||||
|
} finally {
|
||||||
|
closeStoreContext(storeContext);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2123,26 +2128,31 @@ public class KnowledgeDocumentImportTaskAppService {
|
|||||||
if (documentStore == null) {
|
if (documentStore == null) {
|
||||||
throw new BusinessException("向量数据库配置错误");
|
throw new BusinessException("向量数据库配置错误");
|
||||||
}
|
}
|
||||||
Model model = modelService.getModelInstance(knowledge.getVectorEmbedModelId());
|
try {
|
||||||
if (model == null) {
|
Model model = modelService.getModelInstance(knowledge.getVectorEmbedModelId());
|
||||||
throw new BusinessException("该知识库未配置向量模型");
|
if (model == null) {
|
||||||
}
|
throw new BusinessException("该知识库未配置向量模型");
|
||||||
EmbeddingModel embeddingModel = model.toEmbeddingModel();
|
}
|
||||||
documentStore.setEmbeddingModel(embeddingModel);
|
EmbeddingModel embeddingModel = model.toEmbeddingModel();
|
||||||
|
documentStore.setEmbeddingModel(embeddingModel);
|
||||||
|
|
||||||
StoreOptions options = StoreOptions.ofCollectionName(knowledge.getVectorStoreCollection());
|
StoreOptions options = StoreOptions.ofCollectionName(knowledge.getVectorStoreCollection());
|
||||||
EmbeddingOptions embeddingOptions = new EmbeddingOptions();
|
EmbeddingOptions embeddingOptions = new EmbeddingOptions();
|
||||||
embeddingOptions.setModel(model.getModelName());
|
embeddingOptions.setModel(model.getModelName());
|
||||||
embeddingOptions.setDimensions(knowledge.getDimensionOfVectorModel());
|
embeddingOptions.setDimensions(knowledge.getDimensionOfVectorModel());
|
||||||
options.setEmbeddingOptions(embeddingOptions);
|
options.setEmbeddingOptions(embeddingOptions);
|
||||||
options.setIndexName(options.getCollectionName());
|
options.setIndexName(options.getCollectionName());
|
||||||
return new StoreExecutionContext(
|
return new StoreExecutionContext(
|
||||||
knowledge,
|
knowledge,
|
||||||
embeddingModel,
|
embeddingModel,
|
||||||
documentStore,
|
documentStore,
|
||||||
options,
|
options,
|
||||||
searcherFactory.getSearcher()
|
searcherFactory.getSearcher()
|
||||||
);
|
);
|
||||||
|
} catch (RuntimeException e) {
|
||||||
|
DocumentStoreLifecycleSupport.closeQuietly(documentStore);
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void storeDocumentChunks(StoreExecutionContext storeContext, List<DocumentChunk> documentChunks) {
|
private void storeDocumentChunks(StoreExecutionContext storeContext, List<DocumentChunk> documentChunks) {
|
||||||
@@ -2221,6 +2231,13 @@ public class KnowledgeDocumentImportTaskAppService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void closeStoreContext(StoreExecutionContext storeContext) {
|
||||||
|
if (storeContext == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
DocumentStoreLifecycleSupport.closeQuietly(storeContext.documentStore);
|
||||||
|
}
|
||||||
|
|
||||||
private void clearPersistedChunks(BigInteger documentId) {
|
private void clearPersistedChunks(BigInteger documentId) {
|
||||||
if (documentId == null) {
|
if (documentId == null) {
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -11,8 +11,12 @@ import tech.easyflow.common.util.StringUtil;
|
|||||||
|
|
||||||
import java.math.BigInteger;
|
import java.math.BigInteger;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Optional;
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated 该类仅用于 Bot 旧版 function/tool 链路的 MCP 工具调用。
|
||||||
|
* 后续请迁移到 agent-runtime 的 MCP 工具适配链路。
|
||||||
|
*/
|
||||||
|
@Deprecated(since = "0.4", forRemoval = false)
|
||||||
public class McpTool extends BaseTool {
|
public class McpTool extends BaseTool {
|
||||||
private BigInteger mcpId;
|
private BigInteger mcpId;
|
||||||
|
|
||||||
|
|||||||
@@ -37,6 +37,18 @@ public class McpBase extends DateEntity implements Serializable {
|
|||||||
@Column(comment = "完整MCP配置JSON")
|
@Column(comment = "完整MCP配置JSON")
|
||||||
private String configJson;
|
private String configJson;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* MCP连接方式
|
||||||
|
*/
|
||||||
|
@Column(comment = "MCP连接方式")
|
||||||
|
private String transportType;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否启用工具调用审批
|
||||||
|
*/
|
||||||
|
@Column(comment = "是否启用工具调用审批")
|
||||||
|
private Boolean approvalRequired;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 部门ID
|
* 部门ID
|
||||||
*/
|
*/
|
||||||
@@ -111,6 +123,22 @@ public class McpBase extends DateEntity implements Serializable {
|
|||||||
this.configJson = configJson;
|
this.configJson = configJson;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public String getTransportType() {
|
||||||
|
return transportType;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setTransportType(String transportType) {
|
||||||
|
this.transportType = transportType;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Boolean getApprovalRequired() {
|
||||||
|
return approvalRequired;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setApprovalRequired(Boolean approvalRequired) {
|
||||||
|
this.approvalRequired = approvalRequired;
|
||||||
|
}
|
||||||
|
|
||||||
public BigInteger getDeptId() {
|
public BigInteger getDeptId() {
|
||||||
return deptId;
|
return deptId;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,66 @@
|
|||||||
|
package tech.easyflow.ai.mcp;
|
||||||
|
|
||||||
|
import tech.easyflow.common.util.StringUtil;
|
||||||
|
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||||
|
|
||||||
|
import java.util.Locale;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* MCP 连接方式。
|
||||||
|
*/
|
||||||
|
public enum McpTransportType {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 标准输入输出进程通信。
|
||||||
|
*/
|
||||||
|
STDIO("stdio"),
|
||||||
|
|
||||||
|
/**
|
||||||
|
* HTTP SSE 通信。
|
||||||
|
*/
|
||||||
|
SSE("http-sse"),
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Streamable HTTP 通信。
|
||||||
|
*/
|
||||||
|
HTTP("http-stream");
|
||||||
|
|
||||||
|
private final String value;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建 MCP 连接方式。
|
||||||
|
*
|
||||||
|
* @param value 配置值
|
||||||
|
*/
|
||||||
|
McpTransportType(String value) {
|
||||||
|
this.value = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取配置值。
|
||||||
|
*
|
||||||
|
* @return 配置值
|
||||||
|
*/
|
||||||
|
public String getValue() {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析连接方式。
|
||||||
|
*
|
||||||
|
* @param value 连接方式文本
|
||||||
|
* @return MCP 连接方式
|
||||||
|
*/
|
||||||
|
public static McpTransportType from(String value) {
|
||||||
|
if (StringUtil.noText(value)) {
|
||||||
|
return STDIO;
|
||||||
|
}
|
||||||
|
String normalized = value.trim().toLowerCase(Locale.ROOT);
|
||||||
|
return switch (normalized) {
|
||||||
|
case "stdio" -> STDIO;
|
||||||
|
case "sse", "http-sse" -> SSE;
|
||||||
|
case "http", "http-stream", "streamable-http" -> HTTP;
|
||||||
|
default -> throw new BusinessException("不支持的 MCP 连接方式: " + value);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package tech.easyflow.ai.service;
|
package tech.easyflow.ai.service;
|
||||||
|
|
||||||
import com.easyagents.core.model.chat.tool.Tool;
|
import com.easyagents.core.model.chat.tool.Tool;
|
||||||
|
import com.easyagents.mcp.client.McpEnvironmentCheckResult;
|
||||||
import com.mybatisflex.core.paginate.Page;
|
import com.mybatisflex.core.paginate.Page;
|
||||||
import com.mybatisflex.core.service.IService;
|
import com.mybatisflex.core.service.IService;
|
||||||
import tech.easyflow.ai.entity.BotMcp;
|
import tech.easyflow.ai.entity.BotMcp;
|
||||||
@@ -30,4 +31,6 @@ public interface McpService extends IService<Mcp> {
|
|||||||
Mcp getMcpTools(String id);
|
Mcp getMcpTools(String id);
|
||||||
|
|
||||||
Page<Mcp> pageTools(Page<Mcp> mcpPage);
|
Page<Mcp> pageTools(Page<Mcp> mcpPage);
|
||||||
|
|
||||||
|
McpEnvironmentCheckResult checkMcp(String configJson);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ import tech.easyflow.ai.mapper.FaqItemMapper;
|
|||||||
import tech.easyflow.ai.rag.KnowledgeRetrievalRequest;
|
import tech.easyflow.ai.rag.KnowledgeRetrievalRequest;
|
||||||
import tech.easyflow.ai.service.DocumentCollectionService;
|
import tech.easyflow.ai.service.DocumentCollectionService;
|
||||||
import tech.easyflow.ai.service.ModelService;
|
import tech.easyflow.ai.service.ModelService;
|
||||||
|
import tech.easyflow.ai.support.DocumentStoreLifecycleSupport;
|
||||||
import tech.easyflow.ai.utils.CustomBeanUtils;
|
import tech.easyflow.ai.utils.CustomBeanUtils;
|
||||||
import tech.easyflow.ai.utils.RegexUtils;
|
import tech.easyflow.ai.utils.RegexUtils;
|
||||||
import tech.easyflow.common.util.StringUtil;
|
import tech.easyflow.common.util.StringUtil;
|
||||||
@@ -283,34 +284,38 @@ public class DocumentCollectionServiceImpl extends ServiceImpl<DocumentCollectio
|
|||||||
throw new BusinessException("知识库没有配置向量库");
|
throw new BusinessException("知识库没有配置向量库");
|
||||||
}
|
}
|
||||||
|
|
||||||
Model model = llmService.getModelInstance(documentCollection.getVectorEmbedModelId());
|
try {
|
||||||
if (model == null) {
|
Model model = llmService.getModelInstance(documentCollection.getVectorEmbedModelId());
|
||||||
throw new BusinessException("知识库没有配置向量模型");
|
if (model == null) {
|
||||||
}
|
throw new BusinessException("知识库没有配置向量模型");
|
||||||
|
}
|
||||||
|
|
||||||
documentStore.setEmbeddingModel(model.toEmbeddingModel());
|
documentStore.setEmbeddingModel(model.toEmbeddingModel());
|
||||||
SearchWrapper wrapper = new SearchWrapper();
|
SearchWrapper wrapper = new SearchWrapper();
|
||||||
wrapper.setMaxResults(docRecallMaxNum);
|
wrapper.setMaxResults(docRecallMaxNum);
|
||||||
if (minSimilarity != null) {
|
if (minSimilarity != null) {
|
||||||
wrapper.setMinScore((double) minSimilarity);
|
wrapper.setMinScore((double) minSimilarity);
|
||||||
}
|
}
|
||||||
wrapper.setText(keyword);
|
wrapper.setText(keyword);
|
||||||
|
|
||||||
StoreOptions options = StoreOptions.ofCollectionName(documentCollection.getVectorStoreCollection());
|
StoreOptions options = StoreOptions.ofCollectionName(documentCollection.getVectorStoreCollection());
|
||||||
options.setIndexName(documentCollection.getVectorStoreCollection());
|
options.setIndexName(documentCollection.getVectorStoreCollection());
|
||||||
List<Document> documents = documentStore.search(wrapper, options);
|
List<Document> documents = documentStore.search(wrapper, options);
|
||||||
List<Document> result = documents == null ? Collections.<Document>emptyList() : documents;
|
List<Document> result = documents == null ? Collections.<Document>emptyList() : documents;
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Knowledge vector search completed, knowledgeId={}, collectionName={}, query={}, limit={}, minSimilarity={}, hitCount={}, hits={}",
|
"Knowledge vector search completed, knowledgeId={}, collectionName={}, query={}, limit={}, minSimilarity={}, hitCount={}, hits={}",
|
||||||
documentCollection.getId(),
|
documentCollection.getId(),
|
||||||
documentCollection.getVectorStoreCollection(),
|
documentCollection.getVectorStoreCollection(),
|
||||||
keyword,
|
keyword,
|
||||||
docRecallMaxNum,
|
docRecallMaxNum,
|
||||||
minSimilarity,
|
minSimilarity,
|
||||||
result.size(),
|
result.size(),
|
||||||
summarizeDocuments(result)
|
summarizeDocuments(result)
|
||||||
);
|
);
|
||||||
return result;
|
return result;
|
||||||
|
} finally {
|
||||||
|
DocumentStoreLifecycleSupport.closeQuietly(documentStore);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<Document> searchKeywordDocuments(DocumentCollection documentCollection, String keyword, int docRecallMaxNum) {
|
private List<Document> searchKeywordDocuments(DocumentCollection documentCollection, String keyword, int docRecallMaxNum) {
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ import tech.easyflow.ai.service.DocumentChunkService;
|
|||||||
import tech.easyflow.ai.service.DocumentCollectionService;
|
import tech.easyflow.ai.service.DocumentCollectionService;
|
||||||
import tech.easyflow.ai.service.DocumentService;
|
import tech.easyflow.ai.service.DocumentService;
|
||||||
import tech.easyflow.ai.service.ModelService;
|
import tech.easyflow.ai.service.ModelService;
|
||||||
|
import tech.easyflow.ai.support.DocumentStoreLifecycleSupport;
|
||||||
import tech.easyflow.common.ai.rag.ExcelDocumentSplitter;
|
import tech.easyflow.common.ai.rag.ExcelDocumentSplitter;
|
||||||
import tech.easyflow.common.domain.Result;
|
import tech.easyflow.common.domain.Result;
|
||||||
import tech.easyflow.common.filestorage.FileStorageService;
|
import tech.easyflow.common.filestorage.FileStorageService;
|
||||||
@@ -154,34 +155,38 @@ public class DocumentServiceImpl extends ServiceImpl<DocumentMapper, Document> i
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
Model model = modelService.getById(knowledge.getVectorEmbedModelId());
|
try {
|
||||||
if (model == null) {
|
Model model = modelService.getById(knowledge.getVectorEmbedModelId());
|
||||||
return false;
|
if (model == null) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// 设置向量模型
|
||||||
|
StoreOptions options = StoreOptions.ofCollectionName(knowledge.getVectorStoreCollection());
|
||||||
|
EmbeddingOptions embeddingOptions = new EmbeddingOptions();
|
||||||
|
embeddingOptions.setModel(model.getModelName());
|
||||||
|
options.setEmbeddingOptions(embeddingOptions);
|
||||||
|
options.setCollectionName(knowledge.getVectorStoreCollection());
|
||||||
|
// 查询文本分割表tb_document_chunk中对应的有哪些数据,找出来删除
|
||||||
|
QueryWrapper queryWrapper = QueryWrapper.create()
|
||||||
|
.select(DOCUMENT_CHUNK.ID).eq(DocumentChunk::getDocumentId, id);
|
||||||
|
List<BigInteger> chunkIds = documentChunkMapper.selectListByQueryAs(queryWrapper, BigInteger.class);
|
||||||
|
documentStore.delete(chunkIds, options);
|
||||||
|
// 删除搜索引擎中的数据
|
||||||
|
DocumentSearcher searcher = searcherFactory.getSearcher();
|
||||||
|
if (searcher != null) {
|
||||||
|
chunkIds.forEach(searcher::deleteDocument);
|
||||||
|
}
|
||||||
|
int ck = documentChunkMapper.deleteByQuery(QueryWrapper.create().eq(DocumentChunk::getDocumentId, id));
|
||||||
|
if (ck < 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// 再删除指定路径下的文件
|
||||||
|
Document document = documentMapper.selectOneByQuery(queryWrapperDocument);
|
||||||
|
storageService.delete(document.getDocumentPath());
|
||||||
|
return true;
|
||||||
|
} finally {
|
||||||
|
DocumentStoreLifecycleSupport.closeQuietly(documentStore);
|
||||||
}
|
}
|
||||||
// 设置向量模型
|
|
||||||
StoreOptions options = StoreOptions.ofCollectionName(knowledge.getVectorStoreCollection());
|
|
||||||
EmbeddingOptions embeddingOptions = new EmbeddingOptions();
|
|
||||||
embeddingOptions.setModel(model.getModelName());
|
|
||||||
options.setEmbeddingOptions(embeddingOptions);
|
|
||||||
options.setCollectionName(knowledge.getVectorStoreCollection());
|
|
||||||
// 查询文本分割表tb_document_chunk中对应的有哪些数据,找出来删除
|
|
||||||
QueryWrapper queryWrapper = QueryWrapper.create()
|
|
||||||
.select(DOCUMENT_CHUNK.ID).eq(DocumentChunk::getDocumentId, id);
|
|
||||||
List<BigInteger> chunkIds = documentChunkMapper.selectListByQueryAs(queryWrapper, BigInteger.class);
|
|
||||||
documentStore.delete(chunkIds, options);
|
|
||||||
// 删除搜索引擎中的数据
|
|
||||||
DocumentSearcher searcher = searcherFactory.getSearcher();
|
|
||||||
if (searcher != null) {
|
|
||||||
chunkIds.forEach(searcher::deleteDocument);
|
|
||||||
}
|
|
||||||
int ck = documentChunkMapper.deleteByQuery(QueryWrapper.create().eq(DocumentChunk::getDocumentId, id));
|
|
||||||
if (ck < 0) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// 再删除指定路径下的文件
|
|
||||||
Document document = documentMapper.selectOneByQuery(queryWrapperDocument);
|
|
||||||
storageService.delete(document.getDocumentPath());
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -286,8 +291,8 @@ public class DocumentServiceImpl extends ServiceImpl<DocumentMapper, Document> i
|
|||||||
}
|
}
|
||||||
|
|
||||||
StoreExecutionContext storeContext = prepareStoreContext(document);
|
StoreExecutionContext storeContext = prepareStoreContext(document);
|
||||||
storeDocumentChunks(storeContext, validChunks);
|
|
||||||
try {
|
try {
|
||||||
|
storeDocumentChunks(storeContext, validChunks);
|
||||||
persistDocumentWithChunks(document, validChunks);
|
persistDocumentWithChunks(document, validChunks);
|
||||||
updateKnowledgeAfterStore(storeContext);
|
updateKnowledgeAfterStore(storeContext);
|
||||||
return Result.ok();
|
return Result.ok();
|
||||||
@@ -296,14 +301,20 @@ public class DocumentServiceImpl extends ServiceImpl<DocumentMapper, Document> i
|
|||||||
rollbackStoredChunks(storeContext, validChunks);
|
rollbackStoredChunks(storeContext, validChunks);
|
||||||
Log.error("保存文档失败: documentId={}, title={}", document.getId(), document.getTitle(), e);
|
Log.error("保存文档失败: documentId={}, title={}", document.getId(), document.getTitle(), e);
|
||||||
throw new BusinessException("保存失败:" + e.getMessage());
|
throw new BusinessException("保存失败:" + e.getMessage());
|
||||||
|
} finally {
|
||||||
|
closeStoreContext(storeContext);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected Boolean storeDocument(Document entity, List<DocumentChunk> documentChunks) {
|
protected Boolean storeDocument(Document entity, List<DocumentChunk> documentChunks) {
|
||||||
StoreExecutionContext storeContext = prepareStoreContext(entity);
|
StoreExecutionContext storeContext = prepareStoreContext(entity);
|
||||||
storeDocumentChunks(storeContext, documentChunks);
|
try {
|
||||||
updateKnowledgeAfterStore(storeContext);
|
storeDocumentChunks(storeContext, documentChunks);
|
||||||
return true;
|
updateKnowledgeAfterStore(storeContext);
|
||||||
|
return true;
|
||||||
|
} finally {
|
||||||
|
closeStoreContext(storeContext);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -430,14 +441,16 @@ public class DocumentServiceImpl extends ServiceImpl<DocumentMapper, Document> i
|
|||||||
}
|
}
|
||||||
|
|
||||||
StoreExecutionContext storeContext = prepareStoreContext(document);
|
StoreExecutionContext storeContext = prepareStoreContext(document);
|
||||||
storeDocumentChunks(storeContext, session.getDocumentChunks());
|
|
||||||
try {
|
try {
|
||||||
|
storeDocumentChunks(storeContext, session.getDocumentChunks());
|
||||||
persistDocumentWithChunks(document, session.getDocumentChunks());
|
persistDocumentWithChunks(document, session.getDocumentChunks());
|
||||||
updateKnowledgeAfterStore(storeContext);
|
updateKnowledgeAfterStore(storeContext);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
cleanupPersistedDocument(document);
|
cleanupPersistedDocument(document);
|
||||||
rollbackStoredChunks(storeContext, session.getDocumentChunks());
|
rollbackStoredChunks(storeContext, session.getDocumentChunks());
|
||||||
throw new BusinessException("提交导入失败:" + e.getMessage());
|
throw new BusinessException("提交导入失败:" + e.getMessage());
|
||||||
|
} finally {
|
||||||
|
closeStoreContext(storeContext);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -751,24 +764,28 @@ public class DocumentServiceImpl extends ServiceImpl<DocumentMapper, Document> i
|
|||||||
if (documentStore == null) {
|
if (documentStore == null) {
|
||||||
throw new BusinessException("向量数据库配置错误");
|
throw new BusinessException("向量数据库配置错误");
|
||||||
}
|
}
|
||||||
|
try {
|
||||||
|
Model model = modelService.getModelInstance(knowledge.getVectorEmbedModelId());
|
||||||
|
if (model == null) {
|
||||||
|
throw new BusinessException("该知识库未配置大模型");
|
||||||
|
}
|
||||||
|
EmbeddingModel embeddingModel = model.toEmbeddingModel();
|
||||||
|
documentStore.setEmbeddingModel(embeddingModel);
|
||||||
|
|
||||||
Model model = modelService.getModelInstance(knowledge.getVectorEmbedModelId());
|
StoreOptions options = StoreOptions.ofCollectionName(knowledge.getVectorStoreCollection());
|
||||||
if (model == null) {
|
EmbeddingOptions embeddingOptions = new EmbeddingOptions();
|
||||||
throw new BusinessException("该知识库未配置大模型");
|
embeddingOptions.setModel(model.getModelName());
|
||||||
|
embeddingOptions.setDimensions(knowledge.getDimensionOfVectorModel());
|
||||||
|
options.setEmbeddingOptions(embeddingOptions);
|
||||||
|
options.setIndexName(options.getCollectionName());
|
||||||
|
|
||||||
|
DocumentSearcher searcher = null;
|
||||||
|
searcher = searcherFactory.getSearcher();
|
||||||
|
return new StoreExecutionContext(knowledge, model, embeddingModel, documentStore, options, searcher);
|
||||||
|
} catch (RuntimeException e) {
|
||||||
|
DocumentStoreLifecycleSupport.closeQuietly(documentStore);
|
||||||
|
throw e;
|
||||||
}
|
}
|
||||||
EmbeddingModel embeddingModel = model.toEmbeddingModel();
|
|
||||||
documentStore.setEmbeddingModel(embeddingModel);
|
|
||||||
|
|
||||||
StoreOptions options = StoreOptions.ofCollectionName(knowledge.getVectorStoreCollection());
|
|
||||||
EmbeddingOptions embeddingOptions = new EmbeddingOptions();
|
|
||||||
embeddingOptions.setModel(model.getModelName());
|
|
||||||
embeddingOptions.setDimensions(knowledge.getDimensionOfVectorModel());
|
|
||||||
options.setEmbeddingOptions(embeddingOptions);
|
|
||||||
options.setIndexName(options.getCollectionName());
|
|
||||||
|
|
||||||
DocumentSearcher searcher = null;
|
|
||||||
searcher = searcherFactory.getSearcher();
|
|
||||||
return new StoreExecutionContext(knowledge, model, embeddingModel, documentStore, options, searcher);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void storeDocumentChunks(StoreExecutionContext storeContext, List<DocumentChunk> documentChunks) {
|
private void storeDocumentChunks(StoreExecutionContext storeContext, List<DocumentChunk> documentChunks) {
|
||||||
@@ -841,6 +858,13 @@ public class DocumentServiceImpl extends ServiceImpl<DocumentMapper, Document> i
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void closeStoreContext(StoreExecutionContext storeContext) {
|
||||||
|
if (storeContext == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
DocumentStoreLifecycleSupport.closeQuietly(storeContext.documentStore);
|
||||||
|
}
|
||||||
|
|
||||||
private void persistDocumentWithChunks(Document document, List<DocumentChunk> chunks) {
|
private void persistDocumentWithChunks(Document document, List<DocumentChunk> chunks) {
|
||||||
this.getMapper().insert(document);
|
this.getMapper().insert(document);
|
||||||
AtomicInteger sort = new AtomicInteger(1);
|
AtomicInteger sort = new AtomicInteger(1);
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package tech.easyflow.ai.service.impl;
|
package tech.easyflow.ai.service.impl;
|
||||||
|
|
||||||
|
import cn.dev33.satoken.stp.StpUtil;
|
||||||
import cn.idev.excel.EasyExcel;
|
import cn.idev.excel.EasyExcel;
|
||||||
import cn.idev.excel.ExcelWriter;
|
import cn.idev.excel.ExcelWriter;
|
||||||
import cn.idev.excel.FastExcel;
|
import cn.idev.excel.FastExcel;
|
||||||
@@ -8,7 +9,6 @@ import cn.idev.excel.metadata.data.ReadCellData;
|
|||||||
import cn.idev.excel.read.listener.ReadListener;
|
import cn.idev.excel.read.listener.ReadListener;
|
||||||
import cn.idev.excel.write.metadata.WriteSheet;
|
import cn.idev.excel.write.metadata.WriteSheet;
|
||||||
import cn.idev.excel.write.style.column.SimpleColumnWidthStyleStrategy;
|
import cn.idev.excel.write.style.column.SimpleColumnWidthStyleStrategy;
|
||||||
import cn.dev33.satoken.stp.StpUtil;
|
|
||||||
import com.easyagents.core.model.embedding.EmbeddingModel;
|
import com.easyagents.core.model.embedding.EmbeddingModel;
|
||||||
import com.easyagents.core.model.embedding.EmbeddingOptions;
|
import com.easyagents.core.model.embedding.EmbeddingOptions;
|
||||||
import com.easyagents.core.store.DocumentStore;
|
import com.easyagents.core.store.DocumentStore;
|
||||||
@@ -20,15 +20,15 @@ import com.mybatisflex.core.query.QueryWrapper;
|
|||||||
import com.mybatisflex.spring.service.impl.ServiceImpl;
|
import com.mybatisflex.spring.service.impl.ServiceImpl;
|
||||||
import org.jsoup.Jsoup;
|
import org.jsoup.Jsoup;
|
||||||
import org.jsoup.nodes.Element;
|
import org.jsoup.nodes.Element;
|
||||||
import org.jsoup.select.Elements;
|
|
||||||
import org.jsoup.safety.Safelist;
|
import org.jsoup.safety.Safelist;
|
||||||
|
import org.jsoup.select.Elements;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
|
||||||
import org.springframework.transaction.PlatformTransactionManager;
|
import org.springframework.transaction.PlatformTransactionManager;
|
||||||
import org.springframework.transaction.TransactionDefinition;
|
import org.springframework.transaction.TransactionDefinition;
|
||||||
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
import org.springframework.transaction.support.TransactionTemplate;
|
import org.springframework.transaction.support.TransactionTemplate;
|
||||||
import org.springframework.web.multipart.MultipartFile;
|
import org.springframework.web.multipart.MultipartFile;
|
||||||
import tech.easyflow.ai.config.SearcherFactory;
|
import tech.easyflow.ai.config.SearcherFactory;
|
||||||
@@ -40,6 +40,7 @@ import tech.easyflow.ai.service.DocumentCollectionService;
|
|||||||
import tech.easyflow.ai.service.FaqCategoryService;
|
import tech.easyflow.ai.service.FaqCategoryService;
|
||||||
import tech.easyflow.ai.service.FaqItemService;
|
import tech.easyflow.ai.service.FaqItemService;
|
||||||
import tech.easyflow.ai.service.ModelService;
|
import tech.easyflow.ai.service.ModelService;
|
||||||
|
import tech.easyflow.ai.support.DocumentStoreLifecycleSupport;
|
||||||
import tech.easyflow.ai.vo.FaqImportErrorRowVo;
|
import tech.easyflow.ai.vo.FaqImportErrorRowVo;
|
||||||
import tech.easyflow.ai.vo.FaqImportResultVo;
|
import tech.easyflow.ai.vo.FaqImportResultVo;
|
||||||
import tech.easyflow.common.util.StringUtil;
|
import tech.easyflow.common.util.StringUtil;
|
||||||
@@ -348,29 +349,46 @@ public class FaqItemServiceImpl extends ServiceImpl<FaqItemMapper, FaqItem> impl
|
|||||||
|
|
||||||
private void storeToVector(DocumentCollection collection, FaqItem entity, boolean isUpdate) {
|
private void storeToVector(DocumentCollection collection, FaqItem entity, boolean isUpdate) {
|
||||||
PreparedStore preparedStore = prepareStore(collection);
|
PreparedStore preparedStore = prepareStore(collection);
|
||||||
com.easyagents.core.document.Document doc = toSearchDocument(entity);
|
try {
|
||||||
StoreResult result = isUpdate
|
com.easyagents.core.document.Document doc = toSearchDocument(entity);
|
||||||
? preparedStore.documentStore.update(doc, preparedStore.storeOptions)
|
StoreResult result = isUpdate
|
||||||
: preparedStore.documentStore.store(Collections.singletonList(doc), preparedStore.storeOptions);
|
? preparedStore.documentStore.update(doc, preparedStore.storeOptions)
|
||||||
if (result == null || !result.isSuccess()) {
|
: preparedStore.documentStore.store(Collections.singletonList(doc), preparedStore.storeOptions);
|
||||||
throw new BusinessException("FAQ向量化失败");
|
if (result == null || !result.isSuccess()) {
|
||||||
}
|
throw new BusinessException("FAQ向量化失败");
|
||||||
|
|
||||||
DocumentSearcher searcher = searcherFactory.getSearcher();
|
|
||||||
if (searcher != null) {
|
|
||||||
if (isUpdate) {
|
|
||||||
searcher.deleteDocument(entity.getId());
|
|
||||||
}
|
}
|
||||||
searcher.addDocument(doc);
|
|
||||||
|
DocumentSearcher searcher = searcherFactory.getSearcher();
|
||||||
|
if (searcher != null) {
|
||||||
|
if (isUpdate) {
|
||||||
|
searcher.deleteDocument(entity.getId());
|
||||||
|
}
|
||||||
|
searcher.addDocument(doc);
|
||||||
|
}
|
||||||
|
markCollectionEmbedded(collection, preparedStore.embeddingModel);
|
||||||
|
} catch (BusinessException e) {
|
||||||
|
throw e;
|
||||||
|
} catch (RuntimeException e) {
|
||||||
|
LOG.error("FAQ vectorization failed. collectionId={}, faqId={}, isUpdate={}",
|
||||||
|
collection == null ? null : collection.getId(),
|
||||||
|
entity == null ? null : entity.getId(),
|
||||||
|
isUpdate,
|
||||||
|
e);
|
||||||
|
throw new BusinessException("FAQ向量化失败:请检查知识库绑定的向量模型、请求路径、维度或向量库配置");
|
||||||
|
} finally {
|
||||||
|
DocumentStoreLifecycleSupport.closeQuietly(preparedStore.documentStore);
|
||||||
}
|
}
|
||||||
markCollectionEmbedded(collection, preparedStore.embeddingModel);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void removeFromVector(DocumentCollection collection, FaqItem entity) {
|
private void removeFromVector(DocumentCollection collection, FaqItem entity) {
|
||||||
PreparedStore preparedStore = prepareStore(collection);
|
PreparedStore preparedStore = prepareStore(collection);
|
||||||
boolean deleteSuccess = deleteFromVectorStore(preparedStore.documentStore, preparedStore.storeOptions, entity.getId());
|
try {
|
||||||
if (!deleteSuccess) {
|
boolean deleteSuccess = deleteFromVectorStore(preparedStore.documentStore, preparedStore.storeOptions, entity.getId());
|
||||||
throw new BusinessException("FAQ向量删除失败");
|
if (!deleteSuccess) {
|
||||||
|
throw new BusinessException("FAQ向量删除失败");
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
DocumentStoreLifecycleSupport.closeQuietly(preparedStore.documentStore);
|
||||||
}
|
}
|
||||||
|
|
||||||
DocumentSearcher searcher = searcherFactory.getSearcher();
|
DocumentSearcher searcher = searcherFactory.getSearcher();
|
||||||
@@ -413,20 +431,25 @@ public class FaqItemServiceImpl extends ServiceImpl<FaqItemMapper, FaqItem> impl
|
|||||||
if (documentStore == null) {
|
if (documentStore == null) {
|
||||||
throw new BusinessException("向量数据库配置错误");
|
throw new BusinessException("向量数据库配置错误");
|
||||||
}
|
}
|
||||||
Model model = modelService.getModelInstance(collection.getVectorEmbedModelId());
|
try {
|
||||||
if (model == null) {
|
Model model = modelService.getModelInstance(collection.getVectorEmbedModelId());
|
||||||
throw new BusinessException("该知识库未配置向量模型");
|
if (model == null) {
|
||||||
}
|
throw new BusinessException("该知识库未配置向量模型");
|
||||||
EmbeddingModel embeddingModel = model.toEmbeddingModel();
|
}
|
||||||
documentStore.setEmbeddingModel(embeddingModel);
|
EmbeddingModel embeddingModel = model.toEmbeddingModel();
|
||||||
|
documentStore.setEmbeddingModel(embeddingModel);
|
||||||
|
|
||||||
StoreOptions options = StoreOptions.ofCollectionName(collection.getVectorStoreCollection());
|
StoreOptions options = StoreOptions.ofCollectionName(collection.getVectorStoreCollection());
|
||||||
EmbeddingOptions embeddingOptions = new EmbeddingOptions();
|
EmbeddingOptions embeddingOptions = new EmbeddingOptions();
|
||||||
embeddingOptions.setModel(model.getModelName());
|
embeddingOptions.setModel(model.getModelName());
|
||||||
embeddingOptions.setDimensions(collection.getDimensionOfVectorModel());
|
embeddingOptions.setDimensions(collection.getDimensionOfVectorModel());
|
||||||
options.setEmbeddingOptions(embeddingOptions);
|
options.setEmbeddingOptions(embeddingOptions);
|
||||||
options.setIndexName(options.getCollectionName());
|
options.setIndexName(options.getCollectionName());
|
||||||
return new PreparedStore(documentStore, options, embeddingModel);
|
return new PreparedStore(documentStore, options, embeddingModel);
|
||||||
|
} catch (RuntimeException e) {
|
||||||
|
DocumentStoreLifecycleSupport.closeQuietly(documentStore);
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private com.easyagents.core.document.Document toSearchDocument(FaqItem entity) {
|
private com.easyagents.core.document.Document toSearchDocument(FaqItem entity) {
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import tech.easyflow.ai.service.DocumentCollectionService;
|
|||||||
import tech.easyflow.ai.service.FaqItemService;
|
import tech.easyflow.ai.service.FaqItemService;
|
||||||
import tech.easyflow.ai.service.KnowledgeEmbeddingService;
|
import tech.easyflow.ai.service.KnowledgeEmbeddingService;
|
||||||
import tech.easyflow.ai.service.ModelService;
|
import tech.easyflow.ai.service.ModelService;
|
||||||
|
import tech.easyflow.ai.support.DocumentStoreLifecycleSupport;
|
||||||
import tech.easyflow.common.web.exceptions.BusinessException;
|
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||||
|
|
||||||
import javax.annotation.Resource;
|
import javax.annotation.Resource;
|
||||||
@@ -50,20 +51,24 @@ public class KnowledgeEmbeddingServiceImpl implements KnowledgeEmbeddingService
|
|||||||
if (documentStore == null) {
|
if (documentStore == null) {
|
||||||
throw new BusinessException("知识库没有配置向量库");
|
throw new BusinessException("知识库没有配置向量库");
|
||||||
}
|
}
|
||||||
Model model = modelService.getModelInstance(knowledge.getVectorEmbedModelId());
|
try {
|
||||||
if (model == null) {
|
Model model = modelService.getModelInstance(knowledge.getVectorEmbedModelId());
|
||||||
throw new BusinessException("知识库没有配置向量模型");
|
if (model == null) {
|
||||||
}
|
throw new BusinessException("知识库没有配置向量模型");
|
||||||
EmbeddingModel embeddingModel = model.toEmbeddingModel();
|
}
|
||||||
documentStore.setEmbeddingModel(embeddingModel);
|
EmbeddingModel embeddingModel = model.toEmbeddingModel();
|
||||||
StoreOptions storeOptions = StoreOptions.ofCollectionName(knowledge.getVectorStoreCollection());
|
documentStore.setEmbeddingModel(embeddingModel);
|
||||||
storeOptions.setIndexName(knowledge.getVectorStoreCollection());
|
StoreOptions storeOptions = StoreOptions.ofCollectionName(knowledge.getVectorStoreCollection());
|
||||||
|
storeOptions.setIndexName(knowledge.getVectorStoreCollection());
|
||||||
|
|
||||||
if (knowledge.isFaqCollection()) {
|
if (knowledge.isFaqCollection()) {
|
||||||
rebuildFaqVectors(knowledge, documentStore, storeOptions, embeddingModel);
|
rebuildFaqVectors(knowledge, documentStore, storeOptions, embeddingModel);
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
|
rebuildDocumentVectors(knowledge, documentStore, storeOptions, embeddingModel);
|
||||||
|
} finally {
|
||||||
|
DocumentStoreLifecycleSupport.closeQuietly(documentStore);
|
||||||
}
|
}
|
||||||
rebuildDocumentVectors(knowledge, documentStore, storeOptions, embeddingModel);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void rebuildDocumentVectors(
|
private void rebuildDocumentVectors(
|
||||||
@@ -153,4 +158,3 @@ public class KnowledgeEmbeddingServiceImpl implements KnowledgeEmbeddingService
|
|||||||
documentCollectionService.updateById(update);
|
documentCollectionService.updateById(update);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ package tech.easyflow.ai.service.impl;
|
|||||||
import com.easyagents.core.model.chat.tool.Parameter;
|
import com.easyagents.core.model.chat.tool.Parameter;
|
||||||
import com.easyagents.core.model.chat.tool.Tool;
|
import com.easyagents.core.model.chat.tool.Tool;
|
||||||
import com.easyagents.mcp.client.McpClientManager;
|
import com.easyagents.mcp.client.McpClientManager;
|
||||||
|
import com.easyagents.mcp.client.McpEnvironmentCheckResult;
|
||||||
|
import com.easyagents.mcp.client.McpEnvironmentChecker;
|
||||||
import com.alibaba.fastjson2.JSON;
|
import com.alibaba.fastjson2.JSON;
|
||||||
import com.alibaba.fastjson2.JSONObject;
|
import com.alibaba.fastjson2.JSONObject;
|
||||||
import com.mybatisflex.core.paginate.Page;
|
import com.mybatisflex.core.paginate.Page;
|
||||||
@@ -16,6 +18,7 @@ import tech.easyflow.ai.easyagents.tool.McpTool;
|
|||||||
import tech.easyflow.ai.entity.BotMcp;
|
import tech.easyflow.ai.entity.BotMcp;
|
||||||
import tech.easyflow.ai.entity.Mcp;
|
import tech.easyflow.ai.entity.Mcp;
|
||||||
import tech.easyflow.ai.mapper.McpMapper;
|
import tech.easyflow.ai.mapper.McpMapper;
|
||||||
|
import tech.easyflow.ai.mcp.McpTransportType;
|
||||||
import tech.easyflow.ai.service.McpService;
|
import tech.easyflow.ai.service.McpService;
|
||||||
import tech.easyflow.ai.utils.CommonFiledUtil;
|
import tech.easyflow.ai.utils.CommonFiledUtil;
|
||||||
import tech.easyflow.common.constant.enums.EnumRes;
|
import tech.easyflow.common.constant.enums.EnumRes;
|
||||||
@@ -37,7 +40,8 @@ import java.util.*;
|
|||||||
@Service
|
@Service
|
||||||
public class McpServiceImpl extends ServiceImpl<McpMapper, Mcp> implements McpService {
|
public class McpServiceImpl extends ServiceImpl<McpMapper, Mcp> implements McpService {
|
||||||
private final McpClientManager mcpClientManager = McpClientManager.getInstance();
|
private final McpClientManager mcpClientManager = McpClientManager.getInstance();
|
||||||
protected Logger Log = LoggerFactory.getLogger(DocumentServiceImpl.class);
|
private final McpEnvironmentChecker mcpEnvironmentChecker = new McpEnvironmentChecker();
|
||||||
|
protected Logger Log = LoggerFactory.getLogger(McpServiceImpl.class);
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Result<?> saveMcp(Mcp entity) {
|
public Result<?> saveMcp(Mcp entity) {
|
||||||
@@ -49,6 +53,8 @@ public class McpServiceImpl extends ServiceImpl<McpMapper, Mcp> implements McpS
|
|||||||
if (!StringUtil.hasText(serverName)) {
|
if (!StringUtil.hasText(serverName)) {
|
||||||
return Result.fail("未找到mcp服务名称", serverName);
|
return Result.fail("未找到mcp服务名称", serverName);
|
||||||
}
|
}
|
||||||
|
entity.setTransportType(getFirstMcpTransportType(entity.getConfigJson()));
|
||||||
|
entity.setApprovalRequired(Boolean.TRUE.equals(entity.getApprovalRequired()));
|
||||||
try {
|
try {
|
||||||
mcpClientManager.registerFromJson(entity.getConfigJson());
|
mcpClientManager.registerFromJson(entity.getConfigJson());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@@ -79,6 +85,8 @@ public class McpServiceImpl extends ServiceImpl<McpMapper, Mcp> implements McpS
|
|||||||
if (!StringUtil.hasText(serverName)) {
|
if (!StringUtil.hasText(serverName)) {
|
||||||
return Result.fail("未找到mcp服务名称", serverName);
|
return Result.fail("未找到mcp服务名称", serverName);
|
||||||
}
|
}
|
||||||
|
entity.setTransportType(getFirstMcpTransportType(entity.getConfigJson()));
|
||||||
|
entity.setApprovalRequired(Boolean.TRUE.equals(entity.getApprovalRequired()));
|
||||||
if (entity.getStatus()) {
|
if (entity.getStatus()) {
|
||||||
try {
|
try {
|
||||||
mcpClientManager.registerFromJson(entity.getConfigJson());
|
mcpClientManager.registerFromJson(entity.getConfigJson());
|
||||||
@@ -121,6 +129,7 @@ public class McpServiceImpl extends ServiceImpl<McpMapper, Mcp> implements McpS
|
|||||||
records.forEach(mcp -> {
|
records.forEach(mcp -> {
|
||||||
boolean clientOnline = mcpClientManager.isClientOnline(getFirstMcpServerName(mcp.getConfigJson()));
|
boolean clientOnline = mcpClientManager.isClientOnline(getFirstMcpServerName(mcp.getConfigJson()));
|
||||||
mcp.setClientOnline(clientOnline);
|
mcp.setClientOnline(clientOnline);
|
||||||
|
mcp.setTransportType(resolveMcpTransportType(mcp));
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
page.getData().setRecords(records);
|
page.getData().setRecords(records);
|
||||||
@@ -130,6 +139,9 @@ public class McpServiceImpl extends ServiceImpl<McpMapper, Mcp> implements McpS
|
|||||||
@Override
|
@Override
|
||||||
public Mcp getMcpTools(String id) {
|
public Mcp getMcpTools(String id) {
|
||||||
Mcp mcp = this.getById(id);
|
Mcp mcp = this.getById(id);
|
||||||
|
if (mcp != null) {
|
||||||
|
mcp.setTransportType(resolveMcpTransportType(mcp));
|
||||||
|
}
|
||||||
if (mcp != null && mcp.getStatus()) {
|
if (mcp != null && mcp.getStatus()) {
|
||||||
McpSyncClient mcpClient = getMcpClient(mcp, mcpClientManager);
|
McpSyncClient mcpClient = getMcpClient(mcp, mcpClientManager);
|
||||||
List<McpSchema.Tool> tools = null;
|
List<McpSchema.Tool> tools = null;
|
||||||
@@ -209,9 +221,27 @@ public class McpServiceImpl extends ServiceImpl<McpMapper, Mcp> implements McpS
|
|||||||
return firstServerName.orElse(null);
|
return firstServerName.orElse(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static String getFirstMcpTransportType(String mcpJson) {
|
||||||
|
JSONObject rootJson = JSON.parseObject(mcpJson);
|
||||||
|
JSONObject mcpServersJson = rootJson.getJSONObject("mcpServers");
|
||||||
|
if (mcpServersJson == null || mcpServersJson.isEmpty()) {
|
||||||
|
return McpTransportType.STDIO.getValue();
|
||||||
|
}
|
||||||
|
Optional<String> firstServerName = mcpServersJson.keySet().stream().findFirst();
|
||||||
|
if (firstServerName.isEmpty()) {
|
||||||
|
return McpTransportType.STDIO.getValue();
|
||||||
|
}
|
||||||
|
JSONObject serverJson = mcpServersJson.getJSONObject(firstServerName.get());
|
||||||
|
if (serverJson == null) {
|
||||||
|
return McpTransportType.STDIO.getValue();
|
||||||
|
}
|
||||||
|
return McpTransportType.from(serverJson.getString("transport")).getValue();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Page<Mcp> pageTools(Page<Mcp> page) {
|
public Page<Mcp> pageTools(Page<Mcp> page) {
|
||||||
page.getRecords().forEach(mcp -> {
|
page.getRecords().forEach(mcp -> {
|
||||||
|
mcp.setTransportType(resolveMcpTransportType(mcp));
|
||||||
// mcp 未启用,不查询工具
|
// mcp 未启用,不查询工具
|
||||||
if (!mcp.getStatus()) {
|
if (!mcp.getStatus()) {
|
||||||
return;
|
return;
|
||||||
@@ -235,6 +265,11 @@ public class McpServiceImpl extends ServiceImpl<McpMapper, Mcp> implements McpS
|
|||||||
return page;
|
return page;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public McpEnvironmentCheckResult checkMcp(String configJson) {
|
||||||
|
return mcpEnvironmentChecker.check(configJson);
|
||||||
|
}
|
||||||
|
|
||||||
private Result<?> validateMcpConfig(Mcp entity) {
|
private Result<?> validateMcpConfig(Mcp entity) {
|
||||||
if (entity == null || !StringUtil.hasText(entity.getConfigJson())) {
|
if (entity == null || !StringUtil.hasText(entity.getConfigJson())) {
|
||||||
Log.error("MCP 配置不能为空");
|
Log.error("MCP 配置不能为空");
|
||||||
@@ -242,4 +277,14 @@ public class McpServiceImpl extends ServiceImpl<McpMapper, Mcp> implements McpS
|
|||||||
}
|
}
|
||||||
return Result.ok();
|
return Result.ok();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private String resolveMcpTransportType(Mcp mcp) {
|
||||||
|
if (mcp == null) {
|
||||||
|
return McpTransportType.STDIO.getValue();
|
||||||
|
}
|
||||||
|
if (StringUtil.hasText(mcp.getTransportType())) {
|
||||||
|
return McpTransportType.from(mcp.getTransportType()).getValue();
|
||||||
|
}
|
||||||
|
return getFirstMcpTransportType(mcp.getConfigJson());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,32 @@
|
|||||||
|
package tech.easyflow.ai.support;
|
||||||
|
|
||||||
|
import com.easyagents.core.store.DocumentStore;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 文档向量库生命周期辅助工具。
|
||||||
|
*/
|
||||||
|
public final class DocumentStoreLifecycleSupport {
|
||||||
|
|
||||||
|
private static final Logger LOG = LoggerFactory.getLogger(DocumentStoreLifecycleSupport.class);
|
||||||
|
|
||||||
|
private DocumentStoreLifecycleSupport() {
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 关闭支持关闭语义的文档向量库。
|
||||||
|
*
|
||||||
|
* @param documentStore 文档向量库实例
|
||||||
|
*/
|
||||||
|
public static void closeQuietly(DocumentStore documentStore) {
|
||||||
|
if (!(documentStore instanceof AutoCloseable)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
((AutoCloseable) documentStore).close();
|
||||||
|
} catch (Exception e) {
|
||||||
|
LOG.warn("关闭文档向量库连接失败: store={}", documentStore.getClass().getSimpleName(), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,93 @@
|
|||||||
|
package tech.easyflow.ai.config;
|
||||||
|
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.springframework.boot.actuate.health.Health;
|
||||||
|
|
||||||
|
import java.time.Clock;
|
||||||
|
import java.time.Duration;
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.time.ZoneId;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 健康检查短缓存测试。
|
||||||
|
*/
|
||||||
|
public class CachedHealthIndicatorSupportTest {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证 TTL 内重复健康检查复用缓存。
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void shouldReuseHealthWithinCacheTtl() {
|
||||||
|
RagHealthProperties properties = new RagHealthProperties();
|
||||||
|
properties.setCacheTtl(Duration.ofSeconds(5));
|
||||||
|
MutableClock clock = new MutableClock();
|
||||||
|
CountingHealthIndicator indicator = new CountingHealthIndicator(properties, clock);
|
||||||
|
|
||||||
|
indicator.cachedHealth();
|
||||||
|
indicator.cachedHealth();
|
||||||
|
|
||||||
|
Assert.assertEquals(1, indicator.count());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证 TTL 过期后重新执行健康检查。
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void shouldRefreshHealthAfterCacheExpired() {
|
||||||
|
RagHealthProperties properties = new RagHealthProperties();
|
||||||
|
properties.setCacheTtl(Duration.ofSeconds(5));
|
||||||
|
MutableClock clock = new MutableClock();
|
||||||
|
CountingHealthIndicator indicator = new CountingHealthIndicator(properties, clock);
|
||||||
|
|
||||||
|
indicator.cachedHealth();
|
||||||
|
clock.plus(Duration.ofSeconds(6));
|
||||||
|
indicator.cachedHealth();
|
||||||
|
|
||||||
|
Assert.assertEquals(2, indicator.count());
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class CountingHealthIndicator extends CachedHealthIndicatorSupport {
|
||||||
|
|
||||||
|
private final AtomicInteger counter = new AtomicInteger();
|
||||||
|
|
||||||
|
private CountingHealthIndicator(RagHealthProperties properties, Clock clock) {
|
||||||
|
super(properties, clock);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Health doHealthCheck() {
|
||||||
|
counter.incrementAndGet();
|
||||||
|
return Health.up().build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private int count() {
|
||||||
|
return counter.get();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class MutableClock extends Clock {
|
||||||
|
|
||||||
|
private Instant instant = Instant.parse("2026-05-25T00:00:00Z");
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ZoneId getZone() {
|
||||||
|
return ZoneId.of("UTC");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Clock withZone(ZoneId zone) {
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Instant instant() {
|
||||||
|
return instant;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void plus(Duration duration) {
|
||||||
|
instant = instant.plus(duration);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,97 @@
|
|||||||
|
package tech.easyflow.ai.documentimport.task;
|
||||||
|
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.mockito.ArgumentMatchers;
|
||||||
|
import org.mockito.Mockito;
|
||||||
|
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||||
|
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
||||||
|
import tech.easyflow.ai.entity.Document;
|
||||||
|
import tech.easyflow.ai.mapper.DocumentMapper;
|
||||||
|
|
||||||
|
import java.lang.reflect.Field;
|
||||||
|
import java.math.BigInteger;
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@link DocumentImportTaskStatusStreamService} 回归测试。
|
||||||
|
*/
|
||||||
|
public class DocumentImportTaskStatusStreamServiceTest {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证文档状态变更会向 Redis 广播文档 ID。
|
||||||
|
*
|
||||||
|
* @throws Exception 反射注入异常
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void publishAfterCommitShouldBroadcastDocumentId() throws Exception {
|
||||||
|
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||||
|
DocumentImportTaskStatusStreamService service = new DocumentImportTaskStatusStreamService();
|
||||||
|
setField(service, "documentMapper", mockDocumentMapper());
|
||||||
|
setField(service, "sseThreadPool", directExecutor());
|
||||||
|
setField(service, "stringRedisTemplate", redisTemplate);
|
||||||
|
setField(service, "statusBroadcastProperties", statusBroadcastProperties());
|
||||||
|
|
||||||
|
service.publishAfterCommit(BigInteger.valueOf(101));
|
||||||
|
|
||||||
|
Mockito.verify(redisTemplate).convertAndSend("easyflow:document-import:test-status", "101");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证收到 Redis 广播后会重新查询文档状态。
|
||||||
|
*
|
||||||
|
* @throws Exception 反射注入异常
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void publishLocalShouldReloadDocumentStatus() throws Exception {
|
||||||
|
AtomicReference<BigInteger> selectedIdRef = new AtomicReference<BigInteger>();
|
||||||
|
DocumentImportTaskStatusStreamService service = new DocumentImportTaskStatusStreamService();
|
||||||
|
setField(service, "documentMapper", mockDocumentMapper(selectedIdRef));
|
||||||
|
setField(service, "sseThreadPool", directExecutor());
|
||||||
|
setField(service, "stringRedisTemplate", Mockito.mock(StringRedisTemplate.class));
|
||||||
|
setField(service, "statusBroadcastProperties", statusBroadcastProperties());
|
||||||
|
|
||||||
|
service.publishLocal(BigInteger.valueOf(202));
|
||||||
|
|
||||||
|
Assert.assertEquals(BigInteger.valueOf(202), selectedIdRef.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
private DocumentImportStatusBroadcastProperties statusBroadcastProperties() {
|
||||||
|
DocumentImportStatusBroadcastProperties properties = new DocumentImportStatusBroadcastProperties();
|
||||||
|
properties.setStatusBroadcastChannel("easyflow:document-import:test-status");
|
||||||
|
return properties;
|
||||||
|
}
|
||||||
|
|
||||||
|
private DocumentMapper mockDocumentMapper() {
|
||||||
|
return mockDocumentMapper(new AtomicReference<BigInteger>());
|
||||||
|
}
|
||||||
|
|
||||||
|
private DocumentMapper mockDocumentMapper(AtomicReference<BigInteger> selectedIdRef) {
|
||||||
|
DocumentMapper mapper = Mockito.mock(DocumentMapper.class);
|
||||||
|
Mockito.when(mapper.selectOneById(ArgumentMatchers.any())).thenAnswer(invocation -> {
|
||||||
|
Object id = invocation.getArgument(0);
|
||||||
|
selectedIdRef.set((BigInteger) id);
|
||||||
|
Document document = new Document();
|
||||||
|
document.setId((BigInteger) id);
|
||||||
|
document.setCollectionId(BigInteger.valueOf(1));
|
||||||
|
return document;
|
||||||
|
});
|
||||||
|
return mapper;
|
||||||
|
}
|
||||||
|
|
||||||
|
private ThreadPoolTaskExecutor directExecutor() {
|
||||||
|
ThreadPoolTaskExecutor executor = Mockito.mock(ThreadPoolTaskExecutor.class);
|
||||||
|
Mockito.doAnswer(invocation -> {
|
||||||
|
Runnable runnable = invocation.getArgument(0);
|
||||||
|
runnable.run();
|
||||||
|
return null;
|
||||||
|
}).when(executor).execute(ArgumentMatchers.any(Runnable.class));
|
||||||
|
return executor;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void setField(Object target, String fieldName, Object value) throws Exception {
|
||||||
|
Field field = DocumentImportTaskStatusStreamService.class.getDeclaredField(fieldName);
|
||||||
|
field.setAccessible(true);
|
||||||
|
field.set(target, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
package tech.easyflow.ai.mcp;
|
||||||
|
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import tech.easyflow.ai.service.impl.McpServiceImpl;
|
||||||
|
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* {@link McpTransportType} 单元测试。
|
||||||
|
*/
|
||||||
|
public class McpTransportTypeTest {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 应兼容解析 MCP 配置中常见的连接方式文本。
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void fromShouldParseSupportedTransportTypes() {
|
||||||
|
Assert.assertEquals(McpTransportType.STDIO, McpTransportType.from("stdio"));
|
||||||
|
Assert.assertEquals(McpTransportType.SSE, McpTransportType.from("sse"));
|
||||||
|
Assert.assertEquals(McpTransportType.SSE, McpTransportType.from("http-sse"));
|
||||||
|
Assert.assertEquals(McpTransportType.HTTP, McpTransportType.from("http"));
|
||||||
|
Assert.assertEquals(McpTransportType.HTTP, McpTransportType.from("http-stream"));
|
||||||
|
Assert.assertEquals(McpTransportType.HTTP, McpTransportType.from("streamable-http"));
|
||||||
|
Assert.assertEquals(McpTransportType.STDIO, McpTransportType.from(null));
|
||||||
|
Assert.assertEquals(McpTransportType.STDIO, McpTransportType.from(" "));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 应从 MCP 配置 JSON 中推断首个 server 的连接方式。
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void getFirstMcpTransportTypeShouldInferFromConfigJson() {
|
||||||
|
Assert.assertEquals("stdio", McpServiceImpl.getFirstMcpTransportType("""
|
||||||
|
{"mcpServers":{"everything":{"command":"npx","args":["-y","@modelcontextprotocol/server-everything"]}}}
|
||||||
|
"""));
|
||||||
|
Assert.assertEquals("http-sse", McpServiceImpl.getFirstMcpTransportType("""
|
||||||
|
{"mcpServers":{"remote":{"transport":"http-sse","url":"http://127.0.0.1:3000/sse"}}}
|
||||||
|
"""));
|
||||||
|
Assert.assertEquals("http-stream", McpServiceImpl.getFirstMcpTransportType("""
|
||||||
|
{"mcpServers":{"remote":{"transport":"http-stream","url":"http://127.0.0.1:3000/mcp"}}}
|
||||||
|
"""));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 不支持的连接方式应直接失败,避免保存无法启动的 MCP 配置。
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void getFirstMcpTransportTypeShouldRejectUnsupportedTransportType() {
|
||||||
|
try {
|
||||||
|
McpServiceImpl.getFirstMcpTransportType("""
|
||||||
|
{"mcpServers":{"remote":{"transport":"websocket","url":"ws://127.0.0.1:3000/mcp"}}}
|
||||||
|
""");
|
||||||
|
Assert.fail("expected BusinessException");
|
||||||
|
} catch (BusinessException exception) {
|
||||||
|
Assert.assertTrue(exception.getMessage().contains("不支持的 MCP 连接方式"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -70,4 +70,25 @@ public class ChatAssistantAccumulatorTest {
|
|||||||
Assert.assertEquals(1, secondToolCalls.size());
|
Assert.assertEquals(1, secondToolCalls.size());
|
||||||
Assert.assertEquals("call-2", secondToolCalls.get(0).get("id"));
|
Assert.assertEquals("call-2", secondToolCalls.get(0).get("id"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 工具展示名应进入展示链和 assistant toolCalls,但不覆盖真实工具名。
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public void shouldKeepToolDisplayNameWithoutOverridingToolName() {
|
||||||
|
ChatAssistantAccumulator accumulator = new ChatAssistantAccumulator();
|
||||||
|
accumulator.appendToolCall("call-1", "mcp_123_search", "知识库 MCP - search", "{\"q\":\"java\"}");
|
||||||
|
accumulator.appendToolResult("call-1", "mcp_123_search", "知识库 MCP - search", "{\"ok\":true}");
|
||||||
|
|
||||||
|
Map<String, Object> payload = accumulator.buildPayload(null);
|
||||||
|
List<Map<String, Object>> chains = (List<Map<String, Object>>) payload.get("chains");
|
||||||
|
List<Map<String, Object>> messageChain = (List<Map<String, Object>>) payload.get("messageChain");
|
||||||
|
List<Map<String, Object>> toolCalls = (List<Map<String, Object>>) messageChain.get(0).get("toolCalls");
|
||||||
|
|
||||||
|
Assert.assertEquals("mcp_123_search", chains.get(0).get("name"));
|
||||||
|
Assert.assertEquals("知识库 MCP - search", chains.get(0).get("toolDisplayName"));
|
||||||
|
Assert.assertEquals("mcp_123_search", toolCalls.get(0).get("name"));
|
||||||
|
Assert.assertEquals("知识库 MCP - search", toolCalls.get(0).get("toolDisplayName"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,11 +2,13 @@ package tech.easyflow.approval.config;
|
|||||||
|
|
||||||
import org.mybatis.spring.annotation.MapperScan;
|
import org.mybatis.spring.annotation.MapperScan;
|
||||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||||
|
import org.springframework.context.annotation.ComponentScan;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 审批模块配置。
|
* 审批模块配置。
|
||||||
*/
|
*/
|
||||||
@MapperScan("tech.easyflow.approval.mapper")
|
@MapperScan("tech.easyflow.approval.mapper")
|
||||||
|
@ComponentScan("tech.easyflow.approval")
|
||||||
@AutoConfiguration
|
@AutoConfiguration
|
||||||
public class ApprovalModuleConfig {
|
public class ApprovalModuleConfig {
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user