fix: 修复循环节点模板上下文与累计输出
- 统一节点模板渲染上下文,补齐 memory、当前参数与环境变量 - 支持循环体读取父循环上一轮输出,并区分空字符串与缺失值 - 补充模板路径、上下文与循环累计场景回归测试
This commit is contained in:
@@ -370,6 +370,38 @@ public class ChainState implements Serializable {
|
|||||||
return formatArgsMap;
|
return formatArgsMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 构建文本模板渲染所需的上下文列表。
|
||||||
|
* <p>
|
||||||
|
* 顺序为:链路 memory -> 当前节点参数/临时参数 -> 环境变量。
|
||||||
|
* 后面的 map 会覆盖前面的同名 key,这样既能兼容直接引用历史节点输出,
|
||||||
|
* 又能保证当前节点显式解析出的参数优先级更高。
|
||||||
|
*
|
||||||
|
* @param formatArgs 当前节点参与模板渲染的参数
|
||||||
|
* @return 模板渲染上下文列表
|
||||||
|
*/
|
||||||
|
public List<Map<String, Object>> buildTemplateRootMaps(Map<String, Object> formatArgs) {
|
||||||
|
return Arrays.asList(getMemory(), formatArgs, getEnvMap());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 构建文本模板渲染所需的合并上下文。
|
||||||
|
*
|
||||||
|
* @param formatArgs 当前节点参与模板渲染的参数
|
||||||
|
* @return 合并后的模板上下文
|
||||||
|
*/
|
||||||
|
public Map<String, Object> buildTemplateContextMap(Map<String, Object> formatArgs) {
|
||||||
|
Map<String, Object> templateContext = new LinkedHashMap<>();
|
||||||
|
if (memory != null && !memory.isEmpty()) {
|
||||||
|
templateContext.putAll(memory);
|
||||||
|
}
|
||||||
|
if (formatArgs != null && !formatArgs.isEmpty()) {
|
||||||
|
templateContext.putAll(formatArgs);
|
||||||
|
}
|
||||||
|
templateContext.putAll(getEnvMap());
|
||||||
|
return templateContext;
|
||||||
|
}
|
||||||
|
|
||||||
public Map<String, Object> resolveParameters(Node node, List<? extends Parameter> parameters, Map<String, Object> formatArgs, boolean ignoreRequired) {
|
public Map<String, Object> resolveParameters(Node node, List<? extends Parameter> parameters, Map<String, Object> formatArgs, boolean ignoreRequired) {
|
||||||
if (parameters == null || parameters.isEmpty()) {
|
if (parameters == null || parameters.isEmpty()) {
|
||||||
return Collections.emptyMap();
|
return Collections.emptyMap();
|
||||||
@@ -381,7 +413,7 @@ public class ChainState implements Serializable {
|
|||||||
Object value = null;
|
Object value = null;
|
||||||
if (refType == RefType.FIXED) {
|
if (refType == RefType.FIXED) {
|
||||||
value = TextTemplate.of(parameter.getValue())
|
value = TextTemplate.of(parameter.getValue())
|
||||||
.formatToString(Arrays.asList(formatArgs, getEnvMap()));
|
.formatToString(buildTemplateRootMaps(formatArgs));
|
||||||
} else if (refType == RefType.REF) {
|
} else if (refType == RefType.REF) {
|
||||||
value = this.resolveValue(parameter.getRef());
|
value = this.resolveValue(parameter.getRef());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,8 +53,8 @@ public class CodeNode extends BaseNode {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ChainState chainState = chain.getState();
|
ChainState chainState = chain.getState();
|
||||||
List<Map<String, Object>> variables = Arrays.asList(chainState.resolveParameters(this), chainState.getEnvMap());
|
Map<String, Object> parameterValues = chainState.resolveParameters(this);
|
||||||
String newCode = TextTemplate.of(code).formatToString(variables);
|
String newCode = TextTemplate.of(code).formatToString(chainState.buildTemplateRootMaps(parameterValues));
|
||||||
|
|
||||||
CodeRuntimeEngine codeRuntimeEngine = CodeRuntimeEngineManager.getInstance().getCodeRuntimeEngine(this.engine);
|
CodeRuntimeEngine codeRuntimeEngine = CodeRuntimeEngineManager.getInstance().getCodeRuntimeEngine(this.engine);
|
||||||
if (codeRuntimeEngine == null) {
|
if (codeRuntimeEngine == null) {
|
||||||
|
|||||||
@@ -213,7 +213,8 @@ public class HttpNode extends BaseNode {
|
|||||||
public Map<String, Object> doExecute(Chain chain) throws IOException {
|
public Map<String, Object> doExecute(Chain chain) throws IOException {
|
||||||
|
|
||||||
Map<String, Object> argsMap = chain.getState().resolveParameters(this);
|
Map<String, Object> argsMap = chain.getState().resolveParameters(this);
|
||||||
String newUrl = TextTemplate.of(url).formatToString(Arrays.asList(argsMap, chain.getState().getEnvMap()));
|
String newUrl = TextTemplate.of(url)
|
||||||
|
.formatToString(chain.getState().buildTemplateRootMaps(argsMap));
|
||||||
|
|
||||||
Request.Builder reqBuilder = new Request.Builder().url(newUrl);
|
Request.Builder reqBuilder = new Request.Builder().url(newUrl);
|
||||||
|
|
||||||
@@ -280,7 +281,8 @@ public class HttpNode extends BaseNode {
|
|||||||
|
|
||||||
private RequestBody getRequestBody(Chain chain, Map<String, Object> formatArgs) {
|
private RequestBody getRequestBody(Chain chain, Map<String, Object> formatArgs) {
|
||||||
if ("json".equals(bodyType)) {
|
if ("json".equals(bodyType)) {
|
||||||
String bodyJsonString = TextTemplate.of(bodyJson).formatToString(formatArgs, true);
|
String bodyJsonString = TextTemplate.of(bodyJson)
|
||||||
|
.formatToString(chain.getState().buildTemplateContextMap(formatArgs), true);
|
||||||
JSONObject jsonObject = JSON.parseObject(bodyJsonString);
|
JSONObject jsonObject = JSON.parseObject(bodyJsonString);
|
||||||
return RequestBody.create(jsonObject.toString(), MediaType.parse("application/json"));
|
return RequestBody.create(jsonObject.toString(), MediaType.parse("application/json"));
|
||||||
}
|
}
|
||||||
@@ -317,7 +319,8 @@ public class HttpNode extends BaseNode {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if ("raw".equals(bodyType)) {
|
if ("raw".equals(bodyType)) {
|
||||||
String rawBodyString = TextTemplate.of(rawBody).formatToString(Arrays.asList(formatArgs, chain.getState().getEnvMap()));
|
String rawBodyString = TextTemplate.of(rawBody)
|
||||||
|
.formatToString(chain.getState().buildTemplateRootMaps(formatArgs));
|
||||||
return RequestBody.create(rawBodyString, null);
|
return RequestBody.create(rawBodyString, null);
|
||||||
}
|
}
|
||||||
//none
|
//none
|
||||||
|
|||||||
@@ -72,8 +72,10 @@ public class KnowledgeNode extends BaseNode {
|
|||||||
@Override
|
@Override
|
||||||
public Map<String, Object> execute(Chain chain) {
|
public Map<String, Object> execute(Chain chain) {
|
||||||
Map<String, Object> argsMap = chain.getState().resolveParameters(this);
|
Map<String, Object> argsMap = chain.getState().resolveParameters(this);
|
||||||
String realKeyword = TextTemplate.of(keyword).formatToString(Arrays.asList(argsMap, chain.getState().getEnvMap()));
|
String realKeyword = TextTemplate.of(keyword)
|
||||||
String realLimitString = TextTemplate.of(limit).formatToString(Arrays.asList(argsMap, chain.getState().getEnvMap()));
|
.formatToString(chain.getState().buildTemplateRootMaps(argsMap));
|
||||||
|
String realLimitString = TextTemplate.of(limit)
|
||||||
|
.formatToString(chain.getState().buildTemplateRootMaps(argsMap));
|
||||||
int realLimit = 10;
|
int realLimit = 10;
|
||||||
if (StringUtil.hasText(realLimitString)) {
|
if (StringUtil.hasText(realLimitString)) {
|
||||||
try {
|
try {
|
||||||
|
|||||||
@@ -94,7 +94,8 @@ public class LlmNode extends BaseNode {
|
|||||||
throw new RuntimeException("Can not find user prompt");
|
throw new RuntimeException("Can not find user prompt");
|
||||||
}
|
}
|
||||||
|
|
||||||
String userPromptString = TextTemplate.of(userPrompt).formatToString(Arrays.asList(parameterValues, chain.getState().getEnvMap()));
|
String userPromptString = TextTemplate.of(userPrompt)
|
||||||
|
.formatToString(chain.getState().buildTemplateRootMaps(parameterValues));
|
||||||
|
|
||||||
|
|
||||||
Llm llm = LlmManager.getInstance().getChatModel(this.llmId);
|
Llm llm = LlmManager.getInstance().getChatModel(this.llmId);
|
||||||
@@ -102,7 +103,8 @@ public class LlmNode extends BaseNode {
|
|||||||
throw new RuntimeException("Can not find llm: " + this.llmId);
|
throw new RuntimeException("Can not find llm: " + this.llmId);
|
||||||
}
|
}
|
||||||
|
|
||||||
String systemPromptString = TextTemplate.of(this.systemPrompt).formatToString(Arrays.asList(parameterValues, chain.getState().getEnvMap()));
|
String systemPromptString = TextTemplate.of(this.systemPrompt)
|
||||||
|
.formatToString(chain.getState().buildTemplateRootMaps(parameterValues));
|
||||||
|
|
||||||
Llm.MessageInfo messageInfo = new Llm.MessageInfo();
|
Llm.MessageInfo messageInfo = new Llm.MessageInfo();
|
||||||
messageInfo.setMessage(userPromptString);
|
messageInfo.setMessage(userPromptString);
|
||||||
|
|||||||
@@ -71,6 +71,9 @@ public class LoopNode extends BaseNode {
|
|||||||
return Maps.of(ChainConsts.SCHEDULE_NEXT_NODE_DISABLED_KEY, true)
|
return Maps.of(ChainConsts.SCHEDULE_NEXT_NODE_DISABLED_KEY, true)
|
||||||
.set(ChainConsts.NODE_STATE_STATUS_KEY, NodeStatus.RUNNING);
|
.set(ChainConsts.NODE_STATE_STATUS_KEY, NodeStatus.RUNNING);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 首次进入循环时,先为循环体暴露空的累计输出,方便子节点在第一轮读取到“空值”而不是缺参异常。
|
||||||
|
publishLoopProgress(chain, Collections.emptyMap());
|
||||||
}
|
}
|
||||||
// 由子节点返回:从堆栈低部获取当前循环上下文
|
// 由子节点返回:从堆栈低部获取当前循环上下文
|
||||||
else {
|
else {
|
||||||
@@ -105,7 +108,10 @@ public class LoopNode extends BaseNode {
|
|||||||
// 不是第一次执行,合并结果到 subResult
|
// 不是第一次执行,合并结果到 subResult
|
||||||
if (loopContext.currentIndex != 0) {
|
if (loopContext.currentIndex != 0) {
|
||||||
ChainState subState = chain.getState();
|
ChainState subState = chain.getState();
|
||||||
mergeResult(loopContext.subResult, subState);
|
Map<String, Object> currentOutputs = collectCurrentOutputValues(subState);
|
||||||
|
mergeResult(loopContext.subResult, currentOutputs);
|
||||||
|
// 将上一轮最新输出同步到循环节点作用域,供下一轮循环体读取。
|
||||||
|
publishLoopProgress(chain, currentOutputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -198,7 +204,8 @@ public class LoopNode extends BaseNode {
|
|||||||
* @param toResult 主流程的输出参数
|
* @param toResult 主流程的输出参数
|
||||||
* @param subState 子流程的
|
* @param subState 子流程的
|
||||||
*/
|
*/
|
||||||
private void mergeResult(Map<String, Object> toResult, ChainState subState) {
|
private Map<String, Object> collectCurrentOutputValues(ChainState subState) {
|
||||||
|
Map<String, Object> currentOutputs = new LinkedHashMap<>();
|
||||||
List<Parameter> outputDefs = getOutputDefs();
|
List<Parameter> outputDefs = getOutputDefs();
|
||||||
if (outputDefs != null) {
|
if (outputDefs != null) {
|
||||||
for (Parameter outputDef : outputDefs) {
|
for (Parameter outputDef : outputDefs) {
|
||||||
@@ -213,6 +220,19 @@ public class LoopNode extends BaseNode {
|
|||||||
value = outputDef.getValue();
|
value = outputDef.getValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
currentOutputs.put(outputDef.getName(), value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return currentOutputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private void mergeResult(Map<String, Object> toResult, Map<String, Object> currentOutputs) {
|
||||||
|
List<Parameter> outputDefs = getOutputDefs();
|
||||||
|
if (outputDefs != null) {
|
||||||
|
for (Parameter outputDef : outputDefs) {
|
||||||
|
Object value = currentOutputs.get(outputDef.getName());
|
||||||
|
|
||||||
@SuppressWarnings("unchecked") List<Object> existList = (List<Object>) toResult.get(outputDef.getName());
|
@SuppressWarnings("unchecked") List<Object> existList = (List<Object>) toResult.get(outputDef.getName());
|
||||||
if (existList == null) {
|
if (existList == null) {
|
||||||
existList = new ArrayList<>();
|
existList = new ArrayList<>();
|
||||||
@@ -224,6 +244,30 @@ public class LoopNode extends BaseNode {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 将循环当前累计值同步到父循环节点作用域。
|
||||||
|
* <p>
|
||||||
|
* 循环体内读取 `loopNodeId.outputName` 时,拿到的是上一轮该输出的最新值;
|
||||||
|
* 循环结束后,handleNodeResult 会再把最终聚合结果(列表)覆盖回同名作用域。
|
||||||
|
*/
|
||||||
|
private void publishLoopProgress(Chain chain, Map<String, Object> currentOutputs) {
|
||||||
|
List<Parameter> outputDefs = getOutputDefs();
|
||||||
|
if (outputDefs == null || outputDefs.isEmpty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
chain.updateStateSafely(state -> {
|
||||||
|
ConcurrentHashMap<String, Object> memory = state.getMemory();
|
||||||
|
for (Parameter outputDef : outputDefs) {
|
||||||
|
String key = this.id + "." + outputDef.getName();
|
||||||
|
Object value = currentOutputs.get(outputDef.getName());
|
||||||
|
memory.put(key, value == null ? "" : value);
|
||||||
|
}
|
||||||
|
return EnumSet.of(ChainStateField.MEMORY);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
private String buildLoopStackId() {
|
private String buildLoopStackId() {
|
||||||
return this.getId() + "__loop__context";
|
return this.getId() + "__loop__context";
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,8 +63,10 @@ public class SearchEngineNode extends BaseNode {
|
|||||||
@Override
|
@Override
|
||||||
public Map<String, Object> execute(Chain chain) {
|
public Map<String, Object> execute(Chain chain) {
|
||||||
Map<String, Object> argsMap = chain.getState().resolveParameters(this);
|
Map<String, Object> argsMap = chain.getState().resolveParameters(this);
|
||||||
String realKeyword = TextTemplate.of(keyword).formatToString(Arrays.asList(argsMap, chain.getState().getEnvMap()));
|
String realKeyword = TextTemplate.of(keyword)
|
||||||
String realLimitString = TextTemplate.of(limit).formatToString(Arrays.asList(argsMap, chain.getState().getEnvMap()));
|
.formatToString(chain.getState().buildTemplateRootMaps(argsMap));
|
||||||
|
String realLimitString = TextTemplate.of(limit)
|
||||||
|
.formatToString(chain.getState().buildTemplateRootMaps(argsMap));
|
||||||
int realLimit = 10;
|
int realLimit = 10;
|
||||||
if (StringUtil.hasText(realLimitString)) {
|
if (StringUtil.hasText(realLimitString)) {
|
||||||
try {
|
try {
|
||||||
|
|||||||
@@ -122,10 +122,11 @@ public class TextTemplate {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 动态表达式求值
|
// 动态表达式求值
|
||||||
String value = evaluate(token.parseResult, rootMap, escapeForJsonOutput);
|
EvaluationResult evaluationResult = evaluate(token.parseResult, rootMap, escapeForJsonOutput);
|
||||||
|
String value = evaluationResult.value;
|
||||||
|
|
||||||
// 没有兜底且值为空时抛出异常
|
// 没有兜底且表达式完全未命中时抛出异常
|
||||||
if (!token.explicitEmptyFallback && value.isEmpty()) {
|
if (!token.explicitEmptyFallback && !evaluationResult.resolved) {
|
||||||
throw new IllegalArgumentException(String.format(
|
throw new IllegalArgumentException(String.format(
|
||||||
"Missing value for expression: \"%s\"%nTemplate: %s%nProvided parameters:%n%s",
|
"Missing value for expression: \"%s\"%nTemplate: %s%nProvided parameters:%n%s",
|
||||||
token.rawExpression,
|
token.rawExpression,
|
||||||
@@ -202,19 +203,19 @@ public class TextTemplate {
|
|||||||
/**
|
/**
|
||||||
* 递归求值表达式(支持多级兜底)
|
* 递归求值表达式(支持多级兜底)
|
||||||
*/
|
*/
|
||||||
private String evaluate(ParseResult pr, Map<String, Object> root, boolean escapeForJsonOutput) {
|
private EvaluationResult evaluate(ParseResult pr, Map<String, Object> root, boolean escapeForJsonOutput) {
|
||||||
if (pr == null) return "";
|
if (pr == null) return EvaluationResult.unresolved();
|
||||||
|
|
||||||
// 字面量直接返回
|
// 字面量直接返回
|
||||||
if (pr.isLiteral) {
|
if (pr.isLiteral) {
|
||||||
String literal = pr.getUnquotedLiteral();
|
String literal = pr.getUnquotedLiteral();
|
||||||
return escapeForJsonOutput ? escapeJsonString(literal) : literal;
|
return EvaluationResult.resolved(escapeForJsonOutput ? escapeJsonString(literal) : literal);
|
||||||
}
|
}
|
||||||
|
|
||||||
// 尝试从 JSONPath 取值
|
// 尝试从 JSONPath 取值
|
||||||
Object value = getValueByJsonPath(root, pr.expression, escapeForJsonOutput);
|
Object value = getValueByJsonPath(root, pr.expression, escapeForJsonOutput);
|
||||||
if (value != null) {
|
if (value != null) {
|
||||||
return value.toString();
|
return EvaluationResult.resolved(value.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
// 若未取到,则尝试 fallback
|
// 若未取到,则尝试 fallback
|
||||||
@@ -361,4 +362,26 @@ public class TextTemplate {
|
|||||||
this.explicitEmptyFallback = explicitEmptyFallback;
|
this.explicitEmptyFallback = explicitEmptyFallback;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 表达式求值结果。
|
||||||
|
* resolved 表示表达式已成功命中,即使最终字符串为空,也不应视为缺参。
|
||||||
|
*/
|
||||||
|
private static class EvaluationResult {
|
||||||
|
final boolean resolved;
|
||||||
|
final String value;
|
||||||
|
|
||||||
|
private EvaluationResult(boolean resolved, String value) {
|
||||||
|
this.resolved = resolved;
|
||||||
|
this.value = value == null ? "" : value;
|
||||||
|
}
|
||||||
|
|
||||||
|
static EvaluationResult resolved(String value) {
|
||||||
|
return new EvaluationResult(true, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
static EvaluationResult unresolved() {
|
||||||
|
return new EvaluationResult(false, "");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,35 @@
|
|||||||
|
package com.easyagents.flow.core.test;
|
||||||
|
|
||||||
|
import com.easyagents.flow.core.chain.ChainState;
|
||||||
|
import com.easyagents.flow.core.chain.Parameter;
|
||||||
|
import com.easyagents.flow.core.chain.RefType;
|
||||||
|
import com.easyagents.flow.core.node.StartNode;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证固定值模板在运行时可以直接读取链路 memory 中的节点作用域变量。
|
||||||
|
*/
|
||||||
|
public class ChainTemplateContextTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldResolveScopedNumericMemoryKeyInFixedParameterTemplate() {
|
||||||
|
ChainState state = new ChainState();
|
||||||
|
state.getMemory().put("node_1.123", "7");
|
||||||
|
|
||||||
|
Parameter parameter = new Parameter();
|
||||||
|
parameter.setName("nextValue");
|
||||||
|
parameter.setRefType(RefType.FIXED);
|
||||||
|
parameter.setValue("{{node_1.123}}");
|
||||||
|
|
||||||
|
StartNode node = new StartNode();
|
||||||
|
node.setName("开始节点");
|
||||||
|
|
||||||
|
Map<String, Object> result = state.resolveParameters(node, Collections.singletonList(parameter));
|
||||||
|
|
||||||
|
Assert.assertEquals("7", result.get("nextValue"));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,145 @@
|
|||||||
|
package com.easyagents.flow.core.test;
|
||||||
|
|
||||||
|
import com.easyagents.flow.core.chain.*;
|
||||||
|
import com.easyagents.flow.core.chain.repository.ChainDefinitionRepository;
|
||||||
|
import com.easyagents.flow.core.chain.repository.InMemoryChainStateRepository;
|
||||||
|
import com.easyagents.flow.core.chain.repository.InMemoryNodeStateRepository;
|
||||||
|
import com.easyagents.flow.core.chain.runtime.ChainExecutor;
|
||||||
|
import com.easyagents.flow.core.node.BaseNode;
|
||||||
|
import com.easyagents.flow.core.node.EndNode;
|
||||||
|
import com.easyagents.flow.core.node.LoopNode;
|
||||||
|
import com.easyagents.flow.core.node.StartNode;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证循环体内可以读取父循环节点上一轮的最新输出。
|
||||||
|
*/
|
||||||
|
public class LoopNodeProgressContextTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldExposeLatestLoopOutputInsideLoopBody() {
|
||||||
|
ChainDefinition definition = new ChainDefinition();
|
||||||
|
definition.setId("loop-progress-test");
|
||||||
|
|
||||||
|
StartNode startNode = new StartNode();
|
||||||
|
startNode.setId("start");
|
||||||
|
startNode.setName("开始节点");
|
||||||
|
startNode.setParameters(Collections.singletonList(inputParameter("times")));
|
||||||
|
|
||||||
|
LoopNode loopNode = new LoopNode();
|
||||||
|
loopNode.setId("loop");
|
||||||
|
loopNode.setName("循环节点");
|
||||||
|
Parameter loopVar = new Parameter();
|
||||||
|
loopVar.setName("times");
|
||||||
|
loopVar.setRef("times");
|
||||||
|
loopVar.setRefType(RefType.REF);
|
||||||
|
loopNode.setLoopVar(loopVar);
|
||||||
|
|
||||||
|
AccumulatorNode accumulatorNode = new AccumulatorNode("loop");
|
||||||
|
accumulatorNode.setId("acc");
|
||||||
|
accumulatorNode.setName("累计节点");
|
||||||
|
accumulatorNode.setParentId("loop");
|
||||||
|
|
||||||
|
Parameter loopOutput = new Parameter();
|
||||||
|
loopOutput.setName("current");
|
||||||
|
loopOutput.setRef("acc.current");
|
||||||
|
loopOutput.setRefType(RefType.REF);
|
||||||
|
loopNode.setOutputDefs(Collections.singletonList(loopOutput));
|
||||||
|
|
||||||
|
EndNode endNode = new EndNode();
|
||||||
|
endNode.setId("end");
|
||||||
|
endNode.setName("结束节点");
|
||||||
|
Parameter result = new Parameter();
|
||||||
|
result.setName("result");
|
||||||
|
result.setRef("loop.current");
|
||||||
|
result.setRefType(RefType.REF);
|
||||||
|
endNode.setOutputDefs(Collections.singletonList(result));
|
||||||
|
|
||||||
|
definition.addNode(startNode);
|
||||||
|
definition.addNode(loopNode);
|
||||||
|
definition.addNode(accumulatorNode);
|
||||||
|
definition.addNode(endNode);
|
||||||
|
definition.addEdge(edge("e1", "start", "loop"));
|
||||||
|
definition.addEdge(edge("e2", "loop", "acc"));
|
||||||
|
definition.addEdge(edge("e3", "loop", "end"));
|
||||||
|
|
||||||
|
ChainExecutor executor = new ChainExecutor(new FixedDefinitionRepository(definition),
|
||||||
|
new InMemoryChainStateRepository(),
|
||||||
|
new InMemoryNodeStateRepository());
|
||||||
|
|
||||||
|
Map<String, Object> variables = new HashMap<>();
|
||||||
|
variables.put("times", 2);
|
||||||
|
|
||||||
|
Map<String, Object> resultMap = executor.execute("loop-progress-test", variables);
|
||||||
|
|
||||||
|
Assert.assertEquals(java.util.Arrays.asList("1", "2"), resultMap.get("result"));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Parameter inputParameter(String name) {
|
||||||
|
Parameter parameter = new Parameter();
|
||||||
|
parameter.setName(name);
|
||||||
|
parameter.setRefType(RefType.INPUT);
|
||||||
|
parameter.setRequired(true);
|
||||||
|
return parameter;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Edge edge(String id, String source, String target) {
|
||||||
|
Edge edge = new Edge();
|
||||||
|
edge.setId(id);
|
||||||
|
edge.setSource(source);
|
||||||
|
edge.setTarget(target);
|
||||||
|
return edge;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class FixedDefinitionRepository implements ChainDefinitionRepository {
|
||||||
|
private final ChainDefinition definition;
|
||||||
|
|
||||||
|
private FixedDefinitionRepository(ChainDefinition definition) {
|
||||||
|
this.definition = definition;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ChainDefinition getChainDefinitionById(String id) {
|
||||||
|
return definition;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 每一轮都读取父循环节点上一轮的 current,再计算新的结果。
|
||||||
|
*/
|
||||||
|
private static class AccumulatorNode extends BaseNode {
|
||||||
|
private final String loopNodeId;
|
||||||
|
|
||||||
|
private AccumulatorNode(String loopNodeId) {
|
||||||
|
this.loopNodeId = loopNodeId;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<String, Object> execute(Chain chain) {
|
||||||
|
Object previous = chain.getState().resolveValue(loopNodeId + ".current");
|
||||||
|
Object indexValue = chain.getState().resolveValue(loopNodeId + ".index");
|
||||||
|
int index = indexValue instanceof Number
|
||||||
|
? ((Number) indexValue).intValue()
|
||||||
|
: Integer.parseInt(String.valueOf(indexValue));
|
||||||
|
|
||||||
|
if (index == 0) {
|
||||||
|
Assert.assertEquals("", previous);
|
||||||
|
} else if (index == 1) {
|
||||||
|
Assert.assertEquals("1", previous);
|
||||||
|
}
|
||||||
|
|
||||||
|
String current = index == 0
|
||||||
|
? "1"
|
||||||
|
: String.valueOf(Integer.parseInt(String.valueOf(previous)) + 1);
|
||||||
|
|
||||||
|
Map<String, Object> result = new HashMap<>();
|
||||||
|
result.put("current", current);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,8 @@ import com.easyagents.flow.core.util.TextTemplate;
|
|||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
@@ -21,4 +23,25 @@ public class TextTemplatePathTest {
|
|||||||
|
|
||||||
Assert.assertEquals("你好啊", result);
|
Assert.assertEquals("你好啊", result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldResolveFlatScopedNumericKeyFromMergedContext() {
|
||||||
|
Map<String, Object> memory = new HashMap<>();
|
||||||
|
memory.put("node_1.123", "7");
|
||||||
|
|
||||||
|
String result = TextTemplate.of("{{node_1.123}}")
|
||||||
|
.formatToString(Arrays.asList(memory, Collections.emptyMap()));
|
||||||
|
|
||||||
|
Assert.assertEquals("7", result);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldAllowResolvedEmptyStringWithoutTreatingItAsMissing() {
|
||||||
|
Map<String, Object> parameters = new HashMap<>();
|
||||||
|
parameters.put("node_1.current", "");
|
||||||
|
|
||||||
|
String result = TextTemplate.of("value={{node_1.current}}").formatToString(parameters);
|
||||||
|
|
||||||
|
Assert.assertEquals("value=", result);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user