fix: 修复循环节点模板上下文与累计输出

- 统一节点模板渲染上下文,补齐 memory、当前参数与环境变量

- 支持循环体读取父循环上一轮输出,并区分空字符串与缺失值

- 补充模板路径、上下文与循环累计场景回归测试
This commit is contained in:
2026-04-18 21:02:59 +08:00
parent 56ee149e7c
commit 8b34b4ec40
11 changed files with 332 additions and 21 deletions

View File

@@ -370,6 +370,38 @@ public class ChainState implements Serializable {
return formatArgsMap;
}
/**
* 构建文本模板渲染所需的上下文列表。
* <p>
* 顺序为:链路 memory -> 当前节点参数/临时参数 -> 环境变量。
* 后面的 map 会覆盖前面的同名 key这样既能兼容直接引用历史节点输出
* 又能保证当前节点显式解析出的参数优先级更高。
*
* @param formatArgs 当前节点参与模板渲染的参数
* @return 模板渲染上下文列表
*/
public List<Map<String, Object>> buildTemplateRootMaps(Map<String, Object> formatArgs) {
return Arrays.asList(getMemory(), formatArgs, getEnvMap());
}
/**
* 构建文本模板渲染所需的合并上下文。
*
* @param formatArgs 当前节点参与模板渲染的参数
* @return 合并后的模板上下文
*/
public Map<String, Object> buildTemplateContextMap(Map<String, Object> formatArgs) {
Map<String, Object> templateContext = new LinkedHashMap<>();
if (memory != null && !memory.isEmpty()) {
templateContext.putAll(memory);
}
if (formatArgs != null && !formatArgs.isEmpty()) {
templateContext.putAll(formatArgs);
}
templateContext.putAll(getEnvMap());
return templateContext;
}
public Map<String, Object> resolveParameters(Node node, List<? extends Parameter> parameters, Map<String, Object> formatArgs, boolean ignoreRequired) {
if (parameters == null || parameters.isEmpty()) {
return Collections.emptyMap();
@@ -381,7 +413,7 @@ public class ChainState implements Serializable {
Object value = null;
if (refType == RefType.FIXED) {
value = TextTemplate.of(parameter.getValue())
.formatToString(Arrays.asList(formatArgs, getEnvMap()));
.formatToString(buildTemplateRootMaps(formatArgs));
} else if (refType == RefType.REF) {
value = this.resolveValue(parameter.getRef());
}

View File

@@ -53,8 +53,8 @@ public class CodeNode extends BaseNode {
}
ChainState chainState = chain.getState();
List<Map<String, Object>> variables = Arrays.asList(chainState.resolveParameters(this), chainState.getEnvMap());
String newCode = TextTemplate.of(code).formatToString(variables);
Map<String, Object> parameterValues = chainState.resolveParameters(this);
String newCode = TextTemplate.of(code).formatToString(chainState.buildTemplateRootMaps(parameterValues));
CodeRuntimeEngine codeRuntimeEngine = CodeRuntimeEngineManager.getInstance().getCodeRuntimeEngine(this.engine);
if (codeRuntimeEngine == null) {

View File

@@ -213,7 +213,8 @@ public class HttpNode extends BaseNode {
public Map<String, Object> doExecute(Chain chain) throws IOException {
Map<String, Object> argsMap = chain.getState().resolveParameters(this);
String newUrl = TextTemplate.of(url).formatToString(Arrays.asList(argsMap, chain.getState().getEnvMap()));
String newUrl = TextTemplate.of(url)
.formatToString(chain.getState().buildTemplateRootMaps(argsMap));
Request.Builder reqBuilder = new Request.Builder().url(newUrl);
@@ -280,7 +281,8 @@ public class HttpNode extends BaseNode {
private RequestBody getRequestBody(Chain chain, Map<String, Object> formatArgs) {
if ("json".equals(bodyType)) {
String bodyJsonString = TextTemplate.of(bodyJson).formatToString(formatArgs, true);
String bodyJsonString = TextTemplate.of(bodyJson)
.formatToString(chain.getState().buildTemplateContextMap(formatArgs), true);
JSONObject jsonObject = JSON.parseObject(bodyJsonString);
return RequestBody.create(jsonObject.toString(), MediaType.parse("application/json"));
}
@@ -317,7 +319,8 @@ public class HttpNode extends BaseNode {
}
if ("raw".equals(bodyType)) {
String rawBodyString = TextTemplate.of(rawBody).formatToString(Arrays.asList(formatArgs, chain.getState().getEnvMap()));
String rawBodyString = TextTemplate.of(rawBody)
.formatToString(chain.getState().buildTemplateRootMaps(formatArgs));
return RequestBody.create(rawBodyString, null);
}
//none

View File

@@ -72,8 +72,10 @@ public class KnowledgeNode extends BaseNode {
@Override
public Map<String, Object> execute(Chain chain) {
Map<String, Object> argsMap = chain.getState().resolveParameters(this);
String realKeyword = TextTemplate.of(keyword).formatToString(Arrays.asList(argsMap, chain.getState().getEnvMap()));
String realLimitString = TextTemplate.of(limit).formatToString(Arrays.asList(argsMap, chain.getState().getEnvMap()));
String realKeyword = TextTemplate.of(keyword)
.formatToString(chain.getState().buildTemplateRootMaps(argsMap));
String realLimitString = TextTemplate.of(limit)
.formatToString(chain.getState().buildTemplateRootMaps(argsMap));
int realLimit = 10;
if (StringUtil.hasText(realLimitString)) {
try {

View File

@@ -94,7 +94,8 @@ public class LlmNode extends BaseNode {
throw new RuntimeException("Can not find user prompt");
}
String userPromptString = TextTemplate.of(userPrompt).formatToString(Arrays.asList(parameterValues, chain.getState().getEnvMap()));
String userPromptString = TextTemplate.of(userPrompt)
.formatToString(chain.getState().buildTemplateRootMaps(parameterValues));
Llm llm = LlmManager.getInstance().getChatModel(this.llmId);
@@ -102,7 +103,8 @@ public class LlmNode extends BaseNode {
throw new RuntimeException("Can not find llm: " + this.llmId);
}
String systemPromptString = TextTemplate.of(this.systemPrompt).formatToString(Arrays.asList(parameterValues, chain.getState().getEnvMap()));
String systemPromptString = TextTemplate.of(this.systemPrompt)
.formatToString(chain.getState().buildTemplateRootMaps(parameterValues));
Llm.MessageInfo messageInfo = new Llm.MessageInfo();
messageInfo.setMessage(userPromptString);

View File

@@ -71,6 +71,9 @@ public class LoopNode extends BaseNode {
return Maps.of(ChainConsts.SCHEDULE_NEXT_NODE_DISABLED_KEY, true)
.set(ChainConsts.NODE_STATE_STATUS_KEY, NodeStatus.RUNNING);
}
// 首次进入循环时,先为循环体暴露空的累计输出,方便子节点在第一轮读取到“空值”而不是缺参异常。
publishLoopProgress(chain, Collections.emptyMap());
}
// 由子节点返回:从堆栈低部获取当前循环上下文
else {
@@ -105,7 +108,10 @@ public class LoopNode extends BaseNode {
// 不是第一次执行,合并结果到 subResult
if (loopContext.currentIndex != 0) {
ChainState subState = chain.getState();
mergeResult(loopContext.subResult, subState);
Map<String, Object> currentOutputs = collectCurrentOutputValues(subState);
mergeResult(loopContext.subResult, currentOutputs);
// 将上一轮最新输出同步到循环节点作用域,供下一轮循环体读取。
publishLoopProgress(chain, currentOutputs);
}
@@ -198,7 +204,8 @@ public class LoopNode extends BaseNode {
* @param toResult 主流程的输出参数
* @param subState 子流程的
*/
private void mergeResult(Map<String, Object> toResult, ChainState subState) {
private Map<String, Object> collectCurrentOutputValues(ChainState subState) {
Map<String, Object> currentOutputs = new LinkedHashMap<>();
List<Parameter> outputDefs = getOutputDefs();
if (outputDefs != null) {
for (Parameter outputDef : outputDefs) {
@@ -213,6 +220,19 @@ public class LoopNode extends BaseNode {
value = outputDef.getValue();
}
currentOutputs.put(outputDef.getName(), value);
}
}
return currentOutputs;
}
private void mergeResult(Map<String, Object> toResult, Map<String, Object> currentOutputs) {
List<Parameter> outputDefs = getOutputDefs();
if (outputDefs != null) {
for (Parameter outputDef : outputDefs) {
Object value = currentOutputs.get(outputDef.getName());
@SuppressWarnings("unchecked") List<Object> existList = (List<Object>) toResult.get(outputDef.getName());
if (existList == null) {
existList = new ArrayList<>();
@@ -224,6 +244,30 @@ public class LoopNode extends BaseNode {
}
/**
* 将循环当前累计值同步到父循环节点作用域。
* <p>
* 循环体内读取 `loopNodeId.outputName` 时,拿到的是上一轮该输出的最新值;
* 循环结束后handleNodeResult 会再把最终聚合结果(列表)覆盖回同名作用域。
*/
private void publishLoopProgress(Chain chain, Map<String, Object> currentOutputs) {
List<Parameter> outputDefs = getOutputDefs();
if (outputDefs == null || outputDefs.isEmpty()) {
return;
}
chain.updateStateSafely(state -> {
ConcurrentHashMap<String, Object> memory = state.getMemory();
for (Parameter outputDef : outputDefs) {
String key = this.id + "." + outputDef.getName();
Object value = currentOutputs.get(outputDef.getName());
memory.put(key, value == null ? "" : value);
}
return EnumSet.of(ChainStateField.MEMORY);
});
}
private String buildLoopStackId() {
return this.getId() + "__loop__context";
}

View File

@@ -63,8 +63,10 @@ public class SearchEngineNode extends BaseNode {
@Override
public Map<String, Object> execute(Chain chain) {
Map<String, Object> argsMap = chain.getState().resolveParameters(this);
String realKeyword = TextTemplate.of(keyword).formatToString(Arrays.asList(argsMap, chain.getState().getEnvMap()));
String realLimitString = TextTemplate.of(limit).formatToString(Arrays.asList(argsMap, chain.getState().getEnvMap()));
String realKeyword = TextTemplate.of(keyword)
.formatToString(chain.getState().buildTemplateRootMaps(argsMap));
String realLimitString = TextTemplate.of(limit)
.formatToString(chain.getState().buildTemplateRootMaps(argsMap));
int realLimit = 10;
if (StringUtil.hasText(realLimitString)) {
try {

View File

@@ -122,10 +122,11 @@ public class TextTemplate {
}
// 动态表达式求值
String value = evaluate(token.parseResult, rootMap, escapeForJsonOutput);
EvaluationResult evaluationResult = evaluate(token.parseResult, rootMap, escapeForJsonOutput);
String value = evaluationResult.value;
// 没有兜底且值为空时抛出异常
if (!token.explicitEmptyFallback && value.isEmpty()) {
// 没有兜底且表达式完全未命中时抛出异常
if (!token.explicitEmptyFallback && !evaluationResult.resolved) {
throw new IllegalArgumentException(String.format(
"Missing value for expression: \"%s\"%nTemplate: %s%nProvided parameters:%n%s",
token.rawExpression,
@@ -202,19 +203,19 @@ public class TextTemplate {
/**
* 递归求值表达式(支持多级兜底)
*/
private String evaluate(ParseResult pr, Map<String, Object> root, boolean escapeForJsonOutput) {
if (pr == null) return "";
private EvaluationResult evaluate(ParseResult pr, Map<String, Object> root, boolean escapeForJsonOutput) {
if (pr == null) return EvaluationResult.unresolved();
// 字面量直接返回
if (pr.isLiteral) {
String literal = pr.getUnquotedLiteral();
return escapeForJsonOutput ? escapeJsonString(literal) : literal;
return EvaluationResult.resolved(escapeForJsonOutput ? escapeJsonString(literal) : literal);
}
// 尝试从 JSONPath 取值
Object value = getValueByJsonPath(root, pr.expression, escapeForJsonOutput);
if (value != null) {
return value.toString();
return EvaluationResult.resolved(value.toString());
}
// 若未取到,则尝试 fallback
@@ -361,4 +362,26 @@ public class TextTemplate {
this.explicitEmptyFallback = explicitEmptyFallback;
}
}
/**
* 表达式求值结果。
* resolved 表示表达式已成功命中,即使最终字符串为空,也不应视为缺参。
*/
private static class EvaluationResult {
final boolean resolved;
final String value;
private EvaluationResult(boolean resolved, String value) {
this.resolved = resolved;
this.value = value == null ? "" : value;
}
static EvaluationResult resolved(String value) {
return new EvaluationResult(true, value);
}
static EvaluationResult unresolved() {
return new EvaluationResult(false, "");
}
}
}

View File

@@ -0,0 +1,35 @@
package com.easyagents.flow.core.test;
import com.easyagents.flow.core.chain.ChainState;
import com.easyagents.flow.core.chain.Parameter;
import com.easyagents.flow.core.chain.RefType;
import com.easyagents.flow.core.node.StartNode;
import org.junit.Assert;
import org.junit.Test;
import java.util.Collections;
import java.util.Map;
/**
* 验证固定值模板在运行时可以直接读取链路 memory 中的节点作用域变量。
*/
public class ChainTemplateContextTest {
@Test
public void shouldResolveScopedNumericMemoryKeyInFixedParameterTemplate() {
ChainState state = new ChainState();
state.getMemory().put("node_1.123", "7");
Parameter parameter = new Parameter();
parameter.setName("nextValue");
parameter.setRefType(RefType.FIXED);
parameter.setValue("{{node_1.123}}");
StartNode node = new StartNode();
node.setName("开始节点");
Map<String, Object> result = state.resolveParameters(node, Collections.singletonList(parameter));
Assert.assertEquals("7", result.get("nextValue"));
}
}

View File

@@ -0,0 +1,145 @@
package com.easyagents.flow.core.test;
import com.easyagents.flow.core.chain.*;
import com.easyagents.flow.core.chain.repository.ChainDefinitionRepository;
import com.easyagents.flow.core.chain.repository.InMemoryChainStateRepository;
import com.easyagents.flow.core.chain.repository.InMemoryNodeStateRepository;
import com.easyagents.flow.core.chain.runtime.ChainExecutor;
import com.easyagents.flow.core.node.BaseNode;
import com.easyagents.flow.core.node.EndNode;
import com.easyagents.flow.core.node.LoopNode;
import com.easyagents.flow.core.node.StartNode;
import org.junit.Assert;
import org.junit.Test;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
/**
* 验证循环体内可以读取父循环节点上一轮的最新输出。
*/
public class LoopNodeProgressContextTest {
@Test
public void shouldExposeLatestLoopOutputInsideLoopBody() {
ChainDefinition definition = new ChainDefinition();
definition.setId("loop-progress-test");
StartNode startNode = new StartNode();
startNode.setId("start");
startNode.setName("开始节点");
startNode.setParameters(Collections.singletonList(inputParameter("times")));
LoopNode loopNode = new LoopNode();
loopNode.setId("loop");
loopNode.setName("循环节点");
Parameter loopVar = new Parameter();
loopVar.setName("times");
loopVar.setRef("times");
loopVar.setRefType(RefType.REF);
loopNode.setLoopVar(loopVar);
AccumulatorNode accumulatorNode = new AccumulatorNode("loop");
accumulatorNode.setId("acc");
accumulatorNode.setName("累计节点");
accumulatorNode.setParentId("loop");
Parameter loopOutput = new Parameter();
loopOutput.setName("current");
loopOutput.setRef("acc.current");
loopOutput.setRefType(RefType.REF);
loopNode.setOutputDefs(Collections.singletonList(loopOutput));
EndNode endNode = new EndNode();
endNode.setId("end");
endNode.setName("结束节点");
Parameter result = new Parameter();
result.setName("result");
result.setRef("loop.current");
result.setRefType(RefType.REF);
endNode.setOutputDefs(Collections.singletonList(result));
definition.addNode(startNode);
definition.addNode(loopNode);
definition.addNode(accumulatorNode);
definition.addNode(endNode);
definition.addEdge(edge("e1", "start", "loop"));
definition.addEdge(edge("e2", "loop", "acc"));
definition.addEdge(edge("e3", "loop", "end"));
ChainExecutor executor = new ChainExecutor(new FixedDefinitionRepository(definition),
new InMemoryChainStateRepository(),
new InMemoryNodeStateRepository());
Map<String, Object> variables = new HashMap<>();
variables.put("times", 2);
Map<String, Object> resultMap = executor.execute("loop-progress-test", variables);
Assert.assertEquals(java.util.Arrays.asList("1", "2"), resultMap.get("result"));
}
private static Parameter inputParameter(String name) {
Parameter parameter = new Parameter();
parameter.setName(name);
parameter.setRefType(RefType.INPUT);
parameter.setRequired(true);
return parameter;
}
private static Edge edge(String id, String source, String target) {
Edge edge = new Edge();
edge.setId(id);
edge.setSource(source);
edge.setTarget(target);
return edge;
}
private static class FixedDefinitionRepository implements ChainDefinitionRepository {
private final ChainDefinition definition;
private FixedDefinitionRepository(ChainDefinition definition) {
this.definition = definition;
}
@Override
public ChainDefinition getChainDefinitionById(String id) {
return definition;
}
}
/**
* 每一轮都读取父循环节点上一轮的 current再计算新的结果。
*/
private static class AccumulatorNode extends BaseNode {
private final String loopNodeId;
private AccumulatorNode(String loopNodeId) {
this.loopNodeId = loopNodeId;
}
@Override
public Map<String, Object> execute(Chain chain) {
Object previous = chain.getState().resolveValue(loopNodeId + ".current");
Object indexValue = chain.getState().resolveValue(loopNodeId + ".index");
int index = indexValue instanceof Number
? ((Number) indexValue).intValue()
: Integer.parseInt(String.valueOf(indexValue));
if (index == 0) {
Assert.assertEquals("", previous);
} else if (index == 1) {
Assert.assertEquals("1", previous);
}
String current = index == 0
? "1"
: String.valueOf(Integer.parseInt(String.valueOf(previous)) + 1);
Map<String, Object> result = new HashMap<>();
result.put("current", current);
return result;
}
}
}

View File

@@ -4,6 +4,8 @@ import com.easyagents.flow.core.util.TextTemplate;
import org.junit.Assert;
import org.junit.Test;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@@ -21,4 +23,25 @@ public class TextTemplatePathTest {
Assert.assertEquals("你好啊", result);
}
@Test
public void shouldResolveFlatScopedNumericKeyFromMergedContext() {
Map<String, Object> memory = new HashMap<>();
memory.put("node_1.123", "7");
String result = TextTemplate.of("{{node_1.123}}")
.formatToString(Arrays.asList(memory, Collections.emptyMap()));
Assert.assertEquals("7", result);
}
@Test
public void shouldAllowResolvedEmptyStringWithoutTreatingItAsMissing() {
Map<String, Object> parameters = new HashMap<>();
parameters.put("node_1.current", "");
String result = TextTemplate.of("value={{node_1.current}}").formatToString(parameters);
Assert.assertEquals("value=", result);
}
}