feat: 完成 Agent MCP 对接
- 增加 MCP 连接类型、环境检测接口和容器运行环境支持 - 将 Agent 编排改为绑定整体 MCP 并编译为 runtime McpSpec - 优化 MCP 工具展示、审批、草稿试运行和画布回显稳定性
This commit is contained in:
@@ -10,6 +10,8 @@ import com.easyagents.agent.runtime.knowledge.AgentKnowledgeSpec;
|
||||
import com.easyagents.agent.runtime.memory.AgentMemoryCompressionParameter;
|
||||
import com.easyagents.agent.runtime.memory.AgentMemoryPolicy;
|
||||
import com.easyagents.agent.runtime.memory.AgentMemoryType;
|
||||
import com.easyagents.agent.runtime.mcp.McpSpec;
|
||||
import com.easyagents.agent.runtime.mcp.McpTransportType;
|
||||
import com.easyagents.agent.runtime.model.AgentGenerationOptions;
|
||||
import com.easyagents.agent.runtime.model.AgentModelProviderType;
|
||||
import com.easyagents.agent.runtime.model.AgentModelSpec;
|
||||
@@ -28,7 +30,6 @@ import tech.easyflow.agent.entity.AgentKnowledgeBinding;
|
||||
import tech.easyflow.agent.entity.AgentToolBinding;
|
||||
import tech.easyflow.agent.enums.AgentToolType;
|
||||
import tech.easyflow.ai.easyagents.tool.ChatToolNameHelper;
|
||||
import tech.easyflow.ai.easyagents.tool.McpTool;
|
||||
import tech.easyflow.ai.easyagents.tool.WorkflowTool;
|
||||
import tech.easyflow.ai.easyagentsflow.support.PublishedWorkflowDefinitionIds;
|
||||
import tech.easyflow.ai.entity.*;
|
||||
@@ -40,6 +41,8 @@ import tech.easyflow.common.web.exceptions.BusinessException;
|
||||
import javax.annotation.Resource;
|
||||
import java.math.BigInteger;
|
||||
import java.time.Duration;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
@@ -50,6 +53,7 @@ public class AgentDefinitionCompiler {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(AgentDefinitionCompiler.class);
|
||||
private static final int LOG_TEXT_MAX_LENGTH = 500;
|
||||
private static final Pattern MCP_INPUT_PATTERN = Pattern.compile("\\$\\{input:([A-Za-z0-9_.-]+)}");
|
||||
|
||||
@Resource
|
||||
private ModelService modelService;
|
||||
@@ -210,16 +214,29 @@ public class AgentDefinitionCompiler {
|
||||
}
|
||||
List<AgentToolSpec> specs = new ArrayList<>();
|
||||
Map<String, com.easyagents.agent.runtime.tool.AgentToolInvoker> invokers = new LinkedHashMap<>();
|
||||
List<McpSpec> mcpSpecs = new ArrayList<>();
|
||||
Map<BigInteger, McpSpec> mcpSpecMap = new LinkedHashMap<>();
|
||||
for (AgentToolBinding binding : agent.getToolBindings()) {
|
||||
if (!Boolean.TRUE.equals(binding.getEnabled())) {
|
||||
continue;
|
||||
}
|
||||
AgentToolType type = AgentToolType.from(binding.getToolType());
|
||||
if (type == AgentToolType.MCP) {
|
||||
McpSpec mcpSpec = mcpSpecMap.computeIfAbsent(binding.getTargetId(),
|
||||
ignored -> buildMcpSpec(binding));
|
||||
applyMcpToolBinding(mcpSpec, binding);
|
||||
if (!mcpSpecs.contains(mcpSpec)) {
|
||||
mcpSpecs.add(mcpSpec);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Tool tool = buildTool(binding);
|
||||
AgentToolSpec spec = toToolSpec(tool, binding);
|
||||
specs.add(spec);
|
||||
invokers.put(spec.getName(), (arguments, context) -> invokeTool(tool, arguments));
|
||||
}
|
||||
definition.setToolSpecs(specs);
|
||||
definition.setMcpSpecs(mcpSpecs);
|
||||
bundle.setToolInvokers(invokers);
|
||||
}
|
||||
|
||||
@@ -243,16 +260,74 @@ public class AgentDefinitionCompiler {
|
||||
}
|
||||
return pluginItem.toFunction();
|
||||
}
|
||||
throw new BusinessException("不支持的 Agent 工具类型:" + type.name());
|
||||
}
|
||||
|
||||
private McpSpec buildMcpSpec(AgentToolBinding binding) {
|
||||
Mcp mcp = snapshotOrCurrentMcp(binding);
|
||||
if (mcp == null) {
|
||||
throw new BusinessException("绑定 MCP 不存在");
|
||||
}
|
||||
McpTool tool = new McpTool();
|
||||
tool.setMcpId(mcp.getId());
|
||||
tool.setName(binding.getToolName());
|
||||
tool.setDescription(mcp.getDescription());
|
||||
tool.setParameters(new Parameter[0]);
|
||||
return tool;
|
||||
Map.Entry<String, Map<String, Object>> server = firstMcpServer(mcp);
|
||||
Map<String, Object> serverConfig = server.getValue();
|
||||
McpTransportType transportType = parseMcpTransportType(mcp, serverConfig);
|
||||
|
||||
McpSpec spec = new McpSpec();
|
||||
spec.setName(mcpRuntimeName(mcp));
|
||||
spec.setDescription(firstNonBlank(mcp.getDescription(), mcp.getTitle()));
|
||||
spec.setTransportType(transportType);
|
||||
spec.setCommand(resolveMcpInput(stringValue(serverConfig, "command", null)));
|
||||
spec.setArgs(resolveMcpInputs(stringListValue(serverConfig, "args")));
|
||||
spec.setEnv(resolveMcpInputMap(stringMapValue(serverConfig, "env")));
|
||||
spec.setUrl(resolveMcpInput(stringValue(serverConfig, "url", null)));
|
||||
spec.setHeaders(resolveMcpInputMap(stringMapValue(serverConfig, "headers")));
|
||||
spec.setQueryParams(resolveMcpInputMap(stringMapValue(serverConfig, "queryParams")));
|
||||
Duration timeout = durationValue(serverConfig, "timeout");
|
||||
if (timeout != null) {
|
||||
spec.setTimeout(timeout);
|
||||
}
|
||||
Duration initializationTimeout = durationValue(serverConfig, "initializationTimeout");
|
||||
if (initializationTimeout != null) {
|
||||
spec.setInitializationTimeout(initializationTimeout);
|
||||
}
|
||||
spec.setGroupName(mcpRuntimeName(mcp));
|
||||
spec.setApprovalRequired(Boolean.TRUE.equals(mcp.getApprovalRequired()));
|
||||
spec.setApprovalRequest(buildMcpApprovalRequest(mcp));
|
||||
spec.setToolNamePrefix(mcpRuntimeToolPrefix(mcp.getId()));
|
||||
spec.getMetadata().put("toolType", AgentToolType.MCP.name());
|
||||
spec.getMetadata().put("mcpId", String.valueOf(mcp.getId()));
|
||||
spec.getMetadata().put("mcpTitle", mcp.getTitle());
|
||||
spec.getMetadata().put("serverName", server.getKey());
|
||||
return spec;
|
||||
}
|
||||
|
||||
private void applyMcpToolBinding(McpSpec spec, AgentToolBinding binding) {
|
||||
if (Boolean.TRUE.equals(binding.getHitlEnabled())) {
|
||||
spec.setApprovalRequired(true);
|
||||
spec.setApprovalRequest(buildBindingApprovalRequest(binding));
|
||||
}
|
||||
}
|
||||
|
||||
private AgentToolApprovalRequest buildMcpApprovalRequest(Mcp mcp) {
|
||||
AgentToolApprovalRequest request = new AgentToolApprovalRequest();
|
||||
request.setApprovalPrompt("是否批准执行 MCP 工具:" + firstNonBlank(mcp.getTitle(), mcpRuntimeName(mcp)));
|
||||
Map<String, Object> metadata = new LinkedHashMap<>();
|
||||
metadata.put("toolType", AgentToolType.MCP.name());
|
||||
metadata.put("mcpId", String.valueOf(mcp.getId()));
|
||||
metadata.put("mcpTitle", mcp.getTitle());
|
||||
request.setMetadata(metadata);
|
||||
return request;
|
||||
}
|
||||
|
||||
private AgentToolApprovalRequest buildBindingApprovalRequest(AgentToolBinding binding) {
|
||||
AgentToolApprovalRequest request = new AgentToolApprovalRequest();
|
||||
request.setApprovalPrompt(stringValue(binding.getHitlConfigJson(), "prompt", "是否批准执行 MCP 工具"));
|
||||
Map<String, Object> metadata = sanitizedHitlMetadata(binding.getHitlConfigJson());
|
||||
metadata.put("toolType", binding.getToolType());
|
||||
metadata.put("bindingId", binding.getId());
|
||||
metadata.put("targetId", binding.getTargetId());
|
||||
request.setMetadata(metadata);
|
||||
return request;
|
||||
}
|
||||
|
||||
private AgentToolSpec toToolSpec(Tool tool, AgentToolBinding binding) {
|
||||
@@ -477,6 +552,138 @@ public class AgentDefinitionCompiler {
|
||||
return mcpService.getById(binding.getTargetId());
|
||||
}
|
||||
|
||||
private Map.Entry<String, Map<String, Object>> firstMcpServer(Mcp mcp) {
|
||||
Map<String, Object> config = parseMcpConfig(mcp);
|
||||
Map<String, Object> servers = mapValue(config, "mcpServers");
|
||||
if (servers.isEmpty()) {
|
||||
throw new BusinessException("MCP 配置 JSON 中没有找到任何 MCP 服务名称");
|
||||
}
|
||||
Map.Entry<String, Object> first = servers.entrySet().iterator().next();
|
||||
if (!(first.getValue() instanceof Map<?, ?> rawServer)) {
|
||||
throw new BusinessException("MCP 服务配置必须是对象:" + first.getKey());
|
||||
}
|
||||
Map<String, Object> serverConfig = new LinkedHashMap<>();
|
||||
rawServer.forEach((key, value) -> serverConfig.put(String.valueOf(key), value));
|
||||
return Map.entry(first.getKey(), serverConfig);
|
||||
}
|
||||
|
||||
private Map<String, Object> parseMcpConfig(Mcp mcp) {
|
||||
String configJson = mcp == null ? null : mcp.getConfigJson();
|
||||
if (configJson == null || configJson.isBlank()) {
|
||||
throw new BusinessException("MCP 配置 JSON 不能为空");
|
||||
}
|
||||
try {
|
||||
return objectMapper.readValue(configJson, new com.fasterxml.jackson.core.type.TypeReference<>() {});
|
||||
} catch (Exception e) {
|
||||
throw new BusinessException("MCP 配置 JSON 格式错误");
|
||||
}
|
||||
}
|
||||
|
||||
private McpTransportType parseMcpTransportType(Mcp mcp, Map<String, Object> serverConfig) {
|
||||
String transport = firstNonBlank(
|
||||
mcp == null ? null : mcp.getTransportType(),
|
||||
stringValue(serverConfig, "transport", null)
|
||||
);
|
||||
return McpTransportType.from(transport);
|
||||
}
|
||||
|
||||
private String mcpRuntimeName(Mcp mcp) {
|
||||
BigInteger id = mcp == null ? null : mcp.getId();
|
||||
return "mcp_" + safeToolNameSegment(id == null ? "unknown" : String.valueOf(id));
|
||||
}
|
||||
|
||||
private String mcpRuntimeToolPrefix(BigInteger mcpId) {
|
||||
return "mcp_" + safeToolNameSegment(String.valueOf(mcpId)) + "_";
|
||||
}
|
||||
|
||||
private String safeToolNameSegment(String value) {
|
||||
String normalized = String.valueOf(value == null ? "" : value).trim()
|
||||
.replaceAll("[^A-Za-z0-9_-]", "_")
|
||||
.replaceAll("_+", "_");
|
||||
if (normalized.isBlank()) {
|
||||
return "tool";
|
||||
}
|
||||
return normalized;
|
||||
}
|
||||
|
||||
private List<String> stringListValue(Map<String, Object> map, String key) {
|
||||
Object value = map == null ? null : map.get(key);
|
||||
if (value == null) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
if (value instanceof Collection<?> collection) {
|
||||
List<String> result = new ArrayList<>();
|
||||
for (Object item : collection) {
|
||||
if (item != null) {
|
||||
result.add(String.valueOf(item));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
throw new BusinessException("Agent 配置字段必须是数组:" + key);
|
||||
}
|
||||
|
||||
private Duration durationValue(Map<String, Object> map, String key) {
|
||||
Object value = map == null ? null : map.get(key);
|
||||
if (value == null) {
|
||||
return null;
|
||||
}
|
||||
if (value instanceof Number number) {
|
||||
return Duration.ofSeconds(number.longValue());
|
||||
}
|
||||
String text = String.valueOf(value).trim();
|
||||
if (text.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
try {
|
||||
return Duration.parse(text);
|
||||
} catch (Exception ignored) {
|
||||
try {
|
||||
return Duration.ofSeconds(Long.parseLong(text));
|
||||
} catch (NumberFormatException e) {
|
||||
throw new BusinessException("Agent 配置字段必须是秒数或 Duration:" + key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private List<String> resolveMcpInputs(List<String> values) {
|
||||
if (values == null || values.isEmpty()) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
List<String> result = new ArrayList<>(values.size());
|
||||
for (String value : values) {
|
||||
result.add(resolveMcpInput(value));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private Map<String, String> resolveMcpInputMap(Map<String, String> values) {
|
||||
if (values == null || values.isEmpty()) {
|
||||
return new LinkedHashMap<>();
|
||||
}
|
||||
Map<String, String> result = new LinkedHashMap<>();
|
||||
values.forEach((key, value) -> result.put(key, resolveMcpInput(value)));
|
||||
return result;
|
||||
}
|
||||
|
||||
private String resolveMcpInput(String value) {
|
||||
if (value == null || value.isBlank()) {
|
||||
return value;
|
||||
}
|
||||
Matcher matcher = MCP_INPUT_PATTERN.matcher(value);
|
||||
StringBuffer resolved = new StringBuffer();
|
||||
while (matcher.find()) {
|
||||
String inputKey = matcher.group(1);
|
||||
String resolvedValue = System.getProperty("mcp.input." + inputKey);
|
||||
if (resolvedValue == null || resolvedValue.isBlank()) {
|
||||
throw new BusinessException("MCP 输入变量未解析:" + inputKey);
|
||||
}
|
||||
matcher.appendReplacement(resolved, Matcher.quoteReplacement(resolvedValue));
|
||||
}
|
||||
matcher.appendTail(resolved);
|
||||
return resolved.toString();
|
||||
}
|
||||
|
||||
private DocumentCollection snapshotOrPublishedKnowledge(AgentKnowledgeBinding binding) {
|
||||
if (binding.getResourceSnapshot() != null && !binding.getResourceSnapshot().isEmpty()) {
|
||||
DocumentCollection knowledge = objectMapper.convertValue(binding.getResourceSnapshot(), DocumentCollection.class);
|
||||
|
||||
@@ -4,6 +4,8 @@ import com.easyagents.agent.runtime.AgentResumeRequest;
|
||||
import com.easyagents.agent.runtime.AgentRuntime;
|
||||
import com.easyagents.agent.runtime.event.AgentRuntimeEvent;
|
||||
import com.easyagents.agent.runtime.hitl.AgentResumeToken;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
import reactor.core.Disposable;
|
||||
import tech.easyflow.agent.runtime.lock.AgentRunLock;
|
||||
@@ -25,6 +27,8 @@ import java.util.function.Consumer;
|
||||
@Component
|
||||
public class AgentRunRegistry {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(AgentRunRegistry.class);
|
||||
|
||||
private final Map<String, AgentRunContext> runs = new ConcurrentHashMap<>();
|
||||
private final Map<String, String> sessionRuns = new ConcurrentHashMap<>();
|
||||
private final Map<String, String> resumeTokenIndex = new ConcurrentHashMap<>();
|
||||
@@ -138,6 +142,7 @@ public class AgentRunRegistry {
|
||||
if (context != null) {
|
||||
sessionRuns.remove(context.sessionId(), requestId);
|
||||
context.releaseLock();
|
||||
context.closeRuntime();
|
||||
}
|
||||
owners.remove(requestId);
|
||||
Set<String> tokens = requestTokens.remove(requestId);
|
||||
@@ -210,6 +215,23 @@ public class AgentRunRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 当前恢复目标是否为草稿试运行。
|
||||
*
|
||||
* @param requestId 请求 ID,可为空
|
||||
* @param resumeToken 恢复令牌
|
||||
* @return true 表示目标为草稿试运行
|
||||
*/
|
||||
public boolean isDraftResumeTarget(String requestId, String resumeToken) {
|
||||
try {
|
||||
String resolvedRequestId = resolveRequestId(requestId, resumeToken);
|
||||
AgentRunContext context = runs.get(resolvedRequestId);
|
||||
return context != null && !context.persistChatlog();
|
||||
} catch (BusinessException ignored) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private void submit(String requestId, String resumeToken, String userId, boolean approved, String reason) {
|
||||
submit(requestId, resumeToken, userId, approved, reason, null);
|
||||
}
|
||||
@@ -430,6 +452,15 @@ public class AgentRunRegistry {
|
||||
return suspended.get();
|
||||
}
|
||||
|
||||
/**
|
||||
* 当前运行是否持久化聊天日志与运行态。
|
||||
*
|
||||
* @return true 表示正式聊天持久化运行
|
||||
*/
|
||||
public boolean persistChatlog() {
|
||||
return persistChatlog;
|
||||
}
|
||||
|
||||
/**
|
||||
* 绑定运行订阅。
|
||||
*
|
||||
@@ -477,6 +508,18 @@ public class AgentRunRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭底层运行时并释放资源。
|
||||
*/
|
||||
public void closeRuntime() {
|
||||
try {
|
||||
runtime.close();
|
||||
} catch (Exception e) {
|
||||
LOG.warn("Close Agent runtime failed, requestId={}, sessionId={}, message={}",
|
||||
requestId, sessionId, e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过同一个 runtime 恢复挂起运行,事件继续写入原 SSE。
|
||||
*
|
||||
|
||||
@@ -72,10 +72,10 @@ public class AgentRunService {
|
||||
@Resource
|
||||
private AgentChatCapabilityService agentChatCapabilityService;
|
||||
@Resource
|
||||
private AgentSessionStore agentSessionStore;
|
||||
@Resource
|
||||
private EasyFlowAgentSessionStore easyFlowAgentSessionStore;
|
||||
@Resource
|
||||
private AgentSessionStore draftAgentSessionStore;
|
||||
@Resource
|
||||
private AgentRunRegistry agentRunRegistry;
|
||||
@Resource
|
||||
private AgentRunLock agentRunLock;
|
||||
@@ -136,7 +136,7 @@ public class AgentRunService {
|
||||
applyFormalSessionTitle(chatContext, chatRequest.getPrompt(), existingSession);
|
||||
// 执行对话
|
||||
return run(agent, chatRequest.getPrompt(), requestId, traceId, sessionId.toString(),
|
||||
ASSISTANT_CODE, chatContext, true);
|
||||
ASSISTANT_CODE, chatContext, true, easyFlowAgentSessionStore);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -160,7 +160,7 @@ public class AgentRunService {
|
||||
String traceId = UUID.randomUUID().toString();
|
||||
ChatRuntimeContext chatContext = buildChatRuntimeContext(agent, chatSessionId, draftRequest.getPrompt(), account, DRAFT_ASSISTANT_CODE);
|
||||
return run(agent, draftRequest.getPrompt(), requestId, traceId, runtimeSessionId,
|
||||
DRAFT_ASSISTANT_CODE, chatContext, false);
|
||||
DRAFT_ASSISTANT_CODE, chatContext, false, draftAgentSessionStore);
|
||||
}
|
||||
|
||||
private SseEmitter run(Agent agent,
|
||||
@@ -170,7 +170,8 @@ public class AgentRunService {
|
||||
String runtimeSessionId,
|
||||
String assistantCode,
|
||||
ChatRuntimeContext chatContext,
|
||||
boolean persistChatlog) {
|
||||
boolean persistChatlog,
|
||||
AgentSessionStore runtimeSessionStore) {
|
||||
ChatSseEmitter chatSseEmitter = new ChatSseEmitter();
|
||||
// 获取会话锁
|
||||
AgentRunLock.Handle lockHandle = acquireRunLock(agent, runtimeSessionId);
|
||||
@@ -186,7 +187,7 @@ public class AgentRunService {
|
||||
chatRuntimeManager.recordUserMessage(chatContext, buildUserRuntimeMessage(chatContext, prompt));
|
||||
}
|
||||
threadPoolTaskExecutor.execute(() -> startRuntime(agent, prompt, requestId, traceId, runtimeSessionId,
|
||||
assistantCode, chatContext, chatSseEmitter, persistChatlog, lockHandle));
|
||||
assistantCode, chatContext, chatSseEmitter, persistChatlog, runtimeSessionStore, lockHandle));
|
||||
submitted = true;
|
||||
return chatSseEmitter.getEmitter();
|
||||
} finally {
|
||||
@@ -210,11 +211,12 @@ public class AgentRunService {
|
||||
throw new BusinessException("仅允许清理 Agent 草稿试运行会话");
|
||||
}
|
||||
LoginAccount account = requireCurrentLoginAccount();
|
||||
agentRunRegistry.cancelSession(sessionId, account.getId() == null ? null : account.getId().toString());
|
||||
agentSessionStore.delete(sessionId);
|
||||
if (agentHitlPendingService != null) {
|
||||
agentHitlPendingService.deleteByRuntimeSessionId(sessionId);
|
||||
}
|
||||
clearDraftSessionInternal(sessionId, account.getId() == null ? null : account.getId().toString());
|
||||
}
|
||||
|
||||
private void clearDraftSessionInternal(String sessionId, String userId) {
|
||||
agentRunRegistry.cancelSession(sessionId, userId);
|
||||
draftAgentSessionStore.delete(sessionId);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -225,9 +227,16 @@ public class AgentRunService {
|
||||
*/
|
||||
public void approve(String requestId, String resumeToken) {
|
||||
LoginAccount account = requireCurrentLoginAccount();
|
||||
String userId = account.getId() == null ? null : account.getId().toString();
|
||||
approveRuntime(requestId, resumeToken, account.getId(), account.getId() == null ? null : account.getId().toString());
|
||||
}
|
||||
|
||||
private void approveRuntime(String requestId, String resumeToken, BigInteger operatorId, String userId) {
|
||||
if (agentRunRegistry.isDraftResumeTarget(requestId, resumeToken)) {
|
||||
agentRunRegistry.approve(requestId, resumeToken, userId);
|
||||
return;
|
||||
}
|
||||
agentRunRegistry.approve(requestId, resumeToken, userId,
|
||||
() -> agentHitlPendingService.approve(resumeToken, account.getId()));
|
||||
() -> agentHitlPendingService.approve(resumeToken, operatorId));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -239,9 +248,16 @@ public class AgentRunService {
|
||||
*/
|
||||
public void reject(String requestId, String resumeToken, String reason) {
|
||||
LoginAccount account = requireCurrentLoginAccount();
|
||||
String userId = account.getId() == null ? null : account.getId().toString();
|
||||
rejectRuntime(requestId, resumeToken, reason, account.getId(), account.getId() == null ? null : account.getId().toString());
|
||||
}
|
||||
|
||||
private void rejectRuntime(String requestId, String resumeToken, String reason, BigInteger operatorId, String userId) {
|
||||
if (agentRunRegistry.isDraftResumeTarget(requestId, resumeToken)) {
|
||||
agentRunRegistry.reject(requestId, resumeToken, userId, reason);
|
||||
return;
|
||||
}
|
||||
agentRunRegistry.reject(requestId, resumeToken, userId, reason,
|
||||
() -> agentHitlPendingService.reject(resumeToken, account.getId(), reason));
|
||||
() -> agentHitlPendingService.reject(resumeToken, operatorId, reason));
|
||||
}
|
||||
|
||||
private void startRuntime(Agent agent,
|
||||
@@ -253,6 +269,7 @@ public class AgentRunService {
|
||||
ChatRuntimeContext chatContext,
|
||||
ChatSseEmitter chatSseEmitter,
|
||||
boolean persistChatlog,
|
||||
AgentSessionStore runtimeSessionStore,
|
||||
AgentRunLock.Handle initialLockHandle) {
|
||||
AtomicBoolean finished = new AtomicBoolean(false);
|
||||
StringBuilder answer = new StringBuilder();
|
||||
@@ -262,7 +279,9 @@ public class AgentRunService {
|
||||
assistantAccumulator, finished, persistChatlog);
|
||||
AgentRunLock.Handle lockHandle = initialLockHandle;
|
||||
try {
|
||||
bindAgentSession(agent, runtimeSessionId, chatContext);
|
||||
if (persistChatlog) {
|
||||
bindAgentSession(agent, runtimeSessionId, chatContext);
|
||||
}
|
||||
AgentRuntimeBundle bundle = agentDefinitionCompiler.compile(agent);
|
||||
AgentRuntime runtime = agentRuntimeFactory.create();
|
||||
// 会话初始化请求
|
||||
@@ -272,7 +291,7 @@ public class AgentRunService {
|
||||
request.setRuntimeContext(buildAgentRuntimeContext(chatContext, traceId, runtimeSessionId));
|
||||
request.setToolInvokers(bundle.getToolInvokers());
|
||||
request.setKnowledgeRetrievers(bundle.getKnowledgeRetrievers());
|
||||
request.setSessionStore(agentSessionStore);
|
||||
request.setSessionStore(runtimeSessionStore);
|
||||
request.getMetadata().put("assistantCode", assistantCode);
|
||||
runtime.init(request);
|
||||
// 注册会话运行时管理
|
||||
@@ -346,20 +365,20 @@ public class AgentRunService {
|
||||
return agentRunLock.acquire(agent == null ? null : agent.getId(), runtimeSessionId);
|
||||
}
|
||||
|
||||
private void recordRuntimeEvent(String requestId, ChatRuntimeContext chatContext, AgentRuntimeEvent event) {
|
||||
if (agentRunEventRecorder != null) {
|
||||
private void recordRuntimeEvent(String requestId, ChatRuntimeContext chatContext, AgentRuntimeEvent event, boolean persistChatlog) {
|
||||
if (persistChatlog && agentRunEventRecorder != null) {
|
||||
agentRunEventRecorder.record(requestId, chatContext, event);
|
||||
}
|
||||
}
|
||||
|
||||
private void recordApprovalRequired(String requestId, ChatRuntimeContext chatContext, AgentRuntimeEvent event) {
|
||||
if (agentHitlPendingService != null) {
|
||||
private void recordApprovalRequired(String requestId, ChatRuntimeContext chatContext, AgentRuntimeEvent event, boolean persistChatlog) {
|
||||
if (persistChatlog && agentHitlPendingService != null) {
|
||||
agentHitlPendingService.recordApprovalRequired(requestId, chatContext, event);
|
||||
}
|
||||
}
|
||||
|
||||
private void cancelPending(String requestId, String reason) {
|
||||
if (agentHitlPendingService != null) {
|
||||
private void cancelPending(String requestId, String reason, boolean persistChatlog) {
|
||||
if (persistChatlog && agentHitlPendingService != null) {
|
||||
agentHitlPendingService.cancelByRequestId(requestId, reason);
|
||||
}
|
||||
}
|
||||
@@ -397,7 +416,7 @@ public class AgentRunService {
|
||||
}
|
||||
}
|
||||
agentRunRegistry.remove(requestId);
|
||||
cancelPending(requestId, "客户端连接已断开,Agent 运行已取消");
|
||||
cancelPending(requestId, "客户端连接已断开,Agent 运行已取消", persistChatlog);
|
||||
if (!persistChatlog) {
|
||||
return;
|
||||
}
|
||||
@@ -420,7 +439,7 @@ public class AgentRunService {
|
||||
if (event == null || event.getEventType() == null) {
|
||||
return;
|
||||
}
|
||||
recordRuntimeEvent(requestId, chatContext, event);
|
||||
recordRuntimeEvent(requestId, chatContext, event, persistChatlog);
|
||||
if (event.getEventType() == AgentRuntimeEventType.MESSAGE_DELTA) {
|
||||
String text = stringPayload(event, "text");
|
||||
if (text != null) {
|
||||
@@ -448,7 +467,7 @@ public class AgentRunService {
|
||||
if (event.getEventType() == AgentRuntimeEventType.TOOL_APPROVAL_REQUIRED) {
|
||||
String resumeToken = stringPayload(event, "resumeToken");
|
||||
agentRunRegistry.registerResumeToken(requestId, resumeToken);
|
||||
recordApprovalRequired(requestId, chatContext, event);
|
||||
recordApprovalRequired(requestId, chatContext, event, persistChatlog);
|
||||
if (!sendEnvelope(chatSseEmitter, ChatDomain.TOOL, ChatType.FORM_REQUEST, buildToolHitlPayload(requestId, event))) {
|
||||
cancelDisconnectedRun(requestId, chatContext, answer, assistantAccumulator, finished, persistChatlog);
|
||||
}
|
||||
@@ -460,6 +479,7 @@ public class AgentRunService {
|
||||
assistantAccumulator.appendToolCall(
|
||||
firstText(event.getToolCallId(), stringPayload(event, "toolCallId")),
|
||||
firstText(stringPayload(event, "toolName"), stringPayload(event, "name")),
|
||||
stringPayload(event, "toolDisplayName"),
|
||||
firstNonNull(event.getPayload().get("input"), event.getPayload().get("toolInput"))
|
||||
);
|
||||
if (!sendEnvelope(chatSseEmitter, ChatDomain.TOOL, ChatType.TOOL_CALL, buildToolEventPayload(event))) {
|
||||
@@ -473,6 +493,7 @@ public class AgentRunService {
|
||||
assistantAccumulator.appendToolResult(
|
||||
firstText(event.getToolCallId(), stringPayload(event, "toolCallId")),
|
||||
firstText(stringPayload(event, "toolName"), stringPayload(event, "name")),
|
||||
stringPayload(event, "toolDisplayName"),
|
||||
firstNonNull(firstNonNull(event.getPayload().get("output"), event.getPayload().get("result")),
|
||||
event.getPayload().get("text"))
|
||||
);
|
||||
@@ -587,7 +608,7 @@ public class AgentRunService {
|
||||
return;
|
||||
}
|
||||
agentRunRegistry.remove(requestId);
|
||||
cancelPending(requestId, safeErrorMessage(error));
|
||||
cancelPending(requestId, safeErrorMessage(error), persistChatlog);
|
||||
Throwable safeError = error == null ? new BusinessException("Agent 运行失败") : error;
|
||||
LOG.error("Agent run failed, requestId={}, message={}, exception={}", requestId,
|
||||
safeError.getMessage(), safeError.toString(), safeError);
|
||||
@@ -621,7 +642,7 @@ public class AgentRunService {
|
||||
}
|
||||
agentRunRegistry.remove(requestId);
|
||||
String reason = errorMessage(event);
|
||||
cancelPending(requestId, reason);
|
||||
cancelPending(requestId, reason, persistChatlog);
|
||||
LOG.info("Agent run cancelled, requestId={}, reason={}", requestId, reason);
|
||||
if (persistChatlog) {
|
||||
recordPartialAssistantIfPresent(chatContext, answer, assistantAccumulator, reason);
|
||||
|
||||
@@ -0,0 +1,241 @@
|
||||
package tech.easyflow.agent.runtime.session;
|
||||
|
||||
import com.easyagents.agent.runtime.persistence.session.AgentSessionStore;
|
||||
import io.agentscope.core.state.State;
|
||||
import io.agentscope.core.util.JsonUtils;
|
||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.StringUtils;
|
||||
import tech.easyflow.agent.config.AgentRuntimeProperties;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.security.MessageDigest;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Base64;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* Agent 草稿试运行 Redis-only session store。
|
||||
*/
|
||||
@Service
|
||||
public class DraftAgentSessionStore implements AgentSessionStore {
|
||||
|
||||
private static final String REDIS_PREFIX = "easyflow:agent:draft-session:";
|
||||
private static final String ENVELOPE_VERSION = "1";
|
||||
private static final String SINGLE_STATES = "singleStates";
|
||||
private static final String LIST_STATES = "listStates";
|
||||
|
||||
private final StringRedisTemplate stringRedisTemplate;
|
||||
private final AgentRuntimeProperties properties;
|
||||
|
||||
/**
|
||||
* 创建草稿试运行 session store。
|
||||
*
|
||||
* @param stringRedisTemplate Redis 模板
|
||||
* @param properties Agent 运行态配置
|
||||
*/
|
||||
public DraftAgentSessionStore(StringRedisTemplate stringRedisTemplate,
|
||||
AgentRuntimeProperties properties) {
|
||||
this.stringRedisTemplate = stringRedisTemplate;
|
||||
this.properties = properties;
|
||||
}
|
||||
|
||||
/**
|
||||
* 保存单个状态项。
|
||||
*
|
||||
* @param sessionKey 会话键
|
||||
* @param name 状态名称
|
||||
* @param state 状态值
|
||||
*/
|
||||
@Override
|
||||
public void save(String sessionKey, String name, State state) {
|
||||
if (!StringUtils.hasText(sessionKey) || !StringUtils.hasText(name) || state == null) {
|
||||
return;
|
||||
}
|
||||
Map<String, Object> envelope = loadEnvelope(sessionKey);
|
||||
singleStates(envelope).put(name, JsonUtils.getJsonCodec().toJson(state));
|
||||
writeCache(sessionKey, envelope);
|
||||
}
|
||||
|
||||
/**
|
||||
* 保存状态列表。
|
||||
*
|
||||
* @param sessionKey 会话键
|
||||
* @param name 状态名称
|
||||
* @param states 状态列表
|
||||
*/
|
||||
@Override
|
||||
public void saveList(String sessionKey, String name, List<? extends State> states) {
|
||||
if (!StringUtils.hasText(sessionKey) || !StringUtils.hasText(name)) {
|
||||
return;
|
||||
}
|
||||
List<String> values = new ArrayList<>();
|
||||
if (states != null) {
|
||||
for (State state : states) {
|
||||
values.add(JsonUtils.getJsonCodec().toJson(state));
|
||||
}
|
||||
}
|
||||
Map<String, Object> envelope = loadEnvelope(sessionKey);
|
||||
listStates(envelope).put(name, values);
|
||||
writeCache(sessionKey, envelope);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取单个状态项。
|
||||
*
|
||||
* @param sessionKey 会话键
|
||||
* @param name 状态名称
|
||||
* @param type 状态类型
|
||||
* @param <T> 状态类型
|
||||
* @return 可选状态
|
||||
*/
|
||||
@Override
|
||||
public <T extends State> Optional<T> get(String sessionKey, String name, Class<T> type) {
|
||||
if (!StringUtils.hasText(sessionKey) || !StringUtils.hasText(name) || type == null) {
|
||||
return Optional.empty();
|
||||
}
|
||||
Object json = singleStates(loadEnvelope(sessionKey)).get(name);
|
||||
if (!(json instanceof String text) || text.isBlank()) {
|
||||
return Optional.empty();
|
||||
}
|
||||
return Optional.of(JsonUtils.getJsonCodec().fromJson(text, type));
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取状态列表。
|
||||
*
|
||||
* @param sessionKey 会话键
|
||||
* @param name 状态名称
|
||||
* @param itemType 状态元素类型
|
||||
* @param <T> 状态元素类型
|
||||
* @return 状态列表
|
||||
*/
|
||||
@Override
|
||||
public <T extends State> List<T> getList(String sessionKey, String name, Class<T> itemType) {
|
||||
if (!StringUtils.hasText(sessionKey) || !StringUtils.hasText(name) || itemType == null) {
|
||||
return List.of();
|
||||
}
|
||||
Object raw = listStates(loadEnvelope(sessionKey)).get(name);
|
||||
if (!(raw instanceof List<?> values) || values.isEmpty()) {
|
||||
return List.of();
|
||||
}
|
||||
List<T> result = new ArrayList<>();
|
||||
for (Object value : values) {
|
||||
if (value instanceof String text && !text.isBlank()) {
|
||||
result.add(JsonUtils.getJsonCodec().fromJson(text, itemType));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* 判断会话键是否存在。
|
||||
*
|
||||
* @param sessionKey 会话键
|
||||
* @return 存在时为 true
|
||||
*/
|
||||
@Override
|
||||
public boolean exists(String sessionKey) {
|
||||
return StringUtils.hasText(sessionKey) && readCache(sessionKey) != null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除指定会话键下的全部状态。
|
||||
*
|
||||
* @param sessionKey 会话键
|
||||
*/
|
||||
@Override
|
||||
public void delete(String sessionKey) {
|
||||
if (!StringUtils.hasText(sessionKey)) {
|
||||
return;
|
||||
}
|
||||
deleteCache(sessionKey);
|
||||
}
|
||||
|
||||
/**
|
||||
* 列出当前存储中的会话键。
|
||||
*
|
||||
* <p>草稿 session 使用哈希 Redis key,不维护反向索引,避免为试运行引入额外持久化状态。</p>
|
||||
*
|
||||
* @return 空集合
|
||||
*/
|
||||
@Override
|
||||
public Set<String> listSessionKeys() {
|
||||
return new LinkedHashSet<>();
|
||||
}
|
||||
|
||||
private Map<String, Object> loadEnvelope(String sessionKey) {
|
||||
Map<String, Object> cached = readCache(sessionKey);
|
||||
return cached == null ? emptyEnvelope() : deepCopy(cached);
|
||||
}
|
||||
|
||||
private Map<String, Object> emptyEnvelope() {
|
||||
Map<String, Object> envelope = new LinkedHashMap<>();
|
||||
envelope.put("version", ENVELOPE_VERSION);
|
||||
envelope.put(SINGLE_STATES, new LinkedHashMap<String, Object>());
|
||||
envelope.put(LIST_STATES, new LinkedHashMap<String, Object>());
|
||||
return envelope;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private Map<String, Object> singleStates(Map<String, Object> envelope) {
|
||||
return (Map<String, Object>) envelope.computeIfAbsent(SINGLE_STATES, key -> new LinkedHashMap<String, Object>());
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private Map<String, Object> listStates(Map<String, Object> envelope) {
|
||||
return (Map<String, Object>) envelope.computeIfAbsent(LIST_STATES, key -> new LinkedHashMap<String, Object>());
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private Map<String, Object> readCache(String sessionKey) {
|
||||
try {
|
||||
String value = stringRedisTemplate.opsForValue().get(cacheKey(sessionKey));
|
||||
if (!StringUtils.hasText(value)) {
|
||||
return null;
|
||||
}
|
||||
return JsonUtils.getJsonCodec().fromJson(value, Map.class);
|
||||
} catch (RuntimeException e) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private void writeCache(String sessionKey, Map<String, Object> envelope) {
|
||||
long seconds = Math.max(1L, properties.getSessionCacheTtl().toSeconds());
|
||||
stringRedisTemplate.opsForValue().set(cacheKey(sessionKey), JsonUtils.getJsonCodec().toJson(envelope),
|
||||
seconds, TimeUnit.SECONDS);
|
||||
}
|
||||
|
||||
private void deleteCache(String sessionKey) {
|
||||
stringRedisTemplate.delete(cacheKey(sessionKey));
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private Map<String, Object> deepCopy(Map<String, Object> source) {
|
||||
if (source == null || source.isEmpty()) {
|
||||
return emptyEnvelope();
|
||||
}
|
||||
return JsonUtils.getJsonCodec().fromJson(JsonUtils.getJsonCodec().toJson(source), Map.class);
|
||||
}
|
||||
|
||||
private String cacheKey(String sessionKey) {
|
||||
return REDIS_PREFIX + hash(sessionKey);
|
||||
}
|
||||
|
||||
private String hash(String value) {
|
||||
try {
|
||||
MessageDigest digest = MessageDigest.getInstance("SHA-256");
|
||||
byte[] bytes = digest.digest(value.getBytes(StandardCharsets.UTF_8));
|
||||
return Base64.getUrlEncoder().withoutPadding().encodeToString(bytes);
|
||||
} catch (NoSuchAlgorithmException e) {
|
||||
return value.replace(':', '_');
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -276,20 +276,22 @@ public class AgentServiceImpl extends ServiceImpl<AgentMapper, Agent> implements
|
||||
summary.put("bindingId", binding.getId());
|
||||
summary.put("toolType", binding.getToolType());
|
||||
summary.put("targetId", binding.getTargetId());
|
||||
summary.put("toolName", binding.getToolName());
|
||||
summary.put("enabled", Boolean.TRUE.equals(binding.getEnabled()));
|
||||
summary.put("hitlEnabled", Boolean.TRUE.equals(binding.getHitlEnabled()));
|
||||
summary.put("hitlConfigJson", binding.getHitlConfigJson());
|
||||
summary.put("sortNo", binding.getSortNo());
|
||||
if ("WORKFLOW".equalsIgnoreCase(binding.getToolType())) {
|
||||
summary.put("toolName", binding.getToolName());
|
||||
Workflow workflow = workflowService.getById(binding.getTargetId());
|
||||
summary.put("title", workflow == null ? null : workflow.getTitle());
|
||||
} else if ("PLUGIN".equalsIgnoreCase(binding.getToolType())) {
|
||||
summary.put("toolName", binding.getToolName());
|
||||
PluginItem pluginItem = pluginItemService.getById(binding.getTargetId());
|
||||
summary.put("title", pluginItem == null ? null : pluginItem.getName());
|
||||
} else {
|
||||
Mcp mcp = mcpService.getById(binding.getTargetId());
|
||||
summary.put("title", mcp == null ? null : mcp.getTitle());
|
||||
summary.put("tools", mcp == null || mcp.getTools() == null ? List.of() : mcp.getTools());
|
||||
}
|
||||
return summary;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
package tech.easyflow.agent.runtime;
|
||||
|
||||
import com.easyagents.agent.runtime.AgentDefinition;
|
||||
import com.easyagents.agent.runtime.mcp.McpSpec;
|
||||
import com.easyagents.agent.runtime.mcp.McpTransportType;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import tech.easyflow.agent.entity.Agent;
|
||||
import tech.easyflow.agent.entity.AgentToolBinding;
|
||||
import tech.easyflow.agent.enums.AgentToolType;
|
||||
import tech.easyflow.ai.entity.Mcp;
|
||||
import tech.easyflow.ai.entity.Model;
|
||||
import tech.easyflow.ai.entity.ModelProvider;
|
||||
import tech.easyflow.ai.service.McpService;
|
||||
import tech.easyflow.ai.service.ModelService;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.Proxy;
|
||||
import java.math.BigInteger;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Agent MCP 运行时定义编译测试。
|
||||
*/
|
||||
public class AgentDefinitionCompilerMcpTest {
|
||||
|
||||
/**
|
||||
* 验证 Agent 绑定 MCP 后会编译为 runtime 原生 MCP 声明,并按整个 MCP 暴露工具。
|
||||
*
|
||||
* @throws Exception 反射注入依赖失败时抛出
|
||||
*/
|
||||
@Test
|
||||
public void compileShouldBuildWholeMcpSpecWithDynamicPrefixAndApproval() throws Exception {
|
||||
BigInteger modelId = BigInteger.valueOf(10L);
|
||||
BigInteger mcpId = BigInteger.valueOf(20L);
|
||||
Model model = model(modelId);
|
||||
Mcp mcp = mcp(mcpId);
|
||||
AgentDefinitionCompiler compiler = new AgentDefinitionCompiler();
|
||||
setField(compiler, "objectMapper", new com.fasterxml.jackson.databind.ObjectMapper());
|
||||
setField(compiler, "modelService", modelService(model));
|
||||
setField(compiler, "mcpService", mcpService(mcp));
|
||||
|
||||
Agent agent = agent(modelId, mcpId);
|
||||
|
||||
AgentRuntimeBundle bundle = compiler.compile(agent);
|
||||
AgentDefinition definition = bundle.getDefinition();
|
||||
|
||||
Assert.assertTrue(definition.getToolSpecs().isEmpty());
|
||||
Assert.assertTrue(bundle.getToolInvokers().isEmpty());
|
||||
Assert.assertEquals(1, definition.getMcpSpecs().size());
|
||||
McpSpec spec = definition.getMcpSpecs().get(0);
|
||||
Assert.assertEquals("mcp_20", spec.getName());
|
||||
Assert.assertEquals(McpTransportType.STDIO, spec.getTransportType());
|
||||
Assert.assertEquals("npx", spec.getCommand());
|
||||
Assert.assertEquals(List.of("-y", "@modelcontextprotocol/server-everything"), spec.getArgs());
|
||||
Assert.assertTrue(spec.isApprovalRequired());
|
||||
Assert.assertEquals("mcp_20_", spec.getToolNamePrefix());
|
||||
Assert.assertTrue(spec.getToolAliases().isEmpty());
|
||||
Assert.assertTrue(spec.getEnableTools().isEmpty());
|
||||
Assert.assertEquals(AgentToolType.MCP.name(), spec.getMetadata().get("toolType"));
|
||||
Assert.assertEquals(String.valueOf(mcpId), spec.getMetadata().get("mcpId"));
|
||||
Assert.assertEquals("everything", spec.getMetadata().get("serverName"));
|
||||
Assert.assertTrue(spec.getToolApprovalRequests().isEmpty());
|
||||
Assert.assertEquals("确认调用 MCP 工具?", spec.getApprovalRequest().getApprovalPrompt());
|
||||
}
|
||||
|
||||
private Agent agent(BigInteger modelId, BigInteger mcpId) {
|
||||
AgentToolBinding binding = new AgentToolBinding();
|
||||
binding.setToolType(AgentToolType.MCP.name());
|
||||
binding.setTargetId(mcpId);
|
||||
binding.setEnabled(true);
|
||||
binding.setHitlEnabled(true);
|
||||
binding.setHitlConfigJson(Map.of("prompt", "确认调用 MCP 工具?"));
|
||||
|
||||
Agent agent = new Agent();
|
||||
agent.setId(BigInteger.valueOf(1L));
|
||||
agent.setName("MCP Agent");
|
||||
agent.setModelId(modelId);
|
||||
agent.setToolBindings(List.of(binding));
|
||||
return agent;
|
||||
}
|
||||
|
||||
private Model model(BigInteger modelId) {
|
||||
ModelProvider provider = new ModelProvider();
|
||||
provider.setProviderType("openai");
|
||||
provider.setProviderName("OpenAI");
|
||||
Model model = new Model();
|
||||
model.setId(modelId);
|
||||
model.setModelProvider(provider);
|
||||
model.setModelName("gpt-test");
|
||||
model.setEndpoint("https://example.com");
|
||||
model.setRequestPath("/v1/chat/completions");
|
||||
model.setApiKey("test-key");
|
||||
return model;
|
||||
}
|
||||
|
||||
private Mcp mcp(BigInteger mcpId) {
|
||||
Mcp mcp = new Mcp();
|
||||
mcp.setId(mcpId);
|
||||
mcp.setTitle("Everything");
|
||||
mcp.setDescription("MCP Everything");
|
||||
mcp.setApprovalRequired(true);
|
||||
mcp.setStatus(true);
|
||||
mcp.setConfigJson("""
|
||||
{
|
||||
"mcpServers": {
|
||||
"everything": {
|
||||
"transport": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-everything"]
|
||||
}
|
||||
}
|
||||
}
|
||||
""");
|
||||
return mcp;
|
||||
}
|
||||
|
||||
private ModelService modelService(Model model) {
|
||||
return (ModelService) Proxy.newProxyInstance(
|
||||
ModelService.class.getClassLoader(),
|
||||
new Class<?>[]{ModelService.class},
|
||||
(proxy, method, args) -> "getModelInstance".equals(method.getName()) ? model : defaultValue(method.getReturnType()));
|
||||
}
|
||||
|
||||
private McpService mcpService(Mcp mcp) {
|
||||
return (McpService) Proxy.newProxyInstance(
|
||||
McpService.class.getClassLoader(),
|
||||
new Class<?>[]{McpService.class},
|
||||
(proxy, method, args) -> "getById".equals(method.getName()) ? mcp : defaultValue(method.getReturnType()));
|
||||
}
|
||||
|
||||
private Object defaultValue(Class<?> type) {
|
||||
if (type == boolean.class) {
|
||||
return false;
|
||||
}
|
||||
if (type == int.class || type == long.class || type == short.class || type == byte.class) {
|
||||
return 0;
|
||||
}
|
||||
if (type == double.class || type == float.class) {
|
||||
return 0D;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private void setField(Object target, String fieldName, Object value) throws Exception {
|
||||
Field field = target.getClass().getDeclaredField(fieldName);
|
||||
field.setAccessible(true);
|
||||
field.set(target, value);
|
||||
}
|
||||
}
|
||||
@@ -1,15 +1,22 @@
|
||||
package tech.easyflow.agent.runtime;
|
||||
|
||||
import com.easyagents.agent.runtime.AgentInitRequest;
|
||||
import com.easyagents.agent.runtime.AgentRuntime;
|
||||
import com.easyagents.agent.runtime.event.AgentRuntimeEvent;
|
||||
import com.easyagents.agent.runtime.event.AgentRuntimeEventType;
|
||||
import com.easyagents.agent.runtime.message.AgentKnowledgeReference;
|
||||
import com.easyagents.agent.runtime.message.AgentMessage;
|
||||
import com.easyagents.agent.runtime.message.AgentMessageRole;
|
||||
import com.easyagents.agent.runtime.persistence.session.AgentSessionStore;
|
||||
import com.easyagents.agent.runtime.persistence.session.memory.InMemoryAgentSessionStore;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import tech.easyflow.agent.entity.AgentHitlPending;
|
||||
import tech.easyflow.agent.entity.Agent;
|
||||
import tech.easyflow.agent.entity.AgentKnowledgeBinding;
|
||||
import tech.easyflow.agent.entity.AgentToolBinding;
|
||||
import tech.easyflow.agent.runtime.event.AgentRunEventRecorder;
|
||||
import tech.easyflow.agent.runtime.hitl.AgentHitlPendingService;
|
||||
import tech.easyflow.agent.runtime.lock.AgentRunLock;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatSessionSummary;
|
||||
import tech.easyflow.common.entity.LoginAccount;
|
||||
@@ -402,14 +409,150 @@ public class AgentRunServiceDraftAndHitlTest {
|
||||
|
||||
Exception thrown = Assert.assertThrows(Exception.class, () -> invoke(service, "run",
|
||||
new Class<?>[]{Agent.class, String.class, String.class, String.class, String.class,
|
||||
String.class, ChatRuntimeContext.class, boolean.class},
|
||||
agent, "你好", "request-lock", "trace-lock", "session-lock", "AGENT", context, true));
|
||||
String.class, ChatRuntimeContext.class, boolean.class, AgentSessionStore.class},
|
||||
agent, "你好", "request-lock", "trace-lock", "session-lock", "AGENT", context, true,
|
||||
new InMemoryAgentSessionStore()));
|
||||
|
||||
Assert.assertTrue(rootCause(thrown) instanceof BusinessException);
|
||||
Assert.assertEquals(0, chatRuntimeManager.prepareSessionCount);
|
||||
Assert.assertEquals(0, chatRuntimeManager.recordUserMessageCount);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证草稿运行会使用独立 session store,且不会绑定 MySQL session 元信息。
|
||||
*
|
||||
* @throws Exception 反射调用失败时抛出
|
||||
*/
|
||||
@Test
|
||||
public void startRuntimeShouldUseDraftSessionStoreWithoutBindingMysqlSession() throws Exception {
|
||||
AgentRunService service = new AgentRunService();
|
||||
RecordingAgentDefinitionCompiler compiler = new RecordingAgentDefinitionCompiler();
|
||||
RecordingAgentRuntime runtime = new RecordingAgentRuntime();
|
||||
RecordingAgentRuntimeFactory runtimeFactory = new RecordingAgentRuntimeFactory(runtime);
|
||||
AgentSessionStore draftStore = new InMemoryAgentSessionStore();
|
||||
setField(service, "agentDefinitionCompiler", compiler);
|
||||
setField(service, "agentRuntimeFactory", runtimeFactory);
|
||||
setField(service, "agentRunRegistry", new AgentRunRegistry());
|
||||
|
||||
Agent agent = new Agent();
|
||||
agent.setId(BigInteger.valueOf(100));
|
||||
invoke(service, "startRuntime",
|
||||
new Class<?>[]{Agent.class, String.class, String.class, String.class, String.class, String.class,
|
||||
ChatRuntimeContext.class, ChatSseEmitter.class, boolean.class, AgentSessionStore.class,
|
||||
AgentRunLock.Handle.class},
|
||||
agent, "你好", "request-draft", "trace-draft", "agent-draft-100", "AGENT_DRAFT",
|
||||
chatContext(), new RecordingChatSseEmitter(), false, draftStore, null);
|
||||
|
||||
Assert.assertSame(draftStore, runtime.initRequest.getSessionStore());
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证草稿事件不会写运行事件表,正式事件仍会记录。
|
||||
*
|
||||
* @throws Exception 反射调用失败时抛出
|
||||
*/
|
||||
@Test
|
||||
public void handleRuntimeEventShouldOnlyPersistEventsForFormalChat() throws Exception {
|
||||
AgentRunService service = new AgentRunService();
|
||||
setField(service, "agentRunRegistry", new AgentRunRegistry());
|
||||
RecordingAgentRunEventRecorder recorder = new RecordingAgentRunEventRecorder();
|
||||
setField(service, "agentRunEventRecorder", recorder);
|
||||
AgentRuntimeEvent draftEvent = AgentRuntimeEvent.of(AgentRuntimeEventType.TOOL_CALL);
|
||||
draftEvent.getPayload().put("toolName", "search");
|
||||
|
||||
invoke(service, "handleRuntimeEvent",
|
||||
runtimeEventParameterTypes(),
|
||||
draftEvent, "request-draft", new RecordingChatSseEmitter(), new StringBuilder(),
|
||||
new ChatAssistantAccumulator(), chatContext(), new AtomicBoolean(false), false);
|
||||
|
||||
Assert.assertEquals(0, recorder.recordCount);
|
||||
|
||||
AgentRuntimeEvent formalEvent = AgentRuntimeEvent.of(AgentRuntimeEventType.TOOL_CALL);
|
||||
formalEvent.getPayload().put("toolName", "search");
|
||||
invoke(service, "handleRuntimeEvent",
|
||||
runtimeEventParameterTypes(),
|
||||
formalEvent, "request-formal", new RecordingChatSseEmitter(), new StringBuilder(),
|
||||
new ChatAssistantAccumulator(), chatContext(), new AtomicBoolean(false), true);
|
||||
|
||||
Assert.assertEquals(1, recorder.recordCount);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证草稿工具审批只注册内存恢复令牌,不写 HITL pending 表。
|
||||
*
|
||||
* @throws Exception 反射调用失败时抛出
|
||||
*/
|
||||
@Test
|
||||
public void draftToolApprovalShouldNotPersistPending() throws Exception {
|
||||
AgentRunService service = new AgentRunService();
|
||||
AgentRunRegistry registry = new AgentRunRegistry();
|
||||
RecordingAgentHitlPendingService pendingService = new RecordingAgentHitlPendingService();
|
||||
setField(service, "agentRunRegistry", registry);
|
||||
setField(service, "agentHitlPendingService", pendingService);
|
||||
registry.register(runContext("request-draft", "agent-draft-tool", false));
|
||||
AgentRuntimeEvent event = AgentRuntimeEvent.of(AgentRuntimeEventType.TOOL_APPROVAL_REQUIRED);
|
||||
event.getPayload().put("resumeToken", "token-draft");
|
||||
|
||||
invoke(service, "handleRuntimeEvent",
|
||||
runtimeEventParameterTypes(),
|
||||
event, "request-draft", new RecordingChatSseEmitter(), new StringBuilder(),
|
||||
new ChatAssistantAccumulator(), chatContext(), new AtomicBoolean(false), false);
|
||||
|
||||
Assert.assertTrue(registry.containsResumeTarget("request-draft", "token-draft"));
|
||||
Assert.assertEquals(0, pendingService.recordApprovalRequiredCount);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证草稿审批恢复不执行 pending 表消费,正式审批仍执行。
|
||||
*
|
||||
* @throws Exception 反射调用失败时抛出
|
||||
*/
|
||||
@Test
|
||||
public void approveShouldSkipPendingConsumeOnlyForDraftRun() throws Exception {
|
||||
AgentRunService service = new AgentRunService();
|
||||
AgentRunRegistry registry = new AgentRunRegistry();
|
||||
RecordingAgentHitlPendingService pendingService = new RecordingAgentHitlPendingService();
|
||||
setField(service, "agentRunRegistry", registry);
|
||||
setField(service, "agentHitlPendingService", pendingService);
|
||||
|
||||
registry.register(runContext("request-draft-approve", "agent-draft-approve", false));
|
||||
registry.registerResumeToken("request-draft-approve", "token-draft-approve");
|
||||
invoke(service, "approveRuntime",
|
||||
new Class<?>[]{String.class, String.class, BigInteger.class, String.class},
|
||||
"request-draft-approve", "token-draft-approve", BigInteger.ONE, "1");
|
||||
|
||||
Assert.assertEquals(0, pendingService.approveCount);
|
||||
|
||||
registry.register(runContext("request-formal-approve", "session-formal-approve", true));
|
||||
registry.registerResumeToken("request-formal-approve", "token-formal-approve");
|
||||
invoke(service, "approveRuntime",
|
||||
new Class<?>[]{String.class, String.class, BigInteger.class, String.class},
|
||||
"request-formal-approve", "token-formal-approve", BigInteger.ONE, "1");
|
||||
|
||||
Assert.assertEquals(1, pendingService.approveCount);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证清理草稿会话只清草稿 store,不触碰 MySQL pending 清理。
|
||||
*
|
||||
* @throws Exception 反射调用失败时抛出
|
||||
*/
|
||||
@Test
|
||||
public void clearDraftSessionShouldOnlyDeleteDraftStore() throws Exception {
|
||||
AgentRunService service = new AgentRunService();
|
||||
RecordingAgentHitlPendingService pendingService = new RecordingAgentHitlPendingService();
|
||||
RecordingAgentSessionStore draftStore = new RecordingAgentSessionStore();
|
||||
setField(service, "agentRunRegistry", new AgentRunRegistry());
|
||||
setField(service, "agentHitlPendingService", pendingService);
|
||||
setField(service, "draftAgentSessionStore", draftStore);
|
||||
|
||||
invoke(service, "clearDraftSessionInternal",
|
||||
new Class<?>[]{String.class, String.class}, "agent-draft-clear", "1");
|
||||
|
||||
Assert.assertEquals("agent-draft-clear", draftStore.deletedSessionKey);
|
||||
Assert.assertEquals(0, pendingService.deleteByRuntimeSessionIdCount);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证正式聊天会在会话准备完成后向前端返回真实会话 ID。
|
||||
*
|
||||
@@ -530,6 +673,28 @@ public class AgentRunServiceDraftAndHitlTest {
|
||||
ChatRuntimeContext.class, AtomicBoolean.class, boolean.class};
|
||||
}
|
||||
|
||||
private AgentRunRegistry.AgentRunContext runContext(String requestId, String sessionId, boolean persistChatlog) {
|
||||
return new AgentRunRegistry.AgentRunContext(
|
||||
requestId,
|
||||
sessionId,
|
||||
new RecordingAgentRuntime(),
|
||||
new RecordingChatSseEmitter(),
|
||||
chatContext(),
|
||||
new StringBuilder(),
|
||||
new ChatAssistantAccumulator(),
|
||||
new AtomicBoolean(false),
|
||||
persistChatlog,
|
||||
new AgentRunRegistry.RunOwner("agent-1", sessionId, "1"),
|
||||
null,
|
||||
event -> {
|
||||
},
|
||||
error -> {
|
||||
},
|
||||
() -> {
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
private ChatRuntimeContext chatContext() {
|
||||
ChatRuntimeContext context = new ChatRuntimeContext();
|
||||
context.setAssistantId(BigInteger.valueOf(100));
|
||||
@@ -598,6 +763,148 @@ public class AgentRunServiceDraftAndHitlTest {
|
||||
}
|
||||
}
|
||||
|
||||
private static class RecordingAgentRuntime implements AgentRuntime {
|
||||
|
||||
private AgentInitRequest initRequest;
|
||||
private int resumeCount;
|
||||
|
||||
@Override
|
||||
public void init(AgentInitRequest request) {
|
||||
initRequest = request;
|
||||
}
|
||||
|
||||
@Override
|
||||
public reactor.core.publisher.Flux<AgentRuntimeEvent> stream(AgentMessage userMessage) {
|
||||
return reactor.core.publisher.Flux.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public reactor.core.publisher.Flux<AgentRuntimeEvent> resume(com.easyagents.agent.runtime.AgentResumeRequest request) {
|
||||
resumeCount++;
|
||||
return reactor.core.publisher.Flux.empty();
|
||||
}
|
||||
}
|
||||
|
||||
private static class RecordingAgentRuntimeFactory implements AgentRuntimeFactory {
|
||||
|
||||
private final AgentRuntime runtime;
|
||||
|
||||
private RecordingAgentRuntimeFactory(AgentRuntime runtime) {
|
||||
this.runtime = runtime;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AgentRuntime create() {
|
||||
return runtime;
|
||||
}
|
||||
}
|
||||
|
||||
private static class RecordingAgentDefinitionCompiler extends AgentDefinitionCompiler {
|
||||
|
||||
@Override
|
||||
public AgentRuntimeBundle compile(Agent agent) {
|
||||
AgentRuntimeBundle bundle = new AgentRuntimeBundle();
|
||||
bundle.setDefinition(new com.easyagents.agent.runtime.AgentDefinition());
|
||||
return bundle;
|
||||
}
|
||||
}
|
||||
|
||||
private static class RecordingAgentRunEventRecorder implements AgentRunEventRecorder {
|
||||
|
||||
private int recordCount;
|
||||
|
||||
@Override
|
||||
public void record(String requestId, ChatRuntimeContext chatContext, AgentRuntimeEvent event) {
|
||||
recordCount++;
|
||||
}
|
||||
}
|
||||
|
||||
private static class RecordingAgentHitlPendingService implements AgentHitlPendingService {
|
||||
|
||||
private int recordApprovalRequiredCount;
|
||||
private int approveCount;
|
||||
private int rejectCount;
|
||||
private int cancelByRequestIdCount;
|
||||
private int deleteByRuntimeSessionIdCount;
|
||||
|
||||
@Override
|
||||
public void recordApprovalRequired(String requestId, ChatRuntimeContext chatContext, AgentRuntimeEvent event) {
|
||||
recordApprovalRequiredCount++;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AgentHitlPending approve(String resumeToken, BigInteger operatorId) {
|
||||
approveCount++;
|
||||
return new AgentHitlPending();
|
||||
}
|
||||
|
||||
@Override
|
||||
public AgentHitlPending reject(String resumeToken, BigInteger operatorId, String reason) {
|
||||
rejectCount++;
|
||||
return new AgentHitlPending();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void cancelByRequestId(String requestId, String reason) {
|
||||
cancelByRequestIdCount++;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteByChatSessionId(BigInteger chatSessionId) {
|
||||
// 测试桩无需处理。
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteByRuntimeSessionId(String runtimeSessionId) {
|
||||
deleteByRuntimeSessionIdCount++;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AgentHitlPending> expirePending(int limit) {
|
||||
return List.of();
|
||||
}
|
||||
}
|
||||
|
||||
private static class RecordingAgentSessionStore implements AgentSessionStore {
|
||||
|
||||
private String deletedSessionKey;
|
||||
|
||||
@Override
|
||||
public void save(String sessionKey, String name, io.agentscope.core.state.State state) {
|
||||
// 测试桩无需处理。
|
||||
}
|
||||
|
||||
@Override
|
||||
public void saveList(String sessionKey, String name, List<? extends io.agentscope.core.state.State> states) {
|
||||
// 测试桩无需处理。
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T extends io.agentscope.core.state.State> java.util.Optional<T> get(String sessionKey, String name, Class<T> type) {
|
||||
return java.util.Optional.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T extends io.agentscope.core.state.State> List<T> getList(String sessionKey, String name, Class<T> itemType) {
|
||||
return List.of();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean exists(String sessionKey) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void delete(String sessionKey) {
|
||||
deletedSessionKey = sessionKey;
|
||||
}
|
||||
|
||||
@Override
|
||||
public java.util.Set<String> listSessionKeys() {
|
||||
return java.util.Set.of();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 记录 chatlog 写入动作的测试桩。
|
||||
*/
|
||||
|
||||
@@ -37,6 +37,18 @@ public class McpBase extends DateEntity implements Serializable {
|
||||
@Column(comment = "完整MCP配置JSON")
|
||||
private String configJson;
|
||||
|
||||
/**
|
||||
* MCP连接方式
|
||||
*/
|
||||
@Column(comment = "MCP连接方式")
|
||||
private String transportType;
|
||||
|
||||
/**
|
||||
* 是否启用工具调用审批
|
||||
*/
|
||||
@Column(comment = "是否启用工具调用审批")
|
||||
private Boolean approvalRequired;
|
||||
|
||||
/**
|
||||
* 部门ID
|
||||
*/
|
||||
@@ -111,6 +123,22 @@ public class McpBase extends DateEntity implements Serializable {
|
||||
this.configJson = configJson;
|
||||
}
|
||||
|
||||
public String getTransportType() {
|
||||
return transportType;
|
||||
}
|
||||
|
||||
public void setTransportType(String transportType) {
|
||||
this.transportType = transportType;
|
||||
}
|
||||
|
||||
public Boolean getApprovalRequired() {
|
||||
return approvalRequired;
|
||||
}
|
||||
|
||||
public void setApprovalRequired(Boolean approvalRequired) {
|
||||
this.approvalRequired = approvalRequired;
|
||||
}
|
||||
|
||||
public BigInteger getDeptId() {
|
||||
return deptId;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
package tech.easyflow.ai.mcp;
|
||||
|
||||
import tech.easyflow.common.util.StringUtil;
|
||||
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||
|
||||
import java.util.Locale;
|
||||
|
||||
/**
|
||||
* MCP 连接方式。
|
||||
*/
|
||||
public enum McpTransportType {
|
||||
|
||||
/**
|
||||
* 标准输入输出进程通信。
|
||||
*/
|
||||
STDIO("stdio"),
|
||||
|
||||
/**
|
||||
* HTTP SSE 通信。
|
||||
*/
|
||||
SSE("http-sse"),
|
||||
|
||||
/**
|
||||
* Streamable HTTP 通信。
|
||||
*/
|
||||
HTTP("http-stream");
|
||||
|
||||
private final String value;
|
||||
|
||||
/**
|
||||
* 创建 MCP 连接方式。
|
||||
*
|
||||
* @param value 配置值
|
||||
*/
|
||||
McpTransportType(String value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取配置值。
|
||||
*
|
||||
* @return 配置值
|
||||
*/
|
||||
public String getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析连接方式。
|
||||
*
|
||||
* @param value 连接方式文本
|
||||
* @return MCP 连接方式
|
||||
*/
|
||||
public static McpTransportType from(String value) {
|
||||
if (StringUtil.noText(value)) {
|
||||
return STDIO;
|
||||
}
|
||||
String normalized = value.trim().toLowerCase(Locale.ROOT);
|
||||
return switch (normalized) {
|
||||
case "stdio" -> STDIO;
|
||||
case "sse", "http-sse" -> SSE;
|
||||
case "http", "http-stream", "streamable-http" -> HTTP;
|
||||
default -> throw new BusinessException("不支持的 MCP 连接方式: " + value);
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package tech.easyflow.ai.service;
|
||||
|
||||
import com.easyagents.core.model.chat.tool.Tool;
|
||||
import com.easyagents.mcp.client.McpEnvironmentCheckResult;
|
||||
import com.mybatisflex.core.paginate.Page;
|
||||
import com.mybatisflex.core.service.IService;
|
||||
import tech.easyflow.ai.entity.BotMcp;
|
||||
@@ -30,4 +31,6 @@ public interface McpService extends IService<Mcp> {
|
||||
Mcp getMcpTools(String id);
|
||||
|
||||
Page<Mcp> pageTools(Page<Mcp> mcpPage);
|
||||
|
||||
McpEnvironmentCheckResult checkMcp(String configJson);
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package tech.easyflow.ai.service.impl;
|
||||
import com.easyagents.core.model.chat.tool.Parameter;
|
||||
import com.easyagents.core.model.chat.tool.Tool;
|
||||
import com.easyagents.mcp.client.McpClientManager;
|
||||
import com.easyagents.mcp.client.McpEnvironmentCheckResult;
|
||||
import com.easyagents.mcp.client.McpEnvironmentChecker;
|
||||
import com.alibaba.fastjson2.JSON;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import com.mybatisflex.core.paginate.Page;
|
||||
@@ -16,6 +18,7 @@ import tech.easyflow.ai.easyagents.tool.McpTool;
|
||||
import tech.easyflow.ai.entity.BotMcp;
|
||||
import tech.easyflow.ai.entity.Mcp;
|
||||
import tech.easyflow.ai.mapper.McpMapper;
|
||||
import tech.easyflow.ai.mcp.McpTransportType;
|
||||
import tech.easyflow.ai.service.McpService;
|
||||
import tech.easyflow.ai.utils.CommonFiledUtil;
|
||||
import tech.easyflow.common.constant.enums.EnumRes;
|
||||
@@ -37,7 +40,8 @@ import java.util.*;
|
||||
@Service
|
||||
public class McpServiceImpl extends ServiceImpl<McpMapper, Mcp> implements McpService {
|
||||
private final McpClientManager mcpClientManager = McpClientManager.getInstance();
|
||||
protected Logger Log = LoggerFactory.getLogger(DocumentServiceImpl.class);
|
||||
private final McpEnvironmentChecker mcpEnvironmentChecker = new McpEnvironmentChecker();
|
||||
protected Logger Log = LoggerFactory.getLogger(McpServiceImpl.class);
|
||||
|
||||
@Override
|
||||
public Result<?> saveMcp(Mcp entity) {
|
||||
@@ -49,6 +53,8 @@ public class McpServiceImpl extends ServiceImpl<McpMapper, Mcp> implements McpS
|
||||
if (!StringUtil.hasText(serverName)) {
|
||||
return Result.fail("未找到mcp服务名称", serverName);
|
||||
}
|
||||
entity.setTransportType(getFirstMcpTransportType(entity.getConfigJson()));
|
||||
entity.setApprovalRequired(Boolean.TRUE.equals(entity.getApprovalRequired()));
|
||||
try {
|
||||
mcpClientManager.registerFromJson(entity.getConfigJson());
|
||||
} catch (Exception e) {
|
||||
@@ -79,6 +85,8 @@ public class McpServiceImpl extends ServiceImpl<McpMapper, Mcp> implements McpS
|
||||
if (!StringUtil.hasText(serverName)) {
|
||||
return Result.fail("未找到mcp服务名称", serverName);
|
||||
}
|
||||
entity.setTransportType(getFirstMcpTransportType(entity.getConfigJson()));
|
||||
entity.setApprovalRequired(Boolean.TRUE.equals(entity.getApprovalRequired()));
|
||||
if (entity.getStatus()) {
|
||||
try {
|
||||
mcpClientManager.registerFromJson(entity.getConfigJson());
|
||||
@@ -121,6 +129,7 @@ public class McpServiceImpl extends ServiceImpl<McpMapper, Mcp> implements McpS
|
||||
records.forEach(mcp -> {
|
||||
boolean clientOnline = mcpClientManager.isClientOnline(getFirstMcpServerName(mcp.getConfigJson()));
|
||||
mcp.setClientOnline(clientOnline);
|
||||
mcp.setTransportType(resolveMcpTransportType(mcp));
|
||||
}
|
||||
);
|
||||
page.getData().setRecords(records);
|
||||
@@ -130,6 +139,9 @@ public class McpServiceImpl extends ServiceImpl<McpMapper, Mcp> implements McpS
|
||||
@Override
|
||||
public Mcp getMcpTools(String id) {
|
||||
Mcp mcp = this.getById(id);
|
||||
if (mcp != null) {
|
||||
mcp.setTransportType(resolveMcpTransportType(mcp));
|
||||
}
|
||||
if (mcp != null && mcp.getStatus()) {
|
||||
McpSyncClient mcpClient = getMcpClient(mcp, mcpClientManager);
|
||||
List<McpSchema.Tool> tools = null;
|
||||
@@ -209,9 +221,27 @@ public class McpServiceImpl extends ServiceImpl<McpMapper, Mcp> implements McpS
|
||||
return firstServerName.orElse(null);
|
||||
}
|
||||
|
||||
public static String getFirstMcpTransportType(String mcpJson) {
|
||||
JSONObject rootJson = JSON.parseObject(mcpJson);
|
||||
JSONObject mcpServersJson = rootJson.getJSONObject("mcpServers");
|
||||
if (mcpServersJson == null || mcpServersJson.isEmpty()) {
|
||||
return McpTransportType.STDIO.getValue();
|
||||
}
|
||||
Optional<String> firstServerName = mcpServersJson.keySet().stream().findFirst();
|
||||
if (firstServerName.isEmpty()) {
|
||||
return McpTransportType.STDIO.getValue();
|
||||
}
|
||||
JSONObject serverJson = mcpServersJson.getJSONObject(firstServerName.get());
|
||||
if (serverJson == null) {
|
||||
return McpTransportType.STDIO.getValue();
|
||||
}
|
||||
return McpTransportType.from(serverJson.getString("transport")).getValue();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Page<Mcp> pageTools(Page<Mcp> page) {
|
||||
page.getRecords().forEach(mcp -> {
|
||||
mcp.setTransportType(resolveMcpTransportType(mcp));
|
||||
// mcp 未启用,不查询工具
|
||||
if (!mcp.getStatus()) {
|
||||
return;
|
||||
@@ -235,6 +265,11 @@ public class McpServiceImpl extends ServiceImpl<McpMapper, Mcp> implements McpS
|
||||
return page;
|
||||
}
|
||||
|
||||
@Override
|
||||
public McpEnvironmentCheckResult checkMcp(String configJson) {
|
||||
return mcpEnvironmentChecker.check(configJson);
|
||||
}
|
||||
|
||||
private Result<?> validateMcpConfig(Mcp entity) {
|
||||
if (entity == null || !StringUtil.hasText(entity.getConfigJson())) {
|
||||
Log.error("MCP 配置不能为空");
|
||||
@@ -242,4 +277,14 @@ public class McpServiceImpl extends ServiceImpl<McpMapper, Mcp> implements McpS
|
||||
}
|
||||
return Result.ok();
|
||||
}
|
||||
|
||||
private String resolveMcpTransportType(Mcp mcp) {
|
||||
if (mcp == null) {
|
||||
return McpTransportType.STDIO.getValue();
|
||||
}
|
||||
if (StringUtil.hasText(mcp.getTransportType())) {
|
||||
return McpTransportType.from(mcp.getTransportType()).getValue();
|
||||
}
|
||||
return getFirstMcpTransportType(mcp.getConfigJson());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
package tech.easyflow.ai.mcp;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import tech.easyflow.ai.service.impl.McpServiceImpl;
|
||||
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||
|
||||
/**
|
||||
* {@link McpTransportType} 单元测试。
|
||||
*/
|
||||
public class McpTransportTypeTest {
|
||||
|
||||
/**
|
||||
* 应兼容解析 MCP 配置中常见的连接方式文本。
|
||||
*/
|
||||
@Test
|
||||
public void fromShouldParseSupportedTransportTypes() {
|
||||
Assert.assertEquals(McpTransportType.STDIO, McpTransportType.from("stdio"));
|
||||
Assert.assertEquals(McpTransportType.SSE, McpTransportType.from("sse"));
|
||||
Assert.assertEquals(McpTransportType.SSE, McpTransportType.from("http-sse"));
|
||||
Assert.assertEquals(McpTransportType.HTTP, McpTransportType.from("http"));
|
||||
Assert.assertEquals(McpTransportType.HTTP, McpTransportType.from("http-stream"));
|
||||
Assert.assertEquals(McpTransportType.HTTP, McpTransportType.from("streamable-http"));
|
||||
Assert.assertEquals(McpTransportType.STDIO, McpTransportType.from(null));
|
||||
Assert.assertEquals(McpTransportType.STDIO, McpTransportType.from(" "));
|
||||
}
|
||||
|
||||
/**
|
||||
* 应从 MCP 配置 JSON 中推断首个 server 的连接方式。
|
||||
*/
|
||||
@Test
|
||||
public void getFirstMcpTransportTypeShouldInferFromConfigJson() {
|
||||
Assert.assertEquals("stdio", McpServiceImpl.getFirstMcpTransportType("""
|
||||
{"mcpServers":{"everything":{"command":"npx","args":["-y","@modelcontextprotocol/server-everything"]}}}
|
||||
"""));
|
||||
Assert.assertEquals("http-sse", McpServiceImpl.getFirstMcpTransportType("""
|
||||
{"mcpServers":{"remote":{"transport":"http-sse","url":"http://127.0.0.1:3000/sse"}}}
|
||||
"""));
|
||||
Assert.assertEquals("http-stream", McpServiceImpl.getFirstMcpTransportType("""
|
||||
{"mcpServers":{"remote":{"transport":"http-stream","url":"http://127.0.0.1:3000/mcp"}}}
|
||||
"""));
|
||||
}
|
||||
|
||||
/**
|
||||
* 不支持的连接方式应直接失败,避免保存无法启动的 MCP 配置。
|
||||
*/
|
||||
@Test
|
||||
public void getFirstMcpTransportTypeShouldRejectUnsupportedTransportType() {
|
||||
try {
|
||||
McpServiceImpl.getFirstMcpTransportType("""
|
||||
{"mcpServers":{"remote":{"transport":"websocket","url":"ws://127.0.0.1:3000/mcp"}}}
|
||||
""");
|
||||
Assert.fail("expected BusinessException");
|
||||
} catch (BusinessException exception) {
|
||||
Assert.assertTrue(exception.getMessage().contains("不支持的 MCP 连接方式"));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -70,4 +70,25 @@ public class ChatAssistantAccumulatorTest {
|
||||
Assert.assertEquals(1, secondToolCalls.size());
|
||||
Assert.assertEquals("call-2", secondToolCalls.get(0).get("id"));
|
||||
}
|
||||
|
||||
/**
|
||||
* 工具展示名应进入展示链和 assistant toolCalls,但不覆盖真实工具名。
|
||||
*/
|
||||
@Test
|
||||
@SuppressWarnings("unchecked")
|
||||
public void shouldKeepToolDisplayNameWithoutOverridingToolName() {
|
||||
ChatAssistantAccumulator accumulator = new ChatAssistantAccumulator();
|
||||
accumulator.appendToolCall("call-1", "mcp_123_search", "知识库 MCP - search", "{\"q\":\"java\"}");
|
||||
accumulator.appendToolResult("call-1", "mcp_123_search", "知识库 MCP - search", "{\"ok\":true}");
|
||||
|
||||
Map<String, Object> payload = accumulator.buildPayload(null);
|
||||
List<Map<String, Object>> chains = (List<Map<String, Object>>) payload.get("chains");
|
||||
List<Map<String, Object>> messageChain = (List<Map<String, Object>>) payload.get("messageChain");
|
||||
List<Map<String, Object>> toolCalls = (List<Map<String, Object>>) messageChain.get(0).get("toolCalls");
|
||||
|
||||
Assert.assertEquals("mcp_123_search", chains.get(0).get("name"));
|
||||
Assert.assertEquals("知识库 MCP - search", chains.get(0).get("toolDisplayName"));
|
||||
Assert.assertEquals("mcp_123_search", toolCalls.get(0).get("name"));
|
||||
Assert.assertEquals("知识库 MCP - search", toolCalls.get(0).get("toolDisplayName"));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user