feat: 完成 Agent MCP 对接
- 增加 MCP 连接类型、环境检测接口和容器运行环境支持 - 将 Agent 编排改为绑定整体 MCP 并编译为 runtime McpSpec - 优化 MCP 工具展示、审批、草稿试运行和画布回显稳定性
This commit is contained in:
@@ -10,6 +10,8 @@ import com.easyagents.agent.runtime.knowledge.AgentKnowledgeSpec;
|
||||
import com.easyagents.agent.runtime.memory.AgentMemoryCompressionParameter;
|
||||
import com.easyagents.agent.runtime.memory.AgentMemoryPolicy;
|
||||
import com.easyagents.agent.runtime.memory.AgentMemoryType;
|
||||
import com.easyagents.agent.runtime.mcp.McpSpec;
|
||||
import com.easyagents.agent.runtime.mcp.McpTransportType;
|
||||
import com.easyagents.agent.runtime.model.AgentGenerationOptions;
|
||||
import com.easyagents.agent.runtime.model.AgentModelProviderType;
|
||||
import com.easyagents.agent.runtime.model.AgentModelSpec;
|
||||
@@ -28,7 +30,6 @@ import tech.easyflow.agent.entity.AgentKnowledgeBinding;
|
||||
import tech.easyflow.agent.entity.AgentToolBinding;
|
||||
import tech.easyflow.agent.enums.AgentToolType;
|
||||
import tech.easyflow.ai.easyagents.tool.ChatToolNameHelper;
|
||||
import tech.easyflow.ai.easyagents.tool.McpTool;
|
||||
import tech.easyflow.ai.easyagents.tool.WorkflowTool;
|
||||
import tech.easyflow.ai.easyagentsflow.support.PublishedWorkflowDefinitionIds;
|
||||
import tech.easyflow.ai.entity.*;
|
||||
@@ -40,6 +41,8 @@ import tech.easyflow.common.web.exceptions.BusinessException;
|
||||
import javax.annotation.Resource;
|
||||
import java.math.BigInteger;
|
||||
import java.time.Duration;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
@@ -50,6 +53,7 @@ public class AgentDefinitionCompiler {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(AgentDefinitionCompiler.class);
|
||||
private static final int LOG_TEXT_MAX_LENGTH = 500;
|
||||
private static final Pattern MCP_INPUT_PATTERN = Pattern.compile("\\$\\{input:([A-Za-z0-9_.-]+)}");
|
||||
|
||||
@Resource
|
||||
private ModelService modelService;
|
||||
@@ -210,16 +214,29 @@ public class AgentDefinitionCompiler {
|
||||
}
|
||||
List<AgentToolSpec> specs = new ArrayList<>();
|
||||
Map<String, com.easyagents.agent.runtime.tool.AgentToolInvoker> invokers = new LinkedHashMap<>();
|
||||
List<McpSpec> mcpSpecs = new ArrayList<>();
|
||||
Map<BigInteger, McpSpec> mcpSpecMap = new LinkedHashMap<>();
|
||||
for (AgentToolBinding binding : agent.getToolBindings()) {
|
||||
if (!Boolean.TRUE.equals(binding.getEnabled())) {
|
||||
continue;
|
||||
}
|
||||
AgentToolType type = AgentToolType.from(binding.getToolType());
|
||||
if (type == AgentToolType.MCP) {
|
||||
McpSpec mcpSpec = mcpSpecMap.computeIfAbsent(binding.getTargetId(),
|
||||
ignored -> buildMcpSpec(binding));
|
||||
applyMcpToolBinding(mcpSpec, binding);
|
||||
if (!mcpSpecs.contains(mcpSpec)) {
|
||||
mcpSpecs.add(mcpSpec);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Tool tool = buildTool(binding);
|
||||
AgentToolSpec spec = toToolSpec(tool, binding);
|
||||
specs.add(spec);
|
||||
invokers.put(spec.getName(), (arguments, context) -> invokeTool(tool, arguments));
|
||||
}
|
||||
definition.setToolSpecs(specs);
|
||||
definition.setMcpSpecs(mcpSpecs);
|
||||
bundle.setToolInvokers(invokers);
|
||||
}
|
||||
|
||||
@@ -243,16 +260,74 @@ public class AgentDefinitionCompiler {
|
||||
}
|
||||
return pluginItem.toFunction();
|
||||
}
|
||||
throw new BusinessException("不支持的 Agent 工具类型:" + type.name());
|
||||
}
|
||||
|
||||
private McpSpec buildMcpSpec(AgentToolBinding binding) {
|
||||
Mcp mcp = snapshotOrCurrentMcp(binding);
|
||||
if (mcp == null) {
|
||||
throw new BusinessException("绑定 MCP 不存在");
|
||||
}
|
||||
McpTool tool = new McpTool();
|
||||
tool.setMcpId(mcp.getId());
|
||||
tool.setName(binding.getToolName());
|
||||
tool.setDescription(mcp.getDescription());
|
||||
tool.setParameters(new Parameter[0]);
|
||||
return tool;
|
||||
Map.Entry<String, Map<String, Object>> server = firstMcpServer(mcp);
|
||||
Map<String, Object> serverConfig = server.getValue();
|
||||
McpTransportType transportType = parseMcpTransportType(mcp, serverConfig);
|
||||
|
||||
McpSpec spec = new McpSpec();
|
||||
spec.setName(mcpRuntimeName(mcp));
|
||||
spec.setDescription(firstNonBlank(mcp.getDescription(), mcp.getTitle()));
|
||||
spec.setTransportType(transportType);
|
||||
spec.setCommand(resolveMcpInput(stringValue(serverConfig, "command", null)));
|
||||
spec.setArgs(resolveMcpInputs(stringListValue(serverConfig, "args")));
|
||||
spec.setEnv(resolveMcpInputMap(stringMapValue(serverConfig, "env")));
|
||||
spec.setUrl(resolveMcpInput(stringValue(serverConfig, "url", null)));
|
||||
spec.setHeaders(resolveMcpInputMap(stringMapValue(serverConfig, "headers")));
|
||||
spec.setQueryParams(resolveMcpInputMap(stringMapValue(serverConfig, "queryParams")));
|
||||
Duration timeout = durationValue(serverConfig, "timeout");
|
||||
if (timeout != null) {
|
||||
spec.setTimeout(timeout);
|
||||
}
|
||||
Duration initializationTimeout = durationValue(serverConfig, "initializationTimeout");
|
||||
if (initializationTimeout != null) {
|
||||
spec.setInitializationTimeout(initializationTimeout);
|
||||
}
|
||||
spec.setGroupName(mcpRuntimeName(mcp));
|
||||
spec.setApprovalRequired(Boolean.TRUE.equals(mcp.getApprovalRequired()));
|
||||
spec.setApprovalRequest(buildMcpApprovalRequest(mcp));
|
||||
spec.setToolNamePrefix(mcpRuntimeToolPrefix(mcp.getId()));
|
||||
spec.getMetadata().put("toolType", AgentToolType.MCP.name());
|
||||
spec.getMetadata().put("mcpId", String.valueOf(mcp.getId()));
|
||||
spec.getMetadata().put("mcpTitle", mcp.getTitle());
|
||||
spec.getMetadata().put("serverName", server.getKey());
|
||||
return spec;
|
||||
}
|
||||
|
||||
private void applyMcpToolBinding(McpSpec spec, AgentToolBinding binding) {
|
||||
if (Boolean.TRUE.equals(binding.getHitlEnabled())) {
|
||||
spec.setApprovalRequired(true);
|
||||
spec.setApprovalRequest(buildBindingApprovalRequest(binding));
|
||||
}
|
||||
}
|
||||
|
||||
private AgentToolApprovalRequest buildMcpApprovalRequest(Mcp mcp) {
|
||||
AgentToolApprovalRequest request = new AgentToolApprovalRequest();
|
||||
request.setApprovalPrompt("是否批准执行 MCP 工具:" + firstNonBlank(mcp.getTitle(), mcpRuntimeName(mcp)));
|
||||
Map<String, Object> metadata = new LinkedHashMap<>();
|
||||
metadata.put("toolType", AgentToolType.MCP.name());
|
||||
metadata.put("mcpId", String.valueOf(mcp.getId()));
|
||||
metadata.put("mcpTitle", mcp.getTitle());
|
||||
request.setMetadata(metadata);
|
||||
return request;
|
||||
}
|
||||
|
||||
private AgentToolApprovalRequest buildBindingApprovalRequest(AgentToolBinding binding) {
|
||||
AgentToolApprovalRequest request = new AgentToolApprovalRequest();
|
||||
request.setApprovalPrompt(stringValue(binding.getHitlConfigJson(), "prompt", "是否批准执行 MCP 工具"));
|
||||
Map<String, Object> metadata = sanitizedHitlMetadata(binding.getHitlConfigJson());
|
||||
metadata.put("toolType", binding.getToolType());
|
||||
metadata.put("bindingId", binding.getId());
|
||||
metadata.put("targetId", binding.getTargetId());
|
||||
request.setMetadata(metadata);
|
||||
return request;
|
||||
}
|
||||
|
||||
private AgentToolSpec toToolSpec(Tool tool, AgentToolBinding binding) {
|
||||
@@ -477,6 +552,138 @@ public class AgentDefinitionCompiler {
|
||||
return mcpService.getById(binding.getTargetId());
|
||||
}
|
||||
|
||||
private Map.Entry<String, Map<String, Object>> firstMcpServer(Mcp mcp) {
|
||||
Map<String, Object> config = parseMcpConfig(mcp);
|
||||
Map<String, Object> servers = mapValue(config, "mcpServers");
|
||||
if (servers.isEmpty()) {
|
||||
throw new BusinessException("MCP 配置 JSON 中没有找到任何 MCP 服务名称");
|
||||
}
|
||||
Map.Entry<String, Object> first = servers.entrySet().iterator().next();
|
||||
if (!(first.getValue() instanceof Map<?, ?> rawServer)) {
|
||||
throw new BusinessException("MCP 服务配置必须是对象:" + first.getKey());
|
||||
}
|
||||
Map<String, Object> serverConfig = new LinkedHashMap<>();
|
||||
rawServer.forEach((key, value) -> serverConfig.put(String.valueOf(key), value));
|
||||
return Map.entry(first.getKey(), serverConfig);
|
||||
}
|
||||
|
||||
private Map<String, Object> parseMcpConfig(Mcp mcp) {
|
||||
String configJson = mcp == null ? null : mcp.getConfigJson();
|
||||
if (configJson == null || configJson.isBlank()) {
|
||||
throw new BusinessException("MCP 配置 JSON 不能为空");
|
||||
}
|
||||
try {
|
||||
return objectMapper.readValue(configJson, new com.fasterxml.jackson.core.type.TypeReference<>() {});
|
||||
} catch (Exception e) {
|
||||
throw new BusinessException("MCP 配置 JSON 格式错误");
|
||||
}
|
||||
}
|
||||
|
||||
private McpTransportType parseMcpTransportType(Mcp mcp, Map<String, Object> serverConfig) {
|
||||
String transport = firstNonBlank(
|
||||
mcp == null ? null : mcp.getTransportType(),
|
||||
stringValue(serverConfig, "transport", null)
|
||||
);
|
||||
return McpTransportType.from(transport);
|
||||
}
|
||||
|
||||
private String mcpRuntimeName(Mcp mcp) {
|
||||
BigInteger id = mcp == null ? null : mcp.getId();
|
||||
return "mcp_" + safeToolNameSegment(id == null ? "unknown" : String.valueOf(id));
|
||||
}
|
||||
|
||||
private String mcpRuntimeToolPrefix(BigInteger mcpId) {
|
||||
return "mcp_" + safeToolNameSegment(String.valueOf(mcpId)) + "_";
|
||||
}
|
||||
|
||||
private String safeToolNameSegment(String value) {
|
||||
String normalized = String.valueOf(value == null ? "" : value).trim()
|
||||
.replaceAll("[^A-Za-z0-9_-]", "_")
|
||||
.replaceAll("_+", "_");
|
||||
if (normalized.isBlank()) {
|
||||
return "tool";
|
||||
}
|
||||
return normalized;
|
||||
}
|
||||
|
||||
private List<String> stringListValue(Map<String, Object> map, String key) {
|
||||
Object value = map == null ? null : map.get(key);
|
||||
if (value == null) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
if (value instanceof Collection<?> collection) {
|
||||
List<String> result = new ArrayList<>();
|
||||
for (Object item : collection) {
|
||||
if (item != null) {
|
||||
result.add(String.valueOf(item));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
throw new BusinessException("Agent 配置字段必须是数组:" + key);
|
||||
}
|
||||
|
||||
private Duration durationValue(Map<String, Object> map, String key) {
|
||||
Object value = map == null ? null : map.get(key);
|
||||
if (value == null) {
|
||||
return null;
|
||||
}
|
||||
if (value instanceof Number number) {
|
||||
return Duration.ofSeconds(number.longValue());
|
||||
}
|
||||
String text = String.valueOf(value).trim();
|
||||
if (text.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
try {
|
||||
return Duration.parse(text);
|
||||
} catch (Exception ignored) {
|
||||
try {
|
||||
return Duration.ofSeconds(Long.parseLong(text));
|
||||
} catch (NumberFormatException e) {
|
||||
throw new BusinessException("Agent 配置字段必须是秒数或 Duration:" + key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private List<String> resolveMcpInputs(List<String> values) {
|
||||
if (values == null || values.isEmpty()) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
List<String> result = new ArrayList<>(values.size());
|
||||
for (String value : values) {
|
||||
result.add(resolveMcpInput(value));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private Map<String, String> resolveMcpInputMap(Map<String, String> values) {
|
||||
if (values == null || values.isEmpty()) {
|
||||
return new LinkedHashMap<>();
|
||||
}
|
||||
Map<String, String> result = new LinkedHashMap<>();
|
||||
values.forEach((key, value) -> result.put(key, resolveMcpInput(value)));
|
||||
return result;
|
||||
}
|
||||
|
||||
private String resolveMcpInput(String value) {
|
||||
if (value == null || value.isBlank()) {
|
||||
return value;
|
||||
}
|
||||
Matcher matcher = MCP_INPUT_PATTERN.matcher(value);
|
||||
StringBuffer resolved = new StringBuffer();
|
||||
while (matcher.find()) {
|
||||
String inputKey = matcher.group(1);
|
||||
String resolvedValue = System.getProperty("mcp.input." + inputKey);
|
||||
if (resolvedValue == null || resolvedValue.isBlank()) {
|
||||
throw new BusinessException("MCP 输入变量未解析:" + inputKey);
|
||||
}
|
||||
matcher.appendReplacement(resolved, Matcher.quoteReplacement(resolvedValue));
|
||||
}
|
||||
matcher.appendTail(resolved);
|
||||
return resolved.toString();
|
||||
}
|
||||
|
||||
private DocumentCollection snapshotOrPublishedKnowledge(AgentKnowledgeBinding binding) {
|
||||
if (binding.getResourceSnapshot() != null && !binding.getResourceSnapshot().isEmpty()) {
|
||||
DocumentCollection knowledge = objectMapper.convertValue(binding.getResourceSnapshot(), DocumentCollection.class);
|
||||
|
||||
@@ -4,6 +4,8 @@ import com.easyagents.agent.runtime.AgentResumeRequest;
|
||||
import com.easyagents.agent.runtime.AgentRuntime;
|
||||
import com.easyagents.agent.runtime.event.AgentRuntimeEvent;
|
||||
import com.easyagents.agent.runtime.hitl.AgentResumeToken;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
import reactor.core.Disposable;
|
||||
import tech.easyflow.agent.runtime.lock.AgentRunLock;
|
||||
@@ -25,6 +27,8 @@ import java.util.function.Consumer;
|
||||
@Component
|
||||
public class AgentRunRegistry {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(AgentRunRegistry.class);
|
||||
|
||||
private final Map<String, AgentRunContext> runs = new ConcurrentHashMap<>();
|
||||
private final Map<String, String> sessionRuns = new ConcurrentHashMap<>();
|
||||
private final Map<String, String> resumeTokenIndex = new ConcurrentHashMap<>();
|
||||
@@ -138,6 +142,7 @@ public class AgentRunRegistry {
|
||||
if (context != null) {
|
||||
sessionRuns.remove(context.sessionId(), requestId);
|
||||
context.releaseLock();
|
||||
context.closeRuntime();
|
||||
}
|
||||
owners.remove(requestId);
|
||||
Set<String> tokens = requestTokens.remove(requestId);
|
||||
@@ -210,6 +215,23 @@ public class AgentRunRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 当前恢复目标是否为草稿试运行。
|
||||
*
|
||||
* @param requestId 请求 ID,可为空
|
||||
* @param resumeToken 恢复令牌
|
||||
* @return true 表示目标为草稿试运行
|
||||
*/
|
||||
public boolean isDraftResumeTarget(String requestId, String resumeToken) {
|
||||
try {
|
||||
String resolvedRequestId = resolveRequestId(requestId, resumeToken);
|
||||
AgentRunContext context = runs.get(resolvedRequestId);
|
||||
return context != null && !context.persistChatlog();
|
||||
} catch (BusinessException ignored) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private void submit(String requestId, String resumeToken, String userId, boolean approved, String reason) {
|
||||
submit(requestId, resumeToken, userId, approved, reason, null);
|
||||
}
|
||||
@@ -430,6 +452,15 @@ public class AgentRunRegistry {
|
||||
return suspended.get();
|
||||
}
|
||||
|
||||
/**
|
||||
* 当前运行是否持久化聊天日志与运行态。
|
||||
*
|
||||
* @return true 表示正式聊天持久化运行
|
||||
*/
|
||||
public boolean persistChatlog() {
|
||||
return persistChatlog;
|
||||
}
|
||||
|
||||
/**
|
||||
* 绑定运行订阅。
|
||||
*
|
||||
@@ -477,6 +508,18 @@ public class AgentRunRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭底层运行时并释放资源。
|
||||
*/
|
||||
public void closeRuntime() {
|
||||
try {
|
||||
runtime.close();
|
||||
} catch (Exception e) {
|
||||
LOG.warn("Close Agent runtime failed, requestId={}, sessionId={}, message={}",
|
||||
requestId, sessionId, e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过同一个 runtime 恢复挂起运行,事件继续写入原 SSE。
|
||||
*
|
||||
|
||||
@@ -72,10 +72,10 @@ public class AgentRunService {
|
||||
@Resource
|
||||
private AgentChatCapabilityService agentChatCapabilityService;
|
||||
@Resource
|
||||
private AgentSessionStore agentSessionStore;
|
||||
@Resource
|
||||
private EasyFlowAgentSessionStore easyFlowAgentSessionStore;
|
||||
@Resource
|
||||
private AgentSessionStore draftAgentSessionStore;
|
||||
@Resource
|
||||
private AgentRunRegistry agentRunRegistry;
|
||||
@Resource
|
||||
private AgentRunLock agentRunLock;
|
||||
@@ -136,7 +136,7 @@ public class AgentRunService {
|
||||
applyFormalSessionTitle(chatContext, chatRequest.getPrompt(), existingSession);
|
||||
// 执行对话
|
||||
return run(agent, chatRequest.getPrompt(), requestId, traceId, sessionId.toString(),
|
||||
ASSISTANT_CODE, chatContext, true);
|
||||
ASSISTANT_CODE, chatContext, true, easyFlowAgentSessionStore);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -160,7 +160,7 @@ public class AgentRunService {
|
||||
String traceId = UUID.randomUUID().toString();
|
||||
ChatRuntimeContext chatContext = buildChatRuntimeContext(agent, chatSessionId, draftRequest.getPrompt(), account, DRAFT_ASSISTANT_CODE);
|
||||
return run(agent, draftRequest.getPrompt(), requestId, traceId, runtimeSessionId,
|
||||
DRAFT_ASSISTANT_CODE, chatContext, false);
|
||||
DRAFT_ASSISTANT_CODE, chatContext, false, draftAgentSessionStore);
|
||||
}
|
||||
|
||||
private SseEmitter run(Agent agent,
|
||||
@@ -170,7 +170,8 @@ public class AgentRunService {
|
||||
String runtimeSessionId,
|
||||
String assistantCode,
|
||||
ChatRuntimeContext chatContext,
|
||||
boolean persistChatlog) {
|
||||
boolean persistChatlog,
|
||||
AgentSessionStore runtimeSessionStore) {
|
||||
ChatSseEmitter chatSseEmitter = new ChatSseEmitter();
|
||||
// 获取会话锁
|
||||
AgentRunLock.Handle lockHandle = acquireRunLock(agent, runtimeSessionId);
|
||||
@@ -186,7 +187,7 @@ public class AgentRunService {
|
||||
chatRuntimeManager.recordUserMessage(chatContext, buildUserRuntimeMessage(chatContext, prompt));
|
||||
}
|
||||
threadPoolTaskExecutor.execute(() -> startRuntime(agent, prompt, requestId, traceId, runtimeSessionId,
|
||||
assistantCode, chatContext, chatSseEmitter, persistChatlog, lockHandle));
|
||||
assistantCode, chatContext, chatSseEmitter, persistChatlog, runtimeSessionStore, lockHandle));
|
||||
submitted = true;
|
||||
return chatSseEmitter.getEmitter();
|
||||
} finally {
|
||||
@@ -210,11 +211,12 @@ public class AgentRunService {
|
||||
throw new BusinessException("仅允许清理 Agent 草稿试运行会话");
|
||||
}
|
||||
LoginAccount account = requireCurrentLoginAccount();
|
||||
agentRunRegistry.cancelSession(sessionId, account.getId() == null ? null : account.getId().toString());
|
||||
agentSessionStore.delete(sessionId);
|
||||
if (agentHitlPendingService != null) {
|
||||
agentHitlPendingService.deleteByRuntimeSessionId(sessionId);
|
||||
}
|
||||
clearDraftSessionInternal(sessionId, account.getId() == null ? null : account.getId().toString());
|
||||
}
|
||||
|
||||
private void clearDraftSessionInternal(String sessionId, String userId) {
|
||||
agentRunRegistry.cancelSession(sessionId, userId);
|
||||
draftAgentSessionStore.delete(sessionId);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -225,9 +227,16 @@ public class AgentRunService {
|
||||
*/
|
||||
public void approve(String requestId, String resumeToken) {
|
||||
LoginAccount account = requireCurrentLoginAccount();
|
||||
String userId = account.getId() == null ? null : account.getId().toString();
|
||||
approveRuntime(requestId, resumeToken, account.getId(), account.getId() == null ? null : account.getId().toString());
|
||||
}
|
||||
|
||||
private void approveRuntime(String requestId, String resumeToken, BigInteger operatorId, String userId) {
|
||||
if (agentRunRegistry.isDraftResumeTarget(requestId, resumeToken)) {
|
||||
agentRunRegistry.approve(requestId, resumeToken, userId);
|
||||
return;
|
||||
}
|
||||
agentRunRegistry.approve(requestId, resumeToken, userId,
|
||||
() -> agentHitlPendingService.approve(resumeToken, account.getId()));
|
||||
() -> agentHitlPendingService.approve(resumeToken, operatorId));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -239,9 +248,16 @@ public class AgentRunService {
|
||||
*/
|
||||
public void reject(String requestId, String resumeToken, String reason) {
|
||||
LoginAccount account = requireCurrentLoginAccount();
|
||||
String userId = account.getId() == null ? null : account.getId().toString();
|
||||
rejectRuntime(requestId, resumeToken, reason, account.getId(), account.getId() == null ? null : account.getId().toString());
|
||||
}
|
||||
|
||||
private void rejectRuntime(String requestId, String resumeToken, String reason, BigInteger operatorId, String userId) {
|
||||
if (agentRunRegistry.isDraftResumeTarget(requestId, resumeToken)) {
|
||||
agentRunRegistry.reject(requestId, resumeToken, userId, reason);
|
||||
return;
|
||||
}
|
||||
agentRunRegistry.reject(requestId, resumeToken, userId, reason,
|
||||
() -> agentHitlPendingService.reject(resumeToken, account.getId(), reason));
|
||||
() -> agentHitlPendingService.reject(resumeToken, operatorId, reason));
|
||||
}
|
||||
|
||||
private void startRuntime(Agent agent,
|
||||
@@ -253,6 +269,7 @@ public class AgentRunService {
|
||||
ChatRuntimeContext chatContext,
|
||||
ChatSseEmitter chatSseEmitter,
|
||||
boolean persistChatlog,
|
||||
AgentSessionStore runtimeSessionStore,
|
||||
AgentRunLock.Handle initialLockHandle) {
|
||||
AtomicBoolean finished = new AtomicBoolean(false);
|
||||
StringBuilder answer = new StringBuilder();
|
||||
@@ -262,7 +279,9 @@ public class AgentRunService {
|
||||
assistantAccumulator, finished, persistChatlog);
|
||||
AgentRunLock.Handle lockHandle = initialLockHandle;
|
||||
try {
|
||||
bindAgentSession(agent, runtimeSessionId, chatContext);
|
||||
if (persistChatlog) {
|
||||
bindAgentSession(agent, runtimeSessionId, chatContext);
|
||||
}
|
||||
AgentRuntimeBundle bundle = agentDefinitionCompiler.compile(agent);
|
||||
AgentRuntime runtime = agentRuntimeFactory.create();
|
||||
// 会话初始化请求
|
||||
@@ -272,7 +291,7 @@ public class AgentRunService {
|
||||
request.setRuntimeContext(buildAgentRuntimeContext(chatContext, traceId, runtimeSessionId));
|
||||
request.setToolInvokers(bundle.getToolInvokers());
|
||||
request.setKnowledgeRetrievers(bundle.getKnowledgeRetrievers());
|
||||
request.setSessionStore(agentSessionStore);
|
||||
request.setSessionStore(runtimeSessionStore);
|
||||
request.getMetadata().put("assistantCode", assistantCode);
|
||||
runtime.init(request);
|
||||
// 注册会话运行时管理
|
||||
@@ -346,20 +365,20 @@ public class AgentRunService {
|
||||
return agentRunLock.acquire(agent == null ? null : agent.getId(), runtimeSessionId);
|
||||
}
|
||||
|
||||
private void recordRuntimeEvent(String requestId, ChatRuntimeContext chatContext, AgentRuntimeEvent event) {
|
||||
if (agentRunEventRecorder != null) {
|
||||
private void recordRuntimeEvent(String requestId, ChatRuntimeContext chatContext, AgentRuntimeEvent event, boolean persistChatlog) {
|
||||
if (persistChatlog && agentRunEventRecorder != null) {
|
||||
agentRunEventRecorder.record(requestId, chatContext, event);
|
||||
}
|
||||
}
|
||||
|
||||
private void recordApprovalRequired(String requestId, ChatRuntimeContext chatContext, AgentRuntimeEvent event) {
|
||||
if (agentHitlPendingService != null) {
|
||||
private void recordApprovalRequired(String requestId, ChatRuntimeContext chatContext, AgentRuntimeEvent event, boolean persistChatlog) {
|
||||
if (persistChatlog && agentHitlPendingService != null) {
|
||||
agentHitlPendingService.recordApprovalRequired(requestId, chatContext, event);
|
||||
}
|
||||
}
|
||||
|
||||
private void cancelPending(String requestId, String reason) {
|
||||
if (agentHitlPendingService != null) {
|
||||
private void cancelPending(String requestId, String reason, boolean persistChatlog) {
|
||||
if (persistChatlog && agentHitlPendingService != null) {
|
||||
agentHitlPendingService.cancelByRequestId(requestId, reason);
|
||||
}
|
||||
}
|
||||
@@ -397,7 +416,7 @@ public class AgentRunService {
|
||||
}
|
||||
}
|
||||
agentRunRegistry.remove(requestId);
|
||||
cancelPending(requestId, "客户端连接已断开,Agent 运行已取消");
|
||||
cancelPending(requestId, "客户端连接已断开,Agent 运行已取消", persistChatlog);
|
||||
if (!persistChatlog) {
|
||||
return;
|
||||
}
|
||||
@@ -420,7 +439,7 @@ public class AgentRunService {
|
||||
if (event == null || event.getEventType() == null) {
|
||||
return;
|
||||
}
|
||||
recordRuntimeEvent(requestId, chatContext, event);
|
||||
recordRuntimeEvent(requestId, chatContext, event, persistChatlog);
|
||||
if (event.getEventType() == AgentRuntimeEventType.MESSAGE_DELTA) {
|
||||
String text = stringPayload(event, "text");
|
||||
if (text != null) {
|
||||
@@ -448,7 +467,7 @@ public class AgentRunService {
|
||||
if (event.getEventType() == AgentRuntimeEventType.TOOL_APPROVAL_REQUIRED) {
|
||||
String resumeToken = stringPayload(event, "resumeToken");
|
||||
agentRunRegistry.registerResumeToken(requestId, resumeToken);
|
||||
recordApprovalRequired(requestId, chatContext, event);
|
||||
recordApprovalRequired(requestId, chatContext, event, persistChatlog);
|
||||
if (!sendEnvelope(chatSseEmitter, ChatDomain.TOOL, ChatType.FORM_REQUEST, buildToolHitlPayload(requestId, event))) {
|
||||
cancelDisconnectedRun(requestId, chatContext, answer, assistantAccumulator, finished, persistChatlog);
|
||||
}
|
||||
@@ -460,6 +479,7 @@ public class AgentRunService {
|
||||
assistantAccumulator.appendToolCall(
|
||||
firstText(event.getToolCallId(), stringPayload(event, "toolCallId")),
|
||||
firstText(stringPayload(event, "toolName"), stringPayload(event, "name")),
|
||||
stringPayload(event, "toolDisplayName"),
|
||||
firstNonNull(event.getPayload().get("input"), event.getPayload().get("toolInput"))
|
||||
);
|
||||
if (!sendEnvelope(chatSseEmitter, ChatDomain.TOOL, ChatType.TOOL_CALL, buildToolEventPayload(event))) {
|
||||
@@ -473,6 +493,7 @@ public class AgentRunService {
|
||||
assistantAccumulator.appendToolResult(
|
||||
firstText(event.getToolCallId(), stringPayload(event, "toolCallId")),
|
||||
firstText(stringPayload(event, "toolName"), stringPayload(event, "name")),
|
||||
stringPayload(event, "toolDisplayName"),
|
||||
firstNonNull(firstNonNull(event.getPayload().get("output"), event.getPayload().get("result")),
|
||||
event.getPayload().get("text"))
|
||||
);
|
||||
@@ -587,7 +608,7 @@ public class AgentRunService {
|
||||
return;
|
||||
}
|
||||
agentRunRegistry.remove(requestId);
|
||||
cancelPending(requestId, safeErrorMessage(error));
|
||||
cancelPending(requestId, safeErrorMessage(error), persistChatlog);
|
||||
Throwable safeError = error == null ? new BusinessException("Agent 运行失败") : error;
|
||||
LOG.error("Agent run failed, requestId={}, message={}, exception={}", requestId,
|
||||
safeError.getMessage(), safeError.toString(), safeError);
|
||||
@@ -621,7 +642,7 @@ public class AgentRunService {
|
||||
}
|
||||
agentRunRegistry.remove(requestId);
|
||||
String reason = errorMessage(event);
|
||||
cancelPending(requestId, reason);
|
||||
cancelPending(requestId, reason, persistChatlog);
|
||||
LOG.info("Agent run cancelled, requestId={}, reason={}", requestId, reason);
|
||||
if (persistChatlog) {
|
||||
recordPartialAssistantIfPresent(chatContext, answer, assistantAccumulator, reason);
|
||||
|
||||
@@ -0,0 +1,241 @@
|
||||
package tech.easyflow.agent.runtime.session;
|
||||
|
||||
import com.easyagents.agent.runtime.persistence.session.AgentSessionStore;
|
||||
import io.agentscope.core.state.State;
|
||||
import io.agentscope.core.util.JsonUtils;
|
||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.StringUtils;
|
||||
import tech.easyflow.agent.config.AgentRuntimeProperties;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.security.MessageDigest;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Base64;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* Agent 草稿试运行 Redis-only session store。
|
||||
*/
|
||||
@Service
|
||||
public class DraftAgentSessionStore implements AgentSessionStore {
|
||||
|
||||
private static final String REDIS_PREFIX = "easyflow:agent:draft-session:";
|
||||
private static final String ENVELOPE_VERSION = "1";
|
||||
private static final String SINGLE_STATES = "singleStates";
|
||||
private static final String LIST_STATES = "listStates";
|
||||
|
||||
private final StringRedisTemplate stringRedisTemplate;
|
||||
private final AgentRuntimeProperties properties;
|
||||
|
||||
/**
|
||||
* 创建草稿试运行 session store。
|
||||
*
|
||||
* @param stringRedisTemplate Redis 模板
|
||||
* @param properties Agent 运行态配置
|
||||
*/
|
||||
public DraftAgentSessionStore(StringRedisTemplate stringRedisTemplate,
|
||||
AgentRuntimeProperties properties) {
|
||||
this.stringRedisTemplate = stringRedisTemplate;
|
||||
this.properties = properties;
|
||||
}
|
||||
|
||||
/**
|
||||
* 保存单个状态项。
|
||||
*
|
||||
* @param sessionKey 会话键
|
||||
* @param name 状态名称
|
||||
* @param state 状态值
|
||||
*/
|
||||
@Override
|
||||
public void save(String sessionKey, String name, State state) {
|
||||
if (!StringUtils.hasText(sessionKey) || !StringUtils.hasText(name) || state == null) {
|
||||
return;
|
||||
}
|
||||
Map<String, Object> envelope = loadEnvelope(sessionKey);
|
||||
singleStates(envelope).put(name, JsonUtils.getJsonCodec().toJson(state));
|
||||
writeCache(sessionKey, envelope);
|
||||
}
|
||||
|
||||
/**
|
||||
* 保存状态列表。
|
||||
*
|
||||
* @param sessionKey 会话键
|
||||
* @param name 状态名称
|
||||
* @param states 状态列表
|
||||
*/
|
||||
@Override
|
||||
public void saveList(String sessionKey, String name, List<? extends State> states) {
|
||||
if (!StringUtils.hasText(sessionKey) || !StringUtils.hasText(name)) {
|
||||
return;
|
||||
}
|
||||
List<String> values = new ArrayList<>();
|
||||
if (states != null) {
|
||||
for (State state : states) {
|
||||
values.add(JsonUtils.getJsonCodec().toJson(state));
|
||||
}
|
||||
}
|
||||
Map<String, Object> envelope = loadEnvelope(sessionKey);
|
||||
listStates(envelope).put(name, values);
|
||||
writeCache(sessionKey, envelope);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取单个状态项。
|
||||
*
|
||||
* @param sessionKey 会话键
|
||||
* @param name 状态名称
|
||||
* @param type 状态类型
|
||||
* @param <T> 状态类型
|
||||
* @return 可选状态
|
||||
*/
|
||||
@Override
|
||||
public <T extends State> Optional<T> get(String sessionKey, String name, Class<T> type) {
|
||||
if (!StringUtils.hasText(sessionKey) || !StringUtils.hasText(name) || type == null) {
|
||||
return Optional.empty();
|
||||
}
|
||||
Object json = singleStates(loadEnvelope(sessionKey)).get(name);
|
||||
if (!(json instanceof String text) || text.isBlank()) {
|
||||
return Optional.empty();
|
||||
}
|
||||
return Optional.of(JsonUtils.getJsonCodec().fromJson(text, type));
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取状态列表。
|
||||
*
|
||||
* @param sessionKey 会话键
|
||||
* @param name 状态名称
|
||||
* @param itemType 状态元素类型
|
||||
* @param <T> 状态元素类型
|
||||
* @return 状态列表
|
||||
*/
|
||||
@Override
|
||||
public <T extends State> List<T> getList(String sessionKey, String name, Class<T> itemType) {
|
||||
if (!StringUtils.hasText(sessionKey) || !StringUtils.hasText(name) || itemType == null) {
|
||||
return List.of();
|
||||
}
|
||||
Object raw = listStates(loadEnvelope(sessionKey)).get(name);
|
||||
if (!(raw instanceof List<?> values) || values.isEmpty()) {
|
||||
return List.of();
|
||||
}
|
||||
List<T> result = new ArrayList<>();
|
||||
for (Object value : values) {
|
||||
if (value instanceof String text && !text.isBlank()) {
|
||||
result.add(JsonUtils.getJsonCodec().fromJson(text, itemType));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* 判断会话键是否存在。
|
||||
*
|
||||
* @param sessionKey 会话键
|
||||
* @return 存在时为 true
|
||||
*/
|
||||
@Override
|
||||
public boolean exists(String sessionKey) {
|
||||
return StringUtils.hasText(sessionKey) && readCache(sessionKey) != null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除指定会话键下的全部状态。
|
||||
*
|
||||
* @param sessionKey 会话键
|
||||
*/
|
||||
@Override
|
||||
public void delete(String sessionKey) {
|
||||
if (!StringUtils.hasText(sessionKey)) {
|
||||
return;
|
||||
}
|
||||
deleteCache(sessionKey);
|
||||
}
|
||||
|
||||
/**
|
||||
* 列出当前存储中的会话键。
|
||||
*
|
||||
* <p>草稿 session 使用哈希 Redis key,不维护反向索引,避免为试运行引入额外持久化状态。</p>
|
||||
*
|
||||
* @return 空集合
|
||||
*/
|
||||
@Override
|
||||
public Set<String> listSessionKeys() {
|
||||
return new LinkedHashSet<>();
|
||||
}
|
||||
|
||||
private Map<String, Object> loadEnvelope(String sessionKey) {
|
||||
Map<String, Object> cached = readCache(sessionKey);
|
||||
return cached == null ? emptyEnvelope() : deepCopy(cached);
|
||||
}
|
||||
|
||||
private Map<String, Object> emptyEnvelope() {
|
||||
Map<String, Object> envelope = new LinkedHashMap<>();
|
||||
envelope.put("version", ENVELOPE_VERSION);
|
||||
envelope.put(SINGLE_STATES, new LinkedHashMap<String, Object>());
|
||||
envelope.put(LIST_STATES, new LinkedHashMap<String, Object>());
|
||||
return envelope;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private Map<String, Object> singleStates(Map<String, Object> envelope) {
|
||||
return (Map<String, Object>) envelope.computeIfAbsent(SINGLE_STATES, key -> new LinkedHashMap<String, Object>());
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private Map<String, Object> listStates(Map<String, Object> envelope) {
|
||||
return (Map<String, Object>) envelope.computeIfAbsent(LIST_STATES, key -> new LinkedHashMap<String, Object>());
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private Map<String, Object> readCache(String sessionKey) {
|
||||
try {
|
||||
String value = stringRedisTemplate.opsForValue().get(cacheKey(sessionKey));
|
||||
if (!StringUtils.hasText(value)) {
|
||||
return null;
|
||||
}
|
||||
return JsonUtils.getJsonCodec().fromJson(value, Map.class);
|
||||
} catch (RuntimeException e) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private void writeCache(String sessionKey, Map<String, Object> envelope) {
|
||||
long seconds = Math.max(1L, properties.getSessionCacheTtl().toSeconds());
|
||||
stringRedisTemplate.opsForValue().set(cacheKey(sessionKey), JsonUtils.getJsonCodec().toJson(envelope),
|
||||
seconds, TimeUnit.SECONDS);
|
||||
}
|
||||
|
||||
private void deleteCache(String sessionKey) {
|
||||
stringRedisTemplate.delete(cacheKey(sessionKey));
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private Map<String, Object> deepCopy(Map<String, Object> source) {
|
||||
if (source == null || source.isEmpty()) {
|
||||
return emptyEnvelope();
|
||||
}
|
||||
return JsonUtils.getJsonCodec().fromJson(JsonUtils.getJsonCodec().toJson(source), Map.class);
|
||||
}
|
||||
|
||||
private String cacheKey(String sessionKey) {
|
||||
return REDIS_PREFIX + hash(sessionKey);
|
||||
}
|
||||
|
||||
private String hash(String value) {
|
||||
try {
|
||||
MessageDigest digest = MessageDigest.getInstance("SHA-256");
|
||||
byte[] bytes = digest.digest(value.getBytes(StandardCharsets.UTF_8));
|
||||
return Base64.getUrlEncoder().withoutPadding().encodeToString(bytes);
|
||||
} catch (NoSuchAlgorithmException e) {
|
||||
return value.replace(':', '_');
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -276,20 +276,22 @@ public class AgentServiceImpl extends ServiceImpl<AgentMapper, Agent> implements
|
||||
summary.put("bindingId", binding.getId());
|
||||
summary.put("toolType", binding.getToolType());
|
||||
summary.put("targetId", binding.getTargetId());
|
||||
summary.put("toolName", binding.getToolName());
|
||||
summary.put("enabled", Boolean.TRUE.equals(binding.getEnabled()));
|
||||
summary.put("hitlEnabled", Boolean.TRUE.equals(binding.getHitlEnabled()));
|
||||
summary.put("hitlConfigJson", binding.getHitlConfigJson());
|
||||
summary.put("sortNo", binding.getSortNo());
|
||||
if ("WORKFLOW".equalsIgnoreCase(binding.getToolType())) {
|
||||
summary.put("toolName", binding.getToolName());
|
||||
Workflow workflow = workflowService.getById(binding.getTargetId());
|
||||
summary.put("title", workflow == null ? null : workflow.getTitle());
|
||||
} else if ("PLUGIN".equalsIgnoreCase(binding.getToolType())) {
|
||||
summary.put("toolName", binding.getToolName());
|
||||
PluginItem pluginItem = pluginItemService.getById(binding.getTargetId());
|
||||
summary.put("title", pluginItem == null ? null : pluginItem.getName());
|
||||
} else {
|
||||
Mcp mcp = mcpService.getById(binding.getTargetId());
|
||||
summary.put("title", mcp == null ? null : mcp.getTitle());
|
||||
summary.put("tools", mcp == null || mcp.getTools() == null ? List.of() : mcp.getTools());
|
||||
}
|
||||
return summary;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
package tech.easyflow.agent.runtime;
|
||||
|
||||
import com.easyagents.agent.runtime.AgentDefinition;
|
||||
import com.easyagents.agent.runtime.mcp.McpSpec;
|
||||
import com.easyagents.agent.runtime.mcp.McpTransportType;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import tech.easyflow.agent.entity.Agent;
|
||||
import tech.easyflow.agent.entity.AgentToolBinding;
|
||||
import tech.easyflow.agent.enums.AgentToolType;
|
||||
import tech.easyflow.ai.entity.Mcp;
|
||||
import tech.easyflow.ai.entity.Model;
|
||||
import tech.easyflow.ai.entity.ModelProvider;
|
||||
import tech.easyflow.ai.service.McpService;
|
||||
import tech.easyflow.ai.service.ModelService;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.Proxy;
|
||||
import java.math.BigInteger;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Agent MCP 运行时定义编译测试。
|
||||
*/
|
||||
public class AgentDefinitionCompilerMcpTest {
|
||||
|
||||
/**
|
||||
* 验证 Agent 绑定 MCP 后会编译为 runtime 原生 MCP 声明,并按整个 MCP 暴露工具。
|
||||
*
|
||||
* @throws Exception 反射注入依赖失败时抛出
|
||||
*/
|
||||
@Test
|
||||
public void compileShouldBuildWholeMcpSpecWithDynamicPrefixAndApproval() throws Exception {
|
||||
BigInteger modelId = BigInteger.valueOf(10L);
|
||||
BigInteger mcpId = BigInteger.valueOf(20L);
|
||||
Model model = model(modelId);
|
||||
Mcp mcp = mcp(mcpId);
|
||||
AgentDefinitionCompiler compiler = new AgentDefinitionCompiler();
|
||||
setField(compiler, "objectMapper", new com.fasterxml.jackson.databind.ObjectMapper());
|
||||
setField(compiler, "modelService", modelService(model));
|
||||
setField(compiler, "mcpService", mcpService(mcp));
|
||||
|
||||
Agent agent = agent(modelId, mcpId);
|
||||
|
||||
AgentRuntimeBundle bundle = compiler.compile(agent);
|
||||
AgentDefinition definition = bundle.getDefinition();
|
||||
|
||||
Assert.assertTrue(definition.getToolSpecs().isEmpty());
|
||||
Assert.assertTrue(bundle.getToolInvokers().isEmpty());
|
||||
Assert.assertEquals(1, definition.getMcpSpecs().size());
|
||||
McpSpec spec = definition.getMcpSpecs().get(0);
|
||||
Assert.assertEquals("mcp_20", spec.getName());
|
||||
Assert.assertEquals(McpTransportType.STDIO, spec.getTransportType());
|
||||
Assert.assertEquals("npx", spec.getCommand());
|
||||
Assert.assertEquals(List.of("-y", "@modelcontextprotocol/server-everything"), spec.getArgs());
|
||||
Assert.assertTrue(spec.isApprovalRequired());
|
||||
Assert.assertEquals("mcp_20_", spec.getToolNamePrefix());
|
||||
Assert.assertTrue(spec.getToolAliases().isEmpty());
|
||||
Assert.assertTrue(spec.getEnableTools().isEmpty());
|
||||
Assert.assertEquals(AgentToolType.MCP.name(), spec.getMetadata().get("toolType"));
|
||||
Assert.assertEquals(String.valueOf(mcpId), spec.getMetadata().get("mcpId"));
|
||||
Assert.assertEquals("everything", spec.getMetadata().get("serverName"));
|
||||
Assert.assertTrue(spec.getToolApprovalRequests().isEmpty());
|
||||
Assert.assertEquals("确认调用 MCP 工具?", spec.getApprovalRequest().getApprovalPrompt());
|
||||
}
|
||||
|
||||
private Agent agent(BigInteger modelId, BigInteger mcpId) {
|
||||
AgentToolBinding binding = new AgentToolBinding();
|
||||
binding.setToolType(AgentToolType.MCP.name());
|
||||
binding.setTargetId(mcpId);
|
||||
binding.setEnabled(true);
|
||||
binding.setHitlEnabled(true);
|
||||
binding.setHitlConfigJson(Map.of("prompt", "确认调用 MCP 工具?"));
|
||||
|
||||
Agent agent = new Agent();
|
||||
agent.setId(BigInteger.valueOf(1L));
|
||||
agent.setName("MCP Agent");
|
||||
agent.setModelId(modelId);
|
||||
agent.setToolBindings(List.of(binding));
|
||||
return agent;
|
||||
}
|
||||
|
||||
private Model model(BigInteger modelId) {
|
||||
ModelProvider provider = new ModelProvider();
|
||||
provider.setProviderType("openai");
|
||||
provider.setProviderName("OpenAI");
|
||||
Model model = new Model();
|
||||
model.setId(modelId);
|
||||
model.setModelProvider(provider);
|
||||
model.setModelName("gpt-test");
|
||||
model.setEndpoint("https://example.com");
|
||||
model.setRequestPath("/v1/chat/completions");
|
||||
model.setApiKey("test-key");
|
||||
return model;
|
||||
}
|
||||
|
||||
private Mcp mcp(BigInteger mcpId) {
|
||||
Mcp mcp = new Mcp();
|
||||
mcp.setId(mcpId);
|
||||
mcp.setTitle("Everything");
|
||||
mcp.setDescription("MCP Everything");
|
||||
mcp.setApprovalRequired(true);
|
||||
mcp.setStatus(true);
|
||||
mcp.setConfigJson("""
|
||||
{
|
||||
"mcpServers": {
|
||||
"everything": {
|
||||
"transport": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-everything"]
|
||||
}
|
||||
}
|
||||
}
|
||||
""");
|
||||
return mcp;
|
||||
}
|
||||
|
||||
private ModelService modelService(Model model) {
|
||||
return (ModelService) Proxy.newProxyInstance(
|
||||
ModelService.class.getClassLoader(),
|
||||
new Class<?>[]{ModelService.class},
|
||||
(proxy, method, args) -> "getModelInstance".equals(method.getName()) ? model : defaultValue(method.getReturnType()));
|
||||
}
|
||||
|
||||
private McpService mcpService(Mcp mcp) {
|
||||
return (McpService) Proxy.newProxyInstance(
|
||||
McpService.class.getClassLoader(),
|
||||
new Class<?>[]{McpService.class},
|
||||
(proxy, method, args) -> "getById".equals(method.getName()) ? mcp : defaultValue(method.getReturnType()));
|
||||
}
|
||||
|
||||
private Object defaultValue(Class<?> type) {
|
||||
if (type == boolean.class) {
|
||||
return false;
|
||||
}
|
||||
if (type == int.class || type == long.class || type == short.class || type == byte.class) {
|
||||
return 0;
|
||||
}
|
||||
if (type == double.class || type == float.class) {
|
||||
return 0D;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private void setField(Object target, String fieldName, Object value) throws Exception {
|
||||
Field field = target.getClass().getDeclaredField(fieldName);
|
||||
field.setAccessible(true);
|
||||
field.set(target, value);
|
||||
}
|
||||
}
|
||||
@@ -1,15 +1,22 @@
|
||||
package tech.easyflow.agent.runtime;
|
||||
|
||||
import com.easyagents.agent.runtime.AgentInitRequest;
|
||||
import com.easyagents.agent.runtime.AgentRuntime;
|
||||
import com.easyagents.agent.runtime.event.AgentRuntimeEvent;
|
||||
import com.easyagents.agent.runtime.event.AgentRuntimeEventType;
|
||||
import com.easyagents.agent.runtime.message.AgentKnowledgeReference;
|
||||
import com.easyagents.agent.runtime.message.AgentMessage;
|
||||
import com.easyagents.agent.runtime.message.AgentMessageRole;
|
||||
import com.easyagents.agent.runtime.persistence.session.AgentSessionStore;
|
||||
import com.easyagents.agent.runtime.persistence.session.memory.InMemoryAgentSessionStore;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import tech.easyflow.agent.entity.AgentHitlPending;
|
||||
import tech.easyflow.agent.entity.Agent;
|
||||
import tech.easyflow.agent.entity.AgentKnowledgeBinding;
|
||||
import tech.easyflow.agent.entity.AgentToolBinding;
|
||||
import tech.easyflow.agent.runtime.event.AgentRunEventRecorder;
|
||||
import tech.easyflow.agent.runtime.hitl.AgentHitlPendingService;
|
||||
import tech.easyflow.agent.runtime.lock.AgentRunLock;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatSessionSummary;
|
||||
import tech.easyflow.common.entity.LoginAccount;
|
||||
@@ -402,14 +409,150 @@ public class AgentRunServiceDraftAndHitlTest {
|
||||
|
||||
Exception thrown = Assert.assertThrows(Exception.class, () -> invoke(service, "run",
|
||||
new Class<?>[]{Agent.class, String.class, String.class, String.class, String.class,
|
||||
String.class, ChatRuntimeContext.class, boolean.class},
|
||||
agent, "你好", "request-lock", "trace-lock", "session-lock", "AGENT", context, true));
|
||||
String.class, ChatRuntimeContext.class, boolean.class, AgentSessionStore.class},
|
||||
agent, "你好", "request-lock", "trace-lock", "session-lock", "AGENT", context, true,
|
||||
new InMemoryAgentSessionStore()));
|
||||
|
||||
Assert.assertTrue(rootCause(thrown) instanceof BusinessException);
|
||||
Assert.assertEquals(0, chatRuntimeManager.prepareSessionCount);
|
||||
Assert.assertEquals(0, chatRuntimeManager.recordUserMessageCount);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证草稿运行会使用独立 session store,且不会绑定 MySQL session 元信息。
|
||||
*
|
||||
* @throws Exception 反射调用失败时抛出
|
||||
*/
|
||||
@Test
|
||||
public void startRuntimeShouldUseDraftSessionStoreWithoutBindingMysqlSession() throws Exception {
|
||||
AgentRunService service = new AgentRunService();
|
||||
RecordingAgentDefinitionCompiler compiler = new RecordingAgentDefinitionCompiler();
|
||||
RecordingAgentRuntime runtime = new RecordingAgentRuntime();
|
||||
RecordingAgentRuntimeFactory runtimeFactory = new RecordingAgentRuntimeFactory(runtime);
|
||||
AgentSessionStore draftStore = new InMemoryAgentSessionStore();
|
||||
setField(service, "agentDefinitionCompiler", compiler);
|
||||
setField(service, "agentRuntimeFactory", runtimeFactory);
|
||||
setField(service, "agentRunRegistry", new AgentRunRegistry());
|
||||
|
||||
Agent agent = new Agent();
|
||||
agent.setId(BigInteger.valueOf(100));
|
||||
invoke(service, "startRuntime",
|
||||
new Class<?>[]{Agent.class, String.class, String.class, String.class, String.class, String.class,
|
||||
ChatRuntimeContext.class, ChatSseEmitter.class, boolean.class, AgentSessionStore.class,
|
||||
AgentRunLock.Handle.class},
|
||||
agent, "你好", "request-draft", "trace-draft", "agent-draft-100", "AGENT_DRAFT",
|
||||
chatContext(), new RecordingChatSseEmitter(), false, draftStore, null);
|
||||
|
||||
Assert.assertSame(draftStore, runtime.initRequest.getSessionStore());
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证草稿事件不会写运行事件表,正式事件仍会记录。
|
||||
*
|
||||
* @throws Exception 反射调用失败时抛出
|
||||
*/
|
||||
@Test
|
||||
public void handleRuntimeEventShouldOnlyPersistEventsForFormalChat() throws Exception {
|
||||
AgentRunService service = new AgentRunService();
|
||||
setField(service, "agentRunRegistry", new AgentRunRegistry());
|
||||
RecordingAgentRunEventRecorder recorder = new RecordingAgentRunEventRecorder();
|
||||
setField(service, "agentRunEventRecorder", recorder);
|
||||
AgentRuntimeEvent draftEvent = AgentRuntimeEvent.of(AgentRuntimeEventType.TOOL_CALL);
|
||||
draftEvent.getPayload().put("toolName", "search");
|
||||
|
||||
invoke(service, "handleRuntimeEvent",
|
||||
runtimeEventParameterTypes(),
|
||||
draftEvent, "request-draft", new RecordingChatSseEmitter(), new StringBuilder(),
|
||||
new ChatAssistantAccumulator(), chatContext(), new AtomicBoolean(false), false);
|
||||
|
||||
Assert.assertEquals(0, recorder.recordCount);
|
||||
|
||||
AgentRuntimeEvent formalEvent = AgentRuntimeEvent.of(AgentRuntimeEventType.TOOL_CALL);
|
||||
formalEvent.getPayload().put("toolName", "search");
|
||||
invoke(service, "handleRuntimeEvent",
|
||||
runtimeEventParameterTypes(),
|
||||
formalEvent, "request-formal", new RecordingChatSseEmitter(), new StringBuilder(),
|
||||
new ChatAssistantAccumulator(), chatContext(), new AtomicBoolean(false), true);
|
||||
|
||||
Assert.assertEquals(1, recorder.recordCount);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证草稿工具审批只注册内存恢复令牌,不写 HITL pending 表。
|
||||
*
|
||||
* @throws Exception 反射调用失败时抛出
|
||||
*/
|
||||
@Test
|
||||
public void draftToolApprovalShouldNotPersistPending() throws Exception {
|
||||
AgentRunService service = new AgentRunService();
|
||||
AgentRunRegistry registry = new AgentRunRegistry();
|
||||
RecordingAgentHitlPendingService pendingService = new RecordingAgentHitlPendingService();
|
||||
setField(service, "agentRunRegistry", registry);
|
||||
setField(service, "agentHitlPendingService", pendingService);
|
||||
registry.register(runContext("request-draft", "agent-draft-tool", false));
|
||||
AgentRuntimeEvent event = AgentRuntimeEvent.of(AgentRuntimeEventType.TOOL_APPROVAL_REQUIRED);
|
||||
event.getPayload().put("resumeToken", "token-draft");
|
||||
|
||||
invoke(service, "handleRuntimeEvent",
|
||||
runtimeEventParameterTypes(),
|
||||
event, "request-draft", new RecordingChatSseEmitter(), new StringBuilder(),
|
||||
new ChatAssistantAccumulator(), chatContext(), new AtomicBoolean(false), false);
|
||||
|
||||
Assert.assertTrue(registry.containsResumeTarget("request-draft", "token-draft"));
|
||||
Assert.assertEquals(0, pendingService.recordApprovalRequiredCount);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证草稿审批恢复不执行 pending 表消费,正式审批仍执行。
|
||||
*
|
||||
* @throws Exception 反射调用失败时抛出
|
||||
*/
|
||||
@Test
|
||||
public void approveShouldSkipPendingConsumeOnlyForDraftRun() throws Exception {
|
||||
AgentRunService service = new AgentRunService();
|
||||
AgentRunRegistry registry = new AgentRunRegistry();
|
||||
RecordingAgentHitlPendingService pendingService = new RecordingAgentHitlPendingService();
|
||||
setField(service, "agentRunRegistry", registry);
|
||||
setField(service, "agentHitlPendingService", pendingService);
|
||||
|
||||
registry.register(runContext("request-draft-approve", "agent-draft-approve", false));
|
||||
registry.registerResumeToken("request-draft-approve", "token-draft-approve");
|
||||
invoke(service, "approveRuntime",
|
||||
new Class<?>[]{String.class, String.class, BigInteger.class, String.class},
|
||||
"request-draft-approve", "token-draft-approve", BigInteger.ONE, "1");
|
||||
|
||||
Assert.assertEquals(0, pendingService.approveCount);
|
||||
|
||||
registry.register(runContext("request-formal-approve", "session-formal-approve", true));
|
||||
registry.registerResumeToken("request-formal-approve", "token-formal-approve");
|
||||
invoke(service, "approveRuntime",
|
||||
new Class<?>[]{String.class, String.class, BigInteger.class, String.class},
|
||||
"request-formal-approve", "token-formal-approve", BigInteger.ONE, "1");
|
||||
|
||||
Assert.assertEquals(1, pendingService.approveCount);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证清理草稿会话只清草稿 store,不触碰 MySQL pending 清理。
|
||||
*
|
||||
* @throws Exception 反射调用失败时抛出
|
||||
*/
|
||||
@Test
|
||||
public void clearDraftSessionShouldOnlyDeleteDraftStore() throws Exception {
|
||||
AgentRunService service = new AgentRunService();
|
||||
RecordingAgentHitlPendingService pendingService = new RecordingAgentHitlPendingService();
|
||||
RecordingAgentSessionStore draftStore = new RecordingAgentSessionStore();
|
||||
setField(service, "agentRunRegistry", new AgentRunRegistry());
|
||||
setField(service, "agentHitlPendingService", pendingService);
|
||||
setField(service, "draftAgentSessionStore", draftStore);
|
||||
|
||||
invoke(service, "clearDraftSessionInternal",
|
||||
new Class<?>[]{String.class, String.class}, "agent-draft-clear", "1");
|
||||
|
||||
Assert.assertEquals("agent-draft-clear", draftStore.deletedSessionKey);
|
||||
Assert.assertEquals(0, pendingService.deleteByRuntimeSessionIdCount);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证正式聊天会在会话准备完成后向前端返回真实会话 ID。
|
||||
*
|
||||
@@ -530,6 +673,28 @@ public class AgentRunServiceDraftAndHitlTest {
|
||||
ChatRuntimeContext.class, AtomicBoolean.class, boolean.class};
|
||||
}
|
||||
|
||||
private AgentRunRegistry.AgentRunContext runContext(String requestId, String sessionId, boolean persistChatlog) {
|
||||
return new AgentRunRegistry.AgentRunContext(
|
||||
requestId,
|
||||
sessionId,
|
||||
new RecordingAgentRuntime(),
|
||||
new RecordingChatSseEmitter(),
|
||||
chatContext(),
|
||||
new StringBuilder(),
|
||||
new ChatAssistantAccumulator(),
|
||||
new AtomicBoolean(false),
|
||||
persistChatlog,
|
||||
new AgentRunRegistry.RunOwner("agent-1", sessionId, "1"),
|
||||
null,
|
||||
event -> {
|
||||
},
|
||||
error -> {
|
||||
},
|
||||
() -> {
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
private ChatRuntimeContext chatContext() {
|
||||
ChatRuntimeContext context = new ChatRuntimeContext();
|
||||
context.setAssistantId(BigInteger.valueOf(100));
|
||||
@@ -598,6 +763,148 @@ public class AgentRunServiceDraftAndHitlTest {
|
||||
}
|
||||
}
|
||||
|
||||
private static class RecordingAgentRuntime implements AgentRuntime {
|
||||
|
||||
private AgentInitRequest initRequest;
|
||||
private int resumeCount;
|
||||
|
||||
@Override
|
||||
public void init(AgentInitRequest request) {
|
||||
initRequest = request;
|
||||
}
|
||||
|
||||
@Override
|
||||
public reactor.core.publisher.Flux<AgentRuntimeEvent> stream(AgentMessage userMessage) {
|
||||
return reactor.core.publisher.Flux.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public reactor.core.publisher.Flux<AgentRuntimeEvent> resume(com.easyagents.agent.runtime.AgentResumeRequest request) {
|
||||
resumeCount++;
|
||||
return reactor.core.publisher.Flux.empty();
|
||||
}
|
||||
}
|
||||
|
||||
private static class RecordingAgentRuntimeFactory implements AgentRuntimeFactory {
|
||||
|
||||
private final AgentRuntime runtime;
|
||||
|
||||
private RecordingAgentRuntimeFactory(AgentRuntime runtime) {
|
||||
this.runtime = runtime;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AgentRuntime create() {
|
||||
return runtime;
|
||||
}
|
||||
}
|
||||
|
||||
private static class RecordingAgentDefinitionCompiler extends AgentDefinitionCompiler {
|
||||
|
||||
@Override
|
||||
public AgentRuntimeBundle compile(Agent agent) {
|
||||
AgentRuntimeBundle bundle = new AgentRuntimeBundle();
|
||||
bundle.setDefinition(new com.easyagents.agent.runtime.AgentDefinition());
|
||||
return bundle;
|
||||
}
|
||||
}
|
||||
|
||||
private static class RecordingAgentRunEventRecorder implements AgentRunEventRecorder {
|
||||
|
||||
private int recordCount;
|
||||
|
||||
@Override
|
||||
public void record(String requestId, ChatRuntimeContext chatContext, AgentRuntimeEvent event) {
|
||||
recordCount++;
|
||||
}
|
||||
}
|
||||
|
||||
private static class RecordingAgentHitlPendingService implements AgentHitlPendingService {
|
||||
|
||||
private int recordApprovalRequiredCount;
|
||||
private int approveCount;
|
||||
private int rejectCount;
|
||||
private int cancelByRequestIdCount;
|
||||
private int deleteByRuntimeSessionIdCount;
|
||||
|
||||
@Override
|
||||
public void recordApprovalRequired(String requestId, ChatRuntimeContext chatContext, AgentRuntimeEvent event) {
|
||||
recordApprovalRequiredCount++;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AgentHitlPending approve(String resumeToken, BigInteger operatorId) {
|
||||
approveCount++;
|
||||
return new AgentHitlPending();
|
||||
}
|
||||
|
||||
@Override
|
||||
public AgentHitlPending reject(String resumeToken, BigInteger operatorId, String reason) {
|
||||
rejectCount++;
|
||||
return new AgentHitlPending();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void cancelByRequestId(String requestId, String reason) {
|
||||
cancelByRequestIdCount++;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteByChatSessionId(BigInteger chatSessionId) {
|
||||
// 测试桩无需处理。
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteByRuntimeSessionId(String runtimeSessionId) {
|
||||
deleteByRuntimeSessionIdCount++;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AgentHitlPending> expirePending(int limit) {
|
||||
return List.of();
|
||||
}
|
||||
}
|
||||
|
||||
private static class RecordingAgentSessionStore implements AgentSessionStore {
|
||||
|
||||
private String deletedSessionKey;
|
||||
|
||||
@Override
|
||||
public void save(String sessionKey, String name, io.agentscope.core.state.State state) {
|
||||
// 测试桩无需处理。
|
||||
}
|
||||
|
||||
@Override
|
||||
public void saveList(String sessionKey, String name, List<? extends io.agentscope.core.state.State> states) {
|
||||
// 测试桩无需处理。
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T extends io.agentscope.core.state.State> java.util.Optional<T> get(String sessionKey, String name, Class<T> type) {
|
||||
return java.util.Optional.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T extends io.agentscope.core.state.State> List<T> getList(String sessionKey, String name, Class<T> itemType) {
|
||||
return List.of();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean exists(String sessionKey) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void delete(String sessionKey) {
|
||||
deletedSessionKey = sessionKey;
|
||||
}
|
||||
|
||||
@Override
|
||||
public java.util.Set<String> listSessionKeys() {
|
||||
return java.util.Set.of();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 记录 chatlog 写入动作的测试桩。
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user