feat: 完成管理端聊天工作台收口

- 新增管理端聊天工作台与会话级额外知识库持久化

- 补齐发布态聊天、历史会话只读判断与答案版本切换

- 新增 chat_round 热数据与主线消息读取支撑
This commit is contained in:
2026-05-14 20:22:46 +08:00
parent 2ad8935a61
commit 47c2bad839
63 changed files with 8609 additions and 136 deletions

View File

@@ -20,9 +20,11 @@ import tech.easyflow.core.chat.protocol.payload.ErrorPayload;
import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter;
import tech.easyflow.core.runtime.ChatAssistantAccumulator;
import tech.easyflow.core.runtime.ChatRuntimeContext;
import tech.easyflow.core.runtime.ChatRuntimeExtKeys;
import tech.easyflow.core.runtime.ChatRuntimeManager;
import tech.easyflow.core.runtime.ChatRuntimeMessage;
import java.math.BigInteger;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
@@ -176,6 +178,7 @@ public class ChatStreamListener implements StreamResponseListener {
deltaMap.put("role", MessageRole.ASSISTANT.getValue());
deltaMap.put("delta", deltaContent);
chatEnvelope.setPayload(deltaMap);
chatEnvelope.setMeta(buildStreamMeta());
boolean sent = sseEmitter.send(chatEnvelope);
if (!sent) {
@@ -196,6 +199,7 @@ public class ChatStreamListener implements StreamResponseListener {
payload.put("name", toolCall.getName());
payload.put("arguments", toolCall.getArguments());
chatEnvelope.setPayload(payload);
chatEnvelope.setMeta(buildStreamMeta());
boolean sent = sseEmitter.send(chatEnvelope);
if (!sent) {
throw new IllegalStateException("SSE emitter has already completed while sending tool call envelope");
@@ -214,6 +218,7 @@ public class ChatStreamListener implements StreamResponseListener {
payload.put("tool_call_id", toolMessage.getToolCallId());
payload.put("result", toolMessage.getContent());
chatEnvelope.setPayload(payload);
chatEnvelope.setMeta(buildStreamMeta());
boolean sent = sseEmitter.send(chatEnvelope);
if (!sent) {
throw new IllegalStateException("SSE emitter has already completed while sending tool result envelope");
@@ -236,6 +241,7 @@ public class ChatStreamListener implements StreamResponseListener {
envelope.setPayload(payload);
envelope.setDomain(ChatDomain.SYSTEM);
envelope.setType(ChatType.ERROR);
envelope.setMeta(buildStreamMeta());
boolean sent = sseEmitter.sendError(envelope);
if (!sent) {
LOG.warn("sendSystemError skipped because emitter is closed, conversationId={}", conversationId);
@@ -243,6 +249,54 @@ public class ChatStreamListener implements StreamResponseListener {
sseEmitter.complete();
}
private Map<String, Object> buildStreamMeta() {
Map<String, Object> meta = new LinkedHashMap<>();
BigInteger roundId = getBigIntegerExt(ChatRuntimeExtKeys.CURRENT_ROUND_ID);
Integer roundNo = getIntegerExt(ChatRuntimeExtKeys.CURRENT_ROUND_NO);
Integer variantIndex = getIntegerExt(ChatRuntimeExtKeys.CURRENT_VARIANT_INDEX);
BigInteger regenerateRoundId = getBigIntegerExt(ChatRuntimeExtKeys.REGENERATE_ROUND_ID);
if (roundId != null) {
meta.put("roundId", roundId.toString());
}
if (roundNo != null) {
meta.put("roundNo", roundNo);
}
if (variantIndex != null) {
meta.put("variantIndex", variantIndex);
}
meta.put("regenerate", regenerateRoundId != null);
if (regenerateRoundId != null) {
meta.put("regenerateRoundId", regenerateRoundId.toString());
}
return meta;
}
private BigInteger getBigIntegerExt(String key) {
Object value = runtimeContext == null || runtimeContext.getExt() == null
? null
: runtimeContext.getExt().get(key);
if (value == null) {
return null;
}
if (value instanceof BigInteger number) {
return number;
}
return new BigInteger(String.valueOf(value));
}
private Integer getIntegerExt(String key) {
Object value = runtimeContext == null || runtimeContext.getExt() == null
? null
: runtimeContext.getExt().get(key);
if (value == null) {
return null;
}
if (value instanceof Integer number) {
return number;
}
return Integer.parseInt(String.valueOf(value));
}
private void stopStreamClient(StreamContext context, String reason, Throwable source) {
try {
if (context != null && context.getClient() != null) {
@@ -267,6 +321,7 @@ public class ChatStreamListener implements StreamResponseListener {
message.setCreatedAt(new Date());
message.setSenderId(runtimeContext.getAssistantId());
message.setSenderName(runtimeContext.getAssistantName());
applyRoundMetadata(message);
return message;
}
@@ -280,9 +335,24 @@ public class ChatStreamListener implements StreamResponseListener {
message.setCreatedAt(new Date());
message.setSenderId(runtimeContext.getAssistantId());
message.setSenderName(runtimeContext.getAssistantName());
applyRoundMetadata(message);
return message;
}
/**
* 把当前轮次元数据写回 assistant 消息,确保 SSE 与后续持久化链路都能识别轮次和版本。
*
* @param message assistant 运行时消息
*/
private void applyRoundMetadata(ChatRuntimeMessage message) {
if (message == null) {
return;
}
message.setRoundId(getBigIntegerExt(ChatRuntimeExtKeys.CURRENT_ROUND_ID));
message.setRoundNo(getIntegerExt(ChatRuntimeExtKeys.CURRENT_ROUND_NO));
message.setVariantIndex(getIntegerExt(ChatRuntimeExtKeys.CURRENT_VARIANT_INDEX));
}
/**
* 在 tool call assistant 写入临时 memory 前,把 reasoning/content 快照回填到消息对象中,
* 以便前端 history 透传和 DeepSeek 下一轮请求都能拿到完整链路。

View File

@@ -56,6 +56,25 @@ public interface BotService extends IService<Bot> {
SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId, BotServiceImpl.ChatCheckResult chatCheckResult);
SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId,
BotServiceImpl.ChatCheckResult chatCheckResult, BigInteger regenerateRoundId);
/**
* 聊天前置校验,并根据调用方要求决定是否强制走发布态。
*
* @param botId 聊天助手 ID
* @param prompt 用户问题
* @param conversationId 会话 ID
* @param chatCheckResult 校验结果承载对象
* @param publishedOnly 是否强制走发布态
* @return 校验失败时返回错误 SSE成功时返回 {@code null}
*/
SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId,
BotServiceImpl.ChatCheckResult chatCheckResult, boolean publishedOnly);
SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId,
BotServiceImpl.ChatCheckResult chatCheckResult, boolean publishedOnly, BigInteger regenerateRoundId);
SseEmitter startChat(BigInteger botId, String prompt, BigInteger conversationId, List<Map<String, String>> messages,
BotServiceImpl.ChatCheckResult chatCheckResult, List<String> attachments, ChatRuntimeContext runtimeContext);

View File

@@ -13,6 +13,7 @@ import com.easyagents.core.model.chat.ChatOptions;
import com.easyagents.core.model.chat.StreamResponseListener;
import com.easyagents.core.model.chat.tool.Tool;
import com.easyagents.core.prompt.MemoryPrompt;
import com.easyagents.rag.retrieval.RetrievalMode;
import com.mybatisflex.core.query.QueryWrapper;
import com.mybatisflex.spring.service.impl.ServiceImpl;
import org.slf4j.Logger;
@@ -51,6 +52,7 @@ import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter;
import tech.easyflow.core.chat.protocol.sse.ChatSseUtil;
import tech.easyflow.core.runtime.ChatAssistantAccumulator;
import tech.easyflow.core.runtime.ChatRuntimeExtKeys;
import tech.easyflow.core.runtime.ChatRuntimeContext;
import tech.easyflow.core.runtime.ChatRuntimeManager;
import tech.easyflow.core.runtime.ChatRuntimeMessage;
@@ -60,9 +62,11 @@ import javax.annotation.Resource;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Date;
import java.util.LinkedHashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static tech.easyflow.ai.entity.table.BotPluginTableDef.BOT_PLUGIN;
import static tech.easyflow.ai.entity.table.PluginItemTableDef.PLUGIN_ITEM;
@@ -78,6 +82,8 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
private static final Logger log = LoggerFactory.getLogger(BotServiceImpl.class);
private static final String FAQ_IMAGE_SYSTEM_RULE = "当知识工具返回 Markdown 图片(格式:![描述](URL))时,你必须在最终回答中保留并输出对应的图片 Markdown禁止改写、替换或省略图片 URL。";
private static final String EXTRA_KNOWLEDGE_PRIORITY_RULE = "若当前会话显式选择了额外知识库,请优先参考这些额外知识库;仅在额外知识库无法回答时,再回退到助手默认绑定知识库。";
private static final int MAX_EXTRA_KNOWLEDGE_COUNT = 3;
public static class ChatCheckResult {
private Bot aiBot;
@@ -215,6 +221,24 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
}
public SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId, ChatCheckResult chatCheckResult) {
return checkChatBeforeStart(botId, prompt, conversationId, chatCheckResult, false, null);
}
@Override
public SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId,
ChatCheckResult chatCheckResult, BigInteger regenerateRoundId) {
return checkChatBeforeStart(botId, prompt, conversationId, chatCheckResult, false, regenerateRoundId);
}
@Override
public SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId,
ChatCheckResult chatCheckResult, boolean publishedOnly) {
return checkChatBeforeStart(botId, prompt, conversationId, chatCheckResult, publishedOnly, null);
}
@Override
public SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId,
ChatCheckResult chatCheckResult, boolean publishedOnly, BigInteger regenerateRoundId) {
if (!StringUtils.hasLength(prompt)) {
return ChatSseUtil.sendSystemError(conversationId, "提示词不能为空");
}
@@ -248,7 +272,13 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
if ((!login || anonymousAccount) && !aiBot.isAnonymousEnabled()) {
return ChatSseUtil.sendSystemError(conversationId, "此聊天助手不支持匿名访问");
}
if (!login || anonymousAccount) {
if (publishedOnly && login && !anonymousAccount) {
if (PublishStatus.from(aiBot.getPublishStatus()) != PublishStatus.PUBLISHED) {
return ChatSseUtil.sendSystemError(conversationId, "聊天助手尚未发布");
}
aiBot = toPublishedView(aiBot);
chatCheckResult.setPublishedAccess(true);
} else if (!login || anonymousAccount) {
Bot publishedBot = toPublishedView(aiBot);
if (!PublishStatus.from(aiBot.getPublishStatus()).isExternallyVisible()) {
return ChatSseUtil.sendSystemError(conversationId, "聊天助手尚未发布");
@@ -267,7 +297,6 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
if (chatModel == null) {
return ChatSseUtil.sendSystemError(conversationId, "对话模型获取失败,请检查配置");
}
chatCheckResult.setAiBot(aiBot);
chatCheckResult.setModelOptions(modelOptions);
chatCheckResult.setChatModel(chatModel);
@@ -282,8 +311,9 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
ChatModel chatModel = chatCheckResult.getChatModel();
ChatTimeToolAvailabilityContext chatTimeContext = buildChatTimeToolAvailabilityContext(runtimeContext, chatCheckResult.getAiBot());
final MemoryPrompt memoryPrompt = new MemoryPrompt();
String systemPrompt = buildSystemPromptWithFaqImageRule(
MapUtil.getString(modelOptions, Bot.KEY_SYSTEM_PROMPT)
String systemPrompt = buildSystemPrompt(
MapUtil.getString(modelOptions, Bot.KEY_SYSTEM_PROMPT),
runtimeContext
);
Integer maxMessageCount = MapUtil.getInteger(modelOptions, Bot.KEY_MAX_MESSAGE_COUNT);
if (maxMessageCount != null) {
@@ -301,7 +331,8 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
.set("needEnglishName", true)
.set("bot", chatCheckResult.getAiBot())
.set("chatTimeContext", chatTimeContext)
.set("publishedOnly", chatCheckResult.isPublishedAccess())));
.set("publishedOnly", chatCheckResult.isPublishedAccess())
.set("extraKnowledgeIds", resolveExtraKnowledgeIds(runtimeContext))));
ChatOptions chatOptions = getChatOptions(modelOptions);
Boolean enableDeepThinking = MapUtil.getBoolean(modelOptions, Bot.KEY_ENABLE_DEEP_THINKING, false);
chatOptions.setThinkingEnabled(enableDeepThinking);
@@ -353,8 +384,9 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
Boolean enableDeepThinking = MapUtil.getBoolean(modelOptions, Bot.KEY_ENABLE_DEEP_THINKING, false);
chatOptions.setThinkingEnabled(enableDeepThinking);
ChatModel chatModel = chatCheckResult.getChatModel();
String systemPrompt = buildSystemPromptWithFaqImageRule(
MapUtil.getString(modelOptions, Bot.KEY_SYSTEM_PROMPT)
String systemPrompt = buildSystemPrompt(
MapUtil.getString(modelOptions, Bot.KEY_SYSTEM_PROMPT),
runtimeContext
);
UserMessage userMessage = new UserMessage(prompt);
userMessage.addTools(buildFunctionList(Maps.of("botId", botId)
@@ -362,6 +394,7 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
.set("needAccountId", false)
.set("bot", chatCheckResult.getAiBot())
.set("publishedOnly", chatCheckResult.isPublishedAccess())
.set("extraKnowledgeIds", resolveExtraKnowledgeIds(runtimeContext))
));
ChatSseEmitter chatSseEmitter = new ChatSseEmitter();
SseEmitter emitter = chatSseEmitter.getEmitter();
@@ -474,15 +507,18 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
}
Bot runtimeBot = (Bot) buildParams.get("bot");
ChatTimeToolAvailabilityContext chatTimeContext = (ChatTimeToolAvailabilityContext) buildParams.get("chatTimeContext");
List<BigInteger> extraKnowledgeIds = sanitizeExtraKnowledgeIds((List<BigInteger>) buildParams.get("extraKnowledgeIds"));
boolean usePublishedSnapshot = Boolean.TRUE.equals(buildParams.get("publishedOnly"))
&& runtimeBot != null
&& runtimeBot.getPublishedSnapshotJson() != null
&& PublishStatus.from(runtimeBot.getPublishStatus()).isExternallyVisible();
QueryWrapper queryWrapper = QueryWrapper.create();
Set<BigInteger> existingKnowledgeIds = new LinkedHashSet<>();
appendExtraKnowledgeTools(functionList, extraKnowledgeIds, needEnglishName, chatTimeContext, existingKnowledgeIds);
if (usePublishedSnapshot) {
appendPublishedKnowledgeTools(functionList, runtimeBot, needEnglishName, chatTimeContext, existingKnowledgeIds);
appendPublishedWorkflowTools(functionList, runtimeBot, needEnglishName);
appendPublishedKnowledgeTools(functionList, runtimeBot, needEnglishName, chatTimeContext);
} else {
// 工作流 function 集合
queryWrapper.eq(BotWorkflow::getBotId, botId);
@@ -500,7 +536,7 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
queryWrapper.eq(BotDocumentCollection::getBotId, botId);
List<BotDocumentCollection> botDocumentCollections = botDocumentCollectionService.getMapper()
.selectListWithRelationsByQuery(queryWrapper);
functionList.addAll(buildKnowledgeTools(botDocumentCollections, needEnglishName, chatTimeContext));
functionList.addAll(buildKnowledgeTools(botDocumentCollections, needEnglishName, chatTimeContext, existingKnowledgeIds));
}
// 插件 function 集合
@@ -550,6 +586,22 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
List<Tool> buildKnowledgeTools(List<BotDocumentCollection> botDocumentCollections,
boolean needEnglishName,
ChatTimeToolAvailabilityContext chatTimeContext) {
return buildKnowledgeTools(botDocumentCollections, needEnglishName, chatTimeContext, new LinkedHashSet<>());
}
/**
* 将 Bot 绑定的知识库候选项收敛为当前聊天可用的工具列表,并与已选临时知识库去重。
*
* @param botDocumentCollections Bot 知识库绑定项
* @param needEnglishName 是否使用英文名称
* @param chatTimeContext 聊天时权限上下文
* @param existingKnowledgeIds 已装配知识库 ID 集
* @return 知识库工具列表
*/
List<Tool> buildKnowledgeTools(List<BotDocumentCollection> botDocumentCollections,
boolean needEnglishName,
ChatTimeToolAvailabilityContext chatTimeContext,
Set<BigInteger> existingKnowledgeIds) {
List<Tool> functionList = new ArrayList<>();
if (botDocumentCollections == null || botDocumentCollections.isEmpty()) {
return functionList;
@@ -559,7 +611,7 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
: chatTimeToolAvailabilityService.filterAvailable(chatTimeContext, botDocumentCollections);
for (BotDocumentCollection botDocumentCollection : availableBindings) {
DocumentCollection knowledge = botDocumentCollection.getKnowledge();
if (knowledge == null) {
if (knowledge == null || isDuplicateKnowledge(existingKnowledgeIds, knowledge.getId())) {
continue;
}
DocumentCollectionTool function = (DocumentCollectionTool) knowledge.toFunction(
@@ -603,7 +655,8 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
private void appendPublishedKnowledgeTools(List<Tool> functionList,
Bot runtimeBot,
boolean needEnglishName,
ChatTimeToolAvailabilityContext chatTimeContext) {
ChatTimeToolAvailabilityContext chatTimeContext,
Set<BigInteger> existingKnowledgeIds) {
Object knowledges = runtimeBot.getPublishedSnapshotJson().get("knowledgeBindings");
if (!(knowledges instanceof List<?> knowledgeBindings)) {
return;
@@ -617,12 +670,15 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
continue;
}
DocumentCollection knowledge = documentCollectionService.getPublishedById(new BigInteger(String.valueOf(knowledgeId)));
if (knowledge == null) {
if (knowledge == null || PublishStatus.from(knowledge.getPublishStatus()) != PublishStatus.PUBLISHED) {
continue;
}
if (chatTimeContext != null && !chatTimeToolAvailabilityService.evaluate(chatTimeContext, knowledge).isAvailable()) {
continue;
}
if (isDuplicateKnowledge(existingKnowledgeIds, knowledge.getId())) {
continue;
}
Object retrievalMode = bindingMap.get("retrievalMode");
functionList.add(knowledge.toFunction(
needEnglishName,
@@ -632,6 +688,42 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
}
}
/**
* 组装会话级临时知识库工具,并按用户选择顺序优先插入。
*
* @param functionList 工具集合
* @param extraKnowledgeIds 额外知识库 ID
* @param needEnglishName 是否使用英文名称
* @param chatTimeContext 聊天时权限上下文
* @param existingKnowledgeIds 已装配知识库 ID 集
*/
protected void appendExtraKnowledgeTools(List<Tool> functionList,
List<BigInteger> extraKnowledgeIds,
boolean needEnglishName,
ChatTimeToolAvailabilityContext chatTimeContext,
Set<BigInteger> existingKnowledgeIds) {
if (extraKnowledgeIds == null || extraKnowledgeIds.isEmpty()) {
return;
}
for (BigInteger knowledgeId : extraKnowledgeIds) {
if (knowledgeId == null || isDuplicateKnowledge(existingKnowledgeIds, knowledgeId)) {
continue;
}
DocumentCollection knowledge = documentCollectionService.getById(knowledgeId);
if (knowledge == null) {
throw new BusinessException("额外知识库不存在");
}
if (PublishStatus.from(knowledge.getPublishStatus()) != PublishStatus.PUBLISHED) {
throw new BusinessException("额外知识库未发布,无法用于聊天");
}
if (chatTimeContext != null && !chatTimeToolAvailabilityService.evaluate(chatTimeContext, knowledge).isAvailable()) {
throw new BusinessException("当前用户无权使用所选知识库");
}
functionList.add(documentCollectionService.toPublishedView(knowledge)
.toFunction(needEnglishName, RetrievalMode.HYBRID.name(), chatTimeContext));
}
}
/**
* 构造聊天时工具可用性上下文,并显式回填到运行时上下文中供后续异步工具调用复用。
*
@@ -704,14 +796,64 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
return context.getUserAccount();
}
private String buildSystemPromptWithFaqImageRule(String systemPrompt) {
if (!StringUtils.hasLength(systemPrompt)) {
return FAQ_IMAGE_SYSTEM_RULE;
private String buildSystemPrompt(String systemPrompt, ChatRuntimeContext runtimeContext) {
String mergedPrompt = appendPromptRule(systemPrompt, FAQ_IMAGE_SYSTEM_RULE);
if (hasExtraKnowledgeSelection(runtimeContext)) {
mergedPrompt = appendPromptRule(mergedPrompt, EXTRA_KNOWLEDGE_PRIORITY_RULE);
}
if (systemPrompt.contains(FAQ_IMAGE_SYSTEM_RULE)) {
return mergedPrompt;
}
private String appendPromptRule(String systemPrompt, String rule) {
if (!StringUtils.hasLength(systemPrompt)) {
return rule;
}
if (systemPrompt.contains(rule)) {
return systemPrompt;
}
return systemPrompt + "\n\n" + FAQ_IMAGE_SYSTEM_RULE;
return systemPrompt + "\n\n" + rule;
}
private List<BigInteger> resolveExtraKnowledgeIds(ChatRuntimeContext runtimeContext) {
if (runtimeContext == null || runtimeContext.getExt() == null) {
return List.of();
}
Object rawValue = runtimeContext.getExt().get(ChatRuntimeExtKeys.EXTRA_KNOWLEDGE_IDS);
if (!(rawValue instanceof List<?> rawList) || rawList.isEmpty()) {
return List.of();
}
List<BigInteger> values = new ArrayList<>(rawList.size());
for (Object item : rawList) {
if (item == null) {
continue;
}
values.add(new BigInteger(String.valueOf(item)));
}
return values;
}
private List<BigInteger> sanitizeExtraKnowledgeIds(List<BigInteger> extraKnowledgeIds) {
if (extraKnowledgeIds == null || extraKnowledgeIds.isEmpty()) {
return List.of();
}
LinkedHashSet<BigInteger> dedup = new LinkedHashSet<>();
for (BigInteger knowledgeId : extraKnowledgeIds) {
if (knowledgeId != null) {
dedup.add(knowledgeId);
}
}
if (dedup.size() > MAX_EXTRA_KNOWLEDGE_COUNT) {
throw new BusinessException("额外知识库最多选择 3 个");
}
return new ArrayList<>(dedup);
}
private boolean hasExtraKnowledgeSelection(ChatRuntimeContext runtimeContext) {
return !resolveExtraKnowledgeIds(runtimeContext).isEmpty();
}
private boolean isDuplicateKnowledge(Set<BigInteger> existingKnowledgeIds, BigInteger knowledgeId) {
return knowledgeId != null && existingKnowledgeIds != null && !existingKnowledgeIds.add(knowledgeId);
}
}

View File

@@ -121,11 +121,12 @@ public class BotServiceImplTest {
List.class,
Bot.class,
boolean.class,
ChatTimeToolAvailabilityContext.class
ChatTimeToolAvailabilityContext.class,
Set.class
);
method.setAccessible(true);
method.invoke(service, functionList, runtimeBot, false, buildContext(12, 3));
method.invoke(service, functionList, runtimeBot, false, buildContext(12, 3), new LinkedHashSet<>());
Assert.assertEquals(1, functionList.size());
DocumentCollectionTool tool = (DocumentCollectionTool) functionList.get(0);
@@ -133,6 +134,102 @@ public class BotServiceImplTest {
Assert.assertEquals(RetrievalMode.KEYWORD, tool.getRetrievalMode());
}
/**
* 额外知识库应按用户选择顺序去重,并限制最多 3 个。
*
* @throws Exception 反射调用异常
*/
@Test
@SuppressWarnings("unchecked")
public void sanitizeExtraKnowledgeIdsShouldDeduplicateAndKeepOrder() throws Exception {
BotServiceImpl service = new BotServiceImpl();
Method method = BotServiceImpl.class.getDeclaredMethod("sanitizeExtraKnowledgeIds", List.class);
method.setAccessible(true);
List<BigInteger> result = (List<BigInteger>) method.invoke(
service,
List.of(
BigInteger.valueOf(3),
BigInteger.valueOf(1),
BigInteger.valueOf(3),
BigInteger.valueOf(2)
)
);
Assert.assertEquals(
List.of(BigInteger.valueOf(3), BigInteger.ONE, BigInteger.valueOf(2)),
result
);
}
/**
* 额外知识库超过 3 个时应直接拒绝请求。
*
* @throws Exception 反射调用异常
*/
@Test
public void sanitizeExtraKnowledgeIdsShouldRejectWhenMoreThanThree() throws Exception {
BotServiceImpl service = new BotServiceImpl();
Method method = BotServiceImpl.class.getDeclaredMethod("sanitizeExtraKnowledgeIds", List.class);
method.setAccessible(true);
try {
method.invoke(
service,
List.of(
BigInteger.ONE,
BigInteger.valueOf(2),
BigInteger.valueOf(3),
BigInteger.valueOf(4)
)
);
Assert.fail("expected BusinessException");
} catch (java.lang.reflect.InvocationTargetException exception) {
Assert.assertTrue(exception.getTargetException() instanceof tech.easyflow.common.web.exceptions.BusinessException);
Assert.assertEquals("额外知识库最多选择 3 个", exception.getTargetException().getMessage());
}
}
/**
* 额外知识库工具应强制使用发布态视图,并跳过重复选择。
*
* @throws Exception 反射注入异常
*/
@Test
public void appendExtraKnowledgeToolsShouldUsePublishedKnowledgeAndDeduplicate() throws Exception {
BotServiceImpl service = new BotServiceImpl();
injectAvailabilityService(service, buildAvailabilityService(
new RoleCategoryAccessSnapshot("KNOWLEDGE", BigInteger.valueOf(12), false, false, setOf(BigInteger.valueOf(21))),
Collections.emptySet()
));
DocumentCollection draftKnowledge = buildKnowledge(101, 11, 21, "PUBLIC");
draftKnowledge.setPublishStatus("PUBLISHED");
draftKnowledge.setTitle("draft-title");
draftKnowledge.setDescription("draft-desc");
DocumentCollection publishedKnowledge = buildKnowledge(101, 11, 21, "PUBLIC");
publishedKnowledge.setPublishStatus("PUBLISHED");
publishedKnowledge.setTitle("published-title");
publishedKnowledge.setDescription("published-desc");
injectDocumentCollectionService(service, mockDocumentCollectionServiceForExtra(draftKnowledge, publishedKnowledge));
List<Tool> tools = new ArrayList<>();
service.appendExtraKnowledgeTools(
tools,
List.of(BigInteger.valueOf(101), BigInteger.valueOf(101)),
false,
buildContext(12, 3),
new LinkedHashSet<>()
);
Assert.assertEquals(1, tools.size());
DocumentCollectionTool tool = (DocumentCollectionTool) tools.get(0);
Assert.assertEquals(BigInteger.valueOf(101), tool.getKnowledgeId());
Assert.assertEquals("published-desc", tool.getDescription());
}
private void injectAvailabilityService(BotServiceImpl service, ChatTimeToolAvailabilityService availabilityService) throws Exception {
Field field = BotServiceImpl.class.getDeclaredField("chatTimeToolAvailabilityService");
field.setAccessible(true);
@@ -189,6 +286,7 @@ public class BotServiceImplTest {
knowledge.setCreatedBy(BigInteger.valueOf(createdBy));
knowledge.setCategoryId(BigInteger.valueOf(categoryId));
knowledge.setVisibilityScope(visibilityScope);
knowledge.setPublishStatus("PUBLISHED");
return knowledge;
}
@@ -216,6 +314,26 @@ public class BotServiceImplTest {
);
}
private tech.easyflow.ai.service.DocumentCollectionService mockDocumentCollectionServiceForExtra(DocumentCollection draftCollection,
DocumentCollection publishedCollection) {
return (tech.easyflow.ai.service.DocumentCollectionService) Proxy.newProxyInstance(
tech.easyflow.ai.service.DocumentCollectionService.class.getClassLoader(),
new Class<?>[]{tech.easyflow.ai.service.DocumentCollectionService.class},
(proxy, method, args) -> {
if ("getById".equals(method.getName())) {
return draftCollection;
}
if ("toPublishedView".equals(method.getName())) {
return publishedCollection;
}
if ("getPublishedById".equals(method.getName())) {
return publishedCollection;
}
return defaultValue(method.getReturnType());
}
);
}
private CategoryPermissionService mockCategoryPermissionService(RoleCategoryAccessSnapshot accessSnapshot) {
return (CategoryPermissionService) Proxy.newProxyInstance(
CategoryPermissionService.class.getClassLoader(),