feat: 完成管理端聊天工作台收口
- 新增管理端聊天工作台与会话级额外知识库持久化 - 补齐发布态聊天、历史会话只读判断与答案版本切换 - 新增 chat_round 热数据与主线消息读取支撑
This commit is contained in:
@@ -20,9 +20,11 @@ import tech.easyflow.core.chat.protocol.payload.ErrorPayload;
|
||||
import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter;
|
||||
import tech.easyflow.core.runtime.ChatAssistantAccumulator;
|
||||
import tech.easyflow.core.runtime.ChatRuntimeContext;
|
||||
import tech.easyflow.core.runtime.ChatRuntimeExtKeys;
|
||||
import tech.easyflow.core.runtime.ChatRuntimeManager;
|
||||
import tech.easyflow.core.runtime.ChatRuntimeMessage;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.Date;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
@@ -176,6 +178,7 @@ public class ChatStreamListener implements StreamResponseListener {
|
||||
deltaMap.put("role", MessageRole.ASSISTANT.getValue());
|
||||
deltaMap.put("delta", deltaContent);
|
||||
chatEnvelope.setPayload(deltaMap);
|
||||
chatEnvelope.setMeta(buildStreamMeta());
|
||||
|
||||
boolean sent = sseEmitter.send(chatEnvelope);
|
||||
if (!sent) {
|
||||
@@ -196,6 +199,7 @@ public class ChatStreamListener implements StreamResponseListener {
|
||||
payload.put("name", toolCall.getName());
|
||||
payload.put("arguments", toolCall.getArguments());
|
||||
chatEnvelope.setPayload(payload);
|
||||
chatEnvelope.setMeta(buildStreamMeta());
|
||||
boolean sent = sseEmitter.send(chatEnvelope);
|
||||
if (!sent) {
|
||||
throw new IllegalStateException("SSE emitter has already completed while sending tool call envelope");
|
||||
@@ -214,6 +218,7 @@ public class ChatStreamListener implements StreamResponseListener {
|
||||
payload.put("tool_call_id", toolMessage.getToolCallId());
|
||||
payload.put("result", toolMessage.getContent());
|
||||
chatEnvelope.setPayload(payload);
|
||||
chatEnvelope.setMeta(buildStreamMeta());
|
||||
boolean sent = sseEmitter.send(chatEnvelope);
|
||||
if (!sent) {
|
||||
throw new IllegalStateException("SSE emitter has already completed while sending tool result envelope");
|
||||
@@ -236,6 +241,7 @@ public class ChatStreamListener implements StreamResponseListener {
|
||||
envelope.setPayload(payload);
|
||||
envelope.setDomain(ChatDomain.SYSTEM);
|
||||
envelope.setType(ChatType.ERROR);
|
||||
envelope.setMeta(buildStreamMeta());
|
||||
boolean sent = sseEmitter.sendError(envelope);
|
||||
if (!sent) {
|
||||
LOG.warn("sendSystemError skipped because emitter is closed, conversationId={}", conversationId);
|
||||
@@ -243,6 +249,54 @@ public class ChatStreamListener implements StreamResponseListener {
|
||||
sseEmitter.complete();
|
||||
}
|
||||
|
||||
private Map<String, Object> buildStreamMeta() {
|
||||
Map<String, Object> meta = new LinkedHashMap<>();
|
||||
BigInteger roundId = getBigIntegerExt(ChatRuntimeExtKeys.CURRENT_ROUND_ID);
|
||||
Integer roundNo = getIntegerExt(ChatRuntimeExtKeys.CURRENT_ROUND_NO);
|
||||
Integer variantIndex = getIntegerExt(ChatRuntimeExtKeys.CURRENT_VARIANT_INDEX);
|
||||
BigInteger regenerateRoundId = getBigIntegerExt(ChatRuntimeExtKeys.REGENERATE_ROUND_ID);
|
||||
if (roundId != null) {
|
||||
meta.put("roundId", roundId.toString());
|
||||
}
|
||||
if (roundNo != null) {
|
||||
meta.put("roundNo", roundNo);
|
||||
}
|
||||
if (variantIndex != null) {
|
||||
meta.put("variantIndex", variantIndex);
|
||||
}
|
||||
meta.put("regenerate", regenerateRoundId != null);
|
||||
if (regenerateRoundId != null) {
|
||||
meta.put("regenerateRoundId", regenerateRoundId.toString());
|
||||
}
|
||||
return meta;
|
||||
}
|
||||
|
||||
private BigInteger getBigIntegerExt(String key) {
|
||||
Object value = runtimeContext == null || runtimeContext.getExt() == null
|
||||
? null
|
||||
: runtimeContext.getExt().get(key);
|
||||
if (value == null) {
|
||||
return null;
|
||||
}
|
||||
if (value instanceof BigInteger number) {
|
||||
return number;
|
||||
}
|
||||
return new BigInteger(String.valueOf(value));
|
||||
}
|
||||
|
||||
private Integer getIntegerExt(String key) {
|
||||
Object value = runtimeContext == null || runtimeContext.getExt() == null
|
||||
? null
|
||||
: runtimeContext.getExt().get(key);
|
||||
if (value == null) {
|
||||
return null;
|
||||
}
|
||||
if (value instanceof Integer number) {
|
||||
return number;
|
||||
}
|
||||
return Integer.parseInt(String.valueOf(value));
|
||||
}
|
||||
|
||||
private void stopStreamClient(StreamContext context, String reason, Throwable source) {
|
||||
try {
|
||||
if (context != null && context.getClient() != null) {
|
||||
@@ -267,6 +321,7 @@ public class ChatStreamListener implements StreamResponseListener {
|
||||
message.setCreatedAt(new Date());
|
||||
message.setSenderId(runtimeContext.getAssistantId());
|
||||
message.setSenderName(runtimeContext.getAssistantName());
|
||||
applyRoundMetadata(message);
|
||||
return message;
|
||||
}
|
||||
|
||||
@@ -280,9 +335,24 @@ public class ChatStreamListener implements StreamResponseListener {
|
||||
message.setCreatedAt(new Date());
|
||||
message.setSenderId(runtimeContext.getAssistantId());
|
||||
message.setSenderName(runtimeContext.getAssistantName());
|
||||
applyRoundMetadata(message);
|
||||
return message;
|
||||
}
|
||||
|
||||
/**
|
||||
* 把当前轮次元数据写回 assistant 消息,确保 SSE 与后续持久化链路都能识别轮次和版本。
|
||||
*
|
||||
* @param message assistant 运行时消息
|
||||
*/
|
||||
private void applyRoundMetadata(ChatRuntimeMessage message) {
|
||||
if (message == null) {
|
||||
return;
|
||||
}
|
||||
message.setRoundId(getBigIntegerExt(ChatRuntimeExtKeys.CURRENT_ROUND_ID));
|
||||
message.setRoundNo(getIntegerExt(ChatRuntimeExtKeys.CURRENT_ROUND_NO));
|
||||
message.setVariantIndex(getIntegerExt(ChatRuntimeExtKeys.CURRENT_VARIANT_INDEX));
|
||||
}
|
||||
|
||||
/**
|
||||
* 在 tool call assistant 写入临时 memory 前,把 reasoning/content 快照回填到消息对象中,
|
||||
* 以便前端 history 透传和 DeepSeek 下一轮请求都能拿到完整链路。
|
||||
|
||||
@@ -56,6 +56,25 @@ public interface BotService extends IService<Bot> {
|
||||
|
||||
SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId, BotServiceImpl.ChatCheckResult chatCheckResult);
|
||||
|
||||
SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId,
|
||||
BotServiceImpl.ChatCheckResult chatCheckResult, BigInteger regenerateRoundId);
|
||||
|
||||
/**
|
||||
* 聊天前置校验,并根据调用方要求决定是否强制走发布态。
|
||||
*
|
||||
* @param botId 聊天助手 ID
|
||||
* @param prompt 用户问题
|
||||
* @param conversationId 会话 ID
|
||||
* @param chatCheckResult 校验结果承载对象
|
||||
* @param publishedOnly 是否强制走发布态
|
||||
* @return 校验失败时返回错误 SSE;成功时返回 {@code null}
|
||||
*/
|
||||
SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId,
|
||||
BotServiceImpl.ChatCheckResult chatCheckResult, boolean publishedOnly);
|
||||
|
||||
SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId,
|
||||
BotServiceImpl.ChatCheckResult chatCheckResult, boolean publishedOnly, BigInteger regenerateRoundId);
|
||||
|
||||
SseEmitter startChat(BigInteger botId, String prompt, BigInteger conversationId, List<Map<String, String>> messages,
|
||||
BotServiceImpl.ChatCheckResult chatCheckResult, List<String> attachments, ChatRuntimeContext runtimeContext);
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import com.easyagents.core.model.chat.ChatOptions;
|
||||
import com.easyagents.core.model.chat.StreamResponseListener;
|
||||
import com.easyagents.core.model.chat.tool.Tool;
|
||||
import com.easyagents.core.prompt.MemoryPrompt;
|
||||
import com.easyagents.rag.retrieval.RetrievalMode;
|
||||
import com.mybatisflex.core.query.QueryWrapper;
|
||||
import com.mybatisflex.spring.service.impl.ServiceImpl;
|
||||
import org.slf4j.Logger;
|
||||
@@ -51,6 +52,7 @@ import tech.easyflow.common.web.exceptions.BusinessException;
|
||||
import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter;
|
||||
import tech.easyflow.core.chat.protocol.sse.ChatSseUtil;
|
||||
import tech.easyflow.core.runtime.ChatAssistantAccumulator;
|
||||
import tech.easyflow.core.runtime.ChatRuntimeExtKeys;
|
||||
import tech.easyflow.core.runtime.ChatRuntimeContext;
|
||||
import tech.easyflow.core.runtime.ChatRuntimeManager;
|
||||
import tech.easyflow.core.runtime.ChatRuntimeMessage;
|
||||
@@ -60,9 +62,11 @@ import javax.annotation.Resource;
|
||||
import java.math.BigInteger;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Date;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import static tech.easyflow.ai.entity.table.BotPluginTableDef.BOT_PLUGIN;
|
||||
import static tech.easyflow.ai.entity.table.PluginItemTableDef.PLUGIN_ITEM;
|
||||
@@ -78,6 +82,8 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(BotServiceImpl.class);
|
||||
private static final String FAQ_IMAGE_SYSTEM_RULE = "当知识工具返回 Markdown 图片(格式:)时,你必须在最终回答中保留并输出对应的图片 Markdown,禁止改写、替换或省略图片 URL。";
|
||||
private static final String EXTRA_KNOWLEDGE_PRIORITY_RULE = "若当前会话显式选择了额外知识库,请优先参考这些额外知识库;仅在额外知识库无法回答时,再回退到助手默认绑定知识库。";
|
||||
private static final int MAX_EXTRA_KNOWLEDGE_COUNT = 3;
|
||||
|
||||
public static class ChatCheckResult {
|
||||
private Bot aiBot;
|
||||
@@ -215,6 +221,24 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
|
||||
}
|
||||
|
||||
public SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId, ChatCheckResult chatCheckResult) {
|
||||
return checkChatBeforeStart(botId, prompt, conversationId, chatCheckResult, false, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId,
|
||||
ChatCheckResult chatCheckResult, BigInteger regenerateRoundId) {
|
||||
return checkChatBeforeStart(botId, prompt, conversationId, chatCheckResult, false, regenerateRoundId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId,
|
||||
ChatCheckResult chatCheckResult, boolean publishedOnly) {
|
||||
return checkChatBeforeStart(botId, prompt, conversationId, chatCheckResult, publishedOnly, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId,
|
||||
ChatCheckResult chatCheckResult, boolean publishedOnly, BigInteger regenerateRoundId) {
|
||||
if (!StringUtils.hasLength(prompt)) {
|
||||
return ChatSseUtil.sendSystemError(conversationId, "提示词不能为空");
|
||||
}
|
||||
@@ -248,7 +272,13 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
|
||||
if ((!login || anonymousAccount) && !aiBot.isAnonymousEnabled()) {
|
||||
return ChatSseUtil.sendSystemError(conversationId, "此聊天助手不支持匿名访问");
|
||||
}
|
||||
if (!login || anonymousAccount) {
|
||||
if (publishedOnly && login && !anonymousAccount) {
|
||||
if (PublishStatus.from(aiBot.getPublishStatus()) != PublishStatus.PUBLISHED) {
|
||||
return ChatSseUtil.sendSystemError(conversationId, "聊天助手尚未发布");
|
||||
}
|
||||
aiBot = toPublishedView(aiBot);
|
||||
chatCheckResult.setPublishedAccess(true);
|
||||
} else if (!login || anonymousAccount) {
|
||||
Bot publishedBot = toPublishedView(aiBot);
|
||||
if (!PublishStatus.from(aiBot.getPublishStatus()).isExternallyVisible()) {
|
||||
return ChatSseUtil.sendSystemError(conversationId, "聊天助手尚未发布");
|
||||
@@ -267,7 +297,6 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
|
||||
if (chatModel == null) {
|
||||
return ChatSseUtil.sendSystemError(conversationId, "对话模型获取失败,请检查配置");
|
||||
}
|
||||
|
||||
chatCheckResult.setAiBot(aiBot);
|
||||
chatCheckResult.setModelOptions(modelOptions);
|
||||
chatCheckResult.setChatModel(chatModel);
|
||||
@@ -282,8 +311,9 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
|
||||
ChatModel chatModel = chatCheckResult.getChatModel();
|
||||
ChatTimeToolAvailabilityContext chatTimeContext = buildChatTimeToolAvailabilityContext(runtimeContext, chatCheckResult.getAiBot());
|
||||
final MemoryPrompt memoryPrompt = new MemoryPrompt();
|
||||
String systemPrompt = buildSystemPromptWithFaqImageRule(
|
||||
MapUtil.getString(modelOptions, Bot.KEY_SYSTEM_PROMPT)
|
||||
String systemPrompt = buildSystemPrompt(
|
||||
MapUtil.getString(modelOptions, Bot.KEY_SYSTEM_PROMPT),
|
||||
runtimeContext
|
||||
);
|
||||
Integer maxMessageCount = MapUtil.getInteger(modelOptions, Bot.KEY_MAX_MESSAGE_COUNT);
|
||||
if (maxMessageCount != null) {
|
||||
@@ -301,7 +331,8 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
|
||||
.set("needEnglishName", true)
|
||||
.set("bot", chatCheckResult.getAiBot())
|
||||
.set("chatTimeContext", chatTimeContext)
|
||||
.set("publishedOnly", chatCheckResult.isPublishedAccess())));
|
||||
.set("publishedOnly", chatCheckResult.isPublishedAccess())
|
||||
.set("extraKnowledgeIds", resolveExtraKnowledgeIds(runtimeContext))));
|
||||
ChatOptions chatOptions = getChatOptions(modelOptions);
|
||||
Boolean enableDeepThinking = MapUtil.getBoolean(modelOptions, Bot.KEY_ENABLE_DEEP_THINKING, false);
|
||||
chatOptions.setThinkingEnabled(enableDeepThinking);
|
||||
@@ -353,8 +384,9 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
|
||||
Boolean enableDeepThinking = MapUtil.getBoolean(modelOptions, Bot.KEY_ENABLE_DEEP_THINKING, false);
|
||||
chatOptions.setThinkingEnabled(enableDeepThinking);
|
||||
ChatModel chatModel = chatCheckResult.getChatModel();
|
||||
String systemPrompt = buildSystemPromptWithFaqImageRule(
|
||||
MapUtil.getString(modelOptions, Bot.KEY_SYSTEM_PROMPT)
|
||||
String systemPrompt = buildSystemPrompt(
|
||||
MapUtil.getString(modelOptions, Bot.KEY_SYSTEM_PROMPT),
|
||||
runtimeContext
|
||||
);
|
||||
UserMessage userMessage = new UserMessage(prompt);
|
||||
userMessage.addTools(buildFunctionList(Maps.of("botId", botId)
|
||||
@@ -362,6 +394,7 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
|
||||
.set("needAccountId", false)
|
||||
.set("bot", chatCheckResult.getAiBot())
|
||||
.set("publishedOnly", chatCheckResult.isPublishedAccess())
|
||||
.set("extraKnowledgeIds", resolveExtraKnowledgeIds(runtimeContext))
|
||||
));
|
||||
ChatSseEmitter chatSseEmitter = new ChatSseEmitter();
|
||||
SseEmitter emitter = chatSseEmitter.getEmitter();
|
||||
@@ -474,15 +507,18 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
|
||||
}
|
||||
Bot runtimeBot = (Bot) buildParams.get("bot");
|
||||
ChatTimeToolAvailabilityContext chatTimeContext = (ChatTimeToolAvailabilityContext) buildParams.get("chatTimeContext");
|
||||
List<BigInteger> extraKnowledgeIds = sanitizeExtraKnowledgeIds((List<BigInteger>) buildParams.get("extraKnowledgeIds"));
|
||||
boolean usePublishedSnapshot = Boolean.TRUE.equals(buildParams.get("publishedOnly"))
|
||||
&& runtimeBot != null
|
||||
&& runtimeBot.getPublishedSnapshotJson() != null
|
||||
&& PublishStatus.from(runtimeBot.getPublishStatus()).isExternallyVisible();
|
||||
|
||||
QueryWrapper queryWrapper = QueryWrapper.create();
|
||||
Set<BigInteger> existingKnowledgeIds = new LinkedHashSet<>();
|
||||
appendExtraKnowledgeTools(functionList, extraKnowledgeIds, needEnglishName, chatTimeContext, existingKnowledgeIds);
|
||||
if (usePublishedSnapshot) {
|
||||
appendPublishedKnowledgeTools(functionList, runtimeBot, needEnglishName, chatTimeContext, existingKnowledgeIds);
|
||||
appendPublishedWorkflowTools(functionList, runtimeBot, needEnglishName);
|
||||
appendPublishedKnowledgeTools(functionList, runtimeBot, needEnglishName, chatTimeContext);
|
||||
} else {
|
||||
// 工作流 function 集合
|
||||
queryWrapper.eq(BotWorkflow::getBotId, botId);
|
||||
@@ -500,7 +536,7 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
|
||||
queryWrapper.eq(BotDocumentCollection::getBotId, botId);
|
||||
List<BotDocumentCollection> botDocumentCollections = botDocumentCollectionService.getMapper()
|
||||
.selectListWithRelationsByQuery(queryWrapper);
|
||||
functionList.addAll(buildKnowledgeTools(botDocumentCollections, needEnglishName, chatTimeContext));
|
||||
functionList.addAll(buildKnowledgeTools(botDocumentCollections, needEnglishName, chatTimeContext, existingKnowledgeIds));
|
||||
}
|
||||
|
||||
// 插件 function 集合
|
||||
@@ -550,6 +586,22 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
|
||||
List<Tool> buildKnowledgeTools(List<BotDocumentCollection> botDocumentCollections,
|
||||
boolean needEnglishName,
|
||||
ChatTimeToolAvailabilityContext chatTimeContext) {
|
||||
return buildKnowledgeTools(botDocumentCollections, needEnglishName, chatTimeContext, new LinkedHashSet<>());
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 Bot 绑定的知识库候选项收敛为当前聊天可用的工具列表,并与已选临时知识库去重。
|
||||
*
|
||||
* @param botDocumentCollections Bot 知识库绑定项
|
||||
* @param needEnglishName 是否使用英文名称
|
||||
* @param chatTimeContext 聊天时权限上下文
|
||||
* @param existingKnowledgeIds 已装配知识库 ID 集
|
||||
* @return 知识库工具列表
|
||||
*/
|
||||
List<Tool> buildKnowledgeTools(List<BotDocumentCollection> botDocumentCollections,
|
||||
boolean needEnglishName,
|
||||
ChatTimeToolAvailabilityContext chatTimeContext,
|
||||
Set<BigInteger> existingKnowledgeIds) {
|
||||
List<Tool> functionList = new ArrayList<>();
|
||||
if (botDocumentCollections == null || botDocumentCollections.isEmpty()) {
|
||||
return functionList;
|
||||
@@ -559,7 +611,7 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
|
||||
: chatTimeToolAvailabilityService.filterAvailable(chatTimeContext, botDocumentCollections);
|
||||
for (BotDocumentCollection botDocumentCollection : availableBindings) {
|
||||
DocumentCollection knowledge = botDocumentCollection.getKnowledge();
|
||||
if (knowledge == null) {
|
||||
if (knowledge == null || isDuplicateKnowledge(existingKnowledgeIds, knowledge.getId())) {
|
||||
continue;
|
||||
}
|
||||
DocumentCollectionTool function = (DocumentCollectionTool) knowledge.toFunction(
|
||||
@@ -603,7 +655,8 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
|
||||
private void appendPublishedKnowledgeTools(List<Tool> functionList,
|
||||
Bot runtimeBot,
|
||||
boolean needEnglishName,
|
||||
ChatTimeToolAvailabilityContext chatTimeContext) {
|
||||
ChatTimeToolAvailabilityContext chatTimeContext,
|
||||
Set<BigInteger> existingKnowledgeIds) {
|
||||
Object knowledges = runtimeBot.getPublishedSnapshotJson().get("knowledgeBindings");
|
||||
if (!(knowledges instanceof List<?> knowledgeBindings)) {
|
||||
return;
|
||||
@@ -617,12 +670,15 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
|
||||
continue;
|
||||
}
|
||||
DocumentCollection knowledge = documentCollectionService.getPublishedById(new BigInteger(String.valueOf(knowledgeId)));
|
||||
if (knowledge == null) {
|
||||
if (knowledge == null || PublishStatus.from(knowledge.getPublishStatus()) != PublishStatus.PUBLISHED) {
|
||||
continue;
|
||||
}
|
||||
if (chatTimeContext != null && !chatTimeToolAvailabilityService.evaluate(chatTimeContext, knowledge).isAvailable()) {
|
||||
continue;
|
||||
}
|
||||
if (isDuplicateKnowledge(existingKnowledgeIds, knowledge.getId())) {
|
||||
continue;
|
||||
}
|
||||
Object retrievalMode = bindingMap.get("retrievalMode");
|
||||
functionList.add(knowledge.toFunction(
|
||||
needEnglishName,
|
||||
@@ -632,6 +688,42 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 组装会话级临时知识库工具,并按用户选择顺序优先插入。
|
||||
*
|
||||
* @param functionList 工具集合
|
||||
* @param extraKnowledgeIds 额外知识库 ID
|
||||
* @param needEnglishName 是否使用英文名称
|
||||
* @param chatTimeContext 聊天时权限上下文
|
||||
* @param existingKnowledgeIds 已装配知识库 ID 集
|
||||
*/
|
||||
protected void appendExtraKnowledgeTools(List<Tool> functionList,
|
||||
List<BigInteger> extraKnowledgeIds,
|
||||
boolean needEnglishName,
|
||||
ChatTimeToolAvailabilityContext chatTimeContext,
|
||||
Set<BigInteger> existingKnowledgeIds) {
|
||||
if (extraKnowledgeIds == null || extraKnowledgeIds.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
for (BigInteger knowledgeId : extraKnowledgeIds) {
|
||||
if (knowledgeId == null || isDuplicateKnowledge(existingKnowledgeIds, knowledgeId)) {
|
||||
continue;
|
||||
}
|
||||
DocumentCollection knowledge = documentCollectionService.getById(knowledgeId);
|
||||
if (knowledge == null) {
|
||||
throw new BusinessException("额外知识库不存在");
|
||||
}
|
||||
if (PublishStatus.from(knowledge.getPublishStatus()) != PublishStatus.PUBLISHED) {
|
||||
throw new BusinessException("额外知识库未发布,无法用于聊天");
|
||||
}
|
||||
if (chatTimeContext != null && !chatTimeToolAvailabilityService.evaluate(chatTimeContext, knowledge).isAvailable()) {
|
||||
throw new BusinessException("当前用户无权使用所选知识库");
|
||||
}
|
||||
functionList.add(documentCollectionService.toPublishedView(knowledge)
|
||||
.toFunction(needEnglishName, RetrievalMode.HYBRID.name(), chatTimeContext));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 构造聊天时工具可用性上下文,并显式回填到运行时上下文中供后续异步工具调用复用。
|
||||
*
|
||||
@@ -704,14 +796,64 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
|
||||
return context.getUserAccount();
|
||||
}
|
||||
|
||||
private String buildSystemPromptWithFaqImageRule(String systemPrompt) {
|
||||
if (!StringUtils.hasLength(systemPrompt)) {
|
||||
return FAQ_IMAGE_SYSTEM_RULE;
|
||||
private String buildSystemPrompt(String systemPrompt, ChatRuntimeContext runtimeContext) {
|
||||
String mergedPrompt = appendPromptRule(systemPrompt, FAQ_IMAGE_SYSTEM_RULE);
|
||||
if (hasExtraKnowledgeSelection(runtimeContext)) {
|
||||
mergedPrompt = appendPromptRule(mergedPrompt, EXTRA_KNOWLEDGE_PRIORITY_RULE);
|
||||
}
|
||||
if (systemPrompt.contains(FAQ_IMAGE_SYSTEM_RULE)) {
|
||||
return mergedPrompt;
|
||||
}
|
||||
|
||||
private String appendPromptRule(String systemPrompt, String rule) {
|
||||
if (!StringUtils.hasLength(systemPrompt)) {
|
||||
return rule;
|
||||
}
|
||||
if (systemPrompt.contains(rule)) {
|
||||
return systemPrompt;
|
||||
}
|
||||
return systemPrompt + "\n\n" + FAQ_IMAGE_SYSTEM_RULE;
|
||||
return systemPrompt + "\n\n" + rule;
|
||||
}
|
||||
|
||||
private List<BigInteger> resolveExtraKnowledgeIds(ChatRuntimeContext runtimeContext) {
|
||||
if (runtimeContext == null || runtimeContext.getExt() == null) {
|
||||
return List.of();
|
||||
}
|
||||
Object rawValue = runtimeContext.getExt().get(ChatRuntimeExtKeys.EXTRA_KNOWLEDGE_IDS);
|
||||
if (!(rawValue instanceof List<?> rawList) || rawList.isEmpty()) {
|
||||
return List.of();
|
||||
}
|
||||
List<BigInteger> values = new ArrayList<>(rawList.size());
|
||||
for (Object item : rawList) {
|
||||
if (item == null) {
|
||||
continue;
|
||||
}
|
||||
values.add(new BigInteger(String.valueOf(item)));
|
||||
}
|
||||
return values;
|
||||
}
|
||||
|
||||
private List<BigInteger> sanitizeExtraKnowledgeIds(List<BigInteger> extraKnowledgeIds) {
|
||||
if (extraKnowledgeIds == null || extraKnowledgeIds.isEmpty()) {
|
||||
return List.of();
|
||||
}
|
||||
LinkedHashSet<BigInteger> dedup = new LinkedHashSet<>();
|
||||
for (BigInteger knowledgeId : extraKnowledgeIds) {
|
||||
if (knowledgeId != null) {
|
||||
dedup.add(knowledgeId);
|
||||
}
|
||||
}
|
||||
if (dedup.size() > MAX_EXTRA_KNOWLEDGE_COUNT) {
|
||||
throw new BusinessException("额外知识库最多选择 3 个");
|
||||
}
|
||||
return new ArrayList<>(dedup);
|
||||
}
|
||||
|
||||
private boolean hasExtraKnowledgeSelection(ChatRuntimeContext runtimeContext) {
|
||||
return !resolveExtraKnowledgeIds(runtimeContext).isEmpty();
|
||||
}
|
||||
|
||||
private boolean isDuplicateKnowledge(Set<BigInteger> existingKnowledgeIds, BigInteger knowledgeId) {
|
||||
return knowledgeId != null && existingKnowledgeIds != null && !existingKnowledgeIds.add(knowledgeId);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -121,11 +121,12 @@ public class BotServiceImplTest {
|
||||
List.class,
|
||||
Bot.class,
|
||||
boolean.class,
|
||||
ChatTimeToolAvailabilityContext.class
|
||||
ChatTimeToolAvailabilityContext.class,
|
||||
Set.class
|
||||
);
|
||||
method.setAccessible(true);
|
||||
|
||||
method.invoke(service, functionList, runtimeBot, false, buildContext(12, 3));
|
||||
method.invoke(service, functionList, runtimeBot, false, buildContext(12, 3), new LinkedHashSet<>());
|
||||
|
||||
Assert.assertEquals(1, functionList.size());
|
||||
DocumentCollectionTool tool = (DocumentCollectionTool) functionList.get(0);
|
||||
@@ -133,6 +134,102 @@ public class BotServiceImplTest {
|
||||
Assert.assertEquals(RetrievalMode.KEYWORD, tool.getRetrievalMode());
|
||||
}
|
||||
|
||||
/**
|
||||
* 额外知识库应按用户选择顺序去重,并限制最多 3 个。
|
||||
*
|
||||
* @throws Exception 反射调用异常
|
||||
*/
|
||||
@Test
|
||||
@SuppressWarnings("unchecked")
|
||||
public void sanitizeExtraKnowledgeIdsShouldDeduplicateAndKeepOrder() throws Exception {
|
||||
BotServiceImpl service = new BotServiceImpl();
|
||||
Method method = BotServiceImpl.class.getDeclaredMethod("sanitizeExtraKnowledgeIds", List.class);
|
||||
method.setAccessible(true);
|
||||
|
||||
List<BigInteger> result = (List<BigInteger>) method.invoke(
|
||||
service,
|
||||
List.of(
|
||||
BigInteger.valueOf(3),
|
||||
BigInteger.valueOf(1),
|
||||
BigInteger.valueOf(3),
|
||||
BigInteger.valueOf(2)
|
||||
)
|
||||
);
|
||||
|
||||
Assert.assertEquals(
|
||||
List.of(BigInteger.valueOf(3), BigInteger.ONE, BigInteger.valueOf(2)),
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* 额外知识库超过 3 个时应直接拒绝请求。
|
||||
*
|
||||
* @throws Exception 反射调用异常
|
||||
*/
|
||||
@Test
|
||||
public void sanitizeExtraKnowledgeIdsShouldRejectWhenMoreThanThree() throws Exception {
|
||||
BotServiceImpl service = new BotServiceImpl();
|
||||
Method method = BotServiceImpl.class.getDeclaredMethod("sanitizeExtraKnowledgeIds", List.class);
|
||||
method.setAccessible(true);
|
||||
|
||||
try {
|
||||
method.invoke(
|
||||
service,
|
||||
List.of(
|
||||
BigInteger.ONE,
|
||||
BigInteger.valueOf(2),
|
||||
BigInteger.valueOf(3),
|
||||
BigInteger.valueOf(4)
|
||||
)
|
||||
);
|
||||
Assert.fail("expected BusinessException");
|
||||
} catch (java.lang.reflect.InvocationTargetException exception) {
|
||||
Assert.assertTrue(exception.getTargetException() instanceof tech.easyflow.common.web.exceptions.BusinessException);
|
||||
Assert.assertEquals("额外知识库最多选择 3 个", exception.getTargetException().getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 额外知识库工具应强制使用发布态视图,并跳过重复选择。
|
||||
*
|
||||
* @throws Exception 反射注入异常
|
||||
*/
|
||||
@Test
|
||||
public void appendExtraKnowledgeToolsShouldUsePublishedKnowledgeAndDeduplicate() throws Exception {
|
||||
BotServiceImpl service = new BotServiceImpl();
|
||||
injectAvailabilityService(service, buildAvailabilityService(
|
||||
new RoleCategoryAccessSnapshot("KNOWLEDGE", BigInteger.valueOf(12), false, false, setOf(BigInteger.valueOf(21))),
|
||||
Collections.emptySet()
|
||||
));
|
||||
|
||||
DocumentCollection draftKnowledge = buildKnowledge(101, 11, 21, "PUBLIC");
|
||||
draftKnowledge.setPublishStatus("PUBLISHED");
|
||||
draftKnowledge.setTitle("draft-title");
|
||||
draftKnowledge.setDescription("draft-desc");
|
||||
|
||||
DocumentCollection publishedKnowledge = buildKnowledge(101, 11, 21, "PUBLIC");
|
||||
publishedKnowledge.setPublishStatus("PUBLISHED");
|
||||
publishedKnowledge.setTitle("published-title");
|
||||
publishedKnowledge.setDescription("published-desc");
|
||||
|
||||
injectDocumentCollectionService(service, mockDocumentCollectionServiceForExtra(draftKnowledge, publishedKnowledge));
|
||||
|
||||
List<Tool> tools = new ArrayList<>();
|
||||
service.appendExtraKnowledgeTools(
|
||||
tools,
|
||||
List.of(BigInteger.valueOf(101), BigInteger.valueOf(101)),
|
||||
false,
|
||||
buildContext(12, 3),
|
||||
new LinkedHashSet<>()
|
||||
);
|
||||
|
||||
Assert.assertEquals(1, tools.size());
|
||||
DocumentCollectionTool tool = (DocumentCollectionTool) tools.get(0);
|
||||
Assert.assertEquals(BigInteger.valueOf(101), tool.getKnowledgeId());
|
||||
Assert.assertEquals("published-desc", tool.getDescription());
|
||||
}
|
||||
|
||||
private void injectAvailabilityService(BotServiceImpl service, ChatTimeToolAvailabilityService availabilityService) throws Exception {
|
||||
Field field = BotServiceImpl.class.getDeclaredField("chatTimeToolAvailabilityService");
|
||||
field.setAccessible(true);
|
||||
@@ -189,6 +286,7 @@ public class BotServiceImplTest {
|
||||
knowledge.setCreatedBy(BigInteger.valueOf(createdBy));
|
||||
knowledge.setCategoryId(BigInteger.valueOf(categoryId));
|
||||
knowledge.setVisibilityScope(visibilityScope);
|
||||
knowledge.setPublishStatus("PUBLISHED");
|
||||
return knowledge;
|
||||
}
|
||||
|
||||
@@ -216,6 +314,26 @@ public class BotServiceImplTest {
|
||||
);
|
||||
}
|
||||
|
||||
private tech.easyflow.ai.service.DocumentCollectionService mockDocumentCollectionServiceForExtra(DocumentCollection draftCollection,
|
||||
DocumentCollection publishedCollection) {
|
||||
return (tech.easyflow.ai.service.DocumentCollectionService) Proxy.newProxyInstance(
|
||||
tech.easyflow.ai.service.DocumentCollectionService.class.getClassLoader(),
|
||||
new Class<?>[]{tech.easyflow.ai.service.DocumentCollectionService.class},
|
||||
(proxy, method, args) -> {
|
||||
if ("getById".equals(method.getName())) {
|
||||
return draftCollection;
|
||||
}
|
||||
if ("toPublishedView".equals(method.getName())) {
|
||||
return publishedCollection;
|
||||
}
|
||||
if ("getPublishedById".equals(method.getName())) {
|
||||
return publishedCollection;
|
||||
}
|
||||
return defaultValue(method.getReturnType());
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
private CategoryPermissionService mockCategoryPermissionService(RoleCategoryAccessSnapshot accessSnapshot) {
|
||||
return (CategoryPermissionService) Proxy.newProxyInstance(
|
||||
CategoryPermissionService.class.getClassLoader(),
|
||||
|
||||
@@ -7,9 +7,13 @@ import org.springframework.data.redis.core.StringRedisTemplate;
|
||||
import org.springframework.stereotype.Service;
|
||||
import tech.easyflow.chatlog.config.ChatCacheProperties;
|
||||
import tech.easyflow.chatlog.domain.command.ChatAppendMessageCommand;
|
||||
import tech.easyflow.chatlog.domain.command.ChatRoundSelectCommand;
|
||||
import tech.easyflow.chatlog.domain.command.ChatRoundUpsertCommand;
|
||||
import tech.easyflow.chatlog.domain.command.ChatSessionUpsertCommand;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatMessageRecord;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatRoundRecord;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatSessionSummary;
|
||||
import tech.easyflow.chatlog.support.ChatConstants;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigInteger;
|
||||
@@ -61,6 +65,9 @@ public class ChatHotStateService {
|
||||
if (command.getTitle() != null && !command.getTitle().isBlank()) {
|
||||
summary.setTitle(command.getTitle());
|
||||
}
|
||||
if (command.getExtJson() != null && !command.getExtJson().isBlank()) {
|
||||
summary.setExtJson(command.getExtJson());
|
||||
}
|
||||
summary.setAccessAt(defaultDate(command.getOperateAt()));
|
||||
summary.setModified(defaultDate(command.getOperateAt()));
|
||||
summary.setModifiedBy(command.getOperatorId());
|
||||
@@ -119,9 +126,78 @@ public class ChatHotStateService {
|
||||
summary.setModified(defaultDate(command.getCreated()));
|
||||
summary.setModifiedBy(command.getCreatedBy());
|
||||
summary.setIsDeleted(0);
|
||||
summary.setMessageCount((summary.getMessageCount() == null ? 0 : summary.getMessageCount()) + 1);
|
||||
summary.setMessageCount((summary.getMessageCount() == null ? 0 : summary.getMessageCount())
|
||||
+ resolveVisibleMessageIncrement(command));
|
||||
cacheSessionSummaryStrict(summary);
|
||||
appendTail(toMessageRecord(command));
|
||||
appendVisibleTail(toMessageRecord(command));
|
||||
}
|
||||
|
||||
public ChatRoundRecord createOrTouchRound(ChatRoundUpsertCommand command) {
|
||||
if (command == null || command.getSessionId() == null || command.getRoundId() == null) {
|
||||
return null;
|
||||
}
|
||||
ChatRoundRecord record = getRound(command.getSessionId(), command.getRoundId());
|
||||
if (record == null) {
|
||||
record = new ChatRoundRecord();
|
||||
record.setId(command.getRoundId());
|
||||
record.setSessionId(command.getSessionId());
|
||||
record.setCreated(defaultDate(command.getOperateAt()));
|
||||
}
|
||||
if (command.getRoundNo() != null) {
|
||||
record.setRoundNo(command.getRoundNo());
|
||||
}
|
||||
if (command.getUserMessageId() != null) {
|
||||
record.setUserMessageId(command.getUserMessageId());
|
||||
}
|
||||
if (command.getSelectedAssistantMessageId() != null) {
|
||||
record.setSelectedAssistantMessageId(command.getSelectedAssistantMessageId());
|
||||
}
|
||||
if (command.getSelectedVariantIndex() != null) {
|
||||
record.setSelectedVariantIndex(command.getSelectedVariantIndex());
|
||||
}
|
||||
if (command.getVariantCount() != null) {
|
||||
record.setVariantCount(command.getVariantCount());
|
||||
}
|
||||
if (command.getStatus() != null && !command.getStatus().isBlank()) {
|
||||
record.setStatus(command.getStatus());
|
||||
}
|
||||
record.setModified(defaultDate(command.getOperateAt()));
|
||||
cacheRoundStrict(record);
|
||||
syncTailRoundMeta(record);
|
||||
return record;
|
||||
}
|
||||
|
||||
public void selectVariant(ChatRoundSelectCommand command) {
|
||||
if (command == null || command.getSessionId() == null || command.getRoundId() == null) {
|
||||
return;
|
||||
}
|
||||
ChatRoundRecord record = getRound(command.getSessionId(), command.getRoundId());
|
||||
if (record == null) {
|
||||
return;
|
||||
}
|
||||
record.setSelectedAssistantMessageId(command.getSelectedAssistantMessageId());
|
||||
record.setSelectedVariantIndex(command.getSelectedVariantIndex());
|
||||
record.setModified(defaultDate(command.getOperateAt()));
|
||||
cacheRoundStrict(record);
|
||||
if (command.getSelectedAssistantMessage() != null) {
|
||||
applyRoundMeta(command.getSelectedAssistantMessage(), record);
|
||||
replaceSelectedAssistant(record, command.getSelectedAssistantMessage());
|
||||
}
|
||||
}
|
||||
|
||||
public ChatRoundRecord getLatestRound(BigInteger sessionId) {
|
||||
return readValue(keyLatestRound(sessionId), ChatRoundRecord.class);
|
||||
}
|
||||
|
||||
public ChatRoundRecord getRound(BigInteger sessionId, BigInteger roundId) {
|
||||
return readValue(keyRound(sessionId, roundId), ChatRoundRecord.class);
|
||||
}
|
||||
|
||||
public void cacheRound(ChatRoundRecord record) {
|
||||
try {
|
||||
cacheRoundStrict(record);
|
||||
} catch (IllegalStateException ignored) {
|
||||
}
|
||||
}
|
||||
|
||||
public List<BigInteger> listSessionIds(BigInteger userId, long offset, long limit) {
|
||||
@@ -263,18 +339,168 @@ public class ChatHotStateService {
|
||||
}
|
||||
|
||||
public void appendTail(ChatMessageRecord record) {
|
||||
appendVisibleTail(record);
|
||||
}
|
||||
|
||||
public void appendVisibleTail(ChatMessageRecord record) {
|
||||
if (record == null || record.getSessionId() == null) {
|
||||
return;
|
||||
}
|
||||
ChatRoundRecord round = record.getRoundId() == null ? null : getRound(record.getSessionId(), record.getRoundId());
|
||||
applyRoundMeta(record, round);
|
||||
List<ChatMessageRecord> current = getSessionTail(record.getSessionId());
|
||||
List<ChatMessageRecord> updated = new ArrayList<>();
|
||||
updated.add(record);
|
||||
if (current != null) {
|
||||
updated.addAll(current);
|
||||
List<ChatMessageRecord> updated = current == null ? new ArrayList<>() : new ArrayList<>(current);
|
||||
removeExistingVisibleRecord(updated, record);
|
||||
insertSorted(updated, record);
|
||||
if (ChatConstants.MESSAGE_KIND_ASSISTANT_VARIANT.equals(record.getMessageKind())
|
||||
&& record.getVariantIndex() != null
|
||||
&& record.getVariantIndex() > 1) {
|
||||
removeOlderSelectedAssistant(updated, record);
|
||||
}
|
||||
writeValueStrict(keySessionTail(record.getSessionId()), trimTail(updated), properties.getSessionTailTtl());
|
||||
}
|
||||
|
||||
private void replaceSelectedAssistant(ChatRoundRecord round, ChatMessageRecord record) {
|
||||
List<ChatMessageRecord> current = getSessionTail(record.getSessionId());
|
||||
List<ChatMessageRecord> updated = current == null ? new ArrayList<>() : new ArrayList<>(current);
|
||||
syncRoundMeta(updated, round);
|
||||
removeExistingVisibleRecord(updated, record);
|
||||
removeOlderSelectedAssistant(updated, record);
|
||||
insertSorted(updated, record);
|
||||
writeValueStrict(keySessionTail(record.getSessionId()), trimTail(updated), properties.getSessionTailTtl());
|
||||
}
|
||||
|
||||
/**
|
||||
* 将轮次选中态同步到 Redis tail,避免前端主线过滤读到过期版本号。
|
||||
*
|
||||
* @param round 轮次记录
|
||||
*/
|
||||
private void syncTailRoundMeta(ChatRoundRecord round) {
|
||||
if (round == null || round.getSessionId() == null || round.getId() == null) {
|
||||
return;
|
||||
}
|
||||
List<ChatMessageRecord> current = getSessionTail(round.getSessionId());
|
||||
if (current == null || current.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
List<ChatMessageRecord> updated = new ArrayList<>(current);
|
||||
syncRoundMeta(updated, round);
|
||||
if (round.getSelectedAssistantMessageId() != null) {
|
||||
updated.removeIf(item -> item != null
|
||||
&& item.getRoundId() != null
|
||||
&& item.getRoundId().equals(round.getId())
|
||||
&& "assistant".equalsIgnoreCase(item.getSenderRole())
|
||||
&& item.getId() != null
|
||||
&& !item.getId().equals(round.getSelectedAssistantMessageId()));
|
||||
}
|
||||
writeValueStrict(keySessionTail(round.getSessionId()), trimTail(updated), properties.getSessionTailTtl());
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量同步同一轮次的版本元信息。
|
||||
*
|
||||
* @param records tail 消息列表
|
||||
* @param round 轮次记录
|
||||
*/
|
||||
private void syncRoundMeta(List<ChatMessageRecord> records, ChatRoundRecord round) {
|
||||
if (records == null || records.isEmpty() || round == null || round.getId() == null) {
|
||||
return;
|
||||
}
|
||||
for (ChatMessageRecord item : records) {
|
||||
if (item != null && round.getId().equals(item.getRoundId())) {
|
||||
applyRoundMeta(item, round);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 将轮次元信息写入单条消息。
|
||||
*
|
||||
* @param record 消息记录
|
||||
* @param round 轮次记录
|
||||
*/
|
||||
private void applyRoundMeta(ChatMessageRecord record, ChatRoundRecord round) {
|
||||
if (record == null || round == null) {
|
||||
return;
|
||||
}
|
||||
if (round.getRoundNo() != null) {
|
||||
record.setRoundNo(round.getRoundNo());
|
||||
}
|
||||
if (round.getVariantCount() != null) {
|
||||
record.setVariantCount(round.getVariantCount());
|
||||
}
|
||||
if (round.getSelectedVariantIndex() != null) {
|
||||
record.setSelectedVariantIndex(round.getSelectedVariantIndex());
|
||||
}
|
||||
if (round.getStatus() != null) {
|
||||
record.setSwitchable(!ChatConstants.ROUND_STATUS_LOCKED.equalsIgnoreCase(round.getStatus()));
|
||||
}
|
||||
}
|
||||
|
||||
private void removeOlderSelectedAssistant(List<ChatMessageRecord> records, ChatMessageRecord record) {
|
||||
if (record.getRoundId() == null) {
|
||||
return;
|
||||
}
|
||||
records.removeIf(item -> item != null
|
||||
&& item.getId() != null
|
||||
&& !item.getId().equals(record.getId())
|
||||
&& item.getRoundId() != null
|
||||
&& item.getRoundId().equals(record.getRoundId())
|
||||
&& "assistant".equalsIgnoreCase(item.getSenderRole()));
|
||||
}
|
||||
|
||||
private void removeExistingVisibleRecord(List<ChatMessageRecord> records, ChatMessageRecord record) {
|
||||
if (records == null || record == null || record.getId() == null) {
|
||||
return;
|
||||
}
|
||||
records.removeIf(item -> item != null && record.getId().equals(item.getId()));
|
||||
}
|
||||
|
||||
private void insertSorted(List<ChatMessageRecord> records, ChatMessageRecord record) {
|
||||
int insertIndex = 0;
|
||||
while (insertIndex < records.size()) {
|
||||
ChatMessageRecord current = records.get(insertIndex);
|
||||
if (shouldComeAfter(record, current)) {
|
||||
insertIndex++;
|
||||
continue;
|
||||
}
|
||||
break;
|
||||
}
|
||||
records.add(insertIndex, record);
|
||||
}
|
||||
|
||||
private boolean shouldComeAfter(ChatMessageRecord candidate, ChatMessageRecord current) {
|
||||
if (candidate == null) {
|
||||
return true;
|
||||
}
|
||||
if (current == null) {
|
||||
return false;
|
||||
}
|
||||
Date candidateCreated = defaultDate(candidate.getCreated());
|
||||
Date currentCreated = defaultDate(current.getCreated());
|
||||
if (candidateCreated.before(currentCreated)) {
|
||||
return true;
|
||||
}
|
||||
if (candidateCreated.after(currentCreated)) {
|
||||
return false;
|
||||
}
|
||||
BigInteger candidateId = candidate.getId() == null ? BigInteger.ZERO : candidate.getId();
|
||||
BigInteger currentId = current.getId() == null ? BigInteger.ZERO : current.getId();
|
||||
return candidateId.compareTo(currentId) < 0;
|
||||
}
|
||||
|
||||
private void cacheRoundStrict(ChatRoundRecord record) {
|
||||
if (record == null || record.getSessionId() == null || record.getId() == null) {
|
||||
return;
|
||||
}
|
||||
writeValueStrict(keyRound(record.getSessionId(), record.getId()), record, properties.getSessionTailTtl());
|
||||
ChatRoundRecord latest = getLatestRound(record.getSessionId());
|
||||
if (latest == null || latest.getRoundNo() == null
|
||||
|| (record.getRoundNo() != null && record.getRoundNo() >= latest.getRoundNo())) {
|
||||
writeValueStrict(keyLatestRound(record.getSessionId()), record, properties.getSessionTailTtl());
|
||||
}
|
||||
}
|
||||
|
||||
public void evictSessionTail(BigInteger sessionId) {
|
||||
delete(keySessionTail(sessionId));
|
||||
}
|
||||
@@ -292,6 +518,10 @@ public class ChatHotStateService {
|
||||
record.setSenderId(command.getSenderId());
|
||||
record.setSenderName(command.getSenderName());
|
||||
record.setSenderRole(command.getSenderRole());
|
||||
record.setRoundId(command.getRoundId());
|
||||
record.setRoundNo(command.getRoundNo());
|
||||
record.setMessageKind(command.getMessageKind());
|
||||
record.setVariantIndex(command.getVariantIndex());
|
||||
record.setContentType(command.getContentType());
|
||||
record.setContentText(command.getContentText());
|
||||
record.setContentPayload(command.getContentPayload());
|
||||
@@ -301,6 +531,17 @@ public class ChatHotStateService {
|
||||
return record;
|
||||
}
|
||||
|
||||
private int resolveVisibleMessageIncrement(ChatAppendMessageCommand command) {
|
||||
if (command == null) {
|
||||
return 0;
|
||||
}
|
||||
if (ChatConstants.MESSAGE_KIND_ASSISTANT_VARIANT.equals(command.getMessageKind())) {
|
||||
Integer variantIndex = command.getVariantIndex();
|
||||
return variantIndex != null && variantIndex > 1 ? 0 : 1;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
private List<ChatMessageRecord> trimTail(List<ChatMessageRecord> records) {
|
||||
if (records == null || records.isEmpty()) {
|
||||
return Collections.emptyList();
|
||||
@@ -326,6 +567,14 @@ public class ChatHotStateService {
|
||||
return "chat:session:tail:" + sessionId;
|
||||
}
|
||||
|
||||
private String keyLatestRound(BigInteger sessionId) {
|
||||
return "chat:session:round:latest:" + sessionId;
|
||||
}
|
||||
|
||||
private String keyRound(BigInteger sessionId, BigInteger roundId) {
|
||||
return "chat:session:round:" + sessionId + ":" + roundId;
|
||||
}
|
||||
|
||||
private void removeFromSessionIndex(BigInteger userId, BigInteger sessionId) {
|
||||
if (userId == null || sessionId == null) {
|
||||
return;
|
||||
|
||||
@@ -23,6 +23,10 @@ public class ChatAppendMessageCommand implements Serializable {
|
||||
private String contentType;
|
||||
private String contentText;
|
||||
private Map<String, Object> contentPayload;
|
||||
private BigInteger roundId;
|
||||
private Integer roundNo;
|
||||
private String messageKind;
|
||||
private Integer variantIndex;
|
||||
private BigInteger createdBy;
|
||||
private Date created = new Date();
|
||||
|
||||
@@ -154,6 +158,38 @@ public class ChatAppendMessageCommand implements Serializable {
|
||||
this.contentPayload = contentPayload;
|
||||
}
|
||||
|
||||
public BigInteger getRoundId() {
|
||||
return roundId;
|
||||
}
|
||||
|
||||
public void setRoundId(BigInteger roundId) {
|
||||
this.roundId = roundId;
|
||||
}
|
||||
|
||||
public Integer getRoundNo() {
|
||||
return roundNo;
|
||||
}
|
||||
|
||||
public void setRoundNo(Integer roundNo) {
|
||||
this.roundNo = roundNo;
|
||||
}
|
||||
|
||||
public String getMessageKind() {
|
||||
return messageKind;
|
||||
}
|
||||
|
||||
public void setMessageKind(String messageKind) {
|
||||
this.messageKind = messageKind;
|
||||
}
|
||||
|
||||
public Integer getVariantIndex() {
|
||||
return variantIndex;
|
||||
}
|
||||
|
||||
public void setVariantIndex(Integer variantIndex) {
|
||||
this.variantIndex = variantIndex;
|
||||
}
|
||||
|
||||
public BigInteger getCreatedBy() {
|
||||
return createdBy;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
package tech.easyflow.chatlog.domain.command;
|
||||
|
||||
import tech.easyflow.chatlog.domain.dto.ChatMessageRecord;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.math.BigInteger;
|
||||
import java.util.Date;
|
||||
|
||||
/**
|
||||
* 轮次答案版本切换命令。
|
||||
*/
|
||||
public class ChatRoundSelectCommand implements Serializable {
|
||||
|
||||
private BigInteger sessionId;
|
||||
private BigInteger roundId;
|
||||
private Integer selectedVariantIndex;
|
||||
private BigInteger selectedAssistantMessageId;
|
||||
private ChatMessageRecord selectedAssistantMessage;
|
||||
private BigInteger operatorId;
|
||||
private Date operateAt = new Date();
|
||||
|
||||
public BigInteger getSessionId() {
|
||||
return sessionId;
|
||||
}
|
||||
|
||||
public void setSessionId(BigInteger sessionId) {
|
||||
this.sessionId = sessionId;
|
||||
}
|
||||
|
||||
public BigInteger getRoundId() {
|
||||
return roundId;
|
||||
}
|
||||
|
||||
public void setRoundId(BigInteger roundId) {
|
||||
this.roundId = roundId;
|
||||
}
|
||||
|
||||
public Integer getSelectedVariantIndex() {
|
||||
return selectedVariantIndex;
|
||||
}
|
||||
|
||||
public void setSelectedVariantIndex(Integer selectedVariantIndex) {
|
||||
this.selectedVariantIndex = selectedVariantIndex;
|
||||
}
|
||||
|
||||
public BigInteger getSelectedAssistantMessageId() {
|
||||
return selectedAssistantMessageId;
|
||||
}
|
||||
|
||||
public void setSelectedAssistantMessageId(BigInteger selectedAssistantMessageId) {
|
||||
this.selectedAssistantMessageId = selectedAssistantMessageId;
|
||||
}
|
||||
|
||||
public ChatMessageRecord getSelectedAssistantMessage() {
|
||||
return selectedAssistantMessage;
|
||||
}
|
||||
|
||||
public void setSelectedAssistantMessage(ChatMessageRecord selectedAssistantMessage) {
|
||||
this.selectedAssistantMessage = selectedAssistantMessage;
|
||||
}
|
||||
|
||||
public BigInteger getOperatorId() {
|
||||
return operatorId;
|
||||
}
|
||||
|
||||
public void setOperatorId(BigInteger operatorId) {
|
||||
this.operatorId = operatorId;
|
||||
}
|
||||
|
||||
public Date getOperateAt() {
|
||||
return operateAt;
|
||||
}
|
||||
|
||||
public void setOperateAt(Date operateAt) {
|
||||
this.operateAt = operateAt;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package tech.easyflow.chatlog.domain.command;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.math.BigInteger;
|
||||
import java.util.Date;
|
||||
|
||||
/**
|
||||
* 轮次聚合写入命令。
|
||||
*/
|
||||
public class ChatRoundUpsertCommand implements Serializable {
|
||||
|
||||
private BigInteger roundId;
|
||||
private BigInteger sessionId;
|
||||
private Integer roundNo;
|
||||
private BigInteger userMessageId;
|
||||
private BigInteger selectedAssistantMessageId;
|
||||
private Integer selectedVariantIndex;
|
||||
private Integer variantCount;
|
||||
private String status;
|
||||
private BigInteger operatorId;
|
||||
private Date operateAt = new Date();
|
||||
|
||||
public BigInteger getRoundId() {
|
||||
return roundId;
|
||||
}
|
||||
|
||||
public void setRoundId(BigInteger roundId) {
|
||||
this.roundId = roundId;
|
||||
}
|
||||
|
||||
public BigInteger getSessionId() {
|
||||
return sessionId;
|
||||
}
|
||||
|
||||
public void setSessionId(BigInteger sessionId) {
|
||||
this.sessionId = sessionId;
|
||||
}
|
||||
|
||||
public Integer getRoundNo() {
|
||||
return roundNo;
|
||||
}
|
||||
|
||||
public void setRoundNo(Integer roundNo) {
|
||||
this.roundNo = roundNo;
|
||||
}
|
||||
|
||||
public BigInteger getUserMessageId() {
|
||||
return userMessageId;
|
||||
}
|
||||
|
||||
public void setUserMessageId(BigInteger userMessageId) {
|
||||
this.userMessageId = userMessageId;
|
||||
}
|
||||
|
||||
public BigInteger getSelectedAssistantMessageId() {
|
||||
return selectedAssistantMessageId;
|
||||
}
|
||||
|
||||
public void setSelectedAssistantMessageId(BigInteger selectedAssistantMessageId) {
|
||||
this.selectedAssistantMessageId = selectedAssistantMessageId;
|
||||
}
|
||||
|
||||
public Integer getSelectedVariantIndex() {
|
||||
return selectedVariantIndex;
|
||||
}
|
||||
|
||||
public void setSelectedVariantIndex(Integer selectedVariantIndex) {
|
||||
this.selectedVariantIndex = selectedVariantIndex;
|
||||
}
|
||||
|
||||
public Integer getVariantCount() {
|
||||
return variantCount;
|
||||
}
|
||||
|
||||
public void setVariantCount(Integer variantCount) {
|
||||
this.variantCount = variantCount;
|
||||
}
|
||||
|
||||
public String getStatus() {
|
||||
return status;
|
||||
}
|
||||
|
||||
public void setStatus(String status) {
|
||||
this.status = status;
|
||||
}
|
||||
|
||||
public BigInteger getOperatorId() {
|
||||
return operatorId;
|
||||
}
|
||||
|
||||
public void setOperatorId(BigInteger operatorId) {
|
||||
this.operatorId = operatorId;
|
||||
}
|
||||
|
||||
public Date getOperateAt() {
|
||||
return operateAt;
|
||||
}
|
||||
|
||||
public void setOperateAt(Date operateAt) {
|
||||
this.operateAt = operateAt;
|
||||
}
|
||||
}
|
||||
@@ -12,8 +12,11 @@ public class ChatSessionSummaryCommand implements Serializable {
|
||||
private String lastSenderName;
|
||||
private String lastMessagePreview;
|
||||
private Date lastMessageAt = new Date();
|
||||
private Date accessAt = new Date();
|
||||
private Date modifiedAt = new Date();
|
||||
private BigInteger operatorId;
|
||||
private int messageIncrement = 1;
|
||||
private boolean forceOverwrite;
|
||||
|
||||
public BigInteger getSessionId() {
|
||||
return sessionId;
|
||||
@@ -63,6 +66,22 @@ public class ChatSessionSummaryCommand implements Serializable {
|
||||
this.lastMessageAt = lastMessageAt;
|
||||
}
|
||||
|
||||
public Date getAccessAt() {
|
||||
return accessAt;
|
||||
}
|
||||
|
||||
public void setAccessAt(Date accessAt) {
|
||||
this.accessAt = accessAt;
|
||||
}
|
||||
|
||||
public Date getModifiedAt() {
|
||||
return modifiedAt;
|
||||
}
|
||||
|
||||
public void setModifiedAt(Date modifiedAt) {
|
||||
this.modifiedAt = modifiedAt;
|
||||
}
|
||||
|
||||
public BigInteger getOperatorId() {
|
||||
return operatorId;
|
||||
}
|
||||
@@ -78,4 +97,12 @@ public class ChatSessionSummaryCommand implements Serializable {
|
||||
public void setMessageIncrement(int messageIncrement) {
|
||||
this.messageIncrement = Math.max(messageIncrement, 0);
|
||||
}
|
||||
|
||||
public boolean isForceOverwrite() {
|
||||
return forceOverwrite;
|
||||
}
|
||||
|
||||
public void setForceOverwrite(boolean forceOverwrite) {
|
||||
this.forceOverwrite = forceOverwrite;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ public class ChatSessionUpsertCommand implements Serializable {
|
||||
private String assistantCode;
|
||||
private String assistantName;
|
||||
private String title;
|
||||
private String extJson;
|
||||
private BigInteger operatorId;
|
||||
private Date operateAt = new Date();
|
||||
|
||||
@@ -90,6 +91,14 @@ public class ChatSessionUpsertCommand implements Serializable {
|
||||
this.title = title;
|
||||
}
|
||||
|
||||
public String getExtJson() {
|
||||
return extJson;
|
||||
}
|
||||
|
||||
public void setExtJson(String extJson) {
|
||||
this.extJson = extJson;
|
||||
}
|
||||
|
||||
public BigInteger getOperatorId() {
|
||||
return operatorId;
|
||||
}
|
||||
|
||||
@@ -17,6 +17,13 @@ public class ChatMessageRecord implements Serializable {
|
||||
private String contentType;
|
||||
private String contentText;
|
||||
private Map<String, Object> contentPayload;
|
||||
private BigInteger roundId;
|
||||
private Integer roundNo;
|
||||
private String messageKind;
|
||||
private Integer variantIndex;
|
||||
private Integer variantCount;
|
||||
private Integer selectedVariantIndex;
|
||||
private Boolean switchable;
|
||||
private Date created;
|
||||
private BigInteger createdBy;
|
||||
private Long syncVersion;
|
||||
@@ -101,6 +108,62 @@ public class ChatMessageRecord implements Serializable {
|
||||
this.contentPayload = contentPayload;
|
||||
}
|
||||
|
||||
public BigInteger getRoundId() {
|
||||
return roundId;
|
||||
}
|
||||
|
||||
public void setRoundId(BigInteger roundId) {
|
||||
this.roundId = roundId;
|
||||
}
|
||||
|
||||
public Integer getRoundNo() {
|
||||
return roundNo;
|
||||
}
|
||||
|
||||
public void setRoundNo(Integer roundNo) {
|
||||
this.roundNo = roundNo;
|
||||
}
|
||||
|
||||
public String getMessageKind() {
|
||||
return messageKind;
|
||||
}
|
||||
|
||||
public void setMessageKind(String messageKind) {
|
||||
this.messageKind = messageKind;
|
||||
}
|
||||
|
||||
public Integer getVariantIndex() {
|
||||
return variantIndex;
|
||||
}
|
||||
|
||||
public void setVariantIndex(Integer variantIndex) {
|
||||
this.variantIndex = variantIndex;
|
||||
}
|
||||
|
||||
public Integer getVariantCount() {
|
||||
return variantCount;
|
||||
}
|
||||
|
||||
public void setVariantCount(Integer variantCount) {
|
||||
this.variantCount = variantCount;
|
||||
}
|
||||
|
||||
public Integer getSelectedVariantIndex() {
|
||||
return selectedVariantIndex;
|
||||
}
|
||||
|
||||
public void setSelectedVariantIndex(Integer selectedVariantIndex) {
|
||||
this.selectedVariantIndex = selectedVariantIndex;
|
||||
}
|
||||
|
||||
public Boolean getSwitchable() {
|
||||
return switchable;
|
||||
}
|
||||
|
||||
public void setSwitchable(Boolean switchable) {
|
||||
this.switchable = switchable;
|
||||
}
|
||||
|
||||
public Date getCreated() {
|
||||
return created;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
package tech.easyflow.chatlog.domain.dto;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.math.BigInteger;
|
||||
import java.util.Date;
|
||||
|
||||
/**
|
||||
* 聊天轮次记录。
|
||||
*/
|
||||
public class ChatRoundRecord implements Serializable {
|
||||
|
||||
private BigInteger id;
|
||||
private BigInteger sessionId;
|
||||
private Integer roundNo;
|
||||
private BigInteger userMessageId;
|
||||
private BigInteger selectedAssistantMessageId;
|
||||
private Integer selectedVariantIndex;
|
||||
private Integer variantCount;
|
||||
private String status;
|
||||
private Date created;
|
||||
private Date modified;
|
||||
|
||||
public BigInteger getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public void setId(BigInteger id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public BigInteger getSessionId() {
|
||||
return sessionId;
|
||||
}
|
||||
|
||||
public void setSessionId(BigInteger sessionId) {
|
||||
this.sessionId = sessionId;
|
||||
}
|
||||
|
||||
public Integer getRoundNo() {
|
||||
return roundNo;
|
||||
}
|
||||
|
||||
public void setRoundNo(Integer roundNo) {
|
||||
this.roundNo = roundNo;
|
||||
}
|
||||
|
||||
public BigInteger getUserMessageId() {
|
||||
return userMessageId;
|
||||
}
|
||||
|
||||
public void setUserMessageId(BigInteger userMessageId) {
|
||||
this.userMessageId = userMessageId;
|
||||
}
|
||||
|
||||
public BigInteger getSelectedAssistantMessageId() {
|
||||
return selectedAssistantMessageId;
|
||||
}
|
||||
|
||||
public void setSelectedAssistantMessageId(BigInteger selectedAssistantMessageId) {
|
||||
this.selectedAssistantMessageId = selectedAssistantMessageId;
|
||||
}
|
||||
|
||||
public Integer getSelectedVariantIndex() {
|
||||
return selectedVariantIndex;
|
||||
}
|
||||
|
||||
public void setSelectedVariantIndex(Integer selectedVariantIndex) {
|
||||
this.selectedVariantIndex = selectedVariantIndex;
|
||||
}
|
||||
|
||||
public Integer getVariantCount() {
|
||||
return variantCount;
|
||||
}
|
||||
|
||||
public void setVariantCount(Integer variantCount) {
|
||||
this.variantCount = variantCount;
|
||||
}
|
||||
|
||||
public String getStatus() {
|
||||
return status;
|
||||
}
|
||||
|
||||
public void setStatus(String status) {
|
||||
this.status = status;
|
||||
}
|
||||
|
||||
public Date getCreated() {
|
||||
return created;
|
||||
}
|
||||
|
||||
public void setCreated(Date created) {
|
||||
this.created = created;
|
||||
}
|
||||
|
||||
public Date getModified() {
|
||||
return modified;
|
||||
}
|
||||
|
||||
public void setModified(Date modified) {
|
||||
this.modified = modified;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package tech.easyflow.chatlog.domain.dto;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.math.BigInteger;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 会话扩展载荷。
|
||||
*/
|
||||
public class ChatSessionExtPayload implements Serializable {
|
||||
|
||||
private List<BigInteger> extraKnowledgeIds = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* 获取会话级额外知识库 ID 列表。
|
||||
*
|
||||
* @return 额外知识库 ID 列表
|
||||
*/
|
||||
public List<BigInteger> getExtraKnowledgeIds() {
|
||||
return extraKnowledgeIds;
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置会话级额外知识库 ID 列表。
|
||||
*
|
||||
* @param extraKnowledgeIds 额外知识库 ID 列表
|
||||
*/
|
||||
public void setExtraKnowledgeIds(List<BigInteger> extraKnowledgeIds) {
|
||||
this.extraKnowledgeIds = extraKnowledgeIds == null ? new ArrayList<>() : new ArrayList<>(extraKnowledgeIds);
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,7 @@ public class ChatSessionSummary implements Serializable {
|
||||
private String assistantCode;
|
||||
private String assistantName;
|
||||
private String title;
|
||||
private String extJson;
|
||||
private String lastMessagePreview;
|
||||
private BigInteger lastSenderId;
|
||||
private String lastSenderName;
|
||||
@@ -99,6 +100,14 @@ public class ChatSessionSummary implements Serializable {
|
||||
this.title = title;
|
||||
}
|
||||
|
||||
public String getExtJson() {
|
||||
return extJson;
|
||||
}
|
||||
|
||||
public void setExtJson(String extJson) {
|
||||
this.extJson = extJson;
|
||||
}
|
||||
|
||||
public String getLastMessagePreview() {
|
||||
return lastMessagePreview;
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package tech.easyflow.chatlog.domain.event;
|
||||
public enum ChatPersistEventType {
|
||||
|
||||
SESSION_PREPARED,
|
||||
ROUND_UPSERTED,
|
||||
ROUND_VARIANT_SELECTED,
|
||||
USER_MESSAGE_APPENDED,
|
||||
ASSISTANT_MESSAGE_APPENDED,
|
||||
SESSION_RENAMED,
|
||||
|
||||
@@ -5,6 +5,7 @@ import org.springframework.stereotype.Repository;
|
||||
import tech.easyflow.chatlog.domain.command.ChatAppendMessageCommand;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatMessageRecord;
|
||||
import tech.easyflow.chatlog.support.ChatJsonSupport;
|
||||
import tech.easyflow.chatlog.support.ChatConstants;
|
||||
import tech.easyflow.chatlog.support.ChatTableRouter;
|
||||
|
||||
import java.math.BigInteger;
|
||||
@@ -49,8 +50,8 @@ public class MySqlChatLogRepository {
|
||||
for (Map.Entry<YearMonth, List<ChatAppendMessageCommand>> entry : grouped.entrySet()) {
|
||||
String table = tableRouter.resolveLogTable(entry.getKey());
|
||||
String sql = "INSERT IGNORE INTO `" + table + "` " +
|
||||
"(id, tenant_id, dept_id, session_id, user_id, assistant_id, sender_id, sender_name, sender_role, content_type, content_text, content_payload, created, created_by, sync_version) " +
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
|
||||
"(id, tenant_id, dept_id, session_id, user_id, assistant_id, round_id, round_no, sender_id, sender_name, sender_role, message_kind, variant_index, content_type, content_text, content_payload, created, created_by, sync_version) " +
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
|
||||
int[] results = jdbcTemplate.batchUpdate(sql, new org.springframework.jdbc.core.BatchPreparedStatementSetter() {
|
||||
@Override
|
||||
public void setValues(java.sql.PreparedStatement ps, int i) throws java.sql.SQLException {
|
||||
@@ -62,15 +63,19 @@ public class MySqlChatLogRepository {
|
||||
ps.setObject(4, command.getSessionId());
|
||||
ps.setObject(5, command.getUserId());
|
||||
ps.setObject(6, command.getAssistantId());
|
||||
ps.setObject(7, command.getSenderId());
|
||||
ps.setString(8, command.getSenderName());
|
||||
ps.setString(9, command.getSenderRole());
|
||||
ps.setString(10, command.getContentType());
|
||||
ps.setString(11, command.getContentText());
|
||||
ps.setString(12, jsonSupport.toJson(command.getContentPayload()));
|
||||
ps.setTimestamp(13, created);
|
||||
ps.setObject(14, command.getCreatedBy());
|
||||
ps.setLong(15, command.getCreated().getTime());
|
||||
ps.setObject(7, command.getRoundId());
|
||||
ps.setObject(8, command.getRoundNo());
|
||||
ps.setObject(9, command.getSenderId());
|
||||
ps.setString(10, command.getSenderName());
|
||||
ps.setString(11, command.getSenderRole());
|
||||
ps.setString(12, command.getMessageKind());
|
||||
ps.setObject(13, command.getVariantIndex());
|
||||
ps.setString(14, command.getContentType());
|
||||
ps.setString(15, command.getContentText());
|
||||
ps.setString(16, jsonSupport.toJson(command.getContentPayload()));
|
||||
ps.setTimestamp(17, created);
|
||||
ps.setObject(18, command.getCreatedBy());
|
||||
ps.setLong(19, command.getCreated().getTime());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -92,7 +97,14 @@ public class MySqlChatLogRepository {
|
||||
for (YearMonth month : months) {
|
||||
String table = tableRouter.resolveLogTable(month);
|
||||
List<ChatMessageRecord> current = jdbcTemplate.query(
|
||||
"SELECT * FROM `" + table + "` WHERE session_id=? ORDER BY created DESC, id DESC LIMIT ?",
|
||||
"SELECT l.*, r.round_no AS joined_round_no, r.variant_count, r.selected_variant_index, " +
|
||||
"CASE WHEN r.status IS NOT NULL AND r.status <> 'LOCKED' " +
|
||||
"AND NOT EXISTS (SELECT 1 FROM `" + ChatConstants.ROUND_TABLE + "` newer WHERE newer.session_id = r.session_id AND newer.round_no > r.round_no) " +
|
||||
"THEN 1 ELSE 0 END AS switchable " +
|
||||
"FROM `" + table + "` l " +
|
||||
"LEFT JOIN `" + ChatConstants.ROUND_TABLE + "` r ON l.round_id = r.id " +
|
||||
"WHERE l.session_id=? AND (l.round_id IS NULL OR r.id IS NULL OR l.id = r.user_message_id OR l.id = r.selected_assistant_message_id) " +
|
||||
"ORDER BY l.created DESC, l.id DESC LIMIT ?",
|
||||
(rs, rowNum) -> mapRow(rs),
|
||||
sessionId,
|
||||
limit
|
||||
@@ -114,6 +126,159 @@ public class MySqlChatLogRepository {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
/**
|
||||
* 从 MySQL 热表分页查询主线可见消息。
|
||||
*
|
||||
* @param sessionId 会话 ID
|
||||
* @param months 查询月份
|
||||
* @param offset 分页偏移
|
||||
* @param limit 分页条数
|
||||
* @return 主线消息列表,按 created desc、id desc 排序
|
||||
*/
|
||||
public List<ChatMessageRecord> listMainlineMessages(BigInteger sessionId, List<YearMonth> months, long offset, int limit) {
|
||||
if (sessionId == null || months == null || months.isEmpty() || limit <= 0) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
int candidateLimit = resolveCandidateLimit(offset, limit);
|
||||
Map<BigInteger, ChatMessageRecord> recordMap = new LinkedHashMap<>();
|
||||
for (YearMonth month : months) {
|
||||
String table = tableRouter.resolveLogTable(month);
|
||||
List<ChatMessageRecord> current = jdbcTemplate.query(
|
||||
"SELECT l.*, r.round_no AS joined_round_no, r.variant_count, r.selected_variant_index, " +
|
||||
"CASE WHEN r.status IS NOT NULL AND r.status <> 'LOCKED' " +
|
||||
"AND NOT EXISTS (SELECT 1 FROM `" + ChatConstants.ROUND_TABLE + "` newer WHERE newer.session_id = r.session_id AND newer.round_no > r.round_no) " +
|
||||
"THEN 1 ELSE 0 END AS switchable " +
|
||||
"FROM `" + table + "` l " +
|
||||
"LEFT JOIN `" + ChatConstants.ROUND_TABLE + "` r ON l.round_id = r.id " +
|
||||
"WHERE l.session_id=? AND (l.round_id IS NULL OR r.id IS NULL OR l.id = r.user_message_id OR l.id = r.selected_assistant_message_id) " +
|
||||
"ORDER BY l.created DESC, l.id DESC LIMIT ?",
|
||||
(rs, rowNum) -> mapRow(rs),
|
||||
sessionId,
|
||||
candidateLimit
|
||||
);
|
||||
for (ChatMessageRecord record : current) {
|
||||
if (record != null && record.getId() != null) {
|
||||
recordMap.putIfAbsent(record.getId(), record);
|
||||
}
|
||||
}
|
||||
}
|
||||
return recordMap.values().stream()
|
||||
.sorted((a, b) -> {
|
||||
int compare = b.getCreated().compareTo(a.getCreated());
|
||||
if (compare != 0) {
|
||||
return compare;
|
||||
}
|
||||
return b.getId().compareTo(a.getId());
|
||||
})
|
||||
.skip(Math.max(offset, 0))
|
||||
.limit(limit)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
/**
|
||||
* 从 MySQL 热表查询全部主线可见消息。
|
||||
*
|
||||
* @param sessionId 会话 ID
|
||||
* @param months 查询月份
|
||||
* @return 主线消息列表,按 created asc、id asc 排序
|
||||
*/
|
||||
public List<ChatMessageRecord> listMainlineMessages(BigInteger sessionId, List<YearMonth> months) {
|
||||
if (sessionId == null || months == null || months.isEmpty()) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
Map<BigInteger, ChatMessageRecord> recordMap = new LinkedHashMap<>();
|
||||
for (YearMonth month : months) {
|
||||
String table = tableRouter.resolveLogTable(month);
|
||||
List<ChatMessageRecord> current = jdbcTemplate.query(
|
||||
"SELECT l.*, r.round_no AS joined_round_no, r.variant_count, r.selected_variant_index, " +
|
||||
"CASE WHEN r.status IS NOT NULL AND r.status <> 'LOCKED' " +
|
||||
"AND NOT EXISTS (SELECT 1 FROM `" + ChatConstants.ROUND_TABLE + "` newer WHERE newer.session_id = r.session_id AND newer.round_no > r.round_no) " +
|
||||
"THEN 1 ELSE 0 END AS switchable " +
|
||||
"FROM `" + table + "` l " +
|
||||
"LEFT JOIN `" + ChatConstants.ROUND_TABLE + "` r ON l.round_id = r.id " +
|
||||
"WHERE l.session_id=? AND (l.round_id IS NULL OR r.id IS NULL OR l.id = r.user_message_id OR l.id = r.selected_assistant_message_id) " +
|
||||
"ORDER BY l.created ASC, l.id ASC",
|
||||
(rs, rowNum) -> mapRow(rs),
|
||||
sessionId
|
||||
);
|
||||
for (ChatMessageRecord record : current) {
|
||||
if (record != null && record.getId() != null) {
|
||||
recordMap.putIfAbsent(record.getId(), record);
|
||||
}
|
||||
}
|
||||
}
|
||||
return recordMap.values().stream()
|
||||
.sorted((a, b) -> {
|
||||
int compare = a.getCreated().compareTo(b.getCreated());
|
||||
if (compare != 0) {
|
||||
return compare;
|
||||
}
|
||||
return a.getId().compareTo(b.getId());
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public List<ChatMessageRecord> listRoundVariants(BigInteger sessionId, BigInteger roundId, List<YearMonth> months) {
|
||||
List<ChatMessageRecord> records = new ArrayList<>();
|
||||
for (YearMonth month : months) {
|
||||
String table = tableRouter.resolveLogTable(month);
|
||||
records.addAll(jdbcTemplate.query(
|
||||
"SELECT l.*, r.round_no AS joined_round_no, r.variant_count, r.selected_variant_index, " +
|
||||
"0 AS switchable " +
|
||||
"FROM `" + table + "` l " +
|
||||
"INNER JOIN `" + ChatConstants.ROUND_TABLE + "` r ON l.round_id = r.id " +
|
||||
"WHERE l.session_id=? AND l.round_id=? AND l.message_kind=? " +
|
||||
"ORDER BY l.variant_index ASC, l.created ASC, l.id ASC",
|
||||
(rs, rowNum) -> mapRow(rs),
|
||||
sessionId,
|
||||
roundId,
|
||||
ChatConstants.MESSAGE_KIND_ASSISTANT_VARIANT
|
||||
));
|
||||
}
|
||||
return records;
|
||||
}
|
||||
|
||||
/**
|
||||
* 精准查询轮次下指定答案版本。
|
||||
*
|
||||
* @param sessionId 会话 ID
|
||||
* @param roundId 轮次 ID
|
||||
* @param variantIndex 答案版本序号
|
||||
* @param months 查询月份
|
||||
* @return 目标答案版本
|
||||
*/
|
||||
public ChatMessageRecord findRoundVariant(BigInteger sessionId, BigInteger roundId, Integer variantIndex, List<YearMonth> months) {
|
||||
if (sessionId == null || roundId == null || variantIndex == null || variantIndex <= 0 || months == null || months.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
for (YearMonth month : months) {
|
||||
String table = tableRouter.resolveLogTable(month);
|
||||
List<ChatMessageRecord> records = jdbcTemplate.query(
|
||||
"SELECT l.*, r.round_no AS joined_round_no, r.variant_count, r.selected_variant_index, 0 AS switchable " +
|
||||
"FROM `" + table + "` l " +
|
||||
"INNER JOIN `" + ChatConstants.ROUND_TABLE + "` r ON l.round_id = r.id " +
|
||||
"WHERE l.session_id=? AND l.round_id=? AND l.message_kind=? AND l.variant_index=? " +
|
||||
"ORDER BY l.created DESC, l.id DESC LIMIT 1",
|
||||
(rs, rowNum) -> mapRow(rs),
|
||||
sessionId,
|
||||
roundId,
|
||||
ChatConstants.MESSAGE_KIND_ASSISTANT_VARIANT,
|
||||
variantIndex
|
||||
);
|
||||
if (!records.isEmpty()) {
|
||||
return records.get(0);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private int resolveCandidateLimit(long offset, int limit) {
|
||||
long normalizedOffset = Math.max(offset, 0);
|
||||
long normalizedLimit = Math.max(limit, 1);
|
||||
long candidateLimit = normalizedOffset + normalizedLimit;
|
||||
return (int) Math.min(candidateLimit, Integer.MAX_VALUE);
|
||||
}
|
||||
|
||||
public List<ChatMessageRecord> loadIncremental(String table, Date cursorTime, BigInteger cursorId, int limit) {
|
||||
Timestamp timestamp = cursorTime == null ? new Timestamp(0L) : new Timestamp(cursorTime.getTime());
|
||||
return jdbcTemplate.query(
|
||||
@@ -149,9 +314,20 @@ public class MySqlChatLogRepository {
|
||||
record.setSessionId(bigInteger(rs, "session_id"));
|
||||
record.setUserId(bigInteger(rs, "user_id"));
|
||||
record.setAssistantId(bigInteger(rs, "assistant_id"));
|
||||
record.setRoundId(bigInteger(rs, "round_id"));
|
||||
record.setRoundNo(optionalInteger(rs, "round_no"));
|
||||
Integer joinedRoundNo = optionalInteger(rs, "joined_round_no");
|
||||
if (joinedRoundNo != null) {
|
||||
record.setRoundNo(joinedRoundNo);
|
||||
}
|
||||
record.setSenderId(bigInteger(rs, "sender_id"));
|
||||
record.setSenderName(rs.getString("sender_name"));
|
||||
record.setSenderRole(rs.getString("sender_role"));
|
||||
record.setMessageKind(optionalString(rs, "message_kind"));
|
||||
record.setVariantIndex(optionalInteger(rs, "variant_index"));
|
||||
record.setVariantCount(optionalInteger(rs, "variant_count"));
|
||||
record.setSelectedVariantIndex(optionalInteger(rs, "selected_variant_index"));
|
||||
record.setSwitchable(optionalBoolean(rs, "switchable"));
|
||||
record.setContentType(rs.getString("content_type"));
|
||||
record.setContentText(rs.getString("content_text"));
|
||||
record.setContentPayload(jsonSupport.toMap(rs.getString("content_payload")));
|
||||
@@ -168,4 +344,39 @@ public class MySqlChatLogRepository {
|
||||
}
|
||||
return new BigInteger(String.valueOf(value));
|
||||
}
|
||||
|
||||
private Integer optionalInteger(ResultSet rs, String column) {
|
||||
try {
|
||||
Object value = rs.getObject(column);
|
||||
if (value == null) {
|
||||
return null;
|
||||
}
|
||||
return Integer.parseInt(String.valueOf(value));
|
||||
} catch (SQLException ignored) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private Boolean optionalBoolean(ResultSet rs, String column) {
|
||||
try {
|
||||
Object value = rs.getObject(column);
|
||||
if (value == null) {
|
||||
return null;
|
||||
}
|
||||
if (value instanceof Boolean booleanValue) {
|
||||
return booleanValue;
|
||||
}
|
||||
return Integer.parseInt(String.valueOf(value)) != 0;
|
||||
} catch (SQLException ignored) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private String optionalString(ResultSet rs, String column) {
|
||||
try {
|
||||
return rs.getString(column);
|
||||
} catch (SQLException ignored) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,196 @@
|
||||
package tech.easyflow.chatlog.repository.mysql;
|
||||
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
import org.springframework.jdbc.core.RowMapper;
|
||||
import org.springframework.stereotype.Repository;
|
||||
import tech.easyflow.chatlog.domain.command.ChatRoundSelectCommand;
|
||||
import tech.easyflow.chatlog.domain.command.ChatRoundUpsertCommand;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatRoundRecord;
|
||||
import tech.easyflow.chatlog.support.ChatConstants;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.SQLException;
|
||||
import java.sql.Timestamp;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* MySQL 轮次仓储。
|
||||
*/
|
||||
@Repository
|
||||
public class MySqlChatRoundRepository {
|
||||
|
||||
private final JdbcTemplate jdbcTemplate;
|
||||
|
||||
public MySqlChatRoundRepository(JdbcTemplate jdbcTemplate) {
|
||||
this.jdbcTemplate = jdbcTemplate;
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量写入或更新轮次聚合。
|
||||
*
|
||||
* @param commands 轮次命令
|
||||
*/
|
||||
public void createOrTouchBatch(List<ChatRoundUpsertCommand> commands) {
|
||||
if (commands == null || commands.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
String sql = "INSERT INTO `" + ChatConstants.ROUND_TABLE + "` " +
|
||||
"(id, session_id, round_no, user_message_id, selected_assistant_message_id, selected_variant_index, variant_count, status, created, modified) " +
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) " +
|
||||
"ON DUPLICATE KEY UPDATE user_message_id=COALESCE(VALUES(user_message_id), user_message_id), " +
|
||||
"selected_assistant_message_id=CASE " +
|
||||
"WHEN VALUES(selected_assistant_message_id) IS NOT NULL AND (selected_assistant_message_id IS NULL OR VALUES(modified) > modified) " +
|
||||
"THEN VALUES(selected_assistant_message_id) ELSE selected_assistant_message_id END, " +
|
||||
"selected_variant_index=CASE " +
|
||||
"WHEN VALUES(selected_variant_index) IS NOT NULL AND (selected_variant_index IS NULL OR selected_variant_index = 0 OR VALUES(modified) > modified) " +
|
||||
"THEN VALUES(selected_variant_index) ELSE selected_variant_index END, " +
|
||||
"variant_count=GREATEST(COALESCE(VALUES(variant_count), 0), COALESCE(variant_count, 0)), " +
|
||||
"status=CASE " +
|
||||
"WHEN VALUES(status) IS NULL OR VALUES(status) = '' THEN status " +
|
||||
"WHEN VALUES(modified) > modified THEN VALUES(status) " +
|
||||
"WHEN VALUES(modified) = modified AND status = '" + ChatConstants.ROUND_STATUS_ANSWERING + "' " +
|
||||
"AND VALUES(status) <> '" + ChatConstants.ROUND_STATUS_ANSWERING + "' THEN VALUES(status) " +
|
||||
"WHEN VALUES(modified) = modified AND status <> '" + ChatConstants.ROUND_STATUS_LOCKED + "' " +
|
||||
"AND VALUES(status) = '" + ChatConstants.ROUND_STATUS_LOCKED + "' THEN VALUES(status) " +
|
||||
"ELSE status END, " +
|
||||
"modified=GREATEST(VALUES(modified), modified)";
|
||||
jdbcTemplate.batchUpdate(sql, commands, commands.size(), (ps, command) -> {
|
||||
Timestamp operateAt = timestamp(command.getOperateAt());
|
||||
ps.setObject(1, command.getRoundId());
|
||||
ps.setObject(2, command.getSessionId());
|
||||
ps.setObject(3, command.getRoundNo());
|
||||
ps.setObject(4, command.getUserMessageId());
|
||||
ps.setObject(5, command.getSelectedAssistantMessageId());
|
||||
ps.setObject(6, command.getSelectedVariantIndex());
|
||||
ps.setObject(7, command.getVariantCount());
|
||||
ps.setString(8, command.getStatus());
|
||||
ps.setTimestamp(9, operateAt);
|
||||
ps.setTimestamp(10, operateAt);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 应用答案版本切换。
|
||||
*
|
||||
* @param commands 切换命令
|
||||
*/
|
||||
public void selectVariants(List<ChatRoundSelectCommand> commands) {
|
||||
if (commands == null || commands.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
String sql = "UPDATE `" + ChatConstants.ROUND_TABLE + "` SET selected_assistant_message_id=?, selected_variant_index=?, modified=? " +
|
||||
"WHERE id=? AND session_id=? AND modified <= ?";
|
||||
jdbcTemplate.batchUpdate(sql, commands, commands.size(), (ps, command) -> {
|
||||
Timestamp operateAt = timestamp(command.getOperateAt());
|
||||
ps.setObject(1, command.getSelectedAssistantMessageId());
|
||||
ps.setObject(2, command.getSelectedVariantIndex());
|
||||
ps.setTimestamp(3, operateAt);
|
||||
ps.setObject(4, command.getRoundId());
|
||||
ps.setObject(5, command.getSessionId());
|
||||
ps.setTimestamp(6, operateAt);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询指定轮次。
|
||||
*
|
||||
* @param sessionId 会话 ID
|
||||
* @param roundId 轮次 ID
|
||||
* @return 轮次记录
|
||||
*/
|
||||
public ChatRoundRecord findRound(BigInteger sessionId, BigInteger roundId) {
|
||||
List<ChatRoundRecord> records = jdbcTemplate.query(
|
||||
"SELECT * FROM `" + ChatConstants.ROUND_TABLE + "` WHERE session_id=? AND id=? LIMIT 1",
|
||||
rowMapper(),
|
||||
sessionId,
|
||||
roundId
|
||||
);
|
||||
return records.isEmpty() ? null : records.get(0);
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询会话最新轮次。
|
||||
*
|
||||
* @param sessionId 会话 ID
|
||||
* @return 最新轮次
|
||||
*/
|
||||
public ChatRoundRecord findLatestRound(BigInteger sessionId) {
|
||||
List<ChatRoundRecord> records = jdbcTemplate.query(
|
||||
"SELECT * FROM `" + ChatConstants.ROUND_TABLE + "` WHERE session_id=? ORDER BY round_no DESC, id DESC LIMIT 1",
|
||||
rowMapper(),
|
||||
sessionId
|
||||
);
|
||||
return records.isEmpty() ? null : records.get(0);
|
||||
}
|
||||
|
||||
/**
|
||||
* 判断会话是否已使用轮次模型。
|
||||
*
|
||||
* @param sessionId 会话 ID
|
||||
* @return 是否存在轮次
|
||||
*/
|
||||
public boolean existsRounds(BigInteger sessionId) {
|
||||
Long count = jdbcTemplate.queryForObject(
|
||||
"SELECT COUNT(1) FROM `" + ChatConstants.ROUND_TABLE + "` WHERE session_id=?",
|
||||
Long.class,
|
||||
sessionId
|
||||
);
|
||||
return count != null && count > 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* 读取指定时间之后变更的轮次。
|
||||
*
|
||||
* @param cursorTime 游标时间
|
||||
* @param cursorId 游标 ID
|
||||
* @param limit 批大小
|
||||
* @return 轮次记录列表
|
||||
*/
|
||||
public List<ChatRoundRecord> loadModifiedAfter(Date cursorTime, BigInteger cursorId, int limit) {
|
||||
Timestamp timestamp = cursorTime == null ? new Timestamp(0L) : new Timestamp(cursorTime.getTime());
|
||||
return jdbcTemplate.query(
|
||||
"SELECT * FROM `" + ChatConstants.ROUND_TABLE + "` WHERE (modified > ?) OR (modified = ? AND id > ?) " +
|
||||
"ORDER BY modified ASC, id ASC LIMIT ?",
|
||||
rowMapper(),
|
||||
timestamp,
|
||||
timestamp,
|
||||
cursorId == null ? BigInteger.ZERO : cursorId,
|
||||
limit
|
||||
);
|
||||
}
|
||||
|
||||
private RowMapper<ChatRoundRecord> rowMapper() {
|
||||
return new RowMapper<>() {
|
||||
@Override
|
||||
public ChatRoundRecord mapRow(ResultSet rs, int rowNum) throws SQLException {
|
||||
ChatRoundRecord record = new ChatRoundRecord();
|
||||
record.setId(bigInteger(rs, "id"));
|
||||
record.setSessionId(bigInteger(rs, "session_id"));
|
||||
record.setRoundNo(rs.getInt("round_no"));
|
||||
record.setUserMessageId(bigInteger(rs, "user_message_id"));
|
||||
record.setSelectedAssistantMessageId(bigInteger(rs, "selected_assistant_message_id"));
|
||||
record.setSelectedVariantIndex(rs.getInt("selected_variant_index"));
|
||||
record.setVariantCount(rs.getInt("variant_count"));
|
||||
record.setStatus(rs.getString("status"));
|
||||
record.setCreated(rs.getTimestamp("created"));
|
||||
record.setModified(rs.getTimestamp("modified"));
|
||||
return record;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private BigInteger bigInteger(ResultSet rs, String column) throws SQLException {
|
||||
Object value = rs.getObject(column);
|
||||
if (value == null) {
|
||||
return null;
|
||||
}
|
||||
return new BigInteger(String.valueOf(value));
|
||||
}
|
||||
|
||||
private Timestamp timestamp(Date value) {
|
||||
return new Timestamp((value == null ? new Date() : value).getTime());
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,7 @@ import java.math.BigInteger;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.SQLException;
|
||||
import java.sql.Timestamp;
|
||||
import java.sql.Types;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Date;
|
||||
import java.util.LinkedHashMap;
|
||||
@@ -39,11 +40,12 @@ public class MySqlChatSessionRepository {
|
||||
}
|
||||
String table = tableRouter.resolveSessionTable();
|
||||
String sql = "INSERT INTO `" + table + "` " +
|
||||
"(id, tenant_id, dept_id, user_id, user_account, assistant_id, assistant_code, assistant_name, title, last_message_preview, message_count, access_at, created, created_by, modified, modified_by, is_deleted) " +
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, '', 0, ?, ?, ?, ?, ?, 0) " +
|
||||
"(id, tenant_id, dept_id, user_id, user_account, assistant_id, assistant_code, assistant_name, title, ext_json, last_message_preview, message_count, access_at, created, created_by, modified, modified_by, is_deleted) " +
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, '', 0, ?, ?, ?, ?, ?, 0) " +
|
||||
"ON DUPLICATE KEY UPDATE user_account=COALESCE(NULLIF(VALUES(user_account), ''), user_account), " +
|
||||
"assistant_id=VALUES(assistant_id), assistant_code=COALESCE(NULLIF(VALUES(assistant_code), ''), assistant_code), " +
|
||||
"assistant_name=COALESCE(NULLIF(VALUES(assistant_name), ''), assistant_name), " +
|
||||
"ext_json=COALESCE(VALUES(ext_json), ext_json), " +
|
||||
"title=COALESCE(NULLIF(VALUES(title), ''), title), " +
|
||||
"access_at=VALUES(access_at), modified=VALUES(modified), modified_by=VALUES(modified_by), is_deleted=0";
|
||||
jdbcTemplate.batchUpdate(sql, commands, commands.size(), (ps, command) -> {
|
||||
@@ -57,11 +59,12 @@ public class MySqlChatSessionRepository {
|
||||
ps.setString(7, safeString(command.getAssistantCode()));
|
||||
ps.setString(8, safeString(command.getAssistantName()));
|
||||
ps.setString(9, safeString(command.getTitle()));
|
||||
ps.setTimestamp(10, operateAt);
|
||||
setNullableJson(ps, 10, command.getExtJson());
|
||||
ps.setTimestamp(11, operateAt);
|
||||
ps.setObject(12, command.getOperatorId());
|
||||
ps.setTimestamp(13, operateAt);
|
||||
ps.setObject(14, command.getOperatorId());
|
||||
ps.setTimestamp(12, operateAt);
|
||||
ps.setObject(13, command.getOperatorId());
|
||||
ps.setTimestamp(14, operateAt);
|
||||
ps.setObject(15, command.getOperatorId());
|
||||
});
|
||||
}
|
||||
|
||||
@@ -71,33 +74,38 @@ public class MySqlChatSessionRepository {
|
||||
}
|
||||
String table = tableRouter.resolveSessionTable();
|
||||
String sql = "UPDATE `" + table + "` SET " +
|
||||
"last_sender_id=CASE WHEN last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE last_sender_id END, " +
|
||||
"last_sender_name=CASE WHEN last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE last_sender_name END, " +
|
||||
"last_message_preview=CASE WHEN last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE last_message_preview END, " +
|
||||
"last_message_at=CASE WHEN last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE last_message_at END, " +
|
||||
"access_at=CASE WHEN last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE access_at END, " +
|
||||
"last_sender_id=CASE WHEN ?=1 OR last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE last_sender_id END, " +
|
||||
"last_sender_name=CASE WHEN ?=1 OR last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE last_sender_name END, " +
|
||||
"last_message_preview=CASE WHEN ?=1 OR last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE last_message_preview END, " +
|
||||
"last_message_at=CASE WHEN ?=1 OR last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE last_message_at END, " +
|
||||
"access_at=CASE WHEN ?=1 OR last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE access_at END, " +
|
||||
"message_count=COALESCE(message_count, 0) + ?, " +
|
||||
"modified=CASE WHEN last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE modified END, " +
|
||||
"modified_by=CASE WHEN last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE modified_by END " +
|
||||
"modified=?, modified_by=? " +
|
||||
"WHERE id=?";
|
||||
jdbcTemplate.batchUpdate(sql, commands, commands.size(), (ps, command) -> {
|
||||
Timestamp lastMessageAt = timestamp(command.getLastMessageAt());
|
||||
ps.setTimestamp(1, lastMessageAt);
|
||||
ps.setObject(2, command.getLastSenderId());
|
||||
ps.setTimestamp(3, lastMessageAt);
|
||||
ps.setString(4, command.getLastSenderName());
|
||||
Timestamp accessAt = timestamp(command.getAccessAt());
|
||||
Timestamp modifiedAt = timestamp(command.getModifiedAt());
|
||||
int forceOverwrite = command.isForceOverwrite() ? 1 : 0;
|
||||
ps.setInt(1, forceOverwrite);
|
||||
ps.setTimestamp(2, lastMessageAt);
|
||||
ps.setObject(3, command.getLastSenderId());
|
||||
ps.setInt(4, forceOverwrite);
|
||||
ps.setTimestamp(5, lastMessageAt);
|
||||
ps.setString(6, command.getLastMessagePreview());
|
||||
ps.setTimestamp(7, lastMessageAt);
|
||||
ps.setString(6, command.getLastSenderName());
|
||||
ps.setInt(7, forceOverwrite);
|
||||
ps.setTimestamp(8, lastMessageAt);
|
||||
ps.setTimestamp(9, lastMessageAt);
|
||||
ps.setTimestamp(10, lastMessageAt);
|
||||
ps.setInt(11, Math.max(command.getMessageIncrement(), 1));
|
||||
ps.setString(9, safeString(command.getLastMessagePreview()));
|
||||
ps.setInt(10, forceOverwrite);
|
||||
ps.setTimestamp(11, lastMessageAt);
|
||||
ps.setTimestamp(12, lastMessageAt);
|
||||
ps.setTimestamp(13, lastMessageAt);
|
||||
ps.setInt(13, forceOverwrite);
|
||||
ps.setTimestamp(14, lastMessageAt);
|
||||
ps.setObject(15, command.getOperatorId());
|
||||
ps.setObject(16, command.getSessionId());
|
||||
ps.setTimestamp(15, accessAt);
|
||||
ps.setInt(16, Math.max(command.getMessageIncrement(), 0));
|
||||
ps.setTimestamp(17, modifiedAt);
|
||||
ps.setObject(18, command.getOperatorId());
|
||||
ps.setObject(19, command.getSessionId());
|
||||
});
|
||||
}
|
||||
|
||||
@@ -105,13 +113,13 @@ public class MySqlChatSessionRepository {
|
||||
String table = tableRouter.resolveSessionTable();
|
||||
List<Object> params = new ArrayList<>();
|
||||
StringBuilder sql = new StringBuilder("SELECT * FROM `").append(table)
|
||||
.append("` WHERE user_id=? AND is_deleted=0");
|
||||
.append("` WHERE user_id=? AND is_deleted=0 AND last_message_at IS NOT NULL");
|
||||
params.add(userId);
|
||||
if (assistantId != null) {
|
||||
sql.append(" AND assistant_id=?");
|
||||
params.add(assistantId);
|
||||
}
|
||||
sql.append(" ORDER BY access_at DESC, id DESC LIMIT ? OFFSET ?");
|
||||
sql.append(" ORDER BY last_message_at DESC, id DESC LIMIT ? OFFSET ?");
|
||||
params.add(query.getPageSize());
|
||||
params.add(query.getOffset());
|
||||
return jdbcTemplate.query(sql.toString(), sessionRowMapper(), params.toArray());
|
||||
@@ -121,7 +129,7 @@ public class MySqlChatSessionRepository {
|
||||
String table = tableRouter.resolveSessionTable();
|
||||
List<Object> params = new ArrayList<>();
|
||||
StringBuilder sql = new StringBuilder("SELECT COUNT(1) FROM `").append(table)
|
||||
.append("` WHERE user_id=? AND is_deleted=0");
|
||||
.append("` WHERE user_id=? AND is_deleted=0 AND last_message_at IS NOT NULL");
|
||||
params.add(userId);
|
||||
if (assistantId != null) {
|
||||
sql.append(" AND assistant_id=?");
|
||||
@@ -248,6 +256,7 @@ public class MySqlChatSessionRepository {
|
||||
summary.setAssistantCode(rs.getString("assistant_code"));
|
||||
summary.setAssistantName(rs.getString("assistant_name"));
|
||||
summary.setTitle(rs.getString("title"));
|
||||
summary.setExtJson(rs.getString("ext_json"));
|
||||
summary.setLastMessagePreview(rs.getString("last_message_preview"));
|
||||
summary.setLastSenderId(bigInteger(rs, "last_sender_id"));
|
||||
summary.setLastSenderName(rs.getString("last_sender_name"));
|
||||
@@ -279,4 +288,12 @@ public class MySqlChatSessionRepository {
|
||||
private String safeString(String value) {
|
||||
return value == null ? "" : value;
|
||||
}
|
||||
|
||||
private void setNullableJson(java.sql.PreparedStatement ps, int parameterIndex, String value) throws SQLException {
|
||||
if (value == null || value.isBlank()) {
|
||||
ps.setNull(parameterIndex, Types.VARCHAR);
|
||||
return;
|
||||
}
|
||||
ps.setString(parameterIndex, value);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,22 @@
|
||||
package tech.easyflow.chatlog.service;
|
||||
|
||||
import org.slf4j.MDC;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Service;
|
||||
import tech.easyflow.chatlog.cache.ChatHotStateService;
|
||||
import tech.easyflow.chatlog.domain.command.ChatAppendMessageCommand;
|
||||
import tech.easyflow.chatlog.domain.command.ChatRoundSelectCommand;
|
||||
import tech.easyflow.chatlog.domain.command.ChatRoundUpsertCommand;
|
||||
import tech.easyflow.chatlog.domain.command.ChatSessionUpsertCommand;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatRoundRecord;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatSessionSummary;
|
||||
import tech.easyflow.chatlog.domain.event.ChatPersistEvent;
|
||||
import tech.easyflow.chatlog.domain.event.ChatPersistEventType;
|
||||
import tech.easyflow.chatlog.domain.event.payload.ChatSessionDeletePayload;
|
||||
import tech.easyflow.chatlog.domain.event.payload.ChatSessionRenamePayload;
|
||||
import tech.easyflow.chatlog.support.ChatJsonSupport;
|
||||
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.Date;
|
||||
@@ -19,21 +25,26 @@ import java.util.UUID;
|
||||
@Service
|
||||
public class ChatPersistDispatcher {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(ChatPersistDispatcher.class);
|
||||
|
||||
private final ChatHotStateService chatHotStateService;
|
||||
private final ChatPersistEventProducer eventProducer;
|
||||
private final ChatPersistMySqlApplyService mySqlApplyService;
|
||||
private final ChatJsonSupport chatJsonSupport;
|
||||
|
||||
public ChatPersistDispatcher(ChatHotStateService chatHotStateService,
|
||||
ChatPersistEventProducer eventProducer,
|
||||
ChatPersistMySqlApplyService mySqlApplyService,
|
||||
ChatJsonSupport chatJsonSupport) {
|
||||
this.chatHotStateService = chatHotStateService;
|
||||
this.eventProducer = eventProducer;
|
||||
this.mySqlApplyService = mySqlApplyService;
|
||||
this.chatJsonSupport = chatJsonSupport;
|
||||
}
|
||||
|
||||
public ChatSessionSummary createOrTouchSession(ChatSessionUpsertCommand command) {
|
||||
ChatSessionSummary summary = chatHotStateService.touchSession(command);
|
||||
eventProducer.send(buildEvent(
|
||||
ChatPersistEvent event = buildEvent(
|
||||
UUID.randomUUID().toString(),
|
||||
ChatPersistEventType.SESSION_PREPARED,
|
||||
command.getSessionId(),
|
||||
@@ -41,10 +52,43 @@ public class ChatPersistDispatcher {
|
||||
command.getAssistantId(),
|
||||
command.getOperateAt(),
|
||||
chatJsonSupport.toJson(command)
|
||||
));
|
||||
);
|
||||
persistImmediately(event);
|
||||
eventProducer.send(event);
|
||||
return summary;
|
||||
}
|
||||
|
||||
public ChatRoundRecord createOrTouchRound(ChatRoundUpsertCommand command) {
|
||||
ChatRoundRecord record = chatHotStateService.createOrTouchRound(command);
|
||||
ChatPersistEvent event = buildEvent(
|
||||
UUID.randomUUID().toString(),
|
||||
ChatPersistEventType.ROUND_UPSERTED,
|
||||
command.getSessionId(),
|
||||
BigInteger.ZERO,
|
||||
BigInteger.ZERO,
|
||||
command.getOperateAt(),
|
||||
chatJsonSupport.toJson(command)
|
||||
);
|
||||
persistImmediately(event);
|
||||
eventProducer.send(event);
|
||||
return record;
|
||||
}
|
||||
|
||||
public void selectRoundVariant(ChatRoundSelectCommand command) {
|
||||
chatHotStateService.selectVariant(command);
|
||||
ChatPersistEvent event = buildEvent(
|
||||
UUID.randomUUID().toString(),
|
||||
ChatPersistEventType.ROUND_VARIANT_SELECTED,
|
||||
command.getSessionId(),
|
||||
BigInteger.ZERO,
|
||||
BigInteger.ZERO,
|
||||
command.getOperateAt(),
|
||||
chatJsonSupport.toJson(command)
|
||||
);
|
||||
persistImmediately(event);
|
||||
eventProducer.send(event);
|
||||
}
|
||||
|
||||
public void appendUserMessage(ChatAppendMessageCommand command) {
|
||||
appendMessage(command, ChatPersistEventType.USER_MESSAGE_APPENDED);
|
||||
}
|
||||
@@ -96,7 +140,7 @@ public class ChatPersistDispatcher {
|
||||
|
||||
private void appendMessage(ChatAppendMessageCommand command, ChatPersistEventType eventType) {
|
||||
chatHotStateService.appendMessage(command);
|
||||
eventProducer.send(buildEvent(
|
||||
ChatPersistEvent event = buildEvent(
|
||||
eventId("message", command.getMessageId()),
|
||||
eventType,
|
||||
command.getSessionId(),
|
||||
@@ -104,7 +148,27 @@ public class ChatPersistDispatcher {
|
||||
command.getAssistantId(),
|
||||
command.getCreated(),
|
||||
chatJsonSupport.toJson(command)
|
||||
));
|
||||
);
|
||||
persistImmediately(event);
|
||||
eventProducer.send(event);
|
||||
}
|
||||
|
||||
/**
|
||||
* 先同步写入 MySQL,再发送异步事件,保证会话列表和版本切换读取有确定来源。
|
||||
*
|
||||
* @param event 持久化事件
|
||||
*/
|
||||
private void persistImmediately(ChatPersistEvent event) {
|
||||
try {
|
||||
mySqlApplyService.apply(java.util.List.of(event));
|
||||
} catch (RuntimeException ex) {
|
||||
log.error("聊天记录同步写入 MySQL 失败,eventId={}, eventType={}, sessionId={}",
|
||||
event == null ? null : event.getEventId(),
|
||||
event == null ? null : event.getEventType(),
|
||||
event == null ? null : event.getSessionId(),
|
||||
ex);
|
||||
throw new BusinessException("聊天记录持久化失败,请稍后重试");
|
||||
}
|
||||
}
|
||||
|
||||
private ChatPersistEvent buildEvent(String eventId,
|
||||
|
||||
@@ -3,6 +3,8 @@ package tech.easyflow.chatlog.service;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import tech.easyflow.chatlog.domain.command.ChatAppendMessageCommand;
|
||||
import tech.easyflow.chatlog.domain.command.ChatRoundSelectCommand;
|
||||
import tech.easyflow.chatlog.domain.command.ChatRoundUpsertCommand;
|
||||
import tech.easyflow.chatlog.domain.command.ChatSessionSummaryCommand;
|
||||
import tech.easyflow.chatlog.domain.command.ChatSessionUpsertCommand;
|
||||
import tech.easyflow.chatlog.domain.event.ChatPersistEvent;
|
||||
@@ -11,7 +13,9 @@ import tech.easyflow.chatlog.domain.event.payload.ChatSessionDeletePayload;
|
||||
import tech.easyflow.chatlog.domain.event.payload.ChatSessionRenamePayload;
|
||||
import tech.easyflow.chatlog.repository.mysql.MySqlChatLogRepository;
|
||||
import tech.easyflow.chatlog.repository.mysql.MySqlChatLogTableManager;
|
||||
import tech.easyflow.chatlog.repository.mysql.MySqlChatRoundRepository;
|
||||
import tech.easyflow.chatlog.repository.mysql.MySqlChatSessionRepository;
|
||||
import tech.easyflow.chatlog.support.ChatConstants;
|
||||
import tech.easyflow.chatlog.support.ChatJsonSupport;
|
||||
|
||||
import java.math.BigInteger;
|
||||
@@ -30,15 +34,18 @@ public class ChatPersistMySqlApplyService {
|
||||
|
||||
private final MySqlChatSessionRepository sessionRepository;
|
||||
private final MySqlChatLogRepository logRepository;
|
||||
private final MySqlChatRoundRepository roundRepository;
|
||||
private final MySqlChatLogTableManager tableManager;
|
||||
private final ChatJsonSupport chatJsonSupport;
|
||||
|
||||
public ChatPersistMySqlApplyService(MySqlChatSessionRepository sessionRepository,
|
||||
MySqlChatLogRepository logRepository,
|
||||
MySqlChatRoundRepository roundRepository,
|
||||
MySqlChatLogTableManager tableManager,
|
||||
ChatJsonSupport chatJsonSupport) {
|
||||
this.sessionRepository = sessionRepository;
|
||||
this.logRepository = logRepository;
|
||||
this.roundRepository = roundRepository;
|
||||
this.tableManager = tableManager;
|
||||
this.chatJsonSupport = chatJsonSupport;
|
||||
}
|
||||
@@ -50,6 +57,8 @@ public class ChatPersistMySqlApplyService {
|
||||
}
|
||||
|
||||
Map<BigInteger, ChatSessionUpsertCommand> sessionUpserts = new LinkedHashMap<>();
|
||||
Map<BigInteger, ChatRoundUpsertCommand> roundUpserts = new LinkedHashMap<>();
|
||||
List<ChatRoundSelectCommand> roundSelections = new ArrayList<>();
|
||||
List<ChatAppendMessageCommand> appendCommands = new ArrayList<>();
|
||||
Map<BigInteger, ChatSessionSummaryCommand> summaryCommands = new LinkedHashMap<>();
|
||||
List<ChatSessionRenamePayload> renamePayloads = new ArrayList<>();
|
||||
@@ -67,6 +76,18 @@ public class ChatPersistMySqlApplyService {
|
||||
sessionUpserts.put(command.getSessionId(), command);
|
||||
}
|
||||
}
|
||||
case ROUND_UPSERTED -> {
|
||||
ChatRoundUpsertCommand command = chatJsonSupport.fromJson(event.getPayload(), ChatRoundUpsertCommand.class);
|
||||
if (command != null && command.getRoundId() != null) {
|
||||
roundUpserts.put(command.getRoundId(), command);
|
||||
}
|
||||
}
|
||||
case ROUND_VARIANT_SELECTED -> {
|
||||
ChatRoundSelectCommand command = chatJsonSupport.fromJson(event.getPayload(), ChatRoundSelectCommand.class);
|
||||
if (command != null && command.getRoundId() != null) {
|
||||
roundSelections.add(command);
|
||||
}
|
||||
}
|
||||
case USER_MESSAGE_APPENDED, ASSISTANT_MESSAGE_APPENDED -> {
|
||||
ChatAppendMessageCommand command = chatJsonSupport.fromJson(event.getPayload(), ChatAppendMessageCommand.class);
|
||||
if (command == null || command.getSessionId() == null || command.getMessageId() == null) {
|
||||
@@ -96,6 +117,9 @@ public class ChatPersistMySqlApplyService {
|
||||
if (!sessionUpserts.isEmpty()) {
|
||||
sessionRepository.createOrTouchBatch(new ArrayList<>(sessionUpserts.values()));
|
||||
}
|
||||
if (!roundUpserts.isEmpty()) {
|
||||
roundRepository.createOrTouchBatch(new ArrayList<>(roundUpserts.values()));
|
||||
}
|
||||
if (!months.isEmpty()) {
|
||||
for (YearMonth month : months) {
|
||||
tableManager.ensureMonthTable(month);
|
||||
@@ -113,6 +137,9 @@ public class ChatPersistMySqlApplyService {
|
||||
}
|
||||
sessionRepository.updateSummaries(new ArrayList<>(summaryCommands.values()));
|
||||
}
|
||||
if (!roundSelections.isEmpty()) {
|
||||
roundRepository.selectVariants(roundSelections);
|
||||
}
|
||||
if (!renamePayloads.isEmpty()) {
|
||||
sessionRepository.renameSessions(renamePayloads);
|
||||
}
|
||||
@@ -127,15 +154,26 @@ public class ChatPersistMySqlApplyService {
|
||||
ChatSessionSummaryCommand created = new ChatSessionSummaryCommand();
|
||||
created.setSessionId(command.getSessionId());
|
||||
created.setUserId(command.getUserId());
|
||||
created.setLastMessageAt(null);
|
||||
created.setAccessAt(null);
|
||||
created.setModifiedAt(null);
|
||||
created.setMessageIncrement(0);
|
||||
return created;
|
||||
});
|
||||
summary.setMessageIncrement(summary.getMessageIncrement() + 1);
|
||||
if (summary.getLastMessageAt() == null || !command.getCreated().before(summary.getLastMessageAt())) {
|
||||
if (ChatConstants.MESSAGE_KIND_ASSISTANT_VARIANT.equals(command.getMessageKind())
|
||||
&& command.getVariantIndex() != null
|
||||
&& command.getVariantIndex() > 1) {
|
||||
summary.setMessageIncrement(Math.max(summary.getMessageIncrement() - 1, 0));
|
||||
}
|
||||
Date commandCreated = defaultDate(command.getCreated());
|
||||
if (summary.getLastMessageAt() == null || !commandCreated.before(summary.getLastMessageAt())) {
|
||||
summary.setLastSenderId(command.getSenderId());
|
||||
summary.setLastSenderName(command.getSenderName());
|
||||
summary.setLastMessagePreview(trimPreview(command.getContentText()));
|
||||
summary.setLastMessageAt(command.getCreated());
|
||||
summary.setLastMessageAt(commandCreated);
|
||||
summary.setAccessAt(commandCreated);
|
||||
summary.setModifiedAt(commandCreated);
|
||||
summary.setOperatorId(command.getCreatedBy());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
package tech.easyflow.chatlog.service;
|
||||
|
||||
import tech.easyflow.chatlog.domain.command.ChatRoundSelectCommand;
|
||||
import tech.easyflow.chatlog.domain.command.ChatRoundUpsertCommand;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatRoundRecord;
|
||||
|
||||
/**
|
||||
* 聊天轮次写服务。
|
||||
*/
|
||||
public interface ChatRoundCommandService {
|
||||
|
||||
/**
|
||||
* 创建或更新轮次聚合。
|
||||
*
|
||||
* @param command 轮次命令
|
||||
* @return 最新轮次记录
|
||||
*/
|
||||
ChatRoundRecord createOrTouchRound(ChatRoundUpsertCommand command);
|
||||
|
||||
/**
|
||||
* 切换轮次当前选中的答案版本。
|
||||
*
|
||||
* @param command 切换命令
|
||||
*/
|
||||
void selectVariant(ChatRoundSelectCommand command);
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package tech.easyflow.chatlog.service;
|
||||
|
||||
import tech.easyflow.chatlog.domain.dto.ChatMessageRecord;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatRoundRecord;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 聊天轮次业务操作服务。
|
||||
*/
|
||||
public interface ChatRoundOperateService {
|
||||
|
||||
/**
|
||||
* 校验并返回允许重答的轮次。
|
||||
*
|
||||
* @param sessionId 会话 ID
|
||||
* @param roundId 轮次 ID
|
||||
* @return 轮次记录
|
||||
*/
|
||||
ChatRoundRecord requireRegeneratableRound(BigInteger sessionId, BigInteger roundId);
|
||||
|
||||
/**
|
||||
* 查询轮次下所有答案版本。
|
||||
*
|
||||
* @param sessionId 会话 ID
|
||||
* @param roundId 轮次 ID
|
||||
* @return 答案版本列表
|
||||
*/
|
||||
List<ChatMessageRecord> listVariants(BigInteger sessionId, BigInteger roundId);
|
||||
|
||||
/**
|
||||
* 切换指定轮次当前选中的答案版本。
|
||||
*
|
||||
* @param sessionId 会话 ID
|
||||
* @param roundId 轮次 ID
|
||||
* @param variantIndex 目标版本序号
|
||||
* @param operatorId 操作人
|
||||
* @return 选中的答案消息
|
||||
*/
|
||||
ChatMessageRecord selectVariant(BigInteger sessionId, BigInteger roundId, Integer variantIndex, BigInteger operatorId);
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package tech.easyflow.chatlog.service;
|
||||
|
||||
import tech.easyflow.chatlog.domain.dto.ChatMessageRecord;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatRoundRecord;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 聊天轮次读服务。
|
||||
*/
|
||||
public interface ChatRoundQueryService {
|
||||
|
||||
/**
|
||||
* 查询会话最新轮次。
|
||||
*
|
||||
* @param sessionId 会话 ID
|
||||
* @return 最新轮次
|
||||
*/
|
||||
ChatRoundRecord getLatestRound(BigInteger sessionId);
|
||||
|
||||
/**
|
||||
* 查询指定轮次。
|
||||
*
|
||||
* @param sessionId 会话 ID
|
||||
* @param roundId 轮次 ID
|
||||
* @return 轮次记录
|
||||
*/
|
||||
ChatRoundRecord getRound(BigInteger sessionId, BigInteger roundId);
|
||||
|
||||
/**
|
||||
* 查询轮次下所有助手答案版本。
|
||||
*
|
||||
* @param sessionId 会话 ID
|
||||
* @param roundId 轮次 ID
|
||||
* @return 答案版本列表
|
||||
*/
|
||||
List<ChatMessageRecord> listRoundVariants(BigInteger sessionId, BigInteger roundId);
|
||||
|
||||
/**
|
||||
* 查询轮次下指定答案版本。
|
||||
*
|
||||
* @param sessionId 会话 ID
|
||||
* @param roundId 轮次 ID
|
||||
* @param variantIndex 答案版本序号
|
||||
* @return 答案版本记录
|
||||
*/
|
||||
ChatMessageRecord getRoundVariant(BigInteger sessionId, BigInteger roundId, Integer variantIndex);
|
||||
|
||||
/**
|
||||
* 判断会话是否已经启用轮次模型。
|
||||
*
|
||||
* @param sessionId 会话 ID
|
||||
* @return 是否存在轮次
|
||||
*/
|
||||
boolean hasRounds(BigInteger sessionId);
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package tech.easyflow.chatlog.service;
|
||||
|
||||
import tech.easyflow.chatlog.domain.dto.ChatMessageRecord;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatHistoryPage;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatSessionPage;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatSessionSummary;
|
||||
import tech.easyflow.chatlog.domain.query.ChatPageQuery;
|
||||
@@ -18,5 +19,22 @@ public interface ChatSessionQueryService {
|
||||
|
||||
ChatSessionSummary getSessionSummary(BigInteger sessionId);
|
||||
|
||||
/**
|
||||
* 分页查询当前会话的主线可见消息。
|
||||
*
|
||||
* @param sessionId 会话 ID
|
||||
* @param query 分页参数
|
||||
* @return 主线消息分页
|
||||
*/
|
||||
ChatHistoryPage pageMainlineMessages(BigInteger sessionId, ChatPageQuery query);
|
||||
|
||||
/**
|
||||
* 查询当前会话的全部主线可见消息。
|
||||
*
|
||||
* @param sessionId 会话 ID
|
||||
* @return 主线消息列表
|
||||
*/
|
||||
List<ChatMessageRecord> listMainlineMessages(BigInteger sessionId);
|
||||
|
||||
List<ChatMessageRecord> getRecentTail(BigInteger sessionId, int limit);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
package tech.easyflow.chatlog.service.impl;
|
||||
|
||||
import org.springframework.stereotype.Service;
|
||||
import tech.easyflow.chatlog.domain.command.ChatRoundSelectCommand;
|
||||
import tech.easyflow.chatlog.domain.command.ChatRoundUpsertCommand;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatRoundRecord;
|
||||
import tech.easyflow.chatlog.service.ChatPersistDispatcher;
|
||||
import tech.easyflow.chatlog.service.ChatRoundCommandService;
|
||||
|
||||
/**
|
||||
* 聊天轮次写服务实现。
|
||||
*/
|
||||
@Service
|
||||
public class ChatRoundCommandServiceImpl implements ChatRoundCommandService {
|
||||
|
||||
private final ChatPersistDispatcher chatPersistDispatcher;
|
||||
|
||||
public ChatRoundCommandServiceImpl(ChatPersistDispatcher chatPersistDispatcher) {
|
||||
this.chatPersistDispatcher = chatPersistDispatcher;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatRoundRecord createOrTouchRound(ChatRoundUpsertCommand command) {
|
||||
return chatPersistDispatcher.createOrTouchRound(command);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void selectVariant(ChatRoundSelectCommand command) {
|
||||
chatPersistDispatcher.selectRoundVariant(command);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package tech.easyflow.chatlog.service.impl;
|
||||
|
||||
import org.springframework.stereotype.Service;
|
||||
import tech.easyflow.chatlog.domain.command.ChatRoundSelectCommand;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatMessageRecord;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatRoundRecord;
|
||||
import tech.easyflow.chatlog.service.ChatRoundCommandService;
|
||||
import tech.easyflow.chatlog.service.ChatRoundOperateService;
|
||||
import tech.easyflow.chatlog.service.ChatRoundQueryService;
|
||||
import tech.easyflow.chatlog.support.ChatConstants;
|
||||
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* 聊天轮次业务操作服务实现。
|
||||
*/
|
||||
@Service
|
||||
public class ChatRoundOperateServiceImpl implements ChatRoundOperateService {
|
||||
|
||||
private final ChatRoundQueryService chatRoundQueryService;
|
||||
private final ChatRoundCommandService chatRoundCommandService;
|
||||
|
||||
public ChatRoundOperateServiceImpl(ChatRoundQueryService chatRoundQueryService,
|
||||
ChatRoundCommandService chatRoundCommandService) {
|
||||
this.chatRoundQueryService = chatRoundQueryService;
|
||||
this.chatRoundCommandService = chatRoundCommandService;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatRoundRecord requireRegeneratableRound(BigInteger sessionId, BigInteger roundId) {
|
||||
ChatRoundRecord round = requireLatestRound(sessionId, roundId);
|
||||
if (round.getSelectedAssistantMessageId() == null || round.getSelectedVariantIndex() == null
|
||||
|| round.getSelectedVariantIndex() <= 0) {
|
||||
throw new BusinessException("当前轮次暂无可重答的回答");
|
||||
}
|
||||
return round;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatMessageRecord> listVariants(BigInteger sessionId, BigInteger roundId) {
|
||||
ChatRoundRecord round = chatRoundQueryService.getRound(sessionId, roundId);
|
||||
if (round == null) {
|
||||
throw new BusinessException("轮次不存在");
|
||||
}
|
||||
ChatRoundRecord latestRound = chatRoundQueryService.getLatestRound(sessionId);
|
||||
boolean switchable = latestRound != null
|
||||
&& Objects.equals(latestRound.getId(), round.getId())
|
||||
&& !ChatConstants.ROUND_STATUS_LOCKED.equalsIgnoreCase(round.getStatus());
|
||||
List<ChatMessageRecord> variants = new ArrayList<>(chatRoundQueryService.listRoundVariants(sessionId, roundId));
|
||||
for (ChatMessageRecord variant : variants) {
|
||||
variant.setVariantCount(round.getVariantCount());
|
||||
variant.setSelectedVariantIndex(round.getSelectedVariantIndex());
|
||||
variant.setSwitchable(switchable);
|
||||
}
|
||||
return variants;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatMessageRecord selectVariant(BigInteger sessionId, BigInteger roundId, Integer variantIndex, BigInteger operatorId) {
|
||||
ChatRoundRecord round = requireLatestRound(sessionId, roundId);
|
||||
if (variantIndex == null || variantIndex <= 0) {
|
||||
throw new BusinessException("目标答案版本无效");
|
||||
}
|
||||
ChatMessageRecord selected = chatRoundQueryService.getRoundVariant(sessionId, roundId, variantIndex);
|
||||
if (selected == null) {
|
||||
throw new BusinessException("目标答案版本不存在");
|
||||
}
|
||||
|
||||
ChatRoundSelectCommand command = new ChatRoundSelectCommand();
|
||||
command.setSessionId(sessionId);
|
||||
command.setRoundId(roundId);
|
||||
command.setSelectedVariantIndex(variantIndex);
|
||||
command.setSelectedAssistantMessageId(selected.getId());
|
||||
command.setSelectedAssistantMessage(selected);
|
||||
command.setOperatorId(operatorId);
|
||||
chatRoundCommandService.selectVariant(command);
|
||||
selected.setSelectedVariantIndex(variantIndex);
|
||||
selected.setVariantCount(round.getVariantCount());
|
||||
selected.setSwitchable(true);
|
||||
return selected;
|
||||
}
|
||||
|
||||
private ChatRoundRecord requireLatestRound(BigInteger sessionId, BigInteger roundId) {
|
||||
ChatRoundRecord round = chatRoundQueryService.getRound(sessionId, roundId);
|
||||
if (round == null) {
|
||||
throw new BusinessException("轮次不存在");
|
||||
}
|
||||
ChatRoundRecord latestRound = chatRoundQueryService.getLatestRound(sessionId);
|
||||
if (latestRound == null || !Objects.equals(latestRound.getId(), round.getId())) {
|
||||
throw new BusinessException("当前轮次已有后续对话,不支持切换答案版本");
|
||||
}
|
||||
if (ChatConstants.ROUND_STATUS_LOCKED.equalsIgnoreCase(round.getStatus())) {
|
||||
throw new BusinessException("当前轮次已有后续对话,不支持切换答案版本");
|
||||
}
|
||||
return round;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package tech.easyflow.chatlog.service.impl;
|
||||
|
||||
import org.springframework.stereotype.Service;
|
||||
import tech.easyflow.chatlog.cache.ChatHotStateService;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatMessageRecord;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatRoundRecord;
|
||||
import tech.easyflow.chatlog.repository.mysql.MySqlChatLogRepository;
|
||||
import tech.easyflow.chatlog.repository.mysql.MySqlChatLogTableManager;
|
||||
import tech.easyflow.chatlog.repository.mysql.MySqlChatRoundRepository;
|
||||
import tech.easyflow.chatlog.service.ChatRoundQueryService;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 聊天轮次读服务实现。
|
||||
*/
|
||||
@Service
|
||||
public class ChatRoundQueryServiceImpl implements ChatRoundQueryService {
|
||||
|
||||
private final MySqlChatRoundRepository roundRepository;
|
||||
private final MySqlChatLogRepository logRepository;
|
||||
private final MySqlChatLogTableManager tableManager;
|
||||
private final ChatHotStateService chatHotStateService;
|
||||
|
||||
public ChatRoundQueryServiceImpl(MySqlChatRoundRepository roundRepository,
|
||||
MySqlChatLogRepository logRepository,
|
||||
MySqlChatLogTableManager tableManager,
|
||||
ChatHotStateService chatHotStateService) {
|
||||
this.roundRepository = roundRepository;
|
||||
this.logRepository = logRepository;
|
||||
this.tableManager = tableManager;
|
||||
this.chatHotStateService = chatHotStateService;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatRoundRecord getLatestRound(BigInteger sessionId) {
|
||||
ChatRoundRecord cached = chatHotStateService.getLatestRound(sessionId);
|
||||
if (cached != null) {
|
||||
return cached;
|
||||
}
|
||||
ChatRoundRecord record = roundRepository.findLatestRound(sessionId);
|
||||
if (record != null) {
|
||||
chatHotStateService.cacheRound(record);
|
||||
}
|
||||
return record;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatRoundRecord getRound(BigInteger sessionId, BigInteger roundId) {
|
||||
ChatRoundRecord cached = chatHotStateService.getRound(sessionId, roundId);
|
||||
if (cached != null) {
|
||||
return cached;
|
||||
}
|
||||
ChatRoundRecord record = roundRepository.findRound(sessionId, roundId);
|
||||
if (record != null) {
|
||||
chatHotStateService.cacheRound(record);
|
||||
}
|
||||
return record;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatMessageRecord> listRoundVariants(BigInteger sessionId, BigInteger roundId) {
|
||||
return logRepository.listRoundVariants(sessionId, roundId, tableManager.listRecentExistingMonths(3));
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatMessageRecord getRoundVariant(BigInteger sessionId, BigInteger roundId, Integer variantIndex) {
|
||||
return logRepository.findRoundVariant(sessionId, roundId, variantIndex, tableManager.listRecentExistingMonths(3));
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasRounds(BigInteger sessionId) {
|
||||
return roundRepository.existsRounds(sessionId);
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package tech.easyflow.chatlog.service.impl;
|
||||
|
||||
import org.springframework.stereotype.Service;
|
||||
import tech.easyflow.chatlog.cache.ChatHotStateService;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatHistoryPage;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatMessageRecord;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatSessionPage;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatSessionSummary;
|
||||
@@ -12,7 +13,12 @@ import tech.easyflow.chatlog.repository.mysql.MySqlChatSessionRepository;
|
||||
import tech.easyflow.chatlog.service.ChatSessionQueryService;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
@Service
|
||||
public class ChatSessionQueryServiceImpl implements ChatSessionQueryService {
|
||||
@@ -34,21 +40,7 @@ public class ChatSessionQueryServiceImpl implements ChatSessionQueryService {
|
||||
|
||||
@Override
|
||||
public List<ChatSessionSummary> listSessions(BigInteger userId, BigInteger assistantId, ChatPageQuery query) {
|
||||
if (assistantId == null) {
|
||||
List<BigInteger> sessionIds = chatHotStateService.listSessionIds(userId, query.getOffset(), query.getPageSize());
|
||||
if (!sessionIds.isEmpty()) {
|
||||
List<ChatSessionSummary> cached = chatHotStateService.getSessionSummaries(sessionIds);
|
||||
if (cached.size() == sessionIds.size()) {
|
||||
return cached;
|
||||
}
|
||||
}
|
||||
List<ChatSessionSummary> sessions = sessionRepository.listSessions(userId, null, query);
|
||||
chatHotStateService.cacheSessionSummaries(sessions);
|
||||
return sessions;
|
||||
}
|
||||
List<ChatSessionSummary> sessions = sessionRepository.listSessions(userId, assistantId, query);
|
||||
chatHotStateService.cacheSessionSummaries(sessions);
|
||||
return sessions;
|
||||
return sessionRepository.listSessions(userId, assistantId, query);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -62,21 +54,6 @@ public class ChatSessionQueryServiceImpl implements ChatSessionQueryService {
|
||||
page.setPageNumber(query.getPageNumber());
|
||||
page.setPageSize(query.getPageSize());
|
||||
|
||||
if (assistantId == null && chatHotStateService.hasSessionIndex(userId)) {
|
||||
List<BigInteger> sessionIds = chatHotStateService.listSessionIds(userId, query.getOffset(), query.getPageSize());
|
||||
if (sessionIds.isEmpty()) {
|
||||
page.setTotal(chatHotStateService.countSessions(userId));
|
||||
page.setRecords(List.of());
|
||||
return page;
|
||||
}
|
||||
List<ChatSessionSummary> cached = chatHotStateService.getSessionSummaries(sessionIds);
|
||||
if (cached.size() == sessionIds.size()) {
|
||||
page.setTotal(chatHotStateService.countSessions(userId));
|
||||
page.setRecords(cached);
|
||||
return page;
|
||||
}
|
||||
}
|
||||
|
||||
page.setTotal(sessionRepository.countSessions(userId, assistantId));
|
||||
page.setRecords(listSessions(userId, assistantId, query));
|
||||
return page;
|
||||
@@ -95,14 +72,74 @@ public class ChatSessionQueryServiceImpl implements ChatSessionQueryService {
|
||||
return summary;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatHistoryPage pageMainlineMessages(BigInteger sessionId, ChatPageQuery query) {
|
||||
ChatHistoryPage page = new ChatHistoryPage();
|
||||
page.setPageNumber(query.getPageNumber());
|
||||
page.setPageSize(query.getPageSize());
|
||||
ChatSessionSummary summary = getSessionSummary(sessionId);
|
||||
long total = summary == null || summary.getMessageCount() == null ? 0L : summary.getMessageCount();
|
||||
List<ChatMessageRecord> records = logRepository.listMainlineMessages(
|
||||
sessionId,
|
||||
tableManager.listRecentExistingMonths(3),
|
||||
query.getOffset(),
|
||||
Math.toIntExact(query.getPageSize())
|
||||
);
|
||||
page.setRecords(records);
|
||||
page.setTotal(Math.max(total, query.getOffset() + records.size()));
|
||||
return page;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatMessageRecord> listMainlineMessages(BigInteger sessionId) {
|
||||
return logRepository.listMainlineMessages(sessionId, tableManager.listRecentExistingMonths(3));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatMessageRecord> getRecentTail(BigInteger sessionId, int limit) {
|
||||
List<ChatMessageRecord> cached = chatHotStateService.getSessionTail(sessionId);
|
||||
if (cached != null) {
|
||||
if (cached != null && isTailReliable(cached)) {
|
||||
return cached.subList(0, Math.min(limit, cached.size()));
|
||||
}
|
||||
List<ChatMessageRecord> records = logRepository.listRecentTail(sessionId, tableManager.listRecentExistingMonths(3), limit);
|
||||
chatHotStateService.setSessionTail(sessionId, records);
|
||||
return records;
|
||||
}
|
||||
|
||||
/**
|
||||
* 校验 Redis tail 是否符合当前主线版本语义,防止过期选中版本把可见回答过滤掉。
|
||||
*
|
||||
* @param records Redis tail 消息
|
||||
* @return true 表示可直接使用缓存
|
||||
*/
|
||||
private boolean isTailReliable(List<ChatMessageRecord> records) {
|
||||
Map<BigInteger, Integer> selectedVariantByRound = new LinkedHashMap<>();
|
||||
Map<BigInteger, Set<Integer>> assistantVariantsByRound = new LinkedHashMap<>();
|
||||
for (ChatMessageRecord record : records) {
|
||||
if (record == null || record.getRoundId() == null) {
|
||||
continue;
|
||||
}
|
||||
Integer selectedVariantIndex = record.getSelectedVariantIndex();
|
||||
if (selectedVariantIndex != null && selectedVariantIndex > 0) {
|
||||
Integer previous = selectedVariantByRound.putIfAbsent(record.getRoundId(), selectedVariantIndex);
|
||||
if (previous != null && !Objects.equals(previous, selectedVariantIndex)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if ("assistant".equalsIgnoreCase(record.getSenderRole())
|
||||
&& record.getVariantIndex() != null
|
||||
&& record.getVariantIndex() > 0) {
|
||||
assistantVariantsByRound
|
||||
.computeIfAbsent(record.getRoundId(), key -> new LinkedHashSet<>())
|
||||
.add(record.getVariantIndex());
|
||||
}
|
||||
}
|
||||
for (Map.Entry<BigInteger, Integer> entry : selectedVariantByRound.entrySet()) {
|
||||
Set<Integer> visibleVariants = assistantVariantsByRound.get(entry.getKey());
|
||||
if (visibleVariants != null && !visibleVariants.isEmpty() && !visibleVariants.contains(entry.getValue())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,11 +4,19 @@ import com.mybatisflex.core.keygen.impl.SnowFlakeIDKeyGenerator;
|
||||
import org.springframework.core.annotation.Order;
|
||||
import org.springframework.stereotype.Component;
|
||||
import tech.easyflow.chatlog.domain.command.ChatAppendMessageCommand;
|
||||
import tech.easyflow.chatlog.domain.command.ChatRoundUpsertCommand;
|
||||
import tech.easyflow.chatlog.domain.command.ChatSessionUpsertCommand;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatRoundRecord;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatSessionExtPayload;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatMessageRecord;
|
||||
import tech.easyflow.chatlog.service.ChatPersistDispatcher;
|
||||
import tech.easyflow.chatlog.service.ChatRoundOperateService;
|
||||
import tech.easyflow.chatlog.service.ChatRoundQueryService;
|
||||
import tech.easyflow.chatlog.service.ChatSessionQueryService;
|
||||
import tech.easyflow.chatlog.support.ChatConstants;
|
||||
import tech.easyflow.chatlog.support.ChatJsonSupport;
|
||||
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||
import tech.easyflow.core.runtime.ChatRuntimeExtKeys;
|
||||
import tech.easyflow.core.runtime.ChatRuntimeHistoryPayloadHelper;
|
||||
import tech.easyflow.core.runtime.ChatRuntimeContext;
|
||||
import tech.easyflow.core.runtime.ChatRuntimeListener;
|
||||
@@ -27,12 +35,21 @@ public class ChatlogRuntimeListener implements ChatRuntimeListener {
|
||||
private final SnowFlakeIDKeyGenerator idGenerator = new SnowFlakeIDKeyGenerator();
|
||||
|
||||
private final ChatPersistDispatcher chatPersistDispatcher;
|
||||
private final ChatRoundOperateService chatRoundOperateService;
|
||||
private final ChatRoundQueryService chatRoundQueryService;
|
||||
private final ChatSessionQueryService chatSessionQueryService;
|
||||
private final ChatJsonSupport chatJsonSupport;
|
||||
|
||||
public ChatlogRuntimeListener(ChatPersistDispatcher chatPersistDispatcher,
|
||||
ChatSessionQueryService chatSessionQueryService) {
|
||||
ChatRoundOperateService chatRoundOperateService,
|
||||
ChatRoundQueryService chatRoundQueryService,
|
||||
ChatSessionQueryService chatSessionQueryService,
|
||||
ChatJsonSupport chatJsonSupport) {
|
||||
this.chatPersistDispatcher = chatPersistDispatcher;
|
||||
this.chatRoundOperateService = chatRoundOperateService;
|
||||
this.chatRoundQueryService = chatRoundQueryService;
|
||||
this.chatSessionQueryService = chatSessionQueryService;
|
||||
this.chatJsonSupport = chatJsonSupport;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -48,6 +65,7 @@ public class ChatlogRuntimeListener implements ChatRuntimeListener {
|
||||
command.setAssistantCode(context.getAssistantCode());
|
||||
command.setAssistantName(context.getAssistantName());
|
||||
command.setTitle(context.getSessionTitle());
|
||||
command.setExtJson(resolveExtJson(context));
|
||||
command.setOperatorId(defaultNumber(context.getUserId()));
|
||||
chatPersistDispatcher.createOrTouchSession(command);
|
||||
} catch (RuntimeException ex) {
|
||||
@@ -58,6 +76,9 @@ public class ChatlogRuntimeListener implements ChatRuntimeListener {
|
||||
@Override
|
||||
public void onUserMessage(ChatRuntimeContext context, ChatRuntimeMessage message) {
|
||||
try {
|
||||
if (prepareRoundContext(context, message)) {
|
||||
return;
|
||||
}
|
||||
chatPersistDispatcher.appendUserMessage(toAppendCommand(context, message));
|
||||
} catch (RuntimeException ex) {
|
||||
throw persistFailed(ex);
|
||||
@@ -67,7 +88,36 @@ public class ChatlogRuntimeListener implements ChatRuntimeListener {
|
||||
@Override
|
||||
public void onAssistantCompleted(ChatRuntimeContext context, ChatRuntimeMessage message) {
|
||||
try {
|
||||
applyAssistantRoundMetadata(context, message);
|
||||
chatPersistDispatcher.appendAssistantMessage(toAppendCommand(context, message));
|
||||
chatPersistDispatcher.createOrTouchRound(buildAssistantCompletedRoundCommand(context, message));
|
||||
} catch (RuntimeException ex) {
|
||||
throw persistFailed(ex);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onChatFailed(ChatRuntimeContext context, Throwable throwable) {
|
||||
try {
|
||||
BigInteger roundId = resolveNumber(context, ChatRuntimeExtKeys.CURRENT_ROUND_ID);
|
||||
if (context == null || context.getSessionId() == null || roundId == null) {
|
||||
return;
|
||||
}
|
||||
ChatRoundRecord currentRound = chatRoundQueryService.getRound(context.getSessionId(), roundId);
|
||||
if (currentRound == null) {
|
||||
return;
|
||||
}
|
||||
ChatRoundUpsertCommand command = new ChatRoundUpsertCommand();
|
||||
command.setRoundId(currentRound.getId());
|
||||
command.setSessionId(currentRound.getSessionId());
|
||||
command.setRoundNo(currentRound.getRoundNo());
|
||||
command.setUserMessageId(currentRound.getUserMessageId());
|
||||
command.setSelectedAssistantMessageId(currentRound.getSelectedAssistantMessageId());
|
||||
command.setSelectedVariantIndex(currentRound.getSelectedVariantIndex());
|
||||
command.setVariantCount(currentRound.getVariantCount());
|
||||
command.setStatus(ChatConstants.ROUND_STATUS_READY);
|
||||
command.setOperatorId(defaultNumber(context.getUserId()));
|
||||
chatPersistDispatcher.createOrTouchRound(command);
|
||||
} catch (RuntimeException ex) {
|
||||
throw persistFailed(ex);
|
||||
}
|
||||
@@ -78,7 +128,15 @@ public class ChatlogRuntimeListener implements ChatRuntimeListener {
|
||||
if (context == null || context.getSessionId() == null || limit <= 0) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
List<ChatMessageRecord> records = new ArrayList<>(chatSessionQueryService.getRecentTail(context.getSessionId(), limit));
|
||||
BigInteger regenerateRoundId = resolveNumber(context, ChatRuntimeExtKeys.REGENERATE_ROUND_ID);
|
||||
int queryLimit = regenerateRoundId == null ? limit : limit + 4;
|
||||
List<ChatMessageRecord> records = new ArrayList<>(chatSessionQueryService.getRecentTail(context.getSessionId(), queryLimit));
|
||||
if (regenerateRoundId != null) {
|
||||
records.removeIf(record -> regenerateRoundId.equals(record.getRoundId()));
|
||||
if (records.size() > limit) {
|
||||
records = new ArrayList<>(records.subList(0, limit));
|
||||
}
|
||||
}
|
||||
Collections.reverse(records);
|
||||
List<ChatRuntimeMessage> messages = new ArrayList<>(records.size());
|
||||
for (ChatMessageRecord record : records) {
|
||||
@@ -118,11 +176,127 @@ public class ChatlogRuntimeListener implements ChatRuntimeListener {
|
||||
command.setContentType(message.getContentType());
|
||||
command.setContentText(message.getContentText());
|
||||
command.setContentPayload(message.getContentPayload());
|
||||
command.setRoundId(message.getRoundId());
|
||||
command.setRoundNo(message.getRoundNo());
|
||||
command.setMessageKind(message.getMessageKind());
|
||||
command.setVariantIndex(message.getVariantIndex());
|
||||
command.setCreatedBy(defaultNumber(context.getUserId()));
|
||||
command.setCreated(message.getCreatedAt());
|
||||
return command;
|
||||
}
|
||||
|
||||
private boolean prepareRoundContext(ChatRuntimeContext context, ChatRuntimeMessage message) {
|
||||
if (context == null || message == null || context.getSessionId() == null) {
|
||||
return false;
|
||||
}
|
||||
BigInteger regenerateRoundId = resolveNumber(context, ChatRuntimeExtKeys.REGENERATE_ROUND_ID);
|
||||
if (regenerateRoundId != null) {
|
||||
ChatRoundRecord round = chatRoundOperateService.requireRegeneratableRound(context.getSessionId(), regenerateRoundId);
|
||||
context.getExt().put(ChatRuntimeExtKeys.CURRENT_ROUND_ID, round.getId());
|
||||
context.getExt().put(ChatRuntimeExtKeys.CURRENT_ROUND_NO, round.getRoundNo());
|
||||
context.getExt().put(ChatRuntimeExtKeys.CURRENT_VARIANT_INDEX, Math.max(round.getVariantCount() + 1, 1));
|
||||
ChatRoundUpsertCommand command = new ChatRoundUpsertCommand();
|
||||
command.setRoundId(round.getId());
|
||||
command.setSessionId(round.getSessionId());
|
||||
command.setRoundNo(round.getRoundNo());
|
||||
command.setUserMessageId(round.getUserMessageId());
|
||||
command.setSelectedAssistantMessageId(round.getSelectedAssistantMessageId());
|
||||
command.setSelectedVariantIndex(round.getSelectedVariantIndex());
|
||||
command.setVariantCount(round.getVariantCount());
|
||||
command.setStatus(ChatConstants.ROUND_STATUS_ANSWERING);
|
||||
command.setOperatorId(defaultNumber(context.getUserId()));
|
||||
chatPersistDispatcher.createOrTouchRound(command);
|
||||
return true;
|
||||
}
|
||||
ChatRoundRecord latestRound = chatRoundQueryService.getLatestRound(context.getSessionId());
|
||||
if (latestRound != null && latestRound.getId() != null
|
||||
&& !ChatConstants.ROUND_STATUS_LOCKED.equalsIgnoreCase(latestRound.getStatus())) {
|
||||
ChatRoundUpsertCommand lockCommand = new ChatRoundUpsertCommand();
|
||||
lockCommand.setRoundId(latestRound.getId());
|
||||
lockCommand.setSessionId(latestRound.getSessionId());
|
||||
lockCommand.setRoundNo(latestRound.getRoundNo());
|
||||
lockCommand.setUserMessageId(latestRound.getUserMessageId());
|
||||
lockCommand.setSelectedAssistantMessageId(latestRound.getSelectedAssistantMessageId());
|
||||
lockCommand.setSelectedVariantIndex(latestRound.getSelectedVariantIndex());
|
||||
lockCommand.setVariantCount(latestRound.getVariantCount());
|
||||
lockCommand.setStatus(ChatConstants.ROUND_STATUS_LOCKED);
|
||||
lockCommand.setOperatorId(defaultNumber(context.getUserId()));
|
||||
chatPersistDispatcher.createOrTouchRound(lockCommand);
|
||||
}
|
||||
BigInteger roundId = BigInteger.valueOf(idGenerator.nextId());
|
||||
int roundNo = latestRound == null || latestRound.getRoundNo() == null ? 1 : latestRound.getRoundNo() + 1;
|
||||
if (message.getMessageId() == null) {
|
||||
message.setMessageId(BigInteger.valueOf(idGenerator.nextId()));
|
||||
}
|
||||
context.getExt().put(ChatRuntimeExtKeys.CURRENT_ROUND_ID, roundId);
|
||||
context.getExt().put(ChatRuntimeExtKeys.CURRENT_ROUND_NO, roundNo);
|
||||
context.getExt().put(ChatRuntimeExtKeys.CURRENT_VARIANT_INDEX, 1);
|
||||
message.setRoundId(roundId);
|
||||
message.setRoundNo(roundNo);
|
||||
message.setMessageKind(ChatConstants.MESSAGE_KIND_USER_PROMPT);
|
||||
message.setVariantIndex(null);
|
||||
ChatRoundUpsertCommand command = new ChatRoundUpsertCommand();
|
||||
command.setRoundId(roundId);
|
||||
command.setSessionId(context.getSessionId());
|
||||
command.setRoundNo(roundNo);
|
||||
command.setUserMessageId(message.getMessageId());
|
||||
command.setSelectedVariantIndex(0);
|
||||
command.setVariantCount(0);
|
||||
command.setStatus(ChatConstants.ROUND_STATUS_ANSWERING);
|
||||
command.setOperatorId(defaultNumber(context.getUserId()));
|
||||
chatPersistDispatcher.createOrTouchRound(command);
|
||||
return false;
|
||||
}
|
||||
|
||||
private void applyAssistantRoundMetadata(ChatRuntimeContext context, ChatRuntimeMessage message) {
|
||||
if (message.getMessageId() == null) {
|
||||
message.setMessageId(BigInteger.valueOf(idGenerator.nextId()));
|
||||
}
|
||||
message.setRoundId(resolveNumber(context, ChatRuntimeExtKeys.CURRENT_ROUND_ID));
|
||||
message.setRoundNo(resolveInteger(context, ChatRuntimeExtKeys.CURRENT_ROUND_NO));
|
||||
message.setVariantIndex(resolveInteger(context, ChatRuntimeExtKeys.CURRENT_VARIANT_INDEX));
|
||||
message.setMessageKind(ChatConstants.MESSAGE_KIND_ASSISTANT_VARIANT);
|
||||
}
|
||||
|
||||
private ChatRoundUpsertCommand buildAssistantCompletedRoundCommand(ChatRuntimeContext context, ChatRuntimeMessage message) {
|
||||
ChatRoundUpsertCommand command = new ChatRoundUpsertCommand();
|
||||
command.setRoundId(message.getRoundId());
|
||||
command.setSessionId(context.getSessionId());
|
||||
command.setRoundNo(message.getRoundNo());
|
||||
ChatRoundRecord existing = chatRoundQueryService.getRound(context.getSessionId(), message.getRoundId());
|
||||
if (existing != null) {
|
||||
command.setUserMessageId(existing.getUserMessageId());
|
||||
}
|
||||
command.setSelectedAssistantMessageId(message.getMessageId());
|
||||
command.setSelectedVariantIndex(message.getVariantIndex());
|
||||
command.setVariantCount(message.getVariantIndex());
|
||||
command.setStatus(ChatConstants.ROUND_STATUS_READY);
|
||||
command.setOperatorId(defaultNumber(context.getUserId()));
|
||||
return command;
|
||||
}
|
||||
|
||||
private BigInteger resolveNumber(ChatRuntimeContext context, String key) {
|
||||
if (context == null || context.getExt() == null || key == null) {
|
||||
return null;
|
||||
}
|
||||
Object value = context.getExt().get(key);
|
||||
if (value == null) {
|
||||
return null;
|
||||
}
|
||||
return new BigInteger(String.valueOf(value));
|
||||
}
|
||||
|
||||
private Integer resolveInteger(ChatRuntimeContext context, String key) {
|
||||
if (context == null || context.getExt() == null || key == null) {
|
||||
return null;
|
||||
}
|
||||
Object value = context.getExt().get(key);
|
||||
if (value == null) {
|
||||
return null;
|
||||
}
|
||||
return Integer.parseInt(String.valueOf(value));
|
||||
}
|
||||
|
||||
private BigInteger defaultNumber(BigInteger value) {
|
||||
return value == null ? BigInteger.ZERO : value;
|
||||
}
|
||||
@@ -133,4 +307,27 @@ public class ChatlogRuntimeListener implements ChatRuntimeListener {
|
||||
}
|
||||
return new BusinessException("聊天记录持久化失败,请稍后重试");
|
||||
}
|
||||
|
||||
private String resolveExtJson(ChatRuntimeContext context) {
|
||||
if (context == null || context.getExt() == null || context.getExt().isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
if (!context.getExt().containsKey(ChatRuntimeExtKeys.EXTRA_KNOWLEDGE_IDS)) {
|
||||
return null;
|
||||
}
|
||||
Object rawExtraKnowledgeIds = context.getExt().get(ChatRuntimeExtKeys.EXTRA_KNOWLEDGE_IDS);
|
||||
if (!(rawExtraKnowledgeIds instanceof List<?> rawList)) {
|
||||
return null;
|
||||
}
|
||||
List<BigInteger> extraKnowledgeIds = new ArrayList<>();
|
||||
for (Object item : rawList) {
|
||||
if (item == null) {
|
||||
continue;
|
||||
}
|
||||
extraKnowledgeIds.add(new BigInteger(String.valueOf(item)));
|
||||
}
|
||||
ChatSessionExtPayload payload = new ChatSessionExtPayload();
|
||||
payload.setExtraKnowledgeIds(extraKnowledgeIds);
|
||||
return chatJsonSupport.toJson(payload);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,12 +3,18 @@ package tech.easyflow.chatlog.support;
|
||||
public final class ChatConstants {
|
||||
|
||||
public static final String SESSION_TABLE = "chat_session";
|
||||
public static final String ROUND_TABLE = "chat_round";
|
||||
public static final String CHAT_LOG_TEMPLATE = "chat_log_template";
|
||||
public static final String CHAT_LOG_PREFIX = "chat_log_";
|
||||
public static final String CHAT_PERSIST_TOPIC = "chat-persist";
|
||||
public static final String CHAT_PERSIST_GROUP = "chat-persist-group";
|
||||
public static final String CHECKPOINT_SYNC_CODE_SESSION = "chat_session_sync";
|
||||
public static final String CHECKPOINT_SYNC_CODE_LOG = "chat_log_sync";
|
||||
public static final String ROUND_STATUS_ANSWERING = "ANSWERING";
|
||||
public static final String ROUND_STATUS_READY = "READY";
|
||||
public static final String ROUND_STATUS_LOCKED = "LOCKED";
|
||||
public static final String MESSAGE_KIND_USER_PROMPT = "USER_PROMPT";
|
||||
public static final String MESSAGE_KIND_ASSISTANT_VARIANT = "ASSISTANT_VARIANT";
|
||||
|
||||
private ChatConstants() {
|
||||
}
|
||||
|
||||
@@ -1,18 +1,29 @@
|
||||
package tech.easyflow.chatlog.service;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import tech.easyflow.chatlog.domain.command.ChatAppendMessageCommand;
|
||||
import tech.easyflow.chatlog.domain.command.ChatSessionSummaryCommand;
|
||||
import tech.easyflow.chatlog.domain.command.ChatSessionUpsertCommand;
|
||||
import tech.easyflow.chatlog.domain.event.ChatPersistEvent;
|
||||
import tech.easyflow.chatlog.domain.event.ChatPersistEventType;
|
||||
import tech.easyflow.chatlog.repository.mysql.MySqlChatLogRepository;
|
||||
import tech.easyflow.chatlog.repository.mysql.MySqlChatLogTableManager;
|
||||
import tech.easyflow.chatlog.repository.mysql.MySqlChatRoundRepository;
|
||||
import tech.easyflow.chatlog.repository.mysql.MySqlChatSessionRepository;
|
||||
import tech.easyflow.chatlog.support.ChatJsonSupport;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.time.YearMonth;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
|
||||
public class ChatPersistMySqlApplyServiceTest {
|
||||
|
||||
private final ChatPersistMySqlApplyService service =
|
||||
new ChatPersistMySqlApplyService(null, null, null, null);
|
||||
new ChatPersistMySqlApplyService(null, null, null, null, new ChatJsonSupport(new ObjectMapper()));
|
||||
|
||||
@Test
|
||||
public void shouldBuildMissingSessionUpsertFromMessageMetadata() {
|
||||
@@ -69,4 +80,101 @@ public class ChatPersistMySqlApplyServiceTest {
|
||||
Assert.assertEquals("会话-202", upsert.getTitle());
|
||||
Assert.assertEquals(BigInteger.valueOf(7), upsert.getOperatorId());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldNotDoubleCountSummaryWhenMessageEventReplayed() {
|
||||
ChatJsonSupport jsonSupport = new ChatJsonSupport(new ObjectMapper());
|
||||
FakeSessionRepository sessionRepository = new FakeSessionRepository();
|
||||
FakeLogRepository logRepository = new FakeLogRepository(jsonSupport);
|
||||
ChatPersistMySqlApplyService applyService = new ChatPersistMySqlApplyService(
|
||||
sessionRepository,
|
||||
logRepository,
|
||||
new FakeRoundRepository(),
|
||||
new FakeTableManager(),
|
||||
jsonSupport
|
||||
);
|
||||
ChatAppendMessageCommand command = new ChatAppendMessageCommand();
|
||||
command.setMessageId(BigInteger.valueOf(301));
|
||||
command.setSessionId(BigInteger.valueOf(401));
|
||||
command.setTenantId(BigInteger.ONE);
|
||||
command.setDeptId(BigInteger.ONE);
|
||||
command.setUserId(BigInteger.valueOf(7));
|
||||
command.setAssistantId(BigInteger.valueOf(8));
|
||||
command.setSenderId(BigInteger.valueOf(7));
|
||||
command.setSenderName("admin");
|
||||
command.setSenderRole("user");
|
||||
command.setContentText("第一条消息");
|
||||
command.setCreatedBy(BigInteger.valueOf(7));
|
||||
command.setCreated(new Date(4_000L));
|
||||
|
||||
ChatPersistEvent event = new ChatPersistEvent();
|
||||
event.setEventId("message-301");
|
||||
event.setEventType(ChatPersistEventType.USER_MESSAGE_APPENDED);
|
||||
event.setSessionId(command.getSessionId());
|
||||
event.setPayload(jsonSupport.toJson(command));
|
||||
|
||||
applyService.apply(List.of(event));
|
||||
applyService.apply(List.of(event));
|
||||
|
||||
Assert.assertEquals(1, sessionRepository.summaryCommands.size());
|
||||
ChatSessionSummaryCommand summaryCommand = sessionRepository.summaryCommands.get(0);
|
||||
Assert.assertEquals(1, summaryCommand.getMessageIncrement());
|
||||
Assert.assertEquals("第一条消息", summaryCommand.getLastMessagePreview());
|
||||
Assert.assertEquals(new Date(4_000L), summaryCommand.getLastMessageAt());
|
||||
Assert.assertEquals(new Date(4_000L), summaryCommand.getAccessAt());
|
||||
}
|
||||
|
||||
private static final class FakeSessionRepository extends MySqlChatSessionRepository {
|
||||
|
||||
private final List<ChatSessionSummaryCommand> summaryCommands = new ArrayList<>();
|
||||
|
||||
private FakeSessionRepository() {
|
||||
super(null, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void createOrTouchBatch(List<ChatSessionUpsertCommand> commands) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateSummaries(List<ChatSessionSummaryCommand> commands) {
|
||||
summaryCommands.addAll(commands);
|
||||
}
|
||||
}
|
||||
|
||||
private static final class FakeLogRepository extends MySqlChatLogRepository {
|
||||
|
||||
private boolean inserted;
|
||||
|
||||
private FakeLogRepository(ChatJsonSupport jsonSupport) {
|
||||
super(null, null, jsonSupport);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatAppendMessageCommand> appendMessages(List<ChatAppendMessageCommand> commands) {
|
||||
if (inserted) {
|
||||
return List.of();
|
||||
}
|
||||
inserted = true;
|
||||
return commands;
|
||||
}
|
||||
}
|
||||
|
||||
private static final class FakeRoundRepository extends MySqlChatRoundRepository {
|
||||
|
||||
private FakeRoundRepository() {
|
||||
super(null);
|
||||
}
|
||||
}
|
||||
|
||||
private static final class FakeTableManager extends MySqlChatLogTableManager {
|
||||
|
||||
private FakeTableManager() {
|
||||
super(null, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void ensureMonthTable(YearMonth month) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
package tech.easyflow.chatlog.service.impl;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import tech.easyflow.chatlog.domain.command.ChatRoundSelectCommand;
|
||||
import tech.easyflow.chatlog.domain.command.ChatRoundUpsertCommand;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatMessageRecord;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatRoundRecord;
|
||||
import tech.easyflow.chatlog.service.ChatRoundCommandService;
|
||||
import tech.easyflow.chatlog.service.ChatRoundQueryService;
|
||||
import tech.easyflow.chatlog.support.ChatConstants;
|
||||
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link ChatRoundOperateServiceImpl} 单元测试。
|
||||
*/
|
||||
public class ChatRoundOperateServiceImplTest {
|
||||
|
||||
/**
|
||||
* 切换答案版本时应精准查询目标版本,避免先加载全部版本再过滤。
|
||||
*/
|
||||
@Test
|
||||
public void selectVariantShouldReadTargetVariantDirectly() {
|
||||
FakeRoundQueryService queryService = new FakeRoundQueryService();
|
||||
queryService.round = round(BigInteger.valueOf(1001), BigInteger.valueOf(2001), 2, ChatConstants.ROUND_STATUS_READY);
|
||||
queryService.latestRound = queryService.round;
|
||||
queryService.targetVariant = message(BigInteger.valueOf(3002), 2);
|
||||
FakeRoundCommandService commandService = new FakeRoundCommandService();
|
||||
ChatRoundOperateServiceImpl service = new ChatRoundOperateServiceImpl(queryService, commandService);
|
||||
|
||||
ChatMessageRecord selected = service.selectVariant(
|
||||
BigInteger.valueOf(1001),
|
||||
BigInteger.valueOf(2001),
|
||||
2,
|
||||
BigInteger.valueOf(7)
|
||||
);
|
||||
|
||||
Assert.assertEquals(BigInteger.valueOf(3002), selected.getId());
|
||||
Assert.assertEquals(Integer.valueOf(2), selected.getSelectedVariantIndex());
|
||||
Assert.assertEquals(Integer.valueOf(2), selected.getVariantCount());
|
||||
Assert.assertEquals(Boolean.TRUE, selected.getSwitchable());
|
||||
Assert.assertEquals(1, queryService.getRoundVariantCalls);
|
||||
Assert.assertEquals(0, queryService.listRoundVariantsCalls);
|
||||
Assert.assertNotNull(commandService.selectedCommand);
|
||||
Assert.assertEquals(BigInteger.valueOf(3002), commandService.selectedCommand.getSelectedAssistantMessageId());
|
||||
}
|
||||
|
||||
/**
|
||||
* 列出答案版本时应由业务层统一补齐当前选中态和可切换状态。
|
||||
*/
|
||||
@Test
|
||||
public void listVariantsShouldFillVariantMetadata() {
|
||||
FakeRoundQueryService queryService = new FakeRoundQueryService();
|
||||
queryService.round = round(BigInteger.valueOf(1001), BigInteger.valueOf(2001), 2, ChatConstants.ROUND_STATUS_READY);
|
||||
queryService.latestRound = queryService.round;
|
||||
queryService.variants = List.of(message(BigInteger.valueOf(3001), 1), message(BigInteger.valueOf(3002), 2));
|
||||
ChatRoundOperateServiceImpl service = new ChatRoundOperateServiceImpl(queryService, new FakeRoundCommandService());
|
||||
|
||||
List<ChatMessageRecord> variants = service.listVariants(BigInteger.valueOf(1001), BigInteger.valueOf(2001));
|
||||
|
||||
Assert.assertEquals(2, variants.size());
|
||||
for (ChatMessageRecord variant : variants) {
|
||||
Assert.assertEquals(Integer.valueOf(2), variant.getVariantCount());
|
||||
Assert.assertEquals(Integer.valueOf(2), variant.getSelectedVariantIndex());
|
||||
Assert.assertEquals(Boolean.TRUE, variant.getSwitchable());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 已锁定轮次禁止切换,避免改变已有后续上下文。
|
||||
*/
|
||||
@Test(expected = BusinessException.class)
|
||||
public void selectVariantShouldRejectLockedRound() {
|
||||
FakeRoundQueryService queryService = new FakeRoundQueryService();
|
||||
queryService.round = round(BigInteger.valueOf(1001), BigInteger.valueOf(2001), 2, ChatConstants.ROUND_STATUS_LOCKED);
|
||||
queryService.latestRound = queryService.round;
|
||||
ChatRoundOperateServiceImpl service = new ChatRoundOperateServiceImpl(queryService, new FakeRoundCommandService());
|
||||
|
||||
service.selectVariant(BigInteger.valueOf(1001), BigInteger.valueOf(2001), 1, BigInteger.valueOf(7));
|
||||
}
|
||||
|
||||
/**
|
||||
* 新增迁移必须为热表版本切换查询补齐索引。
|
||||
*
|
||||
* @throws Exception 读取迁移文件失败时抛出
|
||||
*/
|
||||
@Test
|
||||
public void migrationShouldCreateRoundVariantIndex() throws Exception {
|
||||
String sql = Files.readString(
|
||||
resolveMigrationPath("V18__mysql_chat_round_variant_index.sql"),
|
||||
StandardCharsets.UTF_8
|
||||
);
|
||||
|
||||
Assert.assertTrue(sql.contains("idx_chat_log_round_variant"));
|
||||
Assert.assertTrue(sql.contains("`session_id`, `round_id`, `message_kind`, `variant_index`, `created`, `id`"));
|
||||
Assert.assertFalse(sql.contains("V16__mysql_chat_round_variant"));
|
||||
}
|
||||
|
||||
/**
|
||||
* 从当前测试工作目录向上查找迁移文件,兼容根工程与模块工程两种运行方式。
|
||||
*
|
||||
* @param fileName 迁移文件名
|
||||
* @return 迁移文件路径
|
||||
* @throws Exception 未找到迁移文件时抛出
|
||||
*/
|
||||
private static Path resolveMigrationPath(String fileName) throws Exception {
|
||||
Path current = Path.of("").toAbsolutePath();
|
||||
while (current != null) {
|
||||
Path candidate = current.resolve(
|
||||
"easyflow-starter/easyflow-starter-all/src/main/resources/db/migration/mysql/" + fileName
|
||||
);
|
||||
if (Files.exists(candidate)) {
|
||||
return candidate;
|
||||
}
|
||||
current = current.getParent();
|
||||
}
|
||||
throw new java.nio.file.NoSuchFileException(fileName);
|
||||
}
|
||||
|
||||
private static ChatRoundRecord round(BigInteger sessionId, BigInteger roundId, int selectedVariantIndex, String status) {
|
||||
ChatRoundRecord round = new ChatRoundRecord();
|
||||
round.setId(roundId);
|
||||
round.setSessionId(sessionId);
|
||||
round.setRoundNo(1);
|
||||
round.setSelectedVariantIndex(selectedVariantIndex);
|
||||
round.setVariantCount(2);
|
||||
round.setStatus(status);
|
||||
return round;
|
||||
}
|
||||
|
||||
private static ChatMessageRecord message(BigInteger id, int variantIndex) {
|
||||
ChatMessageRecord record = new ChatMessageRecord();
|
||||
record.setId(id);
|
||||
record.setSessionId(BigInteger.valueOf(1001));
|
||||
record.setRoundId(BigInteger.valueOf(2001));
|
||||
record.setVariantIndex(variantIndex);
|
||||
record.setSenderRole("assistant");
|
||||
record.setMessageKind(ChatConstants.MESSAGE_KIND_ASSISTANT_VARIANT);
|
||||
record.setContentText("答案 " + variantIndex);
|
||||
return record;
|
||||
}
|
||||
|
||||
/**
|
||||
* 轮次读服务测试替身。
|
||||
*/
|
||||
private static final class FakeRoundQueryService implements ChatRoundQueryService {
|
||||
|
||||
private ChatRoundRecord round;
|
||||
private ChatRoundRecord latestRound;
|
||||
private ChatMessageRecord targetVariant;
|
||||
private List<ChatMessageRecord> variants = List.of();
|
||||
private int getRoundVariantCalls;
|
||||
private int listRoundVariantsCalls;
|
||||
|
||||
@Override
|
||||
public ChatRoundRecord getLatestRound(BigInteger sessionId) {
|
||||
return latestRound;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatRoundRecord getRound(BigInteger sessionId, BigInteger roundId) {
|
||||
return round;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatMessageRecord> listRoundVariants(BigInteger sessionId, BigInteger roundId) {
|
||||
listRoundVariantsCalls += 1;
|
||||
return variants;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatMessageRecord getRoundVariant(BigInteger sessionId, BigInteger roundId, Integer variantIndex) {
|
||||
getRoundVariantCalls += 1;
|
||||
return targetVariant;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasRounds(BigInteger sessionId) {
|
||||
return round != null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 轮次写服务测试替身。
|
||||
*/
|
||||
private static final class FakeRoundCommandService implements ChatRoundCommandService {
|
||||
|
||||
private ChatRoundSelectCommand selectedCommand;
|
||||
|
||||
@Override
|
||||
public ChatRoundRecord createOrTouchRound(ChatRoundUpsertCommand command) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void selectVariant(ChatRoundSelectCommand command) {
|
||||
selectedCommand = command;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,222 @@
|
||||
package tech.easyflow.chatlog.service.impl;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import tech.easyflow.chatlog.cache.ChatHotStateService;
|
||||
import tech.easyflow.chatlog.config.ChatCacheProperties;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatHistoryPage;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatMessageRecord;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatSessionPage;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatSessionSummary;
|
||||
import tech.easyflow.chatlog.domain.query.ChatPageQuery;
|
||||
import tech.easyflow.chatlog.repository.mysql.MySqlChatLogRepository;
|
||||
import tech.easyflow.chatlog.repository.mysql.MySqlChatLogTableManager;
|
||||
import tech.easyflow.chatlog.repository.mysql.MySqlChatSessionRepository;
|
||||
import tech.easyflow.chatlog.support.ChatJsonSupport;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.time.YearMonth;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link ChatSessionQueryServiceImpl} 单元测试。
|
||||
*/
|
||||
public class ChatSessionQueryServiceImplTest {
|
||||
|
||||
/**
|
||||
* 会话列表必须以 MySQL 会话表为唯一权威来源,不再使用 Redis 列表索引。
|
||||
*/
|
||||
@Test
|
||||
public void pageSessionsShouldUseMysqlRepositoryAsAuthority() {
|
||||
FakeSessionRepository sessionRepository = new FakeSessionRepository();
|
||||
sessionRepository.sessions = List.of(session(BigInteger.valueOf(1001), 4));
|
||||
sessionRepository.count = 1;
|
||||
ChatSessionQueryServiceImpl service = new ChatSessionQueryServiceImpl(
|
||||
sessionRepository,
|
||||
new FakeLogRepository(),
|
||||
new FakeTableManager(List.of()),
|
||||
new FakeHotStateService()
|
||||
);
|
||||
|
||||
ChatSessionPage page = service.pageSessions(BigInteger.valueOf(7), null, new ChatPageQuery());
|
||||
|
||||
Assert.assertEquals(1, page.getTotal());
|
||||
Assert.assertEquals(1, page.getRecords().size());
|
||||
Assert.assertEquals(1, sessionRepository.countSessionsCalls);
|
||||
Assert.assertEquals(1, sessionRepository.listSessionsCalls);
|
||||
}
|
||||
|
||||
/**
|
||||
* 工作台消息分页必须走 MySQL 热表主线查询,并保持分页参数语义。
|
||||
*/
|
||||
@Test
|
||||
public void pageMainlineMessagesShouldReadMysqlHotTables() {
|
||||
FakeSessionRepository sessionRepository = new FakeSessionRepository();
|
||||
sessionRepository.summary = session(BigInteger.valueOf(2001), 6);
|
||||
FakeLogRepository logRepository = new FakeLogRepository();
|
||||
logRepository.records = List.of(message(3001), message(3002));
|
||||
List<YearMonth> months = List.of(YearMonth.of(2026, 5));
|
||||
ChatSessionQueryServiceImpl service = new ChatSessionQueryServiceImpl(
|
||||
sessionRepository,
|
||||
logRepository,
|
||||
new FakeTableManager(months),
|
||||
new FakeHotStateService()
|
||||
);
|
||||
ChatPageQuery query = new ChatPageQuery();
|
||||
query.setPageNumber(2);
|
||||
query.setPageSize(2);
|
||||
|
||||
ChatHistoryPage page = service.pageMainlineMessages(BigInteger.valueOf(2001), query);
|
||||
|
||||
Assert.assertEquals(6, page.getTotal());
|
||||
Assert.assertEquals(2, page.getRecords().size());
|
||||
Assert.assertEquals(BigInteger.valueOf(2001), logRepository.capturedSessionId);
|
||||
Assert.assertEquals(months, logRepository.capturedMonths);
|
||||
Assert.assertEquals(2, logRepository.capturedOffset);
|
||||
Assert.assertEquals(2, logRepository.capturedLimit);
|
||||
}
|
||||
|
||||
/**
|
||||
* 当 MySQL 摘要计数滞后时,分页 total 至少覆盖当前已返回的数据范围。
|
||||
*/
|
||||
@Test
|
||||
public void pageMainlineMessagesShouldNotReturnTotalSmallerThanCurrentPage() {
|
||||
FakeSessionRepository sessionRepository = new FakeSessionRepository();
|
||||
sessionRepository.summary = session(BigInteger.valueOf(2002), 1);
|
||||
FakeLogRepository logRepository = new FakeLogRepository();
|
||||
logRepository.records = List.of(message(4001), message(4002));
|
||||
ChatSessionQueryServiceImpl service = new ChatSessionQueryServiceImpl(
|
||||
sessionRepository,
|
||||
logRepository,
|
||||
new FakeTableManager(List.of(YearMonth.of(2026, 5))),
|
||||
new FakeHotStateService()
|
||||
);
|
||||
ChatPageQuery query = new ChatPageQuery();
|
||||
query.setPageNumber(2);
|
||||
query.setPageSize(2);
|
||||
|
||||
ChatHistoryPage page = service.pageMainlineMessages(BigInteger.valueOf(2002), query);
|
||||
|
||||
Assert.assertEquals(4, page.getTotal());
|
||||
}
|
||||
|
||||
private static ChatSessionSummary session(BigInteger id, int messageCount) {
|
||||
ChatSessionSummary summary = new ChatSessionSummary();
|
||||
summary.setId(id);
|
||||
summary.setUserId(BigInteger.valueOf(7));
|
||||
summary.setMessageCount(messageCount);
|
||||
return summary;
|
||||
}
|
||||
|
||||
private static ChatMessageRecord message(long id) {
|
||||
ChatMessageRecord record = new ChatMessageRecord();
|
||||
record.setId(BigInteger.valueOf(id));
|
||||
return record;
|
||||
}
|
||||
|
||||
/**
|
||||
* MySQL 会话仓储测试替身。
|
||||
*/
|
||||
private static final class FakeSessionRepository extends MySqlChatSessionRepository {
|
||||
|
||||
private long count;
|
||||
private int countSessionsCalls;
|
||||
private int listSessionsCalls;
|
||||
private ChatSessionSummary summary;
|
||||
private List<ChatSessionSummary> sessions = new ArrayList<>();
|
||||
|
||||
private FakeSessionRepository() {
|
||||
super(null, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatSessionSummary> listSessions(BigInteger userId, BigInteger assistantId, ChatPageQuery query) {
|
||||
listSessionsCalls += 1;
|
||||
return sessions;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long countSessions(BigInteger userId, BigInteger assistantId) {
|
||||
countSessionsCalls += 1;
|
||||
return count;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatSessionSummary findBySessionId(BigInteger sessionId) {
|
||||
return summary;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* MySQL 消息仓储测试替身。
|
||||
*/
|
||||
private static final class FakeLogRepository extends MySqlChatLogRepository {
|
||||
|
||||
private BigInteger capturedSessionId;
|
||||
private List<YearMonth> capturedMonths;
|
||||
private long capturedOffset;
|
||||
private int capturedLimit;
|
||||
private List<ChatMessageRecord> records = new ArrayList<>();
|
||||
|
||||
private FakeLogRepository() {
|
||||
super(null, null, new ChatJsonSupport(new ObjectMapper()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatMessageRecord> listMainlineMessages(BigInteger sessionId, List<YearMonth> months, long offset, int limit) {
|
||||
capturedSessionId = sessionId;
|
||||
capturedMonths = months;
|
||||
capturedOffset = offset;
|
||||
capturedLimit = limit;
|
||||
return records;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* MySQL 热表管理器测试替身。
|
||||
*/
|
||||
private static final class FakeTableManager extends MySqlChatLogTableManager {
|
||||
|
||||
private final List<YearMonth> months;
|
||||
|
||||
private FakeTableManager(List<YearMonth> months) {
|
||||
super(null, null);
|
||||
this.months = months;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<YearMonth> listRecentExistingMonths(int retentionMonths) {
|
||||
return months;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Redis 热态测试替身,避免单测依赖真实 Redis。
|
||||
*/
|
||||
private static final class FakeHotStateService extends ChatHotStateService {
|
||||
|
||||
private FakeHotStateService() {
|
||||
super(null, new ObjectMapper(), new ChatCacheProperties());
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatSessionSummary getSessionSummary(BigInteger sessionId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void cacheSessionSummary(ChatSessionSummary summary) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatMessageRecord> getSessionTail(BigInteger sessionId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setSessionTail(BigInteger sessionId, List<ChatMessageRecord> records) {
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,219 @@
|
||||
package tech.easyflow.chatlog.service.impl;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatRoundRecord;
|
||||
import tech.easyflow.chatlog.domain.command.ChatSessionUpsertCommand;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatMessageRecord;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatSessionExtPayload;
|
||||
import tech.easyflow.chatlog.domain.dto.ChatSessionSummary;
|
||||
import tech.easyflow.chatlog.service.ChatPersistDispatcher;
|
||||
import tech.easyflow.chatlog.service.ChatRoundOperateService;
|
||||
import tech.easyflow.chatlog.service.ChatRoundQueryService;
|
||||
import tech.easyflow.chatlog.service.ChatSessionQueryService;
|
||||
import tech.easyflow.chatlog.support.ChatJsonSupport;
|
||||
import tech.easyflow.core.runtime.ChatRuntimeContext;
|
||||
import tech.easyflow.core.runtime.ChatRuntimeExtKeys;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link ChatlogRuntimeListener} 单元测试。
|
||||
*/
|
||||
public class ChatlogRuntimeListenerTest {
|
||||
|
||||
/**
|
||||
* 会话准备阶段应把额外知识库选择写入 ext_json。
|
||||
*/
|
||||
@Test
|
||||
public void onSessionPreparedShouldWriteExtraKnowledgeIdsToExtJson() {
|
||||
CapturingChatPersistDispatcher dispatcher = new CapturingChatPersistDispatcher();
|
||||
ChatlogRuntimeListener listener = new ChatlogRuntimeListener(
|
||||
dispatcher,
|
||||
new NoopChatRoundOperateService(),
|
||||
new NoopChatRoundQueryService(),
|
||||
new NoopChatSessionQueryService(),
|
||||
new ChatJsonSupport(new ObjectMapper())
|
||||
);
|
||||
ChatRuntimeContext context = new ChatRuntimeContext();
|
||||
context.setSessionId(BigInteger.valueOf(1001));
|
||||
context.setTenantId(BigInteger.ONE);
|
||||
context.setDeptId(BigInteger.TEN);
|
||||
context.setUserId(BigInteger.valueOf(7));
|
||||
context.setUserAccount("admin");
|
||||
context.setAssistantId(BigInteger.valueOf(88));
|
||||
context.setAssistantCode("bot-88");
|
||||
context.setAssistantName("测试助手");
|
||||
context.setSessionTitle("你好");
|
||||
context.getExt().put(
|
||||
ChatRuntimeExtKeys.EXTRA_KNOWLEDGE_IDS,
|
||||
List.of(BigInteger.valueOf(11), BigInteger.valueOf(22))
|
||||
);
|
||||
|
||||
listener.onSessionPrepared(context);
|
||||
|
||||
Assert.assertNotNull(dispatcher.captured);
|
||||
ChatSessionExtPayload payload = new ChatJsonSupport(new ObjectMapper())
|
||||
.fromJson(dispatcher.captured.getExtJson(), ChatSessionExtPayload.class);
|
||||
Assert.assertEquals(
|
||||
List.of(BigInteger.valueOf(11), BigInteger.valueOf(22)),
|
||||
payload.getExtraKnowledgeIds()
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* 重新生成时历史上下文应排除当前轮旧问题和旧答案。
|
||||
*/
|
||||
@Test
|
||||
public void loadMessagesShouldExcludeRegenerateRoundHistory() {
|
||||
ChatlogRuntimeListener listener = new ChatlogRuntimeListener(
|
||||
null,
|
||||
new NoopChatRoundOperateService(),
|
||||
new NoopChatRoundQueryService(),
|
||||
new TailChatSessionQueryService(List.of(
|
||||
record(4, 2, "assistant", "旧答案"),
|
||||
record(3, 2, "user", "当前问题"),
|
||||
record(2, 1, "assistant", "上一轮答案"),
|
||||
record(1, 1, "user", "上一轮问题")
|
||||
)),
|
||||
new ChatJsonSupport(new ObjectMapper())
|
||||
);
|
||||
ChatRuntimeContext context = new ChatRuntimeContext();
|
||||
context.setSessionId(BigInteger.valueOf(1001));
|
||||
context.getExt().put(ChatRuntimeExtKeys.REGENERATE_ROUND_ID, BigInteger.valueOf(2));
|
||||
|
||||
List<tech.easyflow.core.runtime.ChatRuntimeMessage> messages = listener.loadMessages(context, 10);
|
||||
|
||||
Assert.assertEquals(2, messages.size());
|
||||
Assert.assertEquals("上一轮问题", messages.get(0).getContentText());
|
||||
Assert.assertEquals("上一轮答案", messages.get(1).getContentText());
|
||||
}
|
||||
|
||||
private static ChatMessageRecord record(long id, int roundId, String role, String text) {
|
||||
ChatMessageRecord record = new ChatMessageRecord();
|
||||
record.setId(BigInteger.valueOf(id));
|
||||
record.setSessionId(BigInteger.valueOf(1001));
|
||||
record.setRoundId(BigInteger.valueOf(roundId));
|
||||
record.setSenderRole(role);
|
||||
record.setContentType("TEXT");
|
||||
record.setContentText(text);
|
||||
record.setCreated(new Date(id));
|
||||
return record;
|
||||
}
|
||||
|
||||
private static class CapturingChatPersistDispatcher extends ChatPersistDispatcher {
|
||||
|
||||
private ChatSessionUpsertCommand captured;
|
||||
|
||||
private CapturingChatPersistDispatcher() {
|
||||
super(null, null, null, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatSessionSummary createOrTouchSession(ChatSessionUpsertCommand command) {
|
||||
this.captured = command;
|
||||
return new ChatSessionSummary();
|
||||
}
|
||||
}
|
||||
|
||||
private static class NoopChatRoundOperateService implements ChatRoundOperateService {
|
||||
|
||||
@Override
|
||||
public ChatRoundRecord requireRegeneratableRound(BigInteger sessionId, BigInteger roundId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<tech.easyflow.chatlog.domain.dto.ChatMessageRecord> listVariants(BigInteger sessionId, BigInteger roundId) {
|
||||
return List.of();
|
||||
}
|
||||
|
||||
@Override
|
||||
public tech.easyflow.chatlog.domain.dto.ChatMessageRecord selectVariant(BigInteger sessionId, BigInteger roundId, Integer variantIndex, BigInteger operatorId) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private static class NoopChatRoundQueryService implements ChatRoundQueryService {
|
||||
|
||||
@Override
|
||||
public ChatRoundRecord getLatestRound(BigInteger sessionId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatRoundRecord getRound(BigInteger sessionId, BigInteger roundId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<tech.easyflow.chatlog.domain.dto.ChatMessageRecord> listRoundVariants(BigInteger sessionId, BigInteger roundId) {
|
||||
return List.of();
|
||||
}
|
||||
|
||||
@Override
|
||||
public tech.easyflow.chatlog.domain.dto.ChatMessageRecord getRoundVariant(BigInteger sessionId, BigInteger roundId, Integer variantIndex) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasRounds(BigInteger sessionId) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private static class NoopChatSessionQueryService implements ChatSessionQueryService {
|
||||
|
||||
@Override
|
||||
public List<ChatSessionSummary> listSessions(BigInteger userId, BigInteger assistantId, tech.easyflow.chatlog.domain.query.ChatPageQuery query) {
|
||||
return List.of();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long countSessions(BigInteger userId, BigInteger assistantId) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public tech.easyflow.chatlog.domain.dto.ChatSessionPage pageSessions(BigInteger userId, BigInteger assistantId, tech.easyflow.chatlog.domain.query.ChatPageQuery query) {
|
||||
return new tech.easyflow.chatlog.domain.dto.ChatSessionPage();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatSessionSummary getSessionSummary(BigInteger sessionId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public tech.easyflow.chatlog.domain.dto.ChatHistoryPage pageMainlineMessages(BigInteger sessionId, tech.easyflow.chatlog.domain.query.ChatPageQuery query) {
|
||||
return new tech.easyflow.chatlog.domain.dto.ChatHistoryPage();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<tech.easyflow.chatlog.domain.dto.ChatMessageRecord> listMainlineMessages(BigInteger sessionId) {
|
||||
return List.of();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<tech.easyflow.chatlog.domain.dto.ChatMessageRecord> getRecentTail(BigInteger sessionId, int limit) {
|
||||
return List.of();
|
||||
}
|
||||
}
|
||||
|
||||
private static class TailChatSessionQueryService extends NoopChatSessionQueryService {
|
||||
|
||||
private final List<ChatMessageRecord> records;
|
||||
|
||||
private TailChatSessionQueryService(List<ChatMessageRecord> records) {
|
||||
this.records = records;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatMessageRecord> getRecentTail(BigInteger sessionId, int limit) {
|
||||
return records.subList(0, Math.min(records.size(), limit));
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user