perf: 优化中断续聊表现,被中断的回复仍可以进入上下文中,保证记忆连续性

- 续聊逻辑优化
This commit is contained in:
2026-05-25 11:40:10 +08:00
parent 2b5e701ade
commit 5c7182ac3f
4 changed files with 294 additions and 13 deletions

View File

@@ -7,6 +7,7 @@ import com.easyagents.agent.runtime.model.AgentModelSpec;
import com.easyagents.agent.runtime.persistence.AgentPersistencePolicy;
import com.easyagents.agent.runtime.skill.AgentSkillBoxSpec;
import com.easyagents.agent.runtime.tool.AgentToolSpec;
import com.easyagents.agent.runtime.tool.operate.AgentOperateToolSpec;
import java.util.ArrayList;
import java.util.LinkedHashMap;
@@ -27,6 +28,7 @@ public class AgentDefinition {
private AgentGenerationOptions generationOptions = new AgentGenerationOptions();
private AgentExecutionOptions executionOptions = new AgentExecutionOptions();
private List<AgentToolSpec> toolSpecs = new ArrayList<>();
private List<AgentOperateToolSpec> operateToolSpecs = new ArrayList<>();
private List<AgentKnowledgeSpec> knowledgeSpecs = new ArrayList<>();
private AgentMemoryPolicy memoryPolicy = AgentMemoryPolicy.autoContext();
private AgentPersistencePolicy persistencePolicy = AgentPersistencePolicy.disabled();
@@ -177,6 +179,24 @@ public class AgentDefinition {
this.toolSpecs = toolSpecs == null ? new ArrayList<>() : new ArrayList<>(toolSpecs);
}
/**
 * Returns the operate-type tool definitions.
 *
 * <p>The list held by this definition is returned directly, not a copy.</p>
 *
 * @return operate-type tool definitions
 */
public List<AgentOperateToolSpec> getOperateToolSpecs() {
    return this.operateToolSpecs;
}
/**
 * Sets the operate-type tool definitions.
 *
 * <p>The argument is copied into a new list; {@code null} resets the field to an empty list.</p>
 *
 * @param operateToolSpecs operate-type tool definitions, may be {@code null}
 */
public void setOperateToolSpecs(List<AgentOperateToolSpec> operateToolSpecs) {
    if (operateToolSpecs == null) {
        this.operateToolSpecs = new ArrayList<>();
        return;
    }
    this.operateToolSpecs = new ArrayList<>(operateToolSpecs);
}
/**
* 获取知识库定义。
*

View File

@@ -30,6 +30,15 @@ public class AgentResumeRequest {
*/
private Map<String, Object> metadata = new LinkedHashMap<>();
/**
 * Whether this resume request has already been validated and consumed exactly once by the
 * caller's persistent pending store.
 *
 * <p>This field is intended for the server-side integration layer only. Ordinary callers should
 * not set it; once set, the runtime skips the token existence check of the in-process
 * {@code AgentToolApprovalCoordinator}, so that a pending tool can be continued from the
 * AgentScope session after a service restart or on another node.</p>
 */
private boolean trusted;
/**
* 获取恢复令牌。
*
@@ -101,4 +110,22 @@ public class AgentResumeRequest {
public void setMetadata(Map<String, Object> metadata) {
    // A null argument resets the field to a fresh empty map; a non-null map is stored as-is (no copy).
    if (metadata == null) {
        this.metadata = new LinkedHashMap<>();
    } else {
        this.metadata = metadata;
    }
}
/**
 * Reports whether this resume request has already been validated by the caller's persistence layer.
 *
 * @return {@code true} when the request has already been validated
 */
public boolean isTrusted() {
    return this.trusted;
}
/**
 * Sets whether this resume request has already been validated by the caller's persistence layer.
 *
 * <p>NOTE(review): per the documentation on the {@code trusted} field, a trusted request skips the
 * in-process approval-token existence check. If this object can ever be populated directly from
 * client-supplied input, a client could use this flag to bypass that check — confirm that only
 * server-side integration code can reach this setter.</p>
 *
 * @param trusted already-validated flag
 */
public void setTrusted(boolean trusted) {
this.trusted = trusted;
}
}

View File

@@ -19,6 +19,7 @@ import com.easyagents.agent.runtime.skill.AgentSkillBinding;
import com.easyagents.agent.runtime.skill.AgentSkillRuntimeContext;
import com.easyagents.agent.runtime.tool.AgentToolInvoker;
import com.easyagents.agent.runtime.tool.AgentToolSpec;
import com.easyagents.agent.runtime.tool.operate.AgentOperateToolAdapter;
import io.agentscope.core.ReActAgent;
import io.agentscope.core.agent.Event;
import io.agentscope.core.agent.EventType;
@@ -55,6 +56,7 @@ public class AgentScopeReActRuntime implements AgentRuntime {
private final AgentScopeMemoryAdapter memoryAdapter;
private final AgentScopeSkillAdapter skillAdapter;
private final AgentScopeMessageAdapter messageAdapter;
private final AgentOperateToolAdapter operateToolAdapter = new AgentOperateToolAdapter();
private final AgentKnowledgeCitationMatcher citationMatcher = new HeuristicKnowledgeCitationMatcher();
private final AtomicBoolean initialized = new AtomicBoolean(false);
private final AtomicBoolean running = new AtomicBoolean(false);
@@ -153,7 +155,9 @@ public class AgentScopeReActRuntime implements AgentRuntime {
}
AgentRuntimeExecutionContext executionContext = createResumeExecutionContext(request);
try {
approvalCoordinator.consume(request);
if (!request.isTrusted()) {
approvalCoordinator.consume(request);
}
} catch (RuntimeException error) {
running.set(false);
throw error;
@@ -246,7 +250,7 @@ public class AgentScopeReActRuntime implements AgentRuntime {
// 输出所有事件统一交给调用方存储事件记录,默认是空实现即不记录。
.doOnNext(event -> executionContext.getConversationRecorder().record(executionContext, event))
// 处理中断请求
.doOnCancel(() -> cancelInternal(executionContext, sideEvents, cancelled))
.doOnCancel(() -> cancelInternal(executionContext, sideEvents, finalText, finalMessage, cancelled))
// 释放运行锁并清掉 turn context。
.doFinally(signalType -> cleanupTurn());
}
@@ -608,18 +612,90 @@ public class AgentScopeReActRuntime implements AgentRuntime {
*
* @param context 本轮运行上下文
* @param sideEvents 旁路事件 sink
* @param finalText 当前已累计的助手文本
* @param finalMessage 当前已捕获的结构化助手消息
* @param cancelled 取消去重标记
*/
private void cancelInternal(AgentRuntimeExecutionContext context,
Sinks.Many<AgentRuntimeEvent> sideEvents,
StringBuilder finalText,
AtomicReference<AgentMessage> finalMessage,
AtomicBoolean cancelled) {
if (!cancelled.compareAndSet(false, true)) {
return;
}
context.setCancelReason("cancelled");
approvalCoordinator.cancelAll(context.getCancelReason());
agent.interrupt();
sideEvents.tryEmitComplete();
try {
approvalCoordinator.cancelAll(context.getCancelReason());
agent.interrupt();
persistPartialAssistantOnCancel(finalText, finalMessage);
} finally {
sideEvents.tryEmitComplete();
}
}
/**
 * Appends the assistant content that was already produced before cancellation to the
 * AgentScope memory and then saves the session.
 *
 * <p>AgentScope's normal completion path writes the final assistant message into memory by
 * itself. Cancelling the subscription does not trigger that path, so this method adds the
 * message once, and only when there is non-empty assistant content, so that the next turn can
 * see the context from before the interruption.</p>
 *
 * @param finalText assistant text accumulated so far
 * @param finalMessage structured assistant message captured so far
 */
private void persistPartialAssistantOnCancel(StringBuilder finalText,
                                             AtomicReference<AgentMessage> finalMessage) {
    AgentMessage recovered = partialAssistantMessage(finalText, finalMessage);
    if (recovered != null) {
        agent.getMemory().addMessage(messageAdapter.toMsg(recovered));
        saveSession();
    }
}
/**
 * Builds the assistant message that may be written into memory on cancellation.
 *
 * <p>The captured structured message wins when it has usable content; its role is set to
 * {@code ASSISTANT} on the captured instance itself. Otherwise the accumulated text is wrapped
 * into a new text message.</p>
 *
 * @param finalText assistant text accumulated so far, may be {@code null}
 * @param finalMessage structured assistant message captured so far, may be {@code null}
 * @return an assistant message with content, or {@code null} when there is nothing worth keeping
 */
private AgentMessage partialAssistantMessage(StringBuilder finalText,
                                             AtomicReference<AgentMessage> finalMessage) {
    AgentMessage captured = null;
    if (finalMessage != null) {
        captured = finalMessage.get();
    }
    if (hasContent(captured)) {
        captured.setRole(AgentMessageRole.ASSISTANT);
        return captured;
    }
    if (finalText == null) {
        return null;
    }
    String accumulated = finalText.toString();
    return accumulated.isBlank() ? null : AgentMessage.text(AgentMessageRole.ASSISTANT, accumulated);
}
/**
 * Tells whether a message carries at least one content block usable as context.
 *
 * <p>A text block counts only when its text is non-blank; any other non-null block counts
 * unconditionally.</p>
 *
 * @param message message to inspect, may be {@code null}
 * @return {@code true} when a non-empty content block exists
 */
private boolean hasContent(AgentMessage message) {
    if (message == null) {
        return false;
    }
    var blocks = message.getContentBlocks();
    if (blocks == null || blocks.isEmpty()) {
        return false;
    }
    for (AgentContentBlock candidate : blocks) {
        if (candidate == null) {
            continue;
        }
        if (!(candidate instanceof AgentTextBlock textBlock)) {
            return true;
        }
        String text = textBlock.getText();
        if (text != null && !text.isBlank()) {
            return true;
        }
    }
    return false;
}
/**
@@ -920,6 +996,19 @@ public class AgentScopeReActRuntime implements AgentRuntime {
if (definition.getModelSpec() == null) {
throw new AgentRuntimeException("Agent model spec is required.");
}
validateOperateToolConflicts(definition);
}
/**
 * Rejects a definition in which a regular tool uses a name that
 * {@code operateToolAdapter.enabledToolNames(...)} reports for the configured operate tools.
 *
 * @param definition agent definition being validated
 * @throws AgentRuntimeException if a regular tool name collides with an operate tool name
 */
private void validateOperateToolConflicts(AgentDefinition definition) {
    Set<String> operateToolNames = operateToolAdapter.enabledToolNames(definition.getOperateToolSpecs());
    if (operateToolNames.isEmpty()) {
        return;
    }
    for (AgentToolSpec toolSpec : definition.getToolSpecs()) {
        String toolName = toolSpec == null ? null : toolSpec.getName();
        // A nameless spec cannot collide. Skipping it also avoids the NullPointerException that
        // unmodifiable sets (Set.of / Set.copyOf) throw from contains(null).
        if (toolName != null && operateToolNames.contains(toolName)) {
            throw new AgentRuntimeException("Agent operate tool conflicts with existing tool: " + toolName);
        }
    }
}
/**
@@ -952,7 +1041,8 @@ public class AgentScopeReActRuntime implements AgentRuntime {
AgentDefinition definition = context.getAgentDefinition();
Model model = modelFactory.create(definition.getModelSpec(), definition.getGenerationOptions());
Toolkit toolkit = new Toolkit();
Map<String, List<AgentTool>> skillTools = buildToolkit(context, toolkit);
AgentScopeToolkitBuildResult toolkitBuildResult = buildToolkit(context, toolkit);
Map<String, List<AgentTool>> skillTools = toolkitBuildResult.skillTools();
AgentScopeMemoryBuildResult memoryResult = memoryAdapter.createMemoryResult(null, definition.getMemoryPolicy(), model);
Memory memory = memoryResult.getMemory();
Knowledge knowledge = knowledgeAdapter.createAggregateKnowledge(context, turnContextHolder);
@@ -964,7 +1054,8 @@ public class AgentScopeReActRuntime implements AgentRuntime {
if (memory instanceof AutoContextMemory) {
interceptors.add(new AutoContextInterceptor(eventBridge, memoryResult.getAutoContextConfig()));
}
interceptors.add(new ToolHitlInterceptor(eventBridge, approvalCoordinator, definition.getToolSpecs()));
interceptors.add(new ToolHitlInterceptor(eventBridge, approvalCoordinator,
mergeToolSpecs(definition.getToolSpecs(), toolkitBuildResult.operateToolSpecs())));
// 注册旁路事件监听器与主线路干预器。观察器只发旁路事件,不修改 AgentScope HookEvent。
List<AgentRuntimeObserver> observers = new ArrayList<>();
observers.add(new SkillExecutionObserver(eventBridge, skillContext, skillBox));
@@ -1003,11 +1094,11 @@ public class AgentScopeReActRuntime implements AgentRuntime {
* @param toolkit AgentScope Toolkit
* @return 按 Skill ID 分组的工具
*/
private Map<String, List<AgentTool>> buildToolkit(AgentRuntimeExecutionContext context,
private AgentScopeToolkitBuildResult buildToolkit(AgentRuntimeExecutionContext context,
Toolkit toolkit) {
Map<String, List<AgentTool>> skillTools = new LinkedHashMap<>();
if (!context.getAgentDefinition().getExecutionOptions().isToolCallingEnabled()) {
return skillTools;
return new AgentScopeToolkitBuildResult(skillTools, List.of());
}
for (AgentToolSpec toolSpec : context.getAgentDefinition().getToolSpecs()) {
AgentToolInvoker invoker = context.getToolInvokers().get(toolSpec.getName());
@@ -1020,7 +1111,20 @@ public class AgentScopeReActRuntime implements AgentRuntime {
skillTools.computeIfAbsent(skillBinding.getSkillId(), key -> new ArrayList<>()).add(agentTool);
}
}
return skillTools;
List<AgentToolSpec> operateToolSpecs = operateToolAdapter.register(
context.getAgentDefinition().getOperateToolSpecs(), toolkit);
return new AgentScopeToolkitBuildResult(skillTools, operateToolSpecs);
}
/**
 * Concatenates the regular tool specs and the operate tool specs into a new list.
 *
 * @param toolSpecs regular tool specs, may be {@code null}
 * @param operateToolSpecs operate tool specs, may be {@code null}
 * @return a new mutable list holding the regular specs followed by the operate specs
 */
private List<AgentToolSpec> mergeToolSpecs(List<AgentToolSpec> toolSpecs, List<AgentToolSpec> operateToolSpecs) {
    List<AgentToolSpec> leading = toolSpecs == null ? List.of() : toolSpecs;
    List<AgentToolSpec> combined = new ArrayList<>(leading);
    if (operateToolSpecs != null) {
        combined.addAll(operateToolSpecs);
    }
    return combined;
}
/**
@@ -1057,4 +1161,8 @@ public class AgentScopeReActRuntime implements AgentRuntime {
ReActAgent getAgent() {
return agent;
}
/**
 * Result of building the toolkit: the tools grouped per skill together with the tool specs
 * returned by registering the operate tools.
 *
 * @param skillTools tools grouped by Skill ID
 * @param operateToolSpecs tool specs returned by the operate-tool adapter's registration
 */
private record AgentScopeToolkitBuildResult(Map<String, List<AgentTool>> skillTools,
List<AgentToolSpec> operateToolSpecs) {
}
}

View File

@@ -24,6 +24,9 @@ import com.easyagents.agent.runtime.skill.AgentSkillRuntimeContext;
import com.easyagents.agent.runtime.skill.AgentSkillSpec;
import com.easyagents.agent.runtime.tool.AgentToolResult;
import com.easyagents.agent.runtime.tool.AgentToolSpec;
import com.easyagents.agent.runtime.tool.operate.AgentOperateToolAdapter;
import com.easyagents.agent.runtime.tool.operate.AgentOperateToolSpec;
import com.easyagents.agent.runtime.tool.operate.AgentOperateToolType;
import io.agentscope.core.ReActAgent;
import io.agentscope.core.hook.*;
import io.agentscope.core.memory.autocontext.AutoContextHook;
@@ -45,9 +48,11 @@ import java.lang.reflect.Field;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BooleanSupplier;
/**
* 测试有状态 AgentScope 运行时。
@@ -181,6 +186,65 @@ public class AgentScopeStatefulRuntimeTest {
Assert.assertTrue(interceptors.stream().anyMatch(ToolHitlInterceptor.class::isInstance));
}
/**
 * Initializing a runtime with READ_FILE, WRITE_FILE and SHELL operate tools must expose the
 * corresponding tools in the agent's toolkit.
 */
@Test
public void shouldRegisterOperateToolsIntoToolkit() {
    AgentInitRequest initRequest = initRequest();
    initRequest.getAgentDefinition().setOperateToolSpecs(List.of(
            operateToolSpec(AgentOperateToolType.READ_FILE),
            operateToolSpec(AgentOperateToolType.WRITE_FILE),
            operateToolSpec(AgentOperateToolType.SHELL)));
    AgentScopeReActRuntime runtime = fakeRuntime();

    runtime.init(initRequest);

    Toolkit registered = runtime.getAgent().getToolkit();
    var expectedToolNames = List.of(
            AgentOperateToolAdapter.VIEW_TEXT_FILE_TOOL,
            AgentOperateToolAdapter.LIST_DIRECTORY_TOOL,
            AgentOperateToolAdapter.WRITE_TEXT_FILE_TOOL,
            AgentOperateToolAdapter.INSERT_TEXT_FILE_TOOL,
            AgentOperateToolAdapter.EXECUTE_SHELL_COMMAND_TOOL);
    for (var toolName : expectedToolNames) {
        Assert.assertNotNull(registered.getTool(toolName));
    }
}
/**
 * A shell operate tool configured with an empty allowed-command set must surface a tool approval
 * request and suspend the run when the model asks to execute a shell command.
 */
@Test
public void shouldSuspendShellOperateToolWithToolHitlInterceptor() {
    AgentOperateToolSpec shellSpec = operateToolSpec(AgentOperateToolType.SHELL);
    shellSpec.setShellAllowedCommands(Set.of());
    AgentInitRequest initRequest = initRequest();
    initRequest.getAgentDefinition().setOperateToolSpecs(List.of(shellSpec));

    // Script the model so that its only response is a call to the shell tool.
    var shellCall = ToolUseBlock.builder()
            .id("call-shell")
            .name(AgentOperateToolAdapter.EXECUTE_SHELL_COMMAND_TOOL)
            .input(Map.of("command", "echo hello"))
            .build();
    var scriptedResponse = ChatResponse.builder()
            .id("shell-call-message")
            .content(List.of(shellCall))
            .finishReason("tool_calls")
            .build();
    AgentScopeReActRuntime runtime = runtimeWithModel(List.of(scriptedResponse));
    runtime.init(initRequest);

    List<AgentRuntimeEvent> emitted = runtime.stream(AgentMessage.text(AgentMessageRole.USER, "run shell"))
            .collectList()
            .block(Duration.ofSeconds(5));

    Assert.assertNotNull(emitted);
    Assert.assertTrue(emitted.stream()
            .map(AgentRuntimeEvent::getEventType)
            .anyMatch(type -> type == AgentRuntimeEventType.TOOL_APPROVAL_REQUIRED));
    Assert.assertTrue(emitted.stream()
            .map(AgentRuntimeEvent::getEventType)
            .anyMatch(type -> type == AgentRuntimeEventType.SUSPENDED));
}
/**
 * Initialization must fail when a business tool is declared under the same name as the shell
 * operate tool.
 */
@Test(expected = AgentRuntimeException.class)
public void shouldRejectOperateToolNameConflictWithBusinessTool() {
    var conflictingName = AgentOperateToolAdapter.EXECUTE_SHELL_COMMAND_TOOL;
    AgentToolSpec businessTool = new AgentToolSpec();
    businessTool.setName(conflictingName);
    businessTool.setDescription("conflict");

    AgentInitRequest initRequest = initRequest();
    initRequest.getAgentDefinition().setToolSpecs(List.of(businessTool));
    initRequest.setToolInvokers(Map.of(conflictingName,
            (arguments, context) -> AgentToolResult.success("ok")));
    initRequest.getAgentDefinition().setOperateToolSpecs(List.of(operateToolSpec(AgentOperateToolType.SHELL)));

    fakeRuntime().init(initRequest);
}
@Test
public void shouldEnablePendingToolRecoveryForRejectedHitlContinuation() {
AgentScopeReActRuntime runtime = fakeRuntime();
@@ -558,6 +622,40 @@ public class AgentScopeStatefulRuntimeTest {
}
}
/**
 * Cancelling a stream after the first assistant delta must persist the partial assistant text,
 * so that a runtime restored from the same session store sees it in memory.
 */
@Test
public void shouldPersistPartialAssistantMessageWhenStreamIsCancelled() throws Exception {
    InMemoryAgentSessionStore sessionStore = new InMemoryAgentSessionStore();
    AgentInitRequest request = initRequest();
    request.setSessionStore(sessionStore);
    // The completion delay keeps the model stream open after the first chunk, leaving a window
    // in which the subscription can be cancelled before it completes.
    AgentScopeReActRuntime runtime = runtimeWithModel(List.of(
            ChatResponse.builder()
                    .id("partial-message")
                    .content(List.of(TextBlock.builder().text("partial answer").build()))
                    .finishReason("stop")
                    .build()),
            Duration.ofMillis(500));
    runtime.init(request);
    CompletableFuture<AgentRuntimeEvent> firstDelta = new CompletableFuture<>();
    reactor.core.Disposable disposable = runtime.stream(AgentMessage.text(AgentMessageRole.USER, "first"))
            .subscribe(event -> {
                if (event.getEventType() == AgentRuntimeEventType.MESSAGE_DELTA) {
                    firstDelta.complete(event);
                }
            }, firstDelta::completeExceptionally);
    try {
        AgentRuntimeEvent delta = firstDelta.get(3, TimeUnit.SECONDS);
        Assert.assertEquals("partial answer", delta.getPayload().get("text"));
    } finally {
        // Cancel on every path: on success this is the cancellation under test; on a timeout or
        // failed assertion it prevents the subscription from outliving the test.
        disposable.dispose();
    }
    awaitCondition(() -> sessionStore.exists("session-1"));
    AgentScopeReActRuntime restoredRuntime = fakeRuntime();
    restoredRuntime.init(request);
    Assert.assertTrue(restoredRuntime.getAgent().getMemory().getMessages().stream()
            .anyMatch(message -> message.getRole() == MsgRole.ASSISTANT
                    && "partial answer".equals(message.getTextContent())));
}
@Test
public void shouldNotDuplicateNormalToolEventsFromMainStream() {
AgentInitRequest request = initRequest();
@@ -886,11 +984,15 @@ public class AgentScopeStatefulRuntimeTest {
}
/**
 * Creates a runtime backed by a scripted model that uses no completion delay.
 *
 * @param responses scripted model responses
 * @return runtime using the scripted model
 */
private AgentScopeReActRuntime runtimeWithModel(List<ChatResponse> responses) {
    final Duration noCompletionDelay = Duration.ZERO;
    return runtimeWithModel(responses, noCompletionDelay);
}
private AgentScopeReActRuntime runtimeWithModel(List<ChatResponse> responses, Duration completionDelay) {
AgentScopeModelFactory modelFactory = new AgentScopeModelFactory() {
@Override
public Model create(AgentModelSpec modelSpec,
com.easyagents.agent.runtime.model.AgentGenerationOptions generationOptions) {
return new ScriptedModel(modelSpec == null ? "fake-model" : modelSpec.getModelName(), responses);
return new ScriptedModel(modelSpec == null ? "fake-model" : modelSpec.getModelName(), responses, completionDelay);
}
};
return new AgentScopeReActRuntime(modelFactory, new AgentScopeToolAdapter(),
@@ -902,10 +1004,12 @@ public class AgentScopeStatefulRuntimeTest {
private final String modelName;
private final List<ChatResponse> responses;
private final Duration completionDelay;
private ScriptedModel(String modelName, List<ChatResponse> responses) {
private ScriptedModel(String modelName, List<ChatResponse> responses, Duration completionDelay) {
this.modelName = modelName;
this.responses = responses;
this.completionDelay = completionDelay == null ? Duration.ZERO : completionDelay;
}
@Override
@@ -916,7 +1020,11 @@ public class AgentScopeStatefulRuntimeTest {
List<ChatResponse> selectedResponses = hasToolResult && responses.size() > 1
? responses.subList(1, responses.size())
: responses.subList(0, 1);
return Flux.fromIterable(selectedResponses);
Flux<ChatResponse> responseFlux = Flux.fromIterable(selectedResponses);
if (completionDelay.isZero() || completionDelay.isNegative()) {
return responseFlux;
}
return responseFlux.concatWith(Flux.never()).timeout(completionDelay, Flux.empty());
}
@Override
@@ -968,6 +1076,13 @@ public class AgentScopeStatefulRuntimeTest {
return spec;
}
/**
 * Creates an operate tool spec of the given type whose base directory is the JVM temp directory.
 *
 * @param type operate tool type
 * @return the configured spec
 */
private AgentOperateToolSpec operateToolSpec(AgentOperateToolType type) {
    String tempDir = System.getProperty("java.io.tmpdir");
    AgentOperateToolSpec operateSpec = new AgentOperateToolSpec();
    operateSpec.setType(type);
    operateSpec.setBaseDir(tempDir);
    return operateSpec;
}
private SkillBox skillBox(Toolkit toolkit) {
return new AgentScopeSkillAdapter().createSkillBox(skillBoxSpec(), toolkit,
Map.of("skill-1", List.of(new NoopAgentTool("search"))));
@@ -1022,4 +1137,15 @@ public class AgentScopeStatefulRuntimeTest {
private boolean isRuntimeHook(Hook hook) {
return hook instanceof AgentScopeRuntimeHook;
}
/**
 * Polls a condition every 20 ms until it holds, failing the test after roughly three seconds.
 *
 * <p>The condition is always evaluated before the deadline is consulted, so it is checked at
 * least once and once more after the final sleep; a condition that becomes true during the last
 * sleep is therefore not reported as a timeout.</p>
 *
 * @param condition condition to wait for
 * @throws Exception if the waiting thread is interrupted while sleeping
 */
private void awaitCondition(BooleanSupplier condition) throws Exception {
    long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(3);
    while (!condition.getAsBoolean()) {
        // nanoTime values must be compared by subtraction because the counter may overflow.
        if (System.nanoTime() - deadline >= 0) {
            Assert.fail("Condition was not met in time.");
        }
        Thread.sleep(20L);
    }
}
}