diff --git a/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/controller/ai/BotController.java b/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/controller/ai/BotController.java index e69a656..534013a 100644 --- a/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/controller/ai/BotController.java +++ b/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/controller/ai/BotController.java @@ -13,13 +13,17 @@ import com.mybatisflex.core.query.QueryWrapper; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.util.StringUtils; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.bind.annotation.*; import org.springframework.web.multipart.MultipartFile; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import tech.easyflow.admin.controller.ai.support.AiResourceCreatorNameSupport; +import tech.easyflow.admin.service.ai.ChatWorkspaceService; import tech.easyflow.ai.chattime.availability.ChatTimeToolAvailabilityContext; import tech.easyflow.ai.easyagents.listener.PromptChoreChatStreamListener; import tech.easyflow.ai.entity.*; +import tech.easyflow.ai.enums.PublishStatus; import tech.easyflow.ai.publish.BotPublishAppService; import tech.easyflow.approval.entity.vo.ApprovalActionResult; import tech.easyflow.ai.service.*; @@ -31,9 +35,11 @@ import tech.easyflow.common.satoken.util.SaTokenUtil; import tech.easyflow.common.web.controller.BaseCurdController; import tech.easyflow.common.web.exceptions.BusinessException; import tech.easyflow.common.web.jsonbody.JsonBody; +import tech.easyflow.chatlog.service.ChatRoundOperateService; import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter; import tech.easyflow.core.chat.protocol.sse.ChatSseUtil; import tech.easyflow.core.runtime.ChatChannel; +import tech.easyflow.core.runtime.ChatRuntimeExtKeys; import tech.easyflow.core.runtime.ChatRuntimeContext; import tech.easyflow.system.entity.vo.RoleCategoryAccessSnapshot; import tech.easyflow.system.service.CategoryPermissionService; @@ -74,9 +80,13 @@ public class BotController extends BaseCurdController { @Resource private BotPublishAppService botPublishAppService; @Resource + private ChatRoundOperateService chatRoundOperateService; + @Resource private AiResourceApprovalStateService aiResourceApprovalStateService; @Resource private AiResourceCreatorNameSupport aiResourceCreatorNameSupport; + @Resource + private ChatWorkspaceService chatWorkspaceService; public BotController(BotService service, ModelService modelService, BotWorkflowService botWorkflowService, BotDocumentCollectionService botDocumentCollectionService, BotMessageService botMessageService) { @@ -162,13 +172,30 @@ public class BotController extends BaseCurdController { @JsonBody(value = "botId", required = true) BigInteger botId, @JsonBody(value = "conversationId", required = true) BigInteger conversationId, @JsonBody(value = "messages") List> messages, - @JsonBody(value = "attachments") List attachments + @JsonBody(value = "attachments") List attachments, + @JsonBody(value = "publishedOnly") Boolean publishedOnly, + @JsonBody(value = "extraKnowledgeIds") List extraKnowledgeIds, + @JsonBody(value = "regenerateRoundId") BigInteger regenerateRoundId ) { + boolean usePublishedOnly = Boolean.TRUE.equals(publishedOnly); BotServiceImpl.ChatCheckResult chatCheckResult = new BotServiceImpl.ChatCheckResult(); + if (usePublishedOnly) { + chatWorkspaceService.assertSessionContinuable(requireCurrentLoginAccount(), conversationId, botId); + } + if (regenerateRoundId != null) { + chatRoundOperateService.requireRegeneratableRound(conversationId, regenerateRoundId); + } // 前置校验:失败则直接返回错误SseEmitter - SseEmitter errorEmitter = botService.checkChatBeforeStart(botId, prompt, conversationId.toString(), chatCheckResult); + SseEmitter errorEmitter = botService.checkChatBeforeStart( + botId, + prompt, + conversationId.toString(), + chatCheckResult, + usePublishedOnly, + regenerateRoundId + ); if (errorEmitter != null) { return errorEmitter; } @@ -179,7 +206,7 @@ public class BotController extends BaseCurdController { messages, chatCheckResult, attachments, - buildRuntimeContext(chatCheckResult.getAiBot(), conversationId, prompt, attachments) + buildRuntimeContext(chatCheckResult.getAiBot(), conversationId, prompt, attachments, extraKnowledgeIds, regenerateRoundId) ); } @@ -194,16 +221,24 @@ public class BotController extends BaseCurdController { @GetMapping("getDetail") @SaIgnore public Result getDetail(String id) { - Bot bot = StpUtil.isLogin() ? botService.getDetail(id) : botService.getPublishedDetail(id); - if (bot != null && StpUtil.isLogin()) { - categoryPermissionService.assertCategoryResourceVisible("BOT", bot.getCreatedBy(), bot.getCategoryId(), "无权限访问聊天助手"); + boolean publishedOnly = isPublishedOnlyRequest(); + Bot rawBot = StpUtil.isLogin() ? botService.getDetail(id) : botService.getPublishedDetail(id); + if (rawBot != null && StpUtil.isLogin()) { + categoryPermissionService.assertCategoryResourceVisible("BOT", rawBot.getCreatedBy(), rawBot.getCategoryId(), "无权限访问聊天助手"); } - if (bot == null) { + if (rawBot == null) { return Result.ok(null); } - if (!StpUtil.isLogin() && !tech.easyflow.ai.enums.PublishStatus.from(bot.getPublishStatus()).isExternallyVisible()) { + if (!StpUtil.isLogin() && !PublishStatus.from(rawBot.getPublishStatus()).isExternallyVisible()) { throw new BusinessException("聊天助手尚未发布"); } + Bot bot = rawBot; + if (publishedOnly && StpUtil.isLogin()) { + if (PublishStatus.from(rawBot.getPublishStatus()) != PublishStatus.PUBLISHED) { + throw new BusinessException("聊天助手尚未发布"); + } + bot = botService.toPublishedView(rawBot); + } if (StpUtil.isLogin()) { aiResourceApprovalStateService.fillBotApprovalState(bot); } @@ -213,17 +248,25 @@ public class BotController extends BaseCurdController { @Override @SaIgnore public Result detail(String id) { - Bot data = StpUtil.isLogin() ? botService.getDetail(id) : botService.getPublishedDetail(id); - if (data == null) { - return Result.ok(data); + boolean publishedOnly = isPublishedOnlyRequest(); + Bot rawData = StpUtil.isLogin() ? botService.getDetail(id) : botService.getPublishedDetail(id); + if (rawData == null) { + return Result.ok(rawData); } if (StpUtil.isLogin()) { - categoryPermissionService.assertCategoryResourceVisible("BOT", data.getCreatedBy(), data.getCategoryId(), "无权限访问聊天助手"); + categoryPermissionService.assertCategoryResourceVisible("BOT", rawData.getCreatedBy(), rawData.getCategoryId(), "无权限访问聊天助手"); } - if (!StpUtil.isLogin() && !tech.easyflow.ai.enums.PublishStatus.from(data.getPublishStatus()).isExternallyVisible()) { + if (!StpUtil.isLogin() && !PublishStatus.from(rawData.getPublishStatus()).isExternallyVisible()) { throw new BusinessException("聊天助手尚未发布"); } + Bot data = rawData; + if (publishedOnly && StpUtil.isLogin()) { + if (PublishStatus.from(rawData.getPublishStatus()) != PublishStatus.PUBLISHED) { + throw new BusinessException("聊天助手尚未发布"); + } + data = botService.toPublishedView(rawData); + } Map llmOptions = data.getModelOptions(); if (llmOptions == null) { @@ -298,8 +341,12 @@ public class BotController extends BaseCurdController { public Result> list(Bot entity, Boolean asTree, String sortKey, String sortType) { QueryWrapper queryWrapper = QueryWrapper.create(entity, buildOperators(entity)); applyCategoryPermission(queryWrapper); + applyPublishedOnlyFilter(queryWrapper); queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy())); List bots = service.list(queryWrapper); + if (isPublishedOnlyRequest()) { + bots = bots.stream().map(botService::toPublishedView).toList(); + } aiResourceApprovalStateService.fillBotApprovalState(bots); return Result.ok(bots); } @@ -307,7 +354,11 @@ public class BotController extends BaseCurdController { @Override protected Page queryPage(Page page, QueryWrapper queryWrapper) { applyCategoryPermission(queryWrapper); + applyPublishedOnlyFilter(queryWrapper); Page result = super.queryPage(page, queryWrapper); + if (isPublishedOnlyRequest()) { + result.setRecords(result.getRecords().stream().map(botService::toPublishedView).toList()); + } aiResourceApprovalStateService.fillBotApprovalState(result.getRecords()); aiResourceCreatorNameSupport.fillBotCreatorNames(result.getRecords()); return result; @@ -407,7 +458,9 @@ public class BotController extends BaseCurdController { return result; } - private ChatRuntimeContext buildRuntimeContext(Bot bot, BigInteger conversationId, String prompt, List attachments) { + private ChatRuntimeContext buildRuntimeContext(Bot bot, BigInteger conversationId, String prompt, List attachments, + List extraKnowledgeIds, + BigInteger regenerateRoundId) { LoginAccount account = requireCurrentLoginAccount(); ChatRuntimeContext context = new ChatRuntimeContext(); context.setChannel(ChatChannel.ADMIN); @@ -422,10 +475,30 @@ public class BotController extends BaseCurdController { context.setAssistantName(bot == null ? null : bot.getTitle()); context.setSessionTitle(prompt.length() > 200 ? prompt.substring(0, 200) : prompt); context.setAttachments(attachments); + if (extraKnowledgeIds != null) { + context.getExt().put(ChatRuntimeExtKeys.EXTRA_KNOWLEDGE_IDS, extraKnowledgeIds); + } + if (regenerateRoundId != null) { + context.getExt().put(ChatRuntimeExtKeys.REGENERATE_ROUND_ID, regenerateRoundId); + } ChatTimeToolAvailabilityContext.bindLoggedInSnapshot(context, account, bot); return context; } + private void applyPublishedOnlyFilter(QueryWrapper queryWrapper) { + if (isPublishedOnlyRequest()) { + queryWrapper.eq("publish_status", PublishStatus.PUBLISHED.getCode()); + } + } + + private boolean isPublishedOnlyRequest() { + ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); + if (attributes == null) { + return false; + } + return "true".equalsIgnoreCase(attributes.getRequest().getParameter("publishedOnly")); + } + private LoginAccount requireCurrentLoginAccount() { try { return SaTokenUtil.getLoginAccount(); diff --git a/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/controller/ai/ChatWorkspaceController.java b/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/controller/ai/ChatWorkspaceController.java new file mode 100644 index 0000000..5b9d496 --- /dev/null +++ b/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/controller/ai/ChatWorkspaceController.java @@ -0,0 +1,86 @@ +package tech.easyflow.admin.controller.ai; + +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import tech.easyflow.admin.dto.chatworkspace.ChatWorkspaceConversationView; +import tech.easyflow.admin.dto.chatworkspace.ChatWorkspaceSessionDetailView; +import tech.easyflow.admin.dto.chatworkspace.ChatWorkspaceSessionPage; +import tech.easyflow.admin.service.ai.ChatWorkspaceService; +import tech.easyflow.chatlog.domain.dto.ChatHistoryPage; +import tech.easyflow.chatlog.domain.dto.ChatMessageRecord; +import tech.easyflow.chatlog.domain.query.ChatPageQuery; +import tech.easyflow.common.domain.Result; +import tech.easyflow.common.entity.LoginAccount; +import tech.easyflow.common.satoken.util.SaTokenUtil; +import tech.easyflow.common.web.jsonbody.JsonBody; + +import java.math.BigInteger; +import java.util.List; +import java.util.Map; + +/** + * 管理端聊天工作台控制器。 + */ +@RestController +@RequestMapping("/api/v1/chatWorkspace") +public class ChatWorkspaceController { + + private final ChatWorkspaceService chatWorkspaceService; + + public ChatWorkspaceController(ChatWorkspaceService chatWorkspaceService) { + this.chatWorkspaceService = chatWorkspaceService; + } + + @GetMapping("/sessions") + public Result listSessions(BigInteger assistantId, ChatPageQuery query) { + return Result.ok(chatWorkspaceService.queryCurrentUserSessions(currentAccount(), assistantId, query)); + } + + @GetMapping("/sessions/{sessionId}") + public Result getSession(@PathVariable BigInteger sessionId) { + return Result.ok(chatWorkspaceService.getCurrentUserSession(currentAccount(), sessionId)); + } + + @GetMapping("/sessions/{sessionId}/messages") + public Result queryMessages(@PathVariable BigInteger sessionId, ChatPageQuery query) { + return Result.ok(chatWorkspaceService.queryCurrentUserMessages(currentAccount(), sessionId, query)); + } + + @GetMapping("/sessions/{sessionId}/conversation") + public Result getConversation(@PathVariable BigInteger sessionId) { + return Result.ok(chatWorkspaceService.getCurrentUserConversation(currentAccount(), sessionId)); + } + + @PostMapping("/sessions/{sessionId}/rename") + public Result renameSession(@PathVariable BigInteger sessionId, + @JsonBody(value = "title", required = true) String title) { + chatWorkspaceService.renameCurrentUserSession(currentAccount(), sessionId, title); + return Result.ok(); + } + + @PostMapping("/sessions/{sessionId}/delete") + public Result deleteSession(@PathVariable BigInteger sessionId) { + chatWorkspaceService.deleteCurrentUserSession(currentAccount(), sessionId); + return Result.ok(); + } + + @GetMapping("/sessions/{sessionId}/rounds/{roundId}/variants") + public Result> listRoundVariants(@PathVariable BigInteger sessionId, + @PathVariable BigInteger roundId) { + return Result.ok(chatWorkspaceService.listCurrentUserRoundVariants(currentAccount(), sessionId, roundId)); + } + + @PostMapping("/sessions/{sessionId}/rounds/{roundId}/selectVariant") + public Result selectRoundVariant(@PathVariable BigInteger sessionId, + @PathVariable BigInteger roundId, + @JsonBody(value = "variantIndex", required = true) Integer variantIndex) { + return Result.ok(chatWorkspaceService.selectCurrentUserRoundVariant(currentAccount(), sessionId, roundId, variantIndex)); + } + + private LoginAccount currentAccount() { + return SaTokenUtil.getLoginAccount(); + } +} diff --git a/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceAssistantView.java b/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceAssistantView.java new file mode 100644 index 0000000..5c35719 --- /dev/null +++ b/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceAssistantView.java @@ -0,0 +1,56 @@ +package tech.easyflow.admin.dto.chatworkspace; + +import java.io.Serializable; +import java.math.BigInteger; + +/** + * 工作台助手展示快照。 + */ +public class ChatWorkspaceAssistantView implements Serializable { + + private BigInteger id; + private String alias; + private String title; + private String description; + private String icon; + + public BigInteger getId() { + return id; + } + + public void setId(BigInteger id) { + this.id = id; + } + + public String getAlias() { + return alias; + } + + public void setAlias(String alias) { + this.alias = alias; + } + + public String getTitle() { + return title; + } + + public void setTitle(String title) { + this.title = title; + } + + public String getDescription() { + return description; + } + + public void setDescription(String description) { + this.description = description; + } + + public String getIcon() { + return icon; + } + + public void setIcon(String icon) { + this.icon = icon; + } +} diff --git a/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceConversationView.java b/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceConversationView.java new file mode 100644 index 0000000..9c0d16e --- /dev/null +++ b/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceConversationView.java @@ -0,0 +1,73 @@ +package tech.easyflow.admin.dto.chatworkspace; + +import tech.easyflow.chatlog.domain.dto.ChatMessageRecord; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * 管理端聊天工作台完整会话视图。 + */ +public class ChatWorkspaceConversationView implements Serializable { + + private long total; + private List records = new ArrayList<>(); + private Map> variantsByRound = new LinkedHashMap<>(); + + /** + * 获取当前主线可见消息数量。 + * + * @return 主线消息数量 + */ + public long getTotal() { + return total; + } + + /** + * 设置当前主线可见消息数量。 + * + * @param total 主线消息数量 + */ + public void setTotal(long total) { + this.total = total; + } + + /** + * 获取当前主线可见消息。 + * + * @return 当前主线可见消息 + */ + public List getRecords() { + return records; + } + + /** + * 设置当前主线可见消息。 + * + * @param records 当前主线可见消息 + */ + public void setRecords(List records) { + this.records = records == null ? new ArrayList<>() : records; + } + + /** + * 获取按轮次分组的全部答案版本。 + * + * @return roundId 到答案版本列表的映射 + */ + public Map> getVariantsByRound() { + return variantsByRound; + } + + /** + * 设置按轮次分组的全部答案版本。 + * + * @param variantsByRound roundId 到答案版本列表的映射 + */ + public void setVariantsByRound(Map> variantsByRound) { + this.variantsByRound = variantsByRound == null ? new LinkedHashMap<>() : variantsByRound; + } +} diff --git a/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceKnowledgeView.java b/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceKnowledgeView.java new file mode 100644 index 0000000..d1b5871 --- /dev/null +++ b/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceKnowledgeView.java @@ -0,0 +1,56 @@ +package tech.easyflow.admin.dto.chatworkspace; + +import java.io.Serializable; +import java.math.BigInteger; + +/** + * 工作台知识库展示对象。 + */ +public class ChatWorkspaceKnowledgeView implements Serializable { + + private BigInteger id; + private String alias; + private String title; + private String description; + private String icon; + + public BigInteger getId() { + return id; + } + + public void setId(BigInteger id) { + this.id = id; + } + + public String getAlias() { + return alias; + } + + public void setAlias(String alias) { + this.alias = alias; + } + + public String getTitle() { + return title; + } + + public void setTitle(String title) { + this.title = title; + } + + public String getDescription() { + return description; + } + + public void setDescription(String description) { + this.description = description; + } + + public String getIcon() { + return icon; + } + + public void setIcon(String icon) { + this.icon = icon; + } +} diff --git a/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceReadOnlyReason.java b/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceReadOnlyReason.java new file mode 100644 index 0000000..ecfc4b4 --- /dev/null +++ b/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceReadOnlyReason.java @@ -0,0 +1,10 @@ +package tech.easyflow.admin.dto.chatworkspace; + +/** + * 管理端聊天工作台只读原因。 + */ +public enum ChatWorkspaceReadOnlyReason { + ASSISTANT_OFFLINE, + ASSISTANT_DELETED, + NO_PERMISSION +} diff --git a/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceSessionDetailView.java b/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceSessionDetailView.java new file mode 100644 index 0000000..8bbe1bd --- /dev/null +++ b/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceSessionDetailView.java @@ -0,0 +1,48 @@ +package tech.easyflow.admin.dto.chatworkspace; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +/** + * 工作台会话详情。 + */ +public class ChatWorkspaceSessionDetailView extends ChatWorkspaceSessionView implements Serializable { + + private ChatWorkspaceAssistantView assistant; + private List boundKnowledges = new ArrayList<>(); + private List extraKnowledges = new ArrayList<>(); + private List removedExtraKnowledgeNames = new ArrayList<>(); + + public ChatWorkspaceAssistantView getAssistant() { + return assistant; + } + + public void setAssistant(ChatWorkspaceAssistantView assistant) { + this.assistant = assistant; + } + + public List getBoundKnowledges() { + return boundKnowledges; + } + + public void setBoundKnowledges(List boundKnowledges) { + this.boundKnowledges = boundKnowledges == null ? new ArrayList<>() : boundKnowledges; + } + + public List getExtraKnowledges() { + return extraKnowledges; + } + + public void setExtraKnowledges(List extraKnowledges) { + this.extraKnowledges = extraKnowledges == null ? new ArrayList<>() : extraKnowledges; + } + + public List getRemovedExtraKnowledgeNames() { + return removedExtraKnowledgeNames; + } + + public void setRemovedExtraKnowledgeNames(List removedExtraKnowledgeNames) { + this.removedExtraKnowledgeNames = removedExtraKnowledgeNames == null ? new ArrayList<>() : removedExtraKnowledgeNames; + } +} diff --git a/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceSessionPage.java b/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceSessionPage.java new file mode 100644 index 0000000..ce5d0b7 --- /dev/null +++ b/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceSessionPage.java @@ -0,0 +1,48 @@ +package tech.easyflow.admin.dto.chatworkspace; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +/** + * 工作台会话分页结果。 + */ +public class ChatWorkspaceSessionPage implements Serializable { + + private Long total; + private Long pageNumber; + private Long pageSize; + private List records = new ArrayList<>(); + + public Long getTotal() { + return total; + } + + public void setTotal(Long total) { + this.total = total; + } + + public Long getPageNumber() { + return pageNumber; + } + + public void setPageNumber(Long pageNumber) { + this.pageNumber = pageNumber; + } + + public Long getPageSize() { + return pageSize; + } + + public void setPageSize(Long pageSize) { + this.pageSize = pageSize; + } + + public List getRecords() { + return records; + } + + public void setRecords(List records) { + this.records = records == null ? new ArrayList<>() : records; + } +} diff --git a/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceSessionView.java b/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceSessionView.java new file mode 100644 index 0000000..a343ca6 --- /dev/null +++ b/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/dto/chatworkspace/ChatWorkspaceSessionView.java @@ -0,0 +1,111 @@ +package tech.easyflow.admin.dto.chatworkspace; + +import java.io.Serializable; +import java.math.BigInteger; +import java.util.Date; + +/** + * 工作台会话摘要。 + */ +public class ChatWorkspaceSessionView implements Serializable { + + private BigInteger sessionId; + private BigInteger assistantId; + private String assistantCode; + private String assistantName; + private String title; + private String lastMessagePreview; + private Integer messageCount; + private Date accessAt; + private Date lastMessageAt; + private Boolean continuable; + private ChatWorkspaceReadOnlyReason readOnlyReason; + + public BigInteger getSessionId() { + return sessionId; + } + + public void setSessionId(BigInteger sessionId) { + this.sessionId = sessionId; + } + + public BigInteger getAssistantId() { + return assistantId; + } + + public void setAssistantId(BigInteger assistantId) { + this.assistantId = assistantId; + } + + public String getAssistantCode() { + return assistantCode; + } + + public void setAssistantCode(String assistantCode) { + this.assistantCode = assistantCode; + } + + public String getAssistantName() { + return assistantName; + } + + public void setAssistantName(String assistantName) { + this.assistantName = assistantName; + } + + public String getTitle() { + return title; + } + + public void setTitle(String title) { + this.title = title; + } + + public String getLastMessagePreview() { + return lastMessagePreview; + } + + public void setLastMessagePreview(String lastMessagePreview) { + this.lastMessagePreview = lastMessagePreview; + } + + public Integer getMessageCount() { + return messageCount; + } + + public void setMessageCount(Integer messageCount) { + this.messageCount = messageCount; + } + + public Date getAccessAt() { + return accessAt; + } + + public void setAccessAt(Date accessAt) { + this.accessAt = accessAt; + } + + public Date getLastMessageAt() { + return lastMessageAt; + } + + public void setLastMessageAt(Date lastMessageAt) { + this.lastMessageAt = lastMessageAt; + } + + public Boolean getContinuable() { + return continuable; + } + + public void setContinuable(Boolean continuable) { + this.continuable = continuable; + } + + public ChatWorkspaceReadOnlyReason getReadOnlyReason() { + return readOnlyReason; + } + + public void setReadOnlyReason(ChatWorkspaceReadOnlyReason readOnlyReason) { + this.readOnlyReason = readOnlyReason; + } +} diff --git a/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/service/ai/ChatWorkspaceService.java b/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/service/ai/ChatWorkspaceService.java new file mode 100644 index 0000000..7deccce --- /dev/null +++ b/easyflow-api/easyflow-api-admin/src/main/java/tech/easyflow/admin/service/ai/ChatWorkspaceService.java @@ -0,0 +1,515 @@ +package tech.easyflow.admin.service.ai; + +import com.mybatisflex.core.query.QueryWrapper; +import org.springframework.stereotype.Service; +import org.springframework.util.StringUtils; +import tech.easyflow.admin.dto.chatworkspace.ChatWorkspaceAssistantView; +import tech.easyflow.admin.dto.chatworkspace.ChatWorkspaceConversationView; +import tech.easyflow.admin.dto.chatworkspace.ChatWorkspaceKnowledgeView; +import tech.easyflow.admin.dto.chatworkspace.ChatWorkspaceReadOnlyReason; +import tech.easyflow.admin.dto.chatworkspace.ChatWorkspaceSessionDetailView; +import tech.easyflow.admin.dto.chatworkspace.ChatWorkspaceSessionPage; +import tech.easyflow.admin.dto.chatworkspace.ChatWorkspaceSessionView; +import tech.easyflow.ai.entity.Bot; +import tech.easyflow.ai.entity.DocumentCollection; +import tech.easyflow.ai.enums.PublishStatus; +import tech.easyflow.ai.permission.KnowledgeReadAccessSnapshot; +import tech.easyflow.ai.permission.KnowledgeVisibilityQueryHelper; +import tech.easyflow.ai.service.BotService; +import tech.easyflow.ai.service.DocumentCollectionService; +import tech.easyflow.chatlog.domain.command.ChatSessionUpsertCommand; +import tech.easyflow.chatlog.domain.dto.ChatHistoryPage; +import tech.easyflow.chatlog.domain.dto.ChatMessageRecord; +import tech.easyflow.chatlog.domain.dto.ChatSessionExtPayload; +import tech.easyflow.chatlog.domain.dto.ChatSessionPage; +import tech.easyflow.chatlog.domain.dto.ChatSessionSummary; +import tech.easyflow.chatlog.domain.query.ChatPageQuery; +import tech.easyflow.chatlog.service.ChatRoundOperateService; +import tech.easyflow.chatlog.service.ChatSessionCommandService; +import tech.easyflow.chatlog.service.ChatSessionQueryService; +import tech.easyflow.chatlog.support.ChatJsonSupport; +import tech.easyflow.common.entity.LoginAccount; +import tech.easyflow.common.web.exceptions.BusinessException; +import tech.easyflow.system.entity.vo.RoleCategoryAccessSnapshot; +import tech.easyflow.system.service.CategoryPermissionService; + +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.Date; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +import static tech.easyflow.ai.entity.table.BotTableDef.BOT; + +/** + * 管理端聊天工作台服务。 + */ +@Service +public class ChatWorkspaceService { + + private final ChatSessionQueryService chatSessionQueryService; + private final ChatSessionCommandService chatSessionCommandService; + private final ChatRoundOperateService chatRoundOperateService; + private final BotService botService; + private final DocumentCollectionService documentCollectionService; + private final CategoryPermissionService categoryPermissionService; + private final KnowledgeVisibilityQueryHelper knowledgeVisibilityQueryHelper; + private final ChatJsonSupport chatJsonSupport; + + public ChatWorkspaceService(ChatSessionQueryService chatSessionQueryService, + ChatSessionCommandService chatSessionCommandService, + ChatRoundOperateService chatRoundOperateService, + BotService botService, + DocumentCollectionService documentCollectionService, + CategoryPermissionService categoryPermissionService, + KnowledgeVisibilityQueryHelper knowledgeVisibilityQueryHelper, + ChatJsonSupport chatJsonSupport) { + this.chatSessionQueryService = chatSessionQueryService; + this.chatSessionCommandService = chatSessionCommandService; + this.chatRoundOperateService = chatRoundOperateService; + this.botService = botService; + this.documentCollectionService = documentCollectionService; + this.categoryPermissionService = categoryPermissionService; + this.knowledgeVisibilityQueryHelper = knowledgeVisibilityQueryHelper; + this.chatJsonSupport = chatJsonSupport; + } + + /** + * 查询当前用户会话分页。 + * + * @param account 当前登录用户 + * @param assistantId 助手过滤条件 + * @param query 分页参数 + * @return 工作台会话分页 + */ + public ChatWorkspaceSessionPage queryCurrentUserSessions(LoginAccount account, BigInteger assistantId, ChatPageQuery query) { + ChatSessionPage page = chatSessionQueryService.pageSessions(account.getId(), assistantId, query); + Map availabilityMap = resolveAssistantAvailability(account, page.getRecords()); + ChatWorkspaceSessionPage result = new ChatWorkspaceSessionPage(); + result.setTotal(page.getTotal()); + result.setPageNumber(page.getPageNumber()); + result.setPageSize(page.getPageSize()); + List records = new ArrayList<>(); + for (ChatSessionSummary summary : page.getRecords()) { + records.add(toSessionView(summary, availabilityMap.get(summary.getAssistantId()))); + } + result.setRecords(records); + return result; + } + + /** + * 查询当前用户会话详情。 + * + * @param account 当前登录用户 + * @param sessionId 会话 ID + * @return 工作台会话详情 + */ + public ChatWorkspaceSessionDetailView getCurrentUserSession(LoginAccount account, BigInteger sessionId) { + ChatSessionSummary summary = requireUserSession(account, sessionId); + AssistantAvailability availability = resolveAssistantAvailability(account, List.of(summary)).get(summary.getAssistantId()); + ChatWorkspaceSessionDetailView detail = new ChatWorkspaceSessionDetailView(); + fillSessionView(detail, summary, availability); + if (availability != null && availability.displayBot() != null) { + detail.setAssistant(toAssistantView(availability.displayBot(), summary)); + detail.setBoundKnowledges(resolveBoundKnowledges(availability.displayBot())); + } else { + detail.setAssistant(toAssistantView(null, summary)); + } + ExtraKnowledgeResolution extraKnowledgeResolution = resolveExtraKnowledges(summary); + detail.setExtraKnowledges(extraKnowledgeResolution.validKnowledges()); + detail.setRemovedExtraKnowledgeNames(extraKnowledgeResolution.removedNames()); + if (extraKnowledgeResolution.shouldSync()) { + syncSessionExtraKnowledges(summary, extraKnowledgeResolution.validKnowledgeIds(), account.getId()); + } + return detail; + } + + /** + * 查询当前用户会话消息。 + * + * @param account 当前登录用户 + * @param sessionId 会话 ID + * @param query 分页参数 + * @return 历史消息分页 + */ + public ChatHistoryPage queryCurrentUserMessages(LoginAccount account, BigInteger sessionId, ChatPageQuery query) { + ChatSessionSummary summary = requireUserSession(account, sessionId); + ChatHistoryPage firstPage = restoreRecentMessages(summary, query); + if (firstPage != null) { + return firstPage; + } + return chatSessionQueryService.pageMainlineMessages(sessionId, query); + } + + /** + * 查询当前用户完整工作台会话。 + * + * @param account 当前登录用户 + * @param sessionId 会话 ID + * @return 完整会话视图 + */ + public ChatWorkspaceConversationView getCurrentUserConversation(LoginAccount account, BigInteger sessionId) { + requireUserSession(account, sessionId); + List records = chatSessionQueryService.listMainlineMessages(sessionId); + Map> variantsByRound = new LinkedHashMap<>(); + Set roundIds = new LinkedHashSet<>(); + for (ChatMessageRecord record : records) { + if (record == null || record.getRoundId() == null) { + continue; + } + Integer variantCount = record.getVariantCount(); + if (variantCount != null && variantCount > 1) { + roundIds.add(record.getRoundId()); + } + } + for (BigInteger roundId : roundIds) { + variantsByRound.put(roundId.toString(), chatRoundOperateService.listVariants(sessionId, roundId)); + } + ChatWorkspaceConversationView view = new ChatWorkspaceConversationView(); + view.setRecords(records); + view.setVariantsByRound(variantsByRound); + view.setTotal(records.size()); + return view; + } + + /** + * 重命名当前用户会话。 + * + * @param account 当前登录用户 + * @param sessionId 会话 ID + * @param title 新标题 + */ + public void renameCurrentUserSession(LoginAccount account, BigInteger sessionId, String title) { + if (!StringUtils.hasText(title)) { + throw new BusinessException("标题不能为空"); + } + requireUserSession(account, sessionId); + chatSessionCommandService.renameSession(sessionId, account.getId(), title.trim(), account.getId()); + } + + /** + * 删除当前用户会话。 + * + * @param account 当前登录用户 + * @param sessionId 会话 ID + */ + public void deleteCurrentUserSession(LoginAccount account, BigInteger sessionId) { + requireUserSession(account, sessionId); + chatSessionCommandService.deleteSession(sessionId, account.getId(), account.getId()); + } + + public List listCurrentUserRoundVariants(LoginAccount account, BigInteger sessionId, BigInteger roundId) { + requireUserSession(account, sessionId); + return chatRoundOperateService.listVariants(sessionId, roundId); + } + + public ChatMessageRecord selectCurrentUserRoundVariant(LoginAccount account, BigInteger sessionId, BigInteger roundId, Integer variantIndex) { + requireUserSession(account, sessionId); + return chatRoundOperateService.selectVariant(sessionId, roundId, variantIndex, account.getId()); + } + + /** + * 发送前校验会话是否仍可继续聊天。 + * + * @param account 当前登录用户 + * @param sessionId 会话 ID + * @param requestBotId 本次请求助手 ID + */ + public void assertSessionContinuable(LoginAccount account, BigInteger sessionId, BigInteger requestBotId) { + ChatSessionSummary summary = chatSessionQueryService.getSessionSummary(sessionId); + if (summary == null || Integer.valueOf(1).equals(summary.getIsDeleted())) { + return; + } + if (!Objects.equals(summary.getUserId(), account.getId())) { + throw new BusinessException("无权访问该会话"); + } + if (requestBotId != null && summary.getAssistantId() != null && !Objects.equals(summary.getAssistantId(), requestBotId)) { + throw new BusinessException("当前会话与所选聊天助手不匹配"); + } + AssistantAvailability availability = resolveAssistantAvailability(account, List.of(summary)).get(summary.getAssistantId()); + if (availability == null || !availability.continuable()) { + throw new BusinessException(buildReadOnlyMessage(availability == null ? ChatWorkspaceReadOnlyReason.ASSISTANT_DELETED : availability.reason())); + } + } + + private ChatSessionSummary requireUserSession(LoginAccount account, BigInteger sessionId) { + ChatSessionSummary summary = chatSessionQueryService.getSessionSummary(sessionId); + if (summary == null || Integer.valueOf(1).equals(summary.getIsDeleted())) { + throw new BusinessException("会话不存在"); + } + if (!Objects.equals(summary.getUserId(), account.getId())) { + throw new BusinessException("无权访问该会话"); + } + return summary; + } + + /** + * 首屏优先从热态恢复最近消息,避免分析库延迟导致刚完成的回复不可见。 + * + * @param summary 会话摘要 + * @param query 分页参数 + * @return 命中热态时返回恢复结果,否则返回 null 继续走历史库 + */ + private ChatHistoryPage restoreRecentMessages(ChatSessionSummary summary, ChatPageQuery query) { + if (summary == null || query == null || query.getPageNumber() != 1) { + return null; + } + List records = + chatSessionQueryService.getRecentTail(summary.getId(), Math.toIntExact(query.getPageSize())); + if (records == null || records.isEmpty()) { + return null; + } + if (!isRestoredTailReliable(records)) { + return null; + } + ChatHistoryPage page = new ChatHistoryPage(); + page.setPageNumber(query.getPageNumber()); + page.setPageSize(query.getPageSize()); + page.setRecords(records); + long total = summary.getMessageCount() == null ? 0L : summary.getMessageCount(); + page.setTotal(Math.max(total, records.size())); + return page; + } + + /** + * 校验 Redis tail 是否仍符合当前主线版本语义。 + * + * @param records Redis tail 消息 + * @return true 表示可直接用于首屏恢复 + */ + private boolean isRestoredTailReliable(List records) { + Map selectedVariantByRound = new LinkedHashMap<>(); + Map> assistantVariantsByRound = new LinkedHashMap<>(); + for (ChatMessageRecord record : records) { + if (record == null || record.getRoundId() == null) { + continue; + } + Integer selectedVariantIndex = record.getSelectedVariantIndex(); + if (selectedVariantIndex != null && selectedVariantIndex > 0) { + Integer previous = selectedVariantByRound.putIfAbsent(record.getRoundId(), selectedVariantIndex); + if (previous != null && !Objects.equals(previous, selectedVariantIndex)) { + return false; + } + } + if ("assistant".equalsIgnoreCase(record.getSenderRole()) + && record.getVariantIndex() != null + && record.getVariantIndex() > 0) { + assistantVariantsByRound + .computeIfAbsent(record.getRoundId(), key -> new LinkedHashSet<>()) + .add(record.getVariantIndex()); + } + } + for (Map.Entry entry : selectedVariantByRound.entrySet()) { + Set visibleVariants = assistantVariantsByRound.get(entry.getKey()); + if (visibleVariants != null && !visibleVariants.isEmpty() && !visibleVariants.contains(entry.getValue())) { + return false; + } + } + return true; + } + + private Map resolveAssistantAvailability(LoginAccount account, List sessions) { + Map result = new LinkedHashMap<>(); + if (sessions == null || sessions.isEmpty()) { + return result; + } + Set assistantIds = new LinkedHashSet<>(); + for (ChatSessionSummary session : sessions) { + if (session != null && session.getAssistantId() != null) { + assistantIds.add(session.getAssistantId()); + } + } + if (assistantIds.isEmpty()) { + return result; + } + List bots = botService.list(QueryWrapper.create().where(BOT.ID.in(assistantIds))); + Map botMap = new LinkedHashMap<>(); + for (Bot bot : bots) { + botMap.put(bot.getId(), bot); + } + RoleCategoryAccessSnapshot accessSnapshot = categoryPermissionService.getAccess("BOT", account); + for (BigInteger assistantId : assistantIds) { + Bot currentBot = botMap.get(assistantId); + if (currentBot == null) { + result.put(assistantId, new AssistantAvailability(false, ChatWorkspaceReadOnlyReason.ASSISTANT_DELETED, null)); + continue; + } + if (!accessSnapshot.canAccess(currentBot.getCreatedBy(), currentBot.getCategoryId())) { + result.put(assistantId, new AssistantAvailability(false, ChatWorkspaceReadOnlyReason.NO_PERMISSION, null)); + continue; + } + Bot displayBot = botService.toPublishedView(currentBot); + boolean online = Integer.valueOf(1).equals(currentBot.getStatus()) + && PublishStatus.from(currentBot.getPublishStatus()) == PublishStatus.PUBLISHED; + result.put(assistantId, new AssistantAvailability( + online, + online ? null : ChatWorkspaceReadOnlyReason.ASSISTANT_OFFLINE, + displayBot + )); + } + return result; + } + + private ChatWorkspaceSessionView toSessionView(ChatSessionSummary summary, AssistantAvailability availability) { + ChatWorkspaceSessionView view = new ChatWorkspaceSessionView(); + fillSessionView(view, summary, availability); + return view; + } + + private void fillSessionView(ChatWorkspaceSessionView view, ChatSessionSummary summary, AssistantAvailability availability) { + view.setSessionId(summary.getId()); + view.setAssistantId(summary.getAssistantId()); + view.setAssistantCode(summary.getAssistantCode()); + view.setAssistantName(summary.getAssistantName()); + view.setTitle(summary.getTitle()); + view.setLastMessagePreview(summary.getLastMessagePreview()); + view.setMessageCount(summary.getMessageCount()); + view.setAccessAt(summary.getAccessAt()); + view.setLastMessageAt(summary.getLastMessageAt()); + view.setContinuable(availability != null && availability.continuable()); + view.setReadOnlyReason(availability == null ? ChatWorkspaceReadOnlyReason.ASSISTANT_DELETED : availability.reason()); + } + + private ChatWorkspaceAssistantView toAssistantView(Bot bot, ChatSessionSummary summary) { + ChatWorkspaceAssistantView view = new ChatWorkspaceAssistantView(); + if (bot != null) { + view.setId(bot.getId()); + view.setAlias(bot.getAlias()); + view.setTitle(bot.getTitle()); + view.setDescription(bot.getDescription()); + view.setIcon(bot.getIcon()); + return view; + } + view.setId(summary == null ? null : summary.getAssistantId()); + view.setAlias(summary == null ? null : summary.getAssistantCode()); + view.setTitle(summary == null ? null : summary.getAssistantName()); + return view; + } + + private List resolveBoundKnowledges(Bot displayBot) { + if (displayBot == null || displayBot.getPublishedSnapshotJson() == null) { + return List.of(); + } + Object rawBindings = displayBot.getPublishedSnapshotJson().get("knowledgeBindings"); + if (!(rawBindings instanceof List bindings) || bindings.isEmpty()) { + return List.of(); + } + List knowledgeIds = new ArrayList<>(); + for (Object binding : bindings) { + if (!(binding instanceof Map bindingMap) || bindingMap.get("knowledgeId") == null) { + continue; + } + knowledgeIds.add(new BigInteger(String.valueOf(bindingMap.get("knowledgeId")))); + } + return resolveVisibleKnowledgeViews(knowledgeIds).validKnowledges(); + } + + private ExtraKnowledgeResolution resolveExtraKnowledges(ChatSessionSummary summary) { + ChatSessionExtPayload payload = chatJsonSupport.fromJson(summary.getExtJson(), ChatSessionExtPayload.class); + List extraKnowledgeIds = payload == null ? List.of() : payload.getExtraKnowledgeIds(); + return resolveVisibleKnowledgeViews(extraKnowledgeIds); + } + + private ExtraKnowledgeResolution resolveVisibleKnowledgeViews(List knowledgeIds) { + if (knowledgeIds == null || knowledgeIds.isEmpty()) { + return new ExtraKnowledgeResolution(List.of(), List.of(), List.of(), false); + } + List normalizedIds = new ArrayList<>(); + for (BigInteger knowledgeId : knowledgeIds) { + if (knowledgeId != null && !normalizedIds.contains(knowledgeId)) { + normalizedIds.add(knowledgeId); + } + } + if (normalizedIds.isEmpty()) { + return new ExtraKnowledgeResolution(List.of(), List.of(), List.of(), false); + } + List collections = documentCollectionService.listByIds(normalizedIds); + Map collectionMap = new LinkedHashMap<>(); + for (DocumentCollection collection : collections) { + collectionMap.put(collection.getId(), collection); + } + KnowledgeReadAccessSnapshot accessSnapshot = knowledgeVisibilityQueryHelper.getCurrentReadSnapshot(); + List validKnowledges = new ArrayList<>(); + List validKnowledgeIds = new ArrayList<>(); + List removedNames = new ArrayList<>(); + boolean changed = false; + for (BigInteger knowledgeId : normalizedIds) { + DocumentCollection current = collectionMap.get(knowledgeId); + if (current == null) { + removedNames.add("知识库#" + knowledgeId); + changed = true; + continue; + } + if (PublishStatus.from(current.getPublishStatus()) != PublishStatus.PUBLISHED) { + removedNames.add(current.getTitle()); + changed = true; + continue; + } + if (!knowledgeVisibilityQueryHelper.canRead(current, accessSnapshot)) { + removedNames.add(current.getTitle()); + changed = true; + continue; + } + validKnowledges.add(toKnowledgeView(documentCollectionService.toPublishedView(current))); + validKnowledgeIds.add(current.getId()); + } + if (!Objects.equals(normalizedIds, validKnowledgeIds)) { + changed = true; + } + return new ExtraKnowledgeResolution(validKnowledges, validKnowledgeIds, removedNames, changed); + } + + private ChatWorkspaceKnowledgeView toKnowledgeView(DocumentCollection collection) { + ChatWorkspaceKnowledgeView view = new ChatWorkspaceKnowledgeView(); + view.setId(collection.getId()); + view.setAlias(collection.getAlias()); + view.setTitle(collection.getTitle()); + view.setDescription(collection.getDescription()); + view.setIcon(collection.getIcon()); + return view; + } + + private void syncSessionExtraKnowledges(ChatSessionSummary summary, List validKnowledgeIds, BigInteger operatorId) { + ChatSessionExtPayload payload = new ChatSessionExtPayload(); + payload.setExtraKnowledgeIds(validKnowledgeIds); + ChatSessionUpsertCommand command = new ChatSessionUpsertCommand(); + command.setSessionId(summary.getId()); + command.setTenantId(summary.getTenantId()); + command.setDeptId(summary.getDeptId()); + command.setUserId(summary.getUserId()); + command.setUserAccount(summary.getUserAccount()); + command.setAssistantId(summary.getAssistantId()); + command.setAssistantCode(summary.getAssistantCode()); + command.setAssistantName(summary.getAssistantName()); + command.setTitle(summary.getTitle()); + command.setExtJson(chatJsonSupport.toJson(payload)); + command.setOperatorId(operatorId); + command.setOperateAt(new Date()); + chatSessionCommandService.createOrTouchSession(command); + } + + private String buildReadOnlyMessage(ChatWorkspaceReadOnlyReason reason) { + if (reason == ChatWorkspaceReadOnlyReason.NO_PERMISSION) { + return "当前会话对应的聊天助手已无权限访问,仅支持查看历史记录"; + } + if (reason == ChatWorkspaceReadOnlyReason.ASSISTANT_OFFLINE) { + return "当前会话对应的聊天助手已下架,无法继续聊天"; + } + return "当前会话对应的聊天助手已删除,无法继续聊天"; + } + + private record AssistantAvailability(boolean continuable, + ChatWorkspaceReadOnlyReason reason, + Bot displayBot) { + } + + private record ExtraKnowledgeResolution(List validKnowledges, + List validKnowledgeIds, + List removedNames, + boolean shouldSync) { + } +} diff --git a/easyflow-commons/easyflow-common-chat-protocol/src/main/java/tech/easyflow/core/runtime/ChatRuntimeExtKeys.java b/easyflow-commons/easyflow-common-chat-protocol/src/main/java/tech/easyflow/core/runtime/ChatRuntimeExtKeys.java new file mode 100644 index 0000000..825f812 --- /dev/null +++ b/easyflow-commons/easyflow-common-chat-protocol/src/main/java/tech/easyflow/core/runtime/ChatRuntimeExtKeys.java @@ -0,0 +1,35 @@ +package tech.easyflow.core.runtime; + +/** + * 聊天运行时扩展字段键。 + */ +public final class ChatRuntimeExtKeys { + + /** + * 会话级额外知识库 ID 列表。 + */ + public static final String EXTRA_KNOWLEDGE_IDS = "extraKnowledgeIds"; + + /** + * 当前请求要重答的轮次 ID。 + */ + public static final String REGENERATE_ROUND_ID = "regenerateRoundId"; + + /** + * 当前请求归属的轮次 ID。 + */ + public static final String CURRENT_ROUND_ID = "currentRoundId"; + + /** + * 当前请求归属的轮次序号。 + */ + public static final String CURRENT_ROUND_NO = "currentRoundNo"; + + /** + * 当前请求生成的答案版本序号。 + */ + public static final String CURRENT_VARIANT_INDEX = "currentVariantIndex"; + + private ChatRuntimeExtKeys() { + } +} diff --git a/easyflow-commons/easyflow-common-chat-protocol/src/main/java/tech/easyflow/core/runtime/ChatRuntimeMessage.java b/easyflow-commons/easyflow-common-chat-protocol/src/main/java/tech/easyflow/core/runtime/ChatRuntimeMessage.java index 1cbb713..b7819ff 100644 --- a/easyflow-commons/easyflow-common-chat-protocol/src/main/java/tech/easyflow/core/runtime/ChatRuntimeMessage.java +++ b/easyflow-commons/easyflow-common-chat-protocol/src/main/java/tech/easyflow/core/runtime/ChatRuntimeMessage.java @@ -16,6 +16,10 @@ public class ChatRuntimeMessage implements Serializable { private Date createdAt = new Date(); private BigInteger senderId; private String senderName; + private BigInteger roundId; + private Integer roundNo; + private String messageKind; + private Integer variantIndex; public BigInteger getMessageId() { return messageId; @@ -80,4 +84,36 @@ public class ChatRuntimeMessage implements Serializable { public void setSenderName(String senderName) { this.senderName = senderName; } + + public BigInteger getRoundId() { + return roundId; + } + + public void setRoundId(BigInteger roundId) { + this.roundId = roundId; + } + + public Integer getRoundNo() { + return roundNo; + } + + public void setRoundNo(Integer roundNo) { + this.roundNo = roundNo; + } + + public String getMessageKind() { + return messageKind; + } + + public void setMessageKind(String messageKind) { + this.messageKind = messageKind; + } + + public Integer getVariantIndex() { + return variantIndex; + } + + public void setVariantIndex(Integer variantIndex) { + this.variantIndex = variantIndex; + } } diff --git a/easyflow-modules/easyflow-module-ai/src/main/java/tech/easyflow/ai/easyagents/listener/ChatStreamListener.java b/easyflow-modules/easyflow-module-ai/src/main/java/tech/easyflow/ai/easyagents/listener/ChatStreamListener.java index 318153b..035e771 100644 --- a/easyflow-modules/easyflow-module-ai/src/main/java/tech/easyflow/ai/easyagents/listener/ChatStreamListener.java +++ b/easyflow-modules/easyflow-module-ai/src/main/java/tech/easyflow/ai/easyagents/listener/ChatStreamListener.java @@ -20,9 +20,11 @@ import tech.easyflow.core.chat.protocol.payload.ErrorPayload; import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter; import tech.easyflow.core.runtime.ChatAssistantAccumulator; import tech.easyflow.core.runtime.ChatRuntimeContext; +import tech.easyflow.core.runtime.ChatRuntimeExtKeys; import tech.easyflow.core.runtime.ChatRuntimeManager; import tech.easyflow.core.runtime.ChatRuntimeMessage; +import java.math.BigInteger; import java.util.Date; import java.util.LinkedHashMap; import java.util.List; @@ -176,6 +178,7 @@ public class ChatStreamListener implements StreamResponseListener { deltaMap.put("role", MessageRole.ASSISTANT.getValue()); deltaMap.put("delta", deltaContent); chatEnvelope.setPayload(deltaMap); + chatEnvelope.setMeta(buildStreamMeta()); boolean sent = sseEmitter.send(chatEnvelope); if (!sent) { @@ -196,6 +199,7 @@ public class ChatStreamListener implements StreamResponseListener { payload.put("name", toolCall.getName()); payload.put("arguments", toolCall.getArguments()); chatEnvelope.setPayload(payload); + chatEnvelope.setMeta(buildStreamMeta()); boolean sent = sseEmitter.send(chatEnvelope); if (!sent) { throw new IllegalStateException("SSE emitter has already completed while sending tool call envelope"); @@ -214,6 +218,7 @@ public class ChatStreamListener implements StreamResponseListener { payload.put("tool_call_id", toolMessage.getToolCallId()); payload.put("result", toolMessage.getContent()); chatEnvelope.setPayload(payload); + chatEnvelope.setMeta(buildStreamMeta()); boolean sent = sseEmitter.send(chatEnvelope); if (!sent) { throw new IllegalStateException("SSE emitter has already completed while sending tool result envelope"); @@ -236,6 +241,7 @@ public class ChatStreamListener implements StreamResponseListener { envelope.setPayload(payload); envelope.setDomain(ChatDomain.SYSTEM); envelope.setType(ChatType.ERROR); + envelope.setMeta(buildStreamMeta()); boolean sent = sseEmitter.sendError(envelope); if (!sent) { LOG.warn("sendSystemError skipped because emitter is closed, conversationId={}", conversationId); @@ -243,6 +249,54 @@ public class ChatStreamListener implements StreamResponseListener { sseEmitter.complete(); } + private Map buildStreamMeta() { + Map meta = new LinkedHashMap<>(); + BigInteger roundId = getBigIntegerExt(ChatRuntimeExtKeys.CURRENT_ROUND_ID); + Integer roundNo = getIntegerExt(ChatRuntimeExtKeys.CURRENT_ROUND_NO); + Integer variantIndex = getIntegerExt(ChatRuntimeExtKeys.CURRENT_VARIANT_INDEX); + BigInteger regenerateRoundId = getBigIntegerExt(ChatRuntimeExtKeys.REGENERATE_ROUND_ID); + if (roundId != null) { + meta.put("roundId", roundId.toString()); + } + if (roundNo != null) { + meta.put("roundNo", roundNo); + } + if (variantIndex != null) { + meta.put("variantIndex", variantIndex); + } + meta.put("regenerate", regenerateRoundId != null); + if (regenerateRoundId != null) { + meta.put("regenerateRoundId", regenerateRoundId.toString()); + } + return meta; + } + + private BigInteger getBigIntegerExt(String key) { + Object value = runtimeContext == null || runtimeContext.getExt() == null + ? null + : runtimeContext.getExt().get(key); + if (value == null) { + return null; + } + if (value instanceof BigInteger number) { + return number; + } + return new BigInteger(String.valueOf(value)); + } + + private Integer getIntegerExt(String key) { + Object value = runtimeContext == null || runtimeContext.getExt() == null + ? null + : runtimeContext.getExt().get(key); + if (value == null) { + return null; + } + if (value instanceof Integer number) { + return number; + } + return Integer.parseInt(String.valueOf(value)); + } + private void stopStreamClient(StreamContext context, String reason, Throwable source) { try { if (context != null && context.getClient() != null) { @@ -267,6 +321,7 @@ public class ChatStreamListener implements StreamResponseListener { message.setCreatedAt(new Date()); message.setSenderId(runtimeContext.getAssistantId()); message.setSenderName(runtimeContext.getAssistantName()); + applyRoundMetadata(message); return message; } @@ -280,9 +335,24 @@ public class ChatStreamListener implements StreamResponseListener { message.setCreatedAt(new Date()); message.setSenderId(runtimeContext.getAssistantId()); message.setSenderName(runtimeContext.getAssistantName()); + applyRoundMetadata(message); return message; } + /** + * 把当前轮次元数据写回 assistant 消息,确保 SSE 与后续持久化链路都能识别轮次和版本。 + * + * @param message assistant 运行时消息 + */ + private void applyRoundMetadata(ChatRuntimeMessage message) { + if (message == null) { + return; + } + message.setRoundId(getBigIntegerExt(ChatRuntimeExtKeys.CURRENT_ROUND_ID)); + message.setRoundNo(getIntegerExt(ChatRuntimeExtKeys.CURRENT_ROUND_NO)); + message.setVariantIndex(getIntegerExt(ChatRuntimeExtKeys.CURRENT_VARIANT_INDEX)); + } + /** * 在 tool call assistant 写入临时 memory 前,把 reasoning/content 快照回填到消息对象中, * 以便前端 history 透传和 DeepSeek 下一轮请求都能拿到完整链路。 diff --git a/easyflow-modules/easyflow-module-ai/src/main/java/tech/easyflow/ai/service/BotService.java b/easyflow-modules/easyflow-module-ai/src/main/java/tech/easyflow/ai/service/BotService.java index b0ad996..7b2ae1d 100644 --- a/easyflow-modules/easyflow-module-ai/src/main/java/tech/easyflow/ai/service/BotService.java +++ b/easyflow-modules/easyflow-module-ai/src/main/java/tech/easyflow/ai/service/BotService.java @@ -56,6 +56,25 @@ public interface BotService extends IService { SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId, BotServiceImpl.ChatCheckResult chatCheckResult); + SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId, + BotServiceImpl.ChatCheckResult chatCheckResult, BigInteger regenerateRoundId); + + /** + * 聊天前置校验,并根据调用方要求决定是否强制走发布态。 + * + * @param botId 聊天助手 ID + * @param prompt 用户问题 + * @param conversationId 会话 ID + * @param chatCheckResult 校验结果承载对象 + * @param publishedOnly 是否强制走发布态 + * @return 校验失败时返回错误 SSE;成功时返回 {@code null} + */ + SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId, + BotServiceImpl.ChatCheckResult chatCheckResult, boolean publishedOnly); + + SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId, + BotServiceImpl.ChatCheckResult chatCheckResult, boolean publishedOnly, BigInteger regenerateRoundId); + SseEmitter startChat(BigInteger botId, String prompt, BigInteger conversationId, List> messages, BotServiceImpl.ChatCheckResult chatCheckResult, List attachments, ChatRuntimeContext runtimeContext); diff --git a/easyflow-modules/easyflow-module-ai/src/main/java/tech/easyflow/ai/service/impl/BotServiceImpl.java b/easyflow-modules/easyflow-module-ai/src/main/java/tech/easyflow/ai/service/impl/BotServiceImpl.java index ced955e..d60ba68 100644 --- a/easyflow-modules/easyflow-module-ai/src/main/java/tech/easyflow/ai/service/impl/BotServiceImpl.java +++ b/easyflow-modules/easyflow-module-ai/src/main/java/tech/easyflow/ai/service/impl/BotServiceImpl.java @@ -13,6 +13,7 @@ import com.easyagents.core.model.chat.ChatOptions; import com.easyagents.core.model.chat.StreamResponseListener; import com.easyagents.core.model.chat.tool.Tool; import com.easyagents.core.prompt.MemoryPrompt; +import com.easyagents.rag.retrieval.RetrievalMode; import com.mybatisflex.core.query.QueryWrapper; import com.mybatisflex.spring.service.impl.ServiceImpl; import org.slf4j.Logger; @@ -51,6 +52,7 @@ import tech.easyflow.common.web.exceptions.BusinessException; import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter; import tech.easyflow.core.chat.protocol.sse.ChatSseUtil; import tech.easyflow.core.runtime.ChatAssistantAccumulator; +import tech.easyflow.core.runtime.ChatRuntimeExtKeys; import tech.easyflow.core.runtime.ChatRuntimeContext; import tech.easyflow.core.runtime.ChatRuntimeManager; import tech.easyflow.core.runtime.ChatRuntimeMessage; @@ -60,9 +62,11 @@ import javax.annotation.Resource; import java.math.BigInteger; import java.util.ArrayList; import java.util.Date; +import java.util.LinkedHashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Set; import static tech.easyflow.ai.entity.table.BotPluginTableDef.BOT_PLUGIN; import static tech.easyflow.ai.entity.table.PluginItemTableDef.PLUGIN_ITEM; @@ -78,6 +82,8 @@ public class BotServiceImpl extends ServiceImpl implements BotSe private static final Logger log = LoggerFactory.getLogger(BotServiceImpl.class); private static final String FAQ_IMAGE_SYSTEM_RULE = "当知识工具返回 Markdown 图片(格式:![描述](URL))时,你必须在最终回答中保留并输出对应的图片 Markdown,禁止改写、替换或省略图片 URL。"; + private static final String EXTRA_KNOWLEDGE_PRIORITY_RULE = "若当前会话显式选择了额外知识库,请优先参考这些额外知识库;仅在额外知识库无法回答时,再回退到助手默认绑定知识库。"; + private static final int MAX_EXTRA_KNOWLEDGE_COUNT = 3; public static class ChatCheckResult { private Bot aiBot; @@ -215,6 +221,24 @@ public class BotServiceImpl extends ServiceImpl implements BotSe } public SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId, ChatCheckResult chatCheckResult) { + return checkChatBeforeStart(botId, prompt, conversationId, chatCheckResult, false, null); + } + + @Override + public SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId, + ChatCheckResult chatCheckResult, BigInteger regenerateRoundId) { + return checkChatBeforeStart(botId, prompt, conversationId, chatCheckResult, false, regenerateRoundId); + } + + @Override + public SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId, + ChatCheckResult chatCheckResult, boolean publishedOnly) { + return checkChatBeforeStart(botId, prompt, conversationId, chatCheckResult, publishedOnly, null); + } + + @Override + public SseEmitter checkChatBeforeStart(BigInteger botId, String prompt, String conversationId, + ChatCheckResult chatCheckResult, boolean publishedOnly, BigInteger regenerateRoundId) { if (!StringUtils.hasLength(prompt)) { return ChatSseUtil.sendSystemError(conversationId, "提示词不能为空"); } @@ -248,7 +272,13 @@ public class BotServiceImpl extends ServiceImpl implements BotSe if ((!login || anonymousAccount) && !aiBot.isAnonymousEnabled()) { return ChatSseUtil.sendSystemError(conversationId, "此聊天助手不支持匿名访问"); } - if (!login || anonymousAccount) { + if (publishedOnly && login && !anonymousAccount) { + if (PublishStatus.from(aiBot.getPublishStatus()) != PublishStatus.PUBLISHED) { + return ChatSseUtil.sendSystemError(conversationId, "聊天助手尚未发布"); + } + aiBot = toPublishedView(aiBot); + chatCheckResult.setPublishedAccess(true); + } else if (!login || anonymousAccount) { Bot publishedBot = toPublishedView(aiBot); if (!PublishStatus.from(aiBot.getPublishStatus()).isExternallyVisible()) { return ChatSseUtil.sendSystemError(conversationId, "聊天助手尚未发布"); @@ -267,7 +297,6 @@ public class BotServiceImpl extends ServiceImpl implements BotSe if (chatModel == null) { return ChatSseUtil.sendSystemError(conversationId, "对话模型获取失败,请检查配置"); } - chatCheckResult.setAiBot(aiBot); chatCheckResult.setModelOptions(modelOptions); chatCheckResult.setChatModel(chatModel); @@ -282,8 +311,9 @@ public class BotServiceImpl extends ServiceImpl implements BotSe ChatModel chatModel = chatCheckResult.getChatModel(); ChatTimeToolAvailabilityContext chatTimeContext = buildChatTimeToolAvailabilityContext(runtimeContext, chatCheckResult.getAiBot()); final MemoryPrompt memoryPrompt = new MemoryPrompt(); - String systemPrompt = buildSystemPromptWithFaqImageRule( - MapUtil.getString(modelOptions, Bot.KEY_SYSTEM_PROMPT) + String systemPrompt = buildSystemPrompt( + MapUtil.getString(modelOptions, Bot.KEY_SYSTEM_PROMPT), + runtimeContext ); Integer maxMessageCount = MapUtil.getInteger(modelOptions, Bot.KEY_MAX_MESSAGE_COUNT); if (maxMessageCount != null) { @@ -301,7 +331,8 @@ public class BotServiceImpl extends ServiceImpl implements BotSe .set("needEnglishName", true) .set("bot", chatCheckResult.getAiBot()) .set("chatTimeContext", chatTimeContext) - .set("publishedOnly", chatCheckResult.isPublishedAccess()))); + .set("publishedOnly", chatCheckResult.isPublishedAccess()) + .set("extraKnowledgeIds", resolveExtraKnowledgeIds(runtimeContext)))); ChatOptions chatOptions = getChatOptions(modelOptions); Boolean enableDeepThinking = MapUtil.getBoolean(modelOptions, Bot.KEY_ENABLE_DEEP_THINKING, false); chatOptions.setThinkingEnabled(enableDeepThinking); @@ -353,8 +384,9 @@ public class BotServiceImpl extends ServiceImpl implements BotSe Boolean enableDeepThinking = MapUtil.getBoolean(modelOptions, Bot.KEY_ENABLE_DEEP_THINKING, false); chatOptions.setThinkingEnabled(enableDeepThinking); ChatModel chatModel = chatCheckResult.getChatModel(); - String systemPrompt = buildSystemPromptWithFaqImageRule( - MapUtil.getString(modelOptions, Bot.KEY_SYSTEM_PROMPT) + String systemPrompt = buildSystemPrompt( + MapUtil.getString(modelOptions, Bot.KEY_SYSTEM_PROMPT), + runtimeContext ); UserMessage userMessage = new UserMessage(prompt); userMessage.addTools(buildFunctionList(Maps.of("botId", botId) @@ -362,6 +394,7 @@ public class BotServiceImpl extends ServiceImpl implements BotSe .set("needAccountId", false) .set("bot", chatCheckResult.getAiBot()) .set("publishedOnly", chatCheckResult.isPublishedAccess()) + .set("extraKnowledgeIds", resolveExtraKnowledgeIds(runtimeContext)) )); ChatSseEmitter chatSseEmitter = new ChatSseEmitter(); SseEmitter emitter = chatSseEmitter.getEmitter(); @@ -474,15 +507,18 @@ public class BotServiceImpl extends ServiceImpl implements BotSe } Bot runtimeBot = (Bot) buildParams.get("bot"); ChatTimeToolAvailabilityContext chatTimeContext = (ChatTimeToolAvailabilityContext) buildParams.get("chatTimeContext"); + List extraKnowledgeIds = sanitizeExtraKnowledgeIds((List) buildParams.get("extraKnowledgeIds")); boolean usePublishedSnapshot = Boolean.TRUE.equals(buildParams.get("publishedOnly")) && runtimeBot != null && runtimeBot.getPublishedSnapshotJson() != null && PublishStatus.from(runtimeBot.getPublishStatus()).isExternallyVisible(); QueryWrapper queryWrapper = QueryWrapper.create(); + Set existingKnowledgeIds = new LinkedHashSet<>(); + appendExtraKnowledgeTools(functionList, extraKnowledgeIds, needEnglishName, chatTimeContext, existingKnowledgeIds); if (usePublishedSnapshot) { + appendPublishedKnowledgeTools(functionList, runtimeBot, needEnglishName, chatTimeContext, existingKnowledgeIds); appendPublishedWorkflowTools(functionList, runtimeBot, needEnglishName); - appendPublishedKnowledgeTools(functionList, runtimeBot, needEnglishName, chatTimeContext); } else { // 工作流 function 集合 queryWrapper.eq(BotWorkflow::getBotId, botId); @@ -500,7 +536,7 @@ public class BotServiceImpl extends ServiceImpl implements BotSe queryWrapper.eq(BotDocumentCollection::getBotId, botId); List botDocumentCollections = botDocumentCollectionService.getMapper() .selectListWithRelationsByQuery(queryWrapper); - functionList.addAll(buildKnowledgeTools(botDocumentCollections, needEnglishName, chatTimeContext)); + functionList.addAll(buildKnowledgeTools(botDocumentCollections, needEnglishName, chatTimeContext, existingKnowledgeIds)); } // 插件 function 集合 @@ -550,6 +586,22 @@ public class BotServiceImpl extends ServiceImpl implements BotSe List buildKnowledgeTools(List botDocumentCollections, boolean needEnglishName, ChatTimeToolAvailabilityContext chatTimeContext) { + return buildKnowledgeTools(botDocumentCollections, needEnglishName, chatTimeContext, new LinkedHashSet<>()); + } + + /** + * 将 Bot 绑定的知识库候选项收敛为当前聊天可用的工具列表,并与已选临时知识库去重。 + * + * @param botDocumentCollections Bot 知识库绑定项 + * @param needEnglishName 是否使用英文名称 + * @param chatTimeContext 聊天时权限上下文 + * @param existingKnowledgeIds 已装配知识库 ID 集 + * @return 知识库工具列表 + */ + List buildKnowledgeTools(List botDocumentCollections, + boolean needEnglishName, + ChatTimeToolAvailabilityContext chatTimeContext, + Set existingKnowledgeIds) { List functionList = new ArrayList<>(); if (botDocumentCollections == null || botDocumentCollections.isEmpty()) { return functionList; @@ -559,7 +611,7 @@ public class BotServiceImpl extends ServiceImpl implements BotSe : chatTimeToolAvailabilityService.filterAvailable(chatTimeContext, botDocumentCollections); for (BotDocumentCollection botDocumentCollection : availableBindings) { DocumentCollection knowledge = botDocumentCollection.getKnowledge(); - if (knowledge == null) { + if (knowledge == null || isDuplicateKnowledge(existingKnowledgeIds, knowledge.getId())) { continue; } DocumentCollectionTool function = (DocumentCollectionTool) knowledge.toFunction( @@ -603,7 +655,8 @@ public class BotServiceImpl extends ServiceImpl implements BotSe private void appendPublishedKnowledgeTools(List functionList, Bot runtimeBot, boolean needEnglishName, - ChatTimeToolAvailabilityContext chatTimeContext) { + ChatTimeToolAvailabilityContext chatTimeContext, + Set existingKnowledgeIds) { Object knowledges = runtimeBot.getPublishedSnapshotJson().get("knowledgeBindings"); if (!(knowledges instanceof List knowledgeBindings)) { return; @@ -617,12 +670,15 @@ public class BotServiceImpl extends ServiceImpl implements BotSe continue; } DocumentCollection knowledge = documentCollectionService.getPublishedById(new BigInteger(String.valueOf(knowledgeId))); - if (knowledge == null) { + if (knowledge == null || PublishStatus.from(knowledge.getPublishStatus()) != PublishStatus.PUBLISHED) { continue; } if (chatTimeContext != null && !chatTimeToolAvailabilityService.evaluate(chatTimeContext, knowledge).isAvailable()) { continue; } + if (isDuplicateKnowledge(existingKnowledgeIds, knowledge.getId())) { + continue; + } Object retrievalMode = bindingMap.get("retrievalMode"); functionList.add(knowledge.toFunction( needEnglishName, @@ -632,6 +688,42 @@ public class BotServiceImpl extends ServiceImpl implements BotSe } } + /** + * 组装会话级临时知识库工具,并按用户选择顺序优先插入。 + * + * @param functionList 工具集合 + * @param extraKnowledgeIds 额外知识库 ID + * @param needEnglishName 是否使用英文名称 + * @param chatTimeContext 聊天时权限上下文 + * @param existingKnowledgeIds 已装配知识库 ID 集 + */ + protected void appendExtraKnowledgeTools(List functionList, + List extraKnowledgeIds, + boolean needEnglishName, + ChatTimeToolAvailabilityContext chatTimeContext, + Set existingKnowledgeIds) { + if (extraKnowledgeIds == null || extraKnowledgeIds.isEmpty()) { + return; + } + for (BigInteger knowledgeId : extraKnowledgeIds) { + if (knowledgeId == null || isDuplicateKnowledge(existingKnowledgeIds, knowledgeId)) { + continue; + } + DocumentCollection knowledge = documentCollectionService.getById(knowledgeId); + if (knowledge == null) { + throw new BusinessException("额外知识库不存在"); + } + if (PublishStatus.from(knowledge.getPublishStatus()) != PublishStatus.PUBLISHED) { + throw new BusinessException("额外知识库未发布,无法用于聊天"); + } + if (chatTimeContext != null && !chatTimeToolAvailabilityService.evaluate(chatTimeContext, knowledge).isAvailable()) { + throw new BusinessException("当前用户无权使用所选知识库"); + } + functionList.add(documentCollectionService.toPublishedView(knowledge) + .toFunction(needEnglishName, RetrievalMode.HYBRID.name(), chatTimeContext)); + } + } + /** * 构造聊天时工具可用性上下文,并显式回填到运行时上下文中供后续异步工具调用复用。 * @@ -704,14 +796,64 @@ public class BotServiceImpl extends ServiceImpl implements BotSe return context.getUserAccount(); } - private String buildSystemPromptWithFaqImageRule(String systemPrompt) { - if (!StringUtils.hasLength(systemPrompt)) { - return FAQ_IMAGE_SYSTEM_RULE; + private String buildSystemPrompt(String systemPrompt, ChatRuntimeContext runtimeContext) { + String mergedPrompt = appendPromptRule(systemPrompt, FAQ_IMAGE_SYSTEM_RULE); + if (hasExtraKnowledgeSelection(runtimeContext)) { + mergedPrompt = appendPromptRule(mergedPrompt, EXTRA_KNOWLEDGE_PRIORITY_RULE); } - if (systemPrompt.contains(FAQ_IMAGE_SYSTEM_RULE)) { + return mergedPrompt; + } + + private String appendPromptRule(String systemPrompt, String rule) { + if (!StringUtils.hasLength(systemPrompt)) { + return rule; + } + if (systemPrompt.contains(rule)) { return systemPrompt; } - return systemPrompt + "\n\n" + FAQ_IMAGE_SYSTEM_RULE; + return systemPrompt + "\n\n" + rule; + } + + private List resolveExtraKnowledgeIds(ChatRuntimeContext runtimeContext) { + if (runtimeContext == null || runtimeContext.getExt() == null) { + return List.of(); + } + Object rawValue = runtimeContext.getExt().get(ChatRuntimeExtKeys.EXTRA_KNOWLEDGE_IDS); + if (!(rawValue instanceof List rawList) || rawList.isEmpty()) { + return List.of(); + } + List values = new ArrayList<>(rawList.size()); + for (Object item : rawList) { + if (item == null) { + continue; + } + values.add(new BigInteger(String.valueOf(item))); + } + return values; + } + + private List sanitizeExtraKnowledgeIds(List extraKnowledgeIds) { + if (extraKnowledgeIds == null || extraKnowledgeIds.isEmpty()) { + return List.of(); + } + LinkedHashSet dedup = new LinkedHashSet<>(); + for (BigInteger knowledgeId : extraKnowledgeIds) { + if (knowledgeId != null) { + dedup.add(knowledgeId); + } + } + if (dedup.size() > MAX_EXTRA_KNOWLEDGE_COUNT) { + throw new BusinessException("额外知识库最多选择 3 个"); + } + return new ArrayList<>(dedup); + } + + private boolean hasExtraKnowledgeSelection(ChatRuntimeContext runtimeContext) { + return !resolveExtraKnowledgeIds(runtimeContext).isEmpty(); + } + + private boolean isDuplicateKnowledge(Set existingKnowledgeIds, BigInteger knowledgeId) { + return knowledgeId != null && existingKnowledgeIds != null && !existingKnowledgeIds.add(knowledgeId); } } diff --git a/easyflow-modules/easyflow-module-ai/src/test/java/tech/easyflow/ai/service/impl/BotServiceImplTest.java b/easyflow-modules/easyflow-module-ai/src/test/java/tech/easyflow/ai/service/impl/BotServiceImplTest.java index 0a2d7dc..cb608c4 100644 --- a/easyflow-modules/easyflow-module-ai/src/test/java/tech/easyflow/ai/service/impl/BotServiceImplTest.java +++ b/easyflow-modules/easyflow-module-ai/src/test/java/tech/easyflow/ai/service/impl/BotServiceImplTest.java @@ -121,11 +121,12 @@ public class BotServiceImplTest { List.class, Bot.class, boolean.class, - ChatTimeToolAvailabilityContext.class + ChatTimeToolAvailabilityContext.class, + Set.class ); method.setAccessible(true); - method.invoke(service, functionList, runtimeBot, false, buildContext(12, 3)); + method.invoke(service, functionList, runtimeBot, false, buildContext(12, 3), new LinkedHashSet<>()); Assert.assertEquals(1, functionList.size()); DocumentCollectionTool tool = (DocumentCollectionTool) functionList.get(0); @@ -133,6 +134,102 @@ public class BotServiceImplTest { Assert.assertEquals(RetrievalMode.KEYWORD, tool.getRetrievalMode()); } + /** + * 额外知识库应按用户选择顺序去重,并限制最多 3 个。 + * + * @throws Exception 反射调用异常 + */ + @Test + @SuppressWarnings("unchecked") + public void sanitizeExtraKnowledgeIdsShouldDeduplicateAndKeepOrder() throws Exception { + BotServiceImpl service = new BotServiceImpl(); + Method method = BotServiceImpl.class.getDeclaredMethod("sanitizeExtraKnowledgeIds", List.class); + method.setAccessible(true); + + List result = (List) method.invoke( + service, + List.of( + BigInteger.valueOf(3), + BigInteger.valueOf(1), + BigInteger.valueOf(3), + BigInteger.valueOf(2) + ) + ); + + Assert.assertEquals( + List.of(BigInteger.valueOf(3), BigInteger.ONE, BigInteger.valueOf(2)), + result + ); + } + + /** + * 额外知识库超过 3 个时应直接拒绝请求。 + * + * @throws Exception 反射调用异常 + */ + @Test + public void sanitizeExtraKnowledgeIdsShouldRejectWhenMoreThanThree() throws Exception { + BotServiceImpl service = new BotServiceImpl(); + Method method = BotServiceImpl.class.getDeclaredMethod("sanitizeExtraKnowledgeIds", List.class); + method.setAccessible(true); + + try { + method.invoke( + service, + List.of( + BigInteger.ONE, + BigInteger.valueOf(2), + BigInteger.valueOf(3), + BigInteger.valueOf(4) + ) + ); + Assert.fail("expected BusinessException"); + } catch (java.lang.reflect.InvocationTargetException exception) { + Assert.assertTrue(exception.getTargetException() instanceof tech.easyflow.common.web.exceptions.BusinessException); + Assert.assertEquals("额外知识库最多选择 3 个", exception.getTargetException().getMessage()); + } + } + + /** + * 额外知识库工具应强制使用发布态视图,并跳过重复选择。 + * + * @throws Exception 反射注入异常 + */ + @Test + public void appendExtraKnowledgeToolsShouldUsePublishedKnowledgeAndDeduplicate() throws Exception { + BotServiceImpl service = new BotServiceImpl(); + injectAvailabilityService(service, buildAvailabilityService( + new RoleCategoryAccessSnapshot("KNOWLEDGE", BigInteger.valueOf(12), false, false, setOf(BigInteger.valueOf(21))), + Collections.emptySet() + )); + + DocumentCollection draftKnowledge = buildKnowledge(101, 11, 21, "PUBLIC"); + draftKnowledge.setPublishStatus("PUBLISHED"); + draftKnowledge.setTitle("draft-title"); + draftKnowledge.setDescription("draft-desc"); + + DocumentCollection publishedKnowledge = buildKnowledge(101, 11, 21, "PUBLIC"); + publishedKnowledge.setPublishStatus("PUBLISHED"); + publishedKnowledge.setTitle("published-title"); + publishedKnowledge.setDescription("published-desc"); + + injectDocumentCollectionService(service, mockDocumentCollectionServiceForExtra(draftKnowledge, publishedKnowledge)); + + List tools = new ArrayList<>(); + service.appendExtraKnowledgeTools( + tools, + List.of(BigInteger.valueOf(101), BigInteger.valueOf(101)), + false, + buildContext(12, 3), + new LinkedHashSet<>() + ); + + Assert.assertEquals(1, tools.size()); + DocumentCollectionTool tool = (DocumentCollectionTool) tools.get(0); + Assert.assertEquals(BigInteger.valueOf(101), tool.getKnowledgeId()); + Assert.assertEquals("published-desc", tool.getDescription()); + } + private void injectAvailabilityService(BotServiceImpl service, ChatTimeToolAvailabilityService availabilityService) throws Exception { Field field = BotServiceImpl.class.getDeclaredField("chatTimeToolAvailabilityService"); field.setAccessible(true); @@ -189,6 +286,7 @@ public class BotServiceImplTest { knowledge.setCreatedBy(BigInteger.valueOf(createdBy)); knowledge.setCategoryId(BigInteger.valueOf(categoryId)); knowledge.setVisibilityScope(visibilityScope); + knowledge.setPublishStatus("PUBLISHED"); return knowledge; } @@ -216,6 +314,26 @@ public class BotServiceImplTest { ); } + private tech.easyflow.ai.service.DocumentCollectionService mockDocumentCollectionServiceForExtra(DocumentCollection draftCollection, + DocumentCollection publishedCollection) { + return (tech.easyflow.ai.service.DocumentCollectionService) Proxy.newProxyInstance( + tech.easyflow.ai.service.DocumentCollectionService.class.getClassLoader(), + new Class[]{tech.easyflow.ai.service.DocumentCollectionService.class}, + (proxy, method, args) -> { + if ("getById".equals(method.getName())) { + return draftCollection; + } + if ("toPublishedView".equals(method.getName())) { + return publishedCollection; + } + if ("getPublishedById".equals(method.getName())) { + return publishedCollection; + } + return defaultValue(method.getReturnType()); + } + ); + } + private CategoryPermissionService mockCategoryPermissionService(RoleCategoryAccessSnapshot accessSnapshot) { return (CategoryPermissionService) Proxy.newProxyInstance( CategoryPermissionService.class.getClassLoader(), diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/cache/ChatHotStateService.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/cache/ChatHotStateService.java index c4d2cab..7885460 100644 --- a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/cache/ChatHotStateService.java +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/cache/ChatHotStateService.java @@ -7,9 +7,13 @@ import org.springframework.data.redis.core.StringRedisTemplate; import org.springframework.stereotype.Service; import tech.easyflow.chatlog.config.ChatCacheProperties; import tech.easyflow.chatlog.domain.command.ChatAppendMessageCommand; +import tech.easyflow.chatlog.domain.command.ChatRoundSelectCommand; +import tech.easyflow.chatlog.domain.command.ChatRoundUpsertCommand; import tech.easyflow.chatlog.domain.command.ChatSessionUpsertCommand; import tech.easyflow.chatlog.domain.dto.ChatMessageRecord; +import tech.easyflow.chatlog.domain.dto.ChatRoundRecord; import tech.easyflow.chatlog.domain.dto.ChatSessionSummary; +import tech.easyflow.chatlog.support.ChatConstants; import java.io.IOException; import java.math.BigInteger; @@ -61,6 +65,9 @@ public class ChatHotStateService { if (command.getTitle() != null && !command.getTitle().isBlank()) { summary.setTitle(command.getTitle()); } + if (command.getExtJson() != null && !command.getExtJson().isBlank()) { + summary.setExtJson(command.getExtJson()); + } summary.setAccessAt(defaultDate(command.getOperateAt())); summary.setModified(defaultDate(command.getOperateAt())); summary.setModifiedBy(command.getOperatorId()); @@ -119,9 +126,78 @@ public class ChatHotStateService { summary.setModified(defaultDate(command.getCreated())); summary.setModifiedBy(command.getCreatedBy()); summary.setIsDeleted(0); - summary.setMessageCount((summary.getMessageCount() == null ? 0 : summary.getMessageCount()) + 1); + summary.setMessageCount((summary.getMessageCount() == null ? 0 : summary.getMessageCount()) + + resolveVisibleMessageIncrement(command)); cacheSessionSummaryStrict(summary); - appendTail(toMessageRecord(command)); + appendVisibleTail(toMessageRecord(command)); + } + + public ChatRoundRecord createOrTouchRound(ChatRoundUpsertCommand command) { + if (command == null || command.getSessionId() == null || command.getRoundId() == null) { + return null; + } + ChatRoundRecord record = getRound(command.getSessionId(), command.getRoundId()); + if (record == null) { + record = new ChatRoundRecord(); + record.setId(command.getRoundId()); + record.setSessionId(command.getSessionId()); + record.setCreated(defaultDate(command.getOperateAt())); + } + if (command.getRoundNo() != null) { + record.setRoundNo(command.getRoundNo()); + } + if (command.getUserMessageId() != null) { + record.setUserMessageId(command.getUserMessageId()); + } + if (command.getSelectedAssistantMessageId() != null) { + record.setSelectedAssistantMessageId(command.getSelectedAssistantMessageId()); + } + if (command.getSelectedVariantIndex() != null) { + record.setSelectedVariantIndex(command.getSelectedVariantIndex()); + } + if (command.getVariantCount() != null) { + record.setVariantCount(command.getVariantCount()); + } + if (command.getStatus() != null && !command.getStatus().isBlank()) { + record.setStatus(command.getStatus()); + } + record.setModified(defaultDate(command.getOperateAt())); + cacheRoundStrict(record); + syncTailRoundMeta(record); + return record; + } + + public void selectVariant(ChatRoundSelectCommand command) { + if (command == null || command.getSessionId() == null || command.getRoundId() == null) { + return; + } + ChatRoundRecord record = getRound(command.getSessionId(), command.getRoundId()); + if (record == null) { + return; + } + record.setSelectedAssistantMessageId(command.getSelectedAssistantMessageId()); + record.setSelectedVariantIndex(command.getSelectedVariantIndex()); + record.setModified(defaultDate(command.getOperateAt())); + cacheRoundStrict(record); + if (command.getSelectedAssistantMessage() != null) { + applyRoundMeta(command.getSelectedAssistantMessage(), record); + replaceSelectedAssistant(record, command.getSelectedAssistantMessage()); + } + } + + public ChatRoundRecord getLatestRound(BigInteger sessionId) { + return readValue(keyLatestRound(sessionId), ChatRoundRecord.class); + } + + public ChatRoundRecord getRound(BigInteger sessionId, BigInteger roundId) { + return readValue(keyRound(sessionId, roundId), ChatRoundRecord.class); + } + + public void cacheRound(ChatRoundRecord record) { + try { + cacheRoundStrict(record); + } catch (IllegalStateException ignored) { + } } public List listSessionIds(BigInteger userId, long offset, long limit) { @@ -263,18 +339,168 @@ public class ChatHotStateService { } public void appendTail(ChatMessageRecord record) { + appendVisibleTail(record); + } + + public void appendVisibleTail(ChatMessageRecord record) { if (record == null || record.getSessionId() == null) { return; } + ChatRoundRecord round = record.getRoundId() == null ? null : getRound(record.getSessionId(), record.getRoundId()); + applyRoundMeta(record, round); List current = getSessionTail(record.getSessionId()); - List updated = new ArrayList<>(); - updated.add(record); - if (current != null) { - updated.addAll(current); + List updated = current == null ? new ArrayList<>() : new ArrayList<>(current); + removeExistingVisibleRecord(updated, record); + insertSorted(updated, record); + if (ChatConstants.MESSAGE_KIND_ASSISTANT_VARIANT.equals(record.getMessageKind()) + && record.getVariantIndex() != null + && record.getVariantIndex() > 1) { + removeOlderSelectedAssistant(updated, record); } writeValueStrict(keySessionTail(record.getSessionId()), trimTail(updated), properties.getSessionTailTtl()); } + private void replaceSelectedAssistant(ChatRoundRecord round, ChatMessageRecord record) { + List current = getSessionTail(record.getSessionId()); + List updated = current == null ? new ArrayList<>() : new ArrayList<>(current); + syncRoundMeta(updated, round); + removeExistingVisibleRecord(updated, record); + removeOlderSelectedAssistant(updated, record); + insertSorted(updated, record); + writeValueStrict(keySessionTail(record.getSessionId()), trimTail(updated), properties.getSessionTailTtl()); + } + + /** + * 将轮次选中态同步到 Redis tail,避免前端主线过滤读到过期版本号。 + * + * @param round 轮次记录 + */ + private void syncTailRoundMeta(ChatRoundRecord round) { + if (round == null || round.getSessionId() == null || round.getId() == null) { + return; + } + List current = getSessionTail(round.getSessionId()); + if (current == null || current.isEmpty()) { + return; + } + List updated = new ArrayList<>(current); + syncRoundMeta(updated, round); + if (round.getSelectedAssistantMessageId() != null) { + updated.removeIf(item -> item != null + && item.getRoundId() != null + && item.getRoundId().equals(round.getId()) + && "assistant".equalsIgnoreCase(item.getSenderRole()) + && item.getId() != null + && !item.getId().equals(round.getSelectedAssistantMessageId())); + } + writeValueStrict(keySessionTail(round.getSessionId()), trimTail(updated), properties.getSessionTailTtl()); + } + + /** + * 批量同步同一轮次的版本元信息。 + * + * @param records tail 消息列表 + * @param round 轮次记录 + */ + private void syncRoundMeta(List records, ChatRoundRecord round) { + if (records == null || records.isEmpty() || round == null || round.getId() == null) { + return; + } + for (ChatMessageRecord item : records) { + if (item != null && round.getId().equals(item.getRoundId())) { + applyRoundMeta(item, round); + } + } + } + + /** + * 将轮次元信息写入单条消息。 + * + * @param record 消息记录 + * @param round 轮次记录 + */ + private void applyRoundMeta(ChatMessageRecord record, ChatRoundRecord round) { + if (record == null || round == null) { + return; + } + if (round.getRoundNo() != null) { + record.setRoundNo(round.getRoundNo()); + } + if (round.getVariantCount() != null) { + record.setVariantCount(round.getVariantCount()); + } + if (round.getSelectedVariantIndex() != null) { + record.setSelectedVariantIndex(round.getSelectedVariantIndex()); + } + if (round.getStatus() != null) { + record.setSwitchable(!ChatConstants.ROUND_STATUS_LOCKED.equalsIgnoreCase(round.getStatus())); + } + } + + private void removeOlderSelectedAssistant(List records, ChatMessageRecord record) { + if (record.getRoundId() == null) { + return; + } + records.removeIf(item -> item != null + && item.getId() != null + && !item.getId().equals(record.getId()) + && item.getRoundId() != null + && item.getRoundId().equals(record.getRoundId()) + && "assistant".equalsIgnoreCase(item.getSenderRole())); + } + + private void removeExistingVisibleRecord(List records, ChatMessageRecord record) { + if (records == null || record == null || record.getId() == null) { + return; + } + records.removeIf(item -> item != null && record.getId().equals(item.getId())); + } + + private void insertSorted(List records, ChatMessageRecord record) { + int insertIndex = 0; + while (insertIndex < records.size()) { + ChatMessageRecord current = records.get(insertIndex); + if (shouldComeAfter(record, current)) { + insertIndex++; + continue; + } + break; + } + records.add(insertIndex, record); + } + + private boolean shouldComeAfter(ChatMessageRecord candidate, ChatMessageRecord current) { + if (candidate == null) { + return true; + } + if (current == null) { + return false; + } + Date candidateCreated = defaultDate(candidate.getCreated()); + Date currentCreated = defaultDate(current.getCreated()); + if (candidateCreated.before(currentCreated)) { + return true; + } + if (candidateCreated.after(currentCreated)) { + return false; + } + BigInteger candidateId = candidate.getId() == null ? BigInteger.ZERO : candidate.getId(); + BigInteger currentId = current.getId() == null ? BigInteger.ZERO : current.getId(); + return candidateId.compareTo(currentId) < 0; + } + + private void cacheRoundStrict(ChatRoundRecord record) { + if (record == null || record.getSessionId() == null || record.getId() == null) { + return; + } + writeValueStrict(keyRound(record.getSessionId(), record.getId()), record, properties.getSessionTailTtl()); + ChatRoundRecord latest = getLatestRound(record.getSessionId()); + if (latest == null || latest.getRoundNo() == null + || (record.getRoundNo() != null && record.getRoundNo() >= latest.getRoundNo())) { + writeValueStrict(keyLatestRound(record.getSessionId()), record, properties.getSessionTailTtl()); + } + } + public void evictSessionTail(BigInteger sessionId) { delete(keySessionTail(sessionId)); } @@ -292,6 +518,10 @@ public class ChatHotStateService { record.setSenderId(command.getSenderId()); record.setSenderName(command.getSenderName()); record.setSenderRole(command.getSenderRole()); + record.setRoundId(command.getRoundId()); + record.setRoundNo(command.getRoundNo()); + record.setMessageKind(command.getMessageKind()); + record.setVariantIndex(command.getVariantIndex()); record.setContentType(command.getContentType()); record.setContentText(command.getContentText()); record.setContentPayload(command.getContentPayload()); @@ -301,6 +531,17 @@ public class ChatHotStateService { return record; } + private int resolveVisibleMessageIncrement(ChatAppendMessageCommand command) { + if (command == null) { + return 0; + } + if (ChatConstants.MESSAGE_KIND_ASSISTANT_VARIANT.equals(command.getMessageKind())) { + Integer variantIndex = command.getVariantIndex(); + return variantIndex != null && variantIndex > 1 ? 0 : 1; + } + return 1; + } + private List trimTail(List records) { if (records == null || records.isEmpty()) { return Collections.emptyList(); @@ -326,6 +567,14 @@ public class ChatHotStateService { return "chat:session:tail:" + sessionId; } + private String keyLatestRound(BigInteger sessionId) { + return "chat:session:round:latest:" + sessionId; + } + + private String keyRound(BigInteger sessionId, BigInteger roundId) { + return "chat:session:round:" + sessionId + ":" + roundId; + } + private void removeFromSessionIndex(BigInteger userId, BigInteger sessionId) { if (userId == null || sessionId == null) { return; diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/command/ChatAppendMessageCommand.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/command/ChatAppendMessageCommand.java index 0c8e63f..d96a02b 100644 --- a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/command/ChatAppendMessageCommand.java +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/command/ChatAppendMessageCommand.java @@ -23,6 +23,10 @@ public class ChatAppendMessageCommand implements Serializable { private String contentType; private String contentText; private Map contentPayload; + private BigInteger roundId; + private Integer roundNo; + private String messageKind; + private Integer variantIndex; private BigInteger createdBy; private Date created = new Date(); @@ -154,6 +158,38 @@ public class ChatAppendMessageCommand implements Serializable { this.contentPayload = contentPayload; } + public BigInteger getRoundId() { + return roundId; + } + + public void setRoundId(BigInteger roundId) { + this.roundId = roundId; + } + + public Integer getRoundNo() { + return roundNo; + } + + public void setRoundNo(Integer roundNo) { + this.roundNo = roundNo; + } + + public String getMessageKind() { + return messageKind; + } + + public void setMessageKind(String messageKind) { + this.messageKind = messageKind; + } + + public Integer getVariantIndex() { + return variantIndex; + } + + public void setVariantIndex(Integer variantIndex) { + this.variantIndex = variantIndex; + } + public BigInteger getCreatedBy() { return createdBy; } diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/command/ChatRoundSelectCommand.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/command/ChatRoundSelectCommand.java new file mode 100644 index 0000000..95a9d66 --- /dev/null +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/command/ChatRoundSelectCommand.java @@ -0,0 +1,77 @@ +package tech.easyflow.chatlog.domain.command; + +import tech.easyflow.chatlog.domain.dto.ChatMessageRecord; + +import java.io.Serializable; +import java.math.BigInteger; +import java.util.Date; + +/** + * 轮次答案版本切换命令。 + */ +public class ChatRoundSelectCommand implements Serializable { + + private BigInteger sessionId; + private BigInteger roundId; + private Integer selectedVariantIndex; + private BigInteger selectedAssistantMessageId; + private ChatMessageRecord selectedAssistantMessage; + private BigInteger operatorId; + private Date operateAt = new Date(); + + public BigInteger getSessionId() { + return sessionId; + } + + public void setSessionId(BigInteger sessionId) { + this.sessionId = sessionId; + } + + public BigInteger getRoundId() { + return roundId; + } + + public void setRoundId(BigInteger roundId) { + this.roundId = roundId; + } + + public Integer getSelectedVariantIndex() { + return selectedVariantIndex; + } + + public void setSelectedVariantIndex(Integer selectedVariantIndex) { + this.selectedVariantIndex = selectedVariantIndex; + } + + public BigInteger getSelectedAssistantMessageId() { + return selectedAssistantMessageId; + } + + public void setSelectedAssistantMessageId(BigInteger selectedAssistantMessageId) { + this.selectedAssistantMessageId = selectedAssistantMessageId; + } + + public ChatMessageRecord getSelectedAssistantMessage() { + return selectedAssistantMessage; + } + + public void setSelectedAssistantMessage(ChatMessageRecord selectedAssistantMessage) { + this.selectedAssistantMessage = selectedAssistantMessage; + } + + public BigInteger getOperatorId() { + return operatorId; + } + + public void setOperatorId(BigInteger operatorId) { + this.operatorId = operatorId; + } + + public Date getOperateAt() { + return operateAt; + } + + public void setOperateAt(Date operateAt) { + this.operateAt = operateAt; + } +} diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/command/ChatRoundUpsertCommand.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/command/ChatRoundUpsertCommand.java new file mode 100644 index 0000000..6b15b5c --- /dev/null +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/command/ChatRoundUpsertCommand.java @@ -0,0 +1,102 @@ +package tech.easyflow.chatlog.domain.command; + +import java.io.Serializable; +import java.math.BigInteger; +import java.util.Date; + +/** + * 轮次聚合写入命令。 + */ +public class ChatRoundUpsertCommand implements Serializable { + + private BigInteger roundId; + private BigInteger sessionId; + private Integer roundNo; + private BigInteger userMessageId; + private BigInteger selectedAssistantMessageId; + private Integer selectedVariantIndex; + private Integer variantCount; + private String status; + private BigInteger operatorId; + private Date operateAt = new Date(); + + public BigInteger getRoundId() { + return roundId; + } + + public void setRoundId(BigInteger roundId) { + this.roundId = roundId; + } + + public BigInteger getSessionId() { + return sessionId; + } + + public void setSessionId(BigInteger sessionId) { + this.sessionId = sessionId; + } + + public Integer getRoundNo() { + return roundNo; + } + + public void setRoundNo(Integer roundNo) { + this.roundNo = roundNo; + } + + public BigInteger getUserMessageId() { + return userMessageId; + } + + public void setUserMessageId(BigInteger userMessageId) { + this.userMessageId = userMessageId; + } + + public BigInteger getSelectedAssistantMessageId() { + return selectedAssistantMessageId; + } + + public void setSelectedAssistantMessageId(BigInteger selectedAssistantMessageId) { + this.selectedAssistantMessageId = selectedAssistantMessageId; + } + + public Integer getSelectedVariantIndex() { + return selectedVariantIndex; + } + + public void setSelectedVariantIndex(Integer selectedVariantIndex) { + this.selectedVariantIndex = selectedVariantIndex; + } + + public Integer getVariantCount() { + return variantCount; + } + + public void setVariantCount(Integer variantCount) { + this.variantCount = variantCount; + } + + public String getStatus() { + return status; + } + + public void setStatus(String status) { + this.status = status; + } + + public BigInteger getOperatorId() { + return operatorId; + } + + public void setOperatorId(BigInteger operatorId) { + this.operatorId = operatorId; + } + + public Date getOperateAt() { + return operateAt; + } + + public void setOperateAt(Date operateAt) { + this.operateAt = operateAt; + } +} diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/command/ChatSessionSummaryCommand.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/command/ChatSessionSummaryCommand.java index f36c61e..1462bbd 100644 --- a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/command/ChatSessionSummaryCommand.java +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/command/ChatSessionSummaryCommand.java @@ -12,8 +12,11 @@ public class ChatSessionSummaryCommand implements Serializable { private String lastSenderName; private String lastMessagePreview; private Date lastMessageAt = new Date(); + private Date accessAt = new Date(); + private Date modifiedAt = new Date(); private BigInteger operatorId; private int messageIncrement = 1; + private boolean forceOverwrite; public BigInteger getSessionId() { return sessionId; @@ -63,6 +66,22 @@ public class ChatSessionSummaryCommand implements Serializable { this.lastMessageAt = lastMessageAt; } + public Date getAccessAt() { + return accessAt; + } + + public void setAccessAt(Date accessAt) { + this.accessAt = accessAt; + } + + public Date getModifiedAt() { + return modifiedAt; + } + + public void setModifiedAt(Date modifiedAt) { + this.modifiedAt = modifiedAt; + } + public BigInteger getOperatorId() { return operatorId; } @@ -78,4 +97,12 @@ public class ChatSessionSummaryCommand implements Serializable { public void setMessageIncrement(int messageIncrement) { this.messageIncrement = Math.max(messageIncrement, 0); } + + public boolean isForceOverwrite() { + return forceOverwrite; + } + + public void setForceOverwrite(boolean forceOverwrite) { + this.forceOverwrite = forceOverwrite; + } } diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/command/ChatSessionUpsertCommand.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/command/ChatSessionUpsertCommand.java index dfb0c96..6641f6d 100644 --- a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/command/ChatSessionUpsertCommand.java +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/command/ChatSessionUpsertCommand.java @@ -15,6 +15,7 @@ public class ChatSessionUpsertCommand implements Serializable { private String assistantCode; private String assistantName; private String title; + private String extJson; private BigInteger operatorId; private Date operateAt = new Date(); @@ -90,6 +91,14 @@ public class ChatSessionUpsertCommand implements Serializable { this.title = title; } + public String getExtJson() { + return extJson; + } + + public void setExtJson(String extJson) { + this.extJson = extJson; + } + public BigInteger getOperatorId() { return operatorId; } diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/dto/ChatMessageRecord.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/dto/ChatMessageRecord.java index 7fc95d7..fe051e0 100644 --- a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/dto/ChatMessageRecord.java +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/dto/ChatMessageRecord.java @@ -17,6 +17,13 @@ public class ChatMessageRecord implements Serializable { private String contentType; private String contentText; private Map contentPayload; + private BigInteger roundId; + private Integer roundNo; + private String messageKind; + private Integer variantIndex; + private Integer variantCount; + private Integer selectedVariantIndex; + private Boolean switchable; private Date created; private BigInteger createdBy; private Long syncVersion; @@ -101,6 +108,62 @@ public class ChatMessageRecord implements Serializable { this.contentPayload = contentPayload; } + public BigInteger getRoundId() { + return roundId; + } + + public void setRoundId(BigInteger roundId) { + this.roundId = roundId; + } + + public Integer getRoundNo() { + return roundNo; + } + + public void setRoundNo(Integer roundNo) { + this.roundNo = roundNo; + } + + public String getMessageKind() { + return messageKind; + } + + public void setMessageKind(String messageKind) { + this.messageKind = messageKind; + } + + public Integer getVariantIndex() { + return variantIndex; + } + + public void setVariantIndex(Integer variantIndex) { + this.variantIndex = variantIndex; + } + + public Integer getVariantCount() { + return variantCount; + } + + public void setVariantCount(Integer variantCount) { + this.variantCount = variantCount; + } + + public Integer getSelectedVariantIndex() { + return selectedVariantIndex; + } + + public void setSelectedVariantIndex(Integer selectedVariantIndex) { + this.selectedVariantIndex = selectedVariantIndex; + } + + public Boolean getSwitchable() { + return switchable; + } + + public void setSwitchable(Boolean switchable) { + this.switchable = switchable; + } + public Date getCreated() { return created; } diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/dto/ChatRoundRecord.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/dto/ChatRoundRecord.java new file mode 100644 index 0000000..e5f6999 --- /dev/null +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/dto/ChatRoundRecord.java @@ -0,0 +1,102 @@ +package tech.easyflow.chatlog.domain.dto; + +import java.io.Serializable; +import java.math.BigInteger; +import java.util.Date; + +/** + * 聊天轮次记录。 + */ +public class ChatRoundRecord implements Serializable { + + private BigInteger id; + private BigInteger sessionId; + private Integer roundNo; + private BigInteger userMessageId; + private BigInteger selectedAssistantMessageId; + private Integer selectedVariantIndex; + private Integer variantCount; + private String status; + private Date created; + private Date modified; + + public BigInteger getId() { + return id; + } + + public void setId(BigInteger id) { + this.id = id; + } + + public BigInteger getSessionId() { + return sessionId; + } + + public void setSessionId(BigInteger sessionId) { + this.sessionId = sessionId; + } + + public Integer getRoundNo() { + return roundNo; + } + + public void setRoundNo(Integer roundNo) { + this.roundNo = roundNo; + } + + public BigInteger getUserMessageId() { + return userMessageId; + } + + public void setUserMessageId(BigInteger userMessageId) { + this.userMessageId = userMessageId; + } + + public BigInteger getSelectedAssistantMessageId() { + return selectedAssistantMessageId; + } + + public void setSelectedAssistantMessageId(BigInteger selectedAssistantMessageId) { + this.selectedAssistantMessageId = selectedAssistantMessageId; + } + + public Integer getSelectedVariantIndex() { + return selectedVariantIndex; + } + + public void setSelectedVariantIndex(Integer selectedVariantIndex) { + this.selectedVariantIndex = selectedVariantIndex; + } + + public Integer getVariantCount() { + return variantCount; + } + + public void setVariantCount(Integer variantCount) { + this.variantCount = variantCount; + } + + public String getStatus() { + return status; + } + + public void setStatus(String status) { + this.status = status; + } + + public Date getCreated() { + return created; + } + + public void setCreated(Date created) { + this.created = created; + } + + public Date getModified() { + return modified; + } + + public void setModified(Date modified) { + this.modified = modified; + } +} diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/dto/ChatSessionExtPayload.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/dto/ChatSessionExtPayload.java new file mode 100644 index 0000000..5cd6cda --- /dev/null +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/dto/ChatSessionExtPayload.java @@ -0,0 +1,32 @@ +package tech.easyflow.chatlog.domain.dto; + +import java.io.Serializable; +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.List; + +/** + * 会话扩展载荷。 + */ +public class ChatSessionExtPayload implements Serializable { + + private List extraKnowledgeIds = new ArrayList<>(); + + /** + * 获取会话级额外知识库 ID 列表。 + * + * @return 额外知识库 ID 列表 + */ + public List getExtraKnowledgeIds() { + return extraKnowledgeIds; + } + + /** + * 设置会话级额外知识库 ID 列表。 + * + * @param extraKnowledgeIds 额外知识库 ID 列表 + */ + public void setExtraKnowledgeIds(List extraKnowledgeIds) { + this.extraKnowledgeIds = extraKnowledgeIds == null ? new ArrayList<>() : new ArrayList<>(extraKnowledgeIds); + } +} diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/dto/ChatSessionSummary.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/dto/ChatSessionSummary.java index 5eba1f1..6a8aff1 100644 --- a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/dto/ChatSessionSummary.java +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/dto/ChatSessionSummary.java @@ -15,6 +15,7 @@ public class ChatSessionSummary implements Serializable { private String assistantCode; private String assistantName; private String title; + private String extJson; private String lastMessagePreview; private BigInteger lastSenderId; private String lastSenderName; @@ -99,6 +100,14 @@ public class ChatSessionSummary implements Serializable { this.title = title; } + public String getExtJson() { + return extJson; + } + + public void setExtJson(String extJson) { + this.extJson = extJson; + } + public String getLastMessagePreview() { return lastMessagePreview; } diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/event/ChatPersistEventType.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/event/ChatPersistEventType.java index 812089b..0436376 100644 --- a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/event/ChatPersistEventType.java +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/domain/event/ChatPersistEventType.java @@ -3,6 +3,8 @@ package tech.easyflow.chatlog.domain.event; public enum ChatPersistEventType { SESSION_PREPARED, + ROUND_UPSERTED, + ROUND_VARIANT_SELECTED, USER_MESSAGE_APPENDED, ASSISTANT_MESSAGE_APPENDED, SESSION_RENAMED, diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/repository/mysql/MySqlChatLogRepository.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/repository/mysql/MySqlChatLogRepository.java index b925cd1..0f9e184 100644 --- a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/repository/mysql/MySqlChatLogRepository.java +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/repository/mysql/MySqlChatLogRepository.java @@ -5,6 +5,7 @@ import org.springframework.stereotype.Repository; import tech.easyflow.chatlog.domain.command.ChatAppendMessageCommand; import tech.easyflow.chatlog.domain.dto.ChatMessageRecord; import tech.easyflow.chatlog.support.ChatJsonSupport; +import tech.easyflow.chatlog.support.ChatConstants; import tech.easyflow.chatlog.support.ChatTableRouter; import java.math.BigInteger; @@ -49,8 +50,8 @@ public class MySqlChatLogRepository { for (Map.Entry> entry : grouped.entrySet()) { String table = tableRouter.resolveLogTable(entry.getKey()); String sql = "INSERT IGNORE INTO `" + table + "` " + - "(id, tenant_id, dept_id, session_id, user_id, assistant_id, sender_id, sender_name, sender_role, content_type, content_text, content_payload, created, created_by, sync_version) " + - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; + "(id, tenant_id, dept_id, session_id, user_id, assistant_id, round_id, round_no, sender_id, sender_name, sender_role, message_kind, variant_index, content_type, content_text, content_payload, created, created_by, sync_version) " + + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; int[] results = jdbcTemplate.batchUpdate(sql, new org.springframework.jdbc.core.BatchPreparedStatementSetter() { @Override public void setValues(java.sql.PreparedStatement ps, int i) throws java.sql.SQLException { @@ -62,15 +63,19 @@ public class MySqlChatLogRepository { ps.setObject(4, command.getSessionId()); ps.setObject(5, command.getUserId()); ps.setObject(6, command.getAssistantId()); - ps.setObject(7, command.getSenderId()); - ps.setString(8, command.getSenderName()); - ps.setString(9, command.getSenderRole()); - ps.setString(10, command.getContentType()); - ps.setString(11, command.getContentText()); - ps.setString(12, jsonSupport.toJson(command.getContentPayload())); - ps.setTimestamp(13, created); - ps.setObject(14, command.getCreatedBy()); - ps.setLong(15, command.getCreated().getTime()); + ps.setObject(7, command.getRoundId()); + ps.setObject(8, command.getRoundNo()); + ps.setObject(9, command.getSenderId()); + ps.setString(10, command.getSenderName()); + ps.setString(11, command.getSenderRole()); + ps.setString(12, command.getMessageKind()); + ps.setObject(13, command.getVariantIndex()); + ps.setString(14, command.getContentType()); + ps.setString(15, command.getContentText()); + ps.setString(16, jsonSupport.toJson(command.getContentPayload())); + ps.setTimestamp(17, created); + ps.setObject(18, command.getCreatedBy()); + ps.setLong(19, command.getCreated().getTime()); } @Override @@ -92,7 +97,14 @@ public class MySqlChatLogRepository { for (YearMonth month : months) { String table = tableRouter.resolveLogTable(month); List current = jdbcTemplate.query( - "SELECT * FROM `" + table + "` WHERE session_id=? ORDER BY created DESC, id DESC LIMIT ?", + "SELECT l.*, r.round_no AS joined_round_no, r.variant_count, r.selected_variant_index, " + + "CASE WHEN r.status IS NOT NULL AND r.status <> 'LOCKED' " + + "AND NOT EXISTS (SELECT 1 FROM `" + ChatConstants.ROUND_TABLE + "` newer WHERE newer.session_id = r.session_id AND newer.round_no > r.round_no) " + + "THEN 1 ELSE 0 END AS switchable " + + "FROM `" + table + "` l " + + "LEFT JOIN `" + ChatConstants.ROUND_TABLE + "` r ON l.round_id = r.id " + + "WHERE l.session_id=? AND (l.round_id IS NULL OR r.id IS NULL OR l.id = r.user_message_id OR l.id = r.selected_assistant_message_id) " + + "ORDER BY l.created DESC, l.id DESC LIMIT ?", (rs, rowNum) -> mapRow(rs), sessionId, limit @@ -114,6 +126,159 @@ public class MySqlChatLogRepository { .collect(Collectors.toList()); } + /** + * 从 MySQL 热表分页查询主线可见消息。 + * + * @param sessionId 会话 ID + * @param months 查询月份 + * @param offset 分页偏移 + * @param limit 分页条数 + * @return 主线消息列表,按 created desc、id desc 排序 + */ + public List listMainlineMessages(BigInteger sessionId, List months, long offset, int limit) { + if (sessionId == null || months == null || months.isEmpty() || limit <= 0) { + return Collections.emptyList(); + } + int candidateLimit = resolveCandidateLimit(offset, limit); + Map recordMap = new LinkedHashMap<>(); + for (YearMonth month : months) { + String table = tableRouter.resolveLogTable(month); + List current = jdbcTemplate.query( + "SELECT l.*, r.round_no AS joined_round_no, r.variant_count, r.selected_variant_index, " + + "CASE WHEN r.status IS NOT NULL AND r.status <> 'LOCKED' " + + "AND NOT EXISTS (SELECT 1 FROM `" + ChatConstants.ROUND_TABLE + "` newer WHERE newer.session_id = r.session_id AND newer.round_no > r.round_no) " + + "THEN 1 ELSE 0 END AS switchable " + + "FROM `" + table + "` l " + + "LEFT JOIN `" + ChatConstants.ROUND_TABLE + "` r ON l.round_id = r.id " + + "WHERE l.session_id=? AND (l.round_id IS NULL OR r.id IS NULL OR l.id = r.user_message_id OR l.id = r.selected_assistant_message_id) " + + "ORDER BY l.created DESC, l.id DESC LIMIT ?", + (rs, rowNum) -> mapRow(rs), + sessionId, + candidateLimit + ); + for (ChatMessageRecord record : current) { + if (record != null && record.getId() != null) { + recordMap.putIfAbsent(record.getId(), record); + } + } + } + return recordMap.values().stream() + .sorted((a, b) -> { + int compare = b.getCreated().compareTo(a.getCreated()); + if (compare != 0) { + return compare; + } + return b.getId().compareTo(a.getId()); + }) + .skip(Math.max(offset, 0)) + .limit(limit) + .collect(Collectors.toList()); + } + + /** + * 从 MySQL 热表查询全部主线可见消息。 + * + * @param sessionId 会话 ID + * @param months 查询月份 + * @return 主线消息列表,按 created asc、id asc 排序 + */ + public List listMainlineMessages(BigInteger sessionId, List months) { + if (sessionId == null || months == null || months.isEmpty()) { + return Collections.emptyList(); + } + Map recordMap = new LinkedHashMap<>(); + for (YearMonth month : months) { + String table = tableRouter.resolveLogTable(month); + List current = jdbcTemplate.query( + "SELECT l.*, r.round_no AS joined_round_no, r.variant_count, r.selected_variant_index, " + + "CASE WHEN r.status IS NOT NULL AND r.status <> 'LOCKED' " + + "AND NOT EXISTS (SELECT 1 FROM `" + ChatConstants.ROUND_TABLE + "` newer WHERE newer.session_id = r.session_id AND newer.round_no > r.round_no) " + + "THEN 1 ELSE 0 END AS switchable " + + "FROM `" + table + "` l " + + "LEFT JOIN `" + ChatConstants.ROUND_TABLE + "` r ON l.round_id = r.id " + + "WHERE l.session_id=? AND (l.round_id IS NULL OR r.id IS NULL OR l.id = r.user_message_id OR l.id = r.selected_assistant_message_id) " + + "ORDER BY l.created ASC, l.id ASC", + (rs, rowNum) -> mapRow(rs), + sessionId + ); + for (ChatMessageRecord record : current) { + if (record != null && record.getId() != null) { + recordMap.putIfAbsent(record.getId(), record); + } + } + } + return recordMap.values().stream() + .sorted((a, b) -> { + int compare = a.getCreated().compareTo(b.getCreated()); + if (compare != 0) { + return compare; + } + return a.getId().compareTo(b.getId()); + }) + .collect(Collectors.toList()); + } + + public List listRoundVariants(BigInteger sessionId, BigInteger roundId, List months) { + List records = new ArrayList<>(); + for (YearMonth month : months) { + String table = tableRouter.resolveLogTable(month); + records.addAll(jdbcTemplate.query( + "SELECT l.*, r.round_no AS joined_round_no, r.variant_count, r.selected_variant_index, " + + "0 AS switchable " + + "FROM `" + table + "` l " + + "INNER JOIN `" + ChatConstants.ROUND_TABLE + "` r ON l.round_id = r.id " + + "WHERE l.session_id=? AND l.round_id=? AND l.message_kind=? " + + "ORDER BY l.variant_index ASC, l.created ASC, l.id ASC", + (rs, rowNum) -> mapRow(rs), + sessionId, + roundId, + ChatConstants.MESSAGE_KIND_ASSISTANT_VARIANT + )); + } + return records; + } + + /** + * 精准查询轮次下指定答案版本。 + * + * @param sessionId 会话 ID + * @param roundId 轮次 ID + * @param variantIndex 答案版本序号 + * @param months 查询月份 + * @return 目标答案版本 + */ + public ChatMessageRecord findRoundVariant(BigInteger sessionId, BigInteger roundId, Integer variantIndex, List months) { + if (sessionId == null || roundId == null || variantIndex == null || variantIndex <= 0 || months == null || months.isEmpty()) { + return null; + } + for (YearMonth month : months) { + String table = tableRouter.resolveLogTable(month); + List records = jdbcTemplate.query( + "SELECT l.*, r.round_no AS joined_round_no, r.variant_count, r.selected_variant_index, 0 AS switchable " + + "FROM `" + table + "` l " + + "INNER JOIN `" + ChatConstants.ROUND_TABLE + "` r ON l.round_id = r.id " + + "WHERE l.session_id=? AND l.round_id=? AND l.message_kind=? AND l.variant_index=? " + + "ORDER BY l.created DESC, l.id DESC LIMIT 1", + (rs, rowNum) -> mapRow(rs), + sessionId, + roundId, + ChatConstants.MESSAGE_KIND_ASSISTANT_VARIANT, + variantIndex + ); + if (!records.isEmpty()) { + return records.get(0); + } + } + return null; + } + + private int resolveCandidateLimit(long offset, int limit) { + long normalizedOffset = Math.max(offset, 0); + long normalizedLimit = Math.max(limit, 1); + long candidateLimit = normalizedOffset + normalizedLimit; + return (int) Math.min(candidateLimit, Integer.MAX_VALUE); + } + public List loadIncremental(String table, Date cursorTime, BigInteger cursorId, int limit) { Timestamp timestamp = cursorTime == null ? new Timestamp(0L) : new Timestamp(cursorTime.getTime()); return jdbcTemplate.query( @@ -149,9 +314,20 @@ public class MySqlChatLogRepository { record.setSessionId(bigInteger(rs, "session_id")); record.setUserId(bigInteger(rs, "user_id")); record.setAssistantId(bigInteger(rs, "assistant_id")); + record.setRoundId(bigInteger(rs, "round_id")); + record.setRoundNo(optionalInteger(rs, "round_no")); + Integer joinedRoundNo = optionalInteger(rs, "joined_round_no"); + if (joinedRoundNo != null) { + record.setRoundNo(joinedRoundNo); + } record.setSenderId(bigInteger(rs, "sender_id")); record.setSenderName(rs.getString("sender_name")); record.setSenderRole(rs.getString("sender_role")); + record.setMessageKind(optionalString(rs, "message_kind")); + record.setVariantIndex(optionalInteger(rs, "variant_index")); + record.setVariantCount(optionalInteger(rs, "variant_count")); + record.setSelectedVariantIndex(optionalInteger(rs, "selected_variant_index")); + record.setSwitchable(optionalBoolean(rs, "switchable")); record.setContentType(rs.getString("content_type")); record.setContentText(rs.getString("content_text")); record.setContentPayload(jsonSupport.toMap(rs.getString("content_payload"))); @@ -168,4 +344,39 @@ public class MySqlChatLogRepository { } return new BigInteger(String.valueOf(value)); } + + private Integer optionalInteger(ResultSet rs, String column) { + try { + Object value = rs.getObject(column); + if (value == null) { + return null; + } + return Integer.parseInt(String.valueOf(value)); + } catch (SQLException ignored) { + return null; + } + } + + private Boolean optionalBoolean(ResultSet rs, String column) { + try { + Object value = rs.getObject(column); + if (value == null) { + return null; + } + if (value instanceof Boolean booleanValue) { + return booleanValue; + } + return Integer.parseInt(String.valueOf(value)) != 0; + } catch (SQLException ignored) { + return null; + } + } + + private String optionalString(ResultSet rs, String column) { + try { + return rs.getString(column); + } catch (SQLException ignored) { + return null; + } + } } diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/repository/mysql/MySqlChatRoundRepository.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/repository/mysql/MySqlChatRoundRepository.java new file mode 100644 index 0000000..6b97b79 --- /dev/null +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/repository/mysql/MySqlChatRoundRepository.java @@ -0,0 +1,196 @@ +package tech.easyflow.chatlog.repository.mysql; + +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.RowMapper; +import org.springframework.stereotype.Repository; +import tech.easyflow.chatlog.domain.command.ChatRoundSelectCommand; +import tech.easyflow.chatlog.domain.command.ChatRoundUpsertCommand; +import tech.easyflow.chatlog.domain.dto.ChatRoundRecord; +import tech.easyflow.chatlog.support.ChatConstants; + +import java.math.BigInteger; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; + +/** + * MySQL 轮次仓储。 + */ +@Repository +public class MySqlChatRoundRepository { + + private final JdbcTemplate jdbcTemplate; + + public MySqlChatRoundRepository(JdbcTemplate jdbcTemplate) { + this.jdbcTemplate = jdbcTemplate; + } + + /** + * 批量写入或更新轮次聚合。 + * + * @param commands 轮次命令 + */ + public void createOrTouchBatch(List commands) { + if (commands == null || commands.isEmpty()) { + return; + } + String sql = "INSERT INTO `" + ChatConstants.ROUND_TABLE + "` " + + "(id, session_id, round_no, user_message_id, selected_assistant_message_id, selected_variant_index, variant_count, status, created, modified) " + + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) " + + "ON DUPLICATE KEY UPDATE user_message_id=COALESCE(VALUES(user_message_id), user_message_id), " + + "selected_assistant_message_id=CASE " + + "WHEN VALUES(selected_assistant_message_id) IS NOT NULL AND (selected_assistant_message_id IS NULL OR VALUES(modified) > modified) " + + "THEN VALUES(selected_assistant_message_id) ELSE selected_assistant_message_id END, " + + "selected_variant_index=CASE " + + "WHEN VALUES(selected_variant_index) IS NOT NULL AND (selected_variant_index IS NULL OR selected_variant_index = 0 OR VALUES(modified) > modified) " + + "THEN VALUES(selected_variant_index) ELSE selected_variant_index END, " + + "variant_count=GREATEST(COALESCE(VALUES(variant_count), 0), COALESCE(variant_count, 0)), " + + "status=CASE " + + "WHEN VALUES(status) IS NULL OR VALUES(status) = '' THEN status " + + "WHEN VALUES(modified) > modified THEN VALUES(status) " + + "WHEN VALUES(modified) = modified AND status = '" + ChatConstants.ROUND_STATUS_ANSWERING + "' " + + "AND VALUES(status) <> '" + ChatConstants.ROUND_STATUS_ANSWERING + "' THEN VALUES(status) " + + "WHEN VALUES(modified) = modified AND status <> '" + ChatConstants.ROUND_STATUS_LOCKED + "' " + + "AND VALUES(status) = '" + ChatConstants.ROUND_STATUS_LOCKED + "' THEN VALUES(status) " + + "ELSE status END, " + + "modified=GREATEST(VALUES(modified), modified)"; + jdbcTemplate.batchUpdate(sql, commands, commands.size(), (ps, command) -> { + Timestamp operateAt = timestamp(command.getOperateAt()); + ps.setObject(1, command.getRoundId()); + ps.setObject(2, command.getSessionId()); + ps.setObject(3, command.getRoundNo()); + ps.setObject(4, command.getUserMessageId()); + ps.setObject(5, command.getSelectedAssistantMessageId()); + ps.setObject(6, command.getSelectedVariantIndex()); + ps.setObject(7, command.getVariantCount()); + ps.setString(8, command.getStatus()); + ps.setTimestamp(9, operateAt); + ps.setTimestamp(10, operateAt); + }); + } + + /** + * 应用答案版本切换。 + * + * @param commands 切换命令 + */ + public void selectVariants(List commands) { + if (commands == null || commands.isEmpty()) { + return; + } + String sql = "UPDATE `" + ChatConstants.ROUND_TABLE + "` SET selected_assistant_message_id=?, selected_variant_index=?, modified=? " + + "WHERE id=? AND session_id=? AND modified <= ?"; + jdbcTemplate.batchUpdate(sql, commands, commands.size(), (ps, command) -> { + Timestamp operateAt = timestamp(command.getOperateAt()); + ps.setObject(1, command.getSelectedAssistantMessageId()); + ps.setObject(2, command.getSelectedVariantIndex()); + ps.setTimestamp(3, operateAt); + ps.setObject(4, command.getRoundId()); + ps.setObject(5, command.getSessionId()); + ps.setTimestamp(6, operateAt); + }); + } + + /** + * 查询指定轮次。 + * + * @param sessionId 会话 ID + * @param roundId 轮次 ID + * @return 轮次记录 + */ + public ChatRoundRecord findRound(BigInteger sessionId, BigInteger roundId) { + List records = jdbcTemplate.query( + "SELECT * FROM `" + ChatConstants.ROUND_TABLE + "` WHERE session_id=? AND id=? LIMIT 1", + rowMapper(), + sessionId, + roundId + ); + return records.isEmpty() ? null : records.get(0); + } + + /** + * 查询会话最新轮次。 + * + * @param sessionId 会话 ID + * @return 最新轮次 + */ + public ChatRoundRecord findLatestRound(BigInteger sessionId) { + List records = jdbcTemplate.query( + "SELECT * FROM `" + ChatConstants.ROUND_TABLE + "` WHERE session_id=? ORDER BY round_no DESC, id DESC LIMIT 1", + rowMapper(), + sessionId + ); + return records.isEmpty() ? null : records.get(0); + } + + /** + * 判断会话是否已使用轮次模型。 + * + * @param sessionId 会话 ID + * @return 是否存在轮次 + */ + public boolean existsRounds(BigInteger sessionId) { + Long count = jdbcTemplate.queryForObject( + "SELECT COUNT(1) FROM `" + ChatConstants.ROUND_TABLE + "` WHERE session_id=?", + Long.class, + sessionId + ); + return count != null && count > 0; + } + + /** + * 读取指定时间之后变更的轮次。 + * + * @param cursorTime 游标时间 + * @param cursorId 游标 ID + * @param limit 批大小 + * @return 轮次记录列表 + */ + public List loadModifiedAfter(Date cursorTime, BigInteger cursorId, int limit) { + Timestamp timestamp = cursorTime == null ? new Timestamp(0L) : new Timestamp(cursorTime.getTime()); + return jdbcTemplate.query( + "SELECT * FROM `" + ChatConstants.ROUND_TABLE + "` WHERE (modified > ?) OR (modified = ? AND id > ?) " + + "ORDER BY modified ASC, id ASC LIMIT ?", + rowMapper(), + timestamp, + timestamp, + cursorId == null ? BigInteger.ZERO : cursorId, + limit + ); + } + + private RowMapper rowMapper() { + return new RowMapper<>() { + @Override + public ChatRoundRecord mapRow(ResultSet rs, int rowNum) throws SQLException { + ChatRoundRecord record = new ChatRoundRecord(); + record.setId(bigInteger(rs, "id")); + record.setSessionId(bigInteger(rs, "session_id")); + record.setRoundNo(rs.getInt("round_no")); + record.setUserMessageId(bigInteger(rs, "user_message_id")); + record.setSelectedAssistantMessageId(bigInteger(rs, "selected_assistant_message_id")); + record.setSelectedVariantIndex(rs.getInt("selected_variant_index")); + record.setVariantCount(rs.getInt("variant_count")); + record.setStatus(rs.getString("status")); + record.setCreated(rs.getTimestamp("created")); + record.setModified(rs.getTimestamp("modified")); + return record; + } + }; + } + + private BigInteger bigInteger(ResultSet rs, String column) throws SQLException { + Object value = rs.getObject(column); + if (value == null) { + return null; + } + return new BigInteger(String.valueOf(value)); + } + + private Timestamp timestamp(Date value) { + return new Timestamp((value == null ? new Date() : value).getTime()); + } +} diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/repository/mysql/MySqlChatSessionRepository.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/repository/mysql/MySqlChatSessionRepository.java index dbae1fd..577c71a 100644 --- a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/repository/mysql/MySqlChatSessionRepository.java +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/repository/mysql/MySqlChatSessionRepository.java @@ -15,6 +15,7 @@ import java.math.BigInteger; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Timestamp; +import java.sql.Types; import java.util.ArrayList; import java.util.Date; import java.util.LinkedHashMap; @@ -39,11 +40,12 @@ public class MySqlChatSessionRepository { } String table = tableRouter.resolveSessionTable(); String sql = "INSERT INTO `" + table + "` " + - "(id, tenant_id, dept_id, user_id, user_account, assistant_id, assistant_code, assistant_name, title, last_message_preview, message_count, access_at, created, created_by, modified, modified_by, is_deleted) " + - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, '', 0, ?, ?, ?, ?, ?, 0) " + + "(id, tenant_id, dept_id, user_id, user_account, assistant_id, assistant_code, assistant_name, title, ext_json, last_message_preview, message_count, access_at, created, created_by, modified, modified_by, is_deleted) " + + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, '', 0, ?, ?, ?, ?, ?, 0) " + "ON DUPLICATE KEY UPDATE user_account=COALESCE(NULLIF(VALUES(user_account), ''), user_account), " + "assistant_id=VALUES(assistant_id), assistant_code=COALESCE(NULLIF(VALUES(assistant_code), ''), assistant_code), " + "assistant_name=COALESCE(NULLIF(VALUES(assistant_name), ''), assistant_name), " + + "ext_json=COALESCE(VALUES(ext_json), ext_json), " + "title=COALESCE(NULLIF(VALUES(title), ''), title), " + "access_at=VALUES(access_at), modified=VALUES(modified), modified_by=VALUES(modified_by), is_deleted=0"; jdbcTemplate.batchUpdate(sql, commands, commands.size(), (ps, command) -> { @@ -57,11 +59,12 @@ public class MySqlChatSessionRepository { ps.setString(7, safeString(command.getAssistantCode())); ps.setString(8, safeString(command.getAssistantName())); ps.setString(9, safeString(command.getTitle())); - ps.setTimestamp(10, operateAt); + setNullableJson(ps, 10, command.getExtJson()); ps.setTimestamp(11, operateAt); - ps.setObject(12, command.getOperatorId()); - ps.setTimestamp(13, operateAt); - ps.setObject(14, command.getOperatorId()); + ps.setTimestamp(12, operateAt); + ps.setObject(13, command.getOperatorId()); + ps.setTimestamp(14, operateAt); + ps.setObject(15, command.getOperatorId()); }); } @@ -71,33 +74,38 @@ public class MySqlChatSessionRepository { } String table = tableRouter.resolveSessionTable(); String sql = "UPDATE `" + table + "` SET " + - "last_sender_id=CASE WHEN last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE last_sender_id END, " + - "last_sender_name=CASE WHEN last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE last_sender_name END, " + - "last_message_preview=CASE WHEN last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE last_message_preview END, " + - "last_message_at=CASE WHEN last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE last_message_at END, " + - "access_at=CASE WHEN last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE access_at END, " + + "last_sender_id=CASE WHEN ?=1 OR last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE last_sender_id END, " + + "last_sender_name=CASE WHEN ?=1 OR last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE last_sender_name END, " + + "last_message_preview=CASE WHEN ?=1 OR last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE last_message_preview END, " + + "last_message_at=CASE WHEN ?=1 OR last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE last_message_at END, " + + "access_at=CASE WHEN ?=1 OR last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE access_at END, " + "message_count=COALESCE(message_count, 0) + ?, " + - "modified=CASE WHEN last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE modified END, " + - "modified_by=CASE WHEN last_message_at IS NULL OR last_message_at <= ? THEN ? ELSE modified_by END " + + "modified=?, modified_by=? " + "WHERE id=?"; jdbcTemplate.batchUpdate(sql, commands, commands.size(), (ps, command) -> { Timestamp lastMessageAt = timestamp(command.getLastMessageAt()); - ps.setTimestamp(1, lastMessageAt); - ps.setObject(2, command.getLastSenderId()); - ps.setTimestamp(3, lastMessageAt); - ps.setString(4, command.getLastSenderName()); + Timestamp accessAt = timestamp(command.getAccessAt()); + Timestamp modifiedAt = timestamp(command.getModifiedAt()); + int forceOverwrite = command.isForceOverwrite() ? 1 : 0; + ps.setInt(1, forceOverwrite); + ps.setTimestamp(2, lastMessageAt); + ps.setObject(3, command.getLastSenderId()); + ps.setInt(4, forceOverwrite); ps.setTimestamp(5, lastMessageAt); - ps.setString(6, command.getLastMessagePreview()); - ps.setTimestamp(7, lastMessageAt); + ps.setString(6, command.getLastSenderName()); + ps.setInt(7, forceOverwrite); ps.setTimestamp(8, lastMessageAt); - ps.setTimestamp(9, lastMessageAt); - ps.setTimestamp(10, lastMessageAt); - ps.setInt(11, Math.max(command.getMessageIncrement(), 1)); + ps.setString(9, safeString(command.getLastMessagePreview())); + ps.setInt(10, forceOverwrite); + ps.setTimestamp(11, lastMessageAt); ps.setTimestamp(12, lastMessageAt); - ps.setTimestamp(13, lastMessageAt); + ps.setInt(13, forceOverwrite); ps.setTimestamp(14, lastMessageAt); - ps.setObject(15, command.getOperatorId()); - ps.setObject(16, command.getSessionId()); + ps.setTimestamp(15, accessAt); + ps.setInt(16, Math.max(command.getMessageIncrement(), 0)); + ps.setTimestamp(17, modifiedAt); + ps.setObject(18, command.getOperatorId()); + ps.setObject(19, command.getSessionId()); }); } @@ -105,13 +113,13 @@ public class MySqlChatSessionRepository { String table = tableRouter.resolveSessionTable(); List params = new ArrayList<>(); StringBuilder sql = new StringBuilder("SELECT * FROM `").append(table) - .append("` WHERE user_id=? AND is_deleted=0"); + .append("` WHERE user_id=? AND is_deleted=0 AND last_message_at IS NOT NULL"); params.add(userId); if (assistantId != null) { sql.append(" AND assistant_id=?"); params.add(assistantId); } - sql.append(" ORDER BY access_at DESC, id DESC LIMIT ? OFFSET ?"); + sql.append(" ORDER BY last_message_at DESC, id DESC LIMIT ? OFFSET ?"); params.add(query.getPageSize()); params.add(query.getOffset()); return jdbcTemplate.query(sql.toString(), sessionRowMapper(), params.toArray()); @@ -121,7 +129,7 @@ public class MySqlChatSessionRepository { String table = tableRouter.resolveSessionTable(); List params = new ArrayList<>(); StringBuilder sql = new StringBuilder("SELECT COUNT(1) FROM `").append(table) - .append("` WHERE user_id=? AND is_deleted=0"); + .append("` WHERE user_id=? AND is_deleted=0 AND last_message_at IS NOT NULL"); params.add(userId); if (assistantId != null) { sql.append(" AND assistant_id=?"); @@ -248,6 +256,7 @@ public class MySqlChatSessionRepository { summary.setAssistantCode(rs.getString("assistant_code")); summary.setAssistantName(rs.getString("assistant_name")); summary.setTitle(rs.getString("title")); + summary.setExtJson(rs.getString("ext_json")); summary.setLastMessagePreview(rs.getString("last_message_preview")); summary.setLastSenderId(bigInteger(rs, "last_sender_id")); summary.setLastSenderName(rs.getString("last_sender_name")); @@ -279,4 +288,12 @@ public class MySqlChatSessionRepository { private String safeString(String value) { return value == null ? "" : value; } + + private void setNullableJson(java.sql.PreparedStatement ps, int parameterIndex, String value) throws SQLException { + if (value == null || value.isBlank()) { + ps.setNull(parameterIndex, Types.VARCHAR); + return; + } + ps.setString(parameterIndex, value); + } } diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatPersistDispatcher.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatPersistDispatcher.java index 6012eb0..59318e4 100644 --- a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatPersistDispatcher.java +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatPersistDispatcher.java @@ -1,16 +1,22 @@ package tech.easyflow.chatlog.service; import org.slf4j.MDC; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.stereotype.Service; import tech.easyflow.chatlog.cache.ChatHotStateService; import tech.easyflow.chatlog.domain.command.ChatAppendMessageCommand; +import tech.easyflow.chatlog.domain.command.ChatRoundSelectCommand; +import tech.easyflow.chatlog.domain.command.ChatRoundUpsertCommand; import tech.easyflow.chatlog.domain.command.ChatSessionUpsertCommand; +import tech.easyflow.chatlog.domain.dto.ChatRoundRecord; import tech.easyflow.chatlog.domain.dto.ChatSessionSummary; import tech.easyflow.chatlog.domain.event.ChatPersistEvent; import tech.easyflow.chatlog.domain.event.ChatPersistEventType; import tech.easyflow.chatlog.domain.event.payload.ChatSessionDeletePayload; import tech.easyflow.chatlog.domain.event.payload.ChatSessionRenamePayload; import tech.easyflow.chatlog.support.ChatJsonSupport; +import tech.easyflow.common.web.exceptions.BusinessException; import java.math.BigInteger; import java.util.Date; @@ -19,21 +25,26 @@ import java.util.UUID; @Service public class ChatPersistDispatcher { + private static final Logger log = LoggerFactory.getLogger(ChatPersistDispatcher.class); + private final ChatHotStateService chatHotStateService; private final ChatPersistEventProducer eventProducer; + private final ChatPersistMySqlApplyService mySqlApplyService; private final ChatJsonSupport chatJsonSupport; public ChatPersistDispatcher(ChatHotStateService chatHotStateService, ChatPersistEventProducer eventProducer, + ChatPersistMySqlApplyService mySqlApplyService, ChatJsonSupport chatJsonSupport) { this.chatHotStateService = chatHotStateService; this.eventProducer = eventProducer; + this.mySqlApplyService = mySqlApplyService; this.chatJsonSupport = chatJsonSupport; } public ChatSessionSummary createOrTouchSession(ChatSessionUpsertCommand command) { ChatSessionSummary summary = chatHotStateService.touchSession(command); - eventProducer.send(buildEvent( + ChatPersistEvent event = buildEvent( UUID.randomUUID().toString(), ChatPersistEventType.SESSION_PREPARED, command.getSessionId(), @@ -41,10 +52,43 @@ public class ChatPersistDispatcher { command.getAssistantId(), command.getOperateAt(), chatJsonSupport.toJson(command) - )); + ); + persistImmediately(event); + eventProducer.send(event); return summary; } + public ChatRoundRecord createOrTouchRound(ChatRoundUpsertCommand command) { + ChatRoundRecord record = chatHotStateService.createOrTouchRound(command); + ChatPersistEvent event = buildEvent( + UUID.randomUUID().toString(), + ChatPersistEventType.ROUND_UPSERTED, + command.getSessionId(), + BigInteger.ZERO, + BigInteger.ZERO, + command.getOperateAt(), + chatJsonSupport.toJson(command) + ); + persistImmediately(event); + eventProducer.send(event); + return record; + } + + public void selectRoundVariant(ChatRoundSelectCommand command) { + chatHotStateService.selectVariant(command); + ChatPersistEvent event = buildEvent( + UUID.randomUUID().toString(), + ChatPersistEventType.ROUND_VARIANT_SELECTED, + command.getSessionId(), + BigInteger.ZERO, + BigInteger.ZERO, + command.getOperateAt(), + chatJsonSupport.toJson(command) + ); + persistImmediately(event); + eventProducer.send(event); + } + public void appendUserMessage(ChatAppendMessageCommand command) { appendMessage(command, ChatPersistEventType.USER_MESSAGE_APPENDED); } @@ -96,7 +140,7 @@ public class ChatPersistDispatcher { private void appendMessage(ChatAppendMessageCommand command, ChatPersistEventType eventType) { chatHotStateService.appendMessage(command); - eventProducer.send(buildEvent( + ChatPersistEvent event = buildEvent( eventId("message", command.getMessageId()), eventType, command.getSessionId(), @@ -104,7 +148,27 @@ public class ChatPersistDispatcher { command.getAssistantId(), command.getCreated(), chatJsonSupport.toJson(command) - )); + ); + persistImmediately(event); + eventProducer.send(event); + } + + /** + * 先同步写入 MySQL,再发送异步事件,保证会话列表和版本切换读取有确定来源。 + * + * @param event 持久化事件 + */ + private void persistImmediately(ChatPersistEvent event) { + try { + mySqlApplyService.apply(java.util.List.of(event)); + } catch (RuntimeException ex) { + log.error("聊天记录同步写入 MySQL 失败,eventId={}, eventType={}, sessionId={}", + event == null ? null : event.getEventId(), + event == null ? null : event.getEventType(), + event == null ? null : event.getSessionId(), + ex); + throw new BusinessException("聊天记录持久化失败,请稍后重试"); + } } private ChatPersistEvent buildEvent(String eventId, diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatPersistMySqlApplyService.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatPersistMySqlApplyService.java index fda6f3b..288a6e5 100644 --- a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatPersistMySqlApplyService.java +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatPersistMySqlApplyService.java @@ -3,6 +3,8 @@ package tech.easyflow.chatlog.service; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import tech.easyflow.chatlog.domain.command.ChatAppendMessageCommand; +import tech.easyflow.chatlog.domain.command.ChatRoundSelectCommand; +import tech.easyflow.chatlog.domain.command.ChatRoundUpsertCommand; import tech.easyflow.chatlog.domain.command.ChatSessionSummaryCommand; import tech.easyflow.chatlog.domain.command.ChatSessionUpsertCommand; import tech.easyflow.chatlog.domain.event.ChatPersistEvent; @@ -11,7 +13,9 @@ import tech.easyflow.chatlog.domain.event.payload.ChatSessionDeletePayload; import tech.easyflow.chatlog.domain.event.payload.ChatSessionRenamePayload; import tech.easyflow.chatlog.repository.mysql.MySqlChatLogRepository; import tech.easyflow.chatlog.repository.mysql.MySqlChatLogTableManager; +import tech.easyflow.chatlog.repository.mysql.MySqlChatRoundRepository; import tech.easyflow.chatlog.repository.mysql.MySqlChatSessionRepository; +import tech.easyflow.chatlog.support.ChatConstants; import tech.easyflow.chatlog.support.ChatJsonSupport; import java.math.BigInteger; @@ -30,15 +34,18 @@ public class ChatPersistMySqlApplyService { private final MySqlChatSessionRepository sessionRepository; private final MySqlChatLogRepository logRepository; + private final MySqlChatRoundRepository roundRepository; private final MySqlChatLogTableManager tableManager; private final ChatJsonSupport chatJsonSupport; public ChatPersistMySqlApplyService(MySqlChatSessionRepository sessionRepository, MySqlChatLogRepository logRepository, + MySqlChatRoundRepository roundRepository, MySqlChatLogTableManager tableManager, ChatJsonSupport chatJsonSupport) { this.sessionRepository = sessionRepository; this.logRepository = logRepository; + this.roundRepository = roundRepository; this.tableManager = tableManager; this.chatJsonSupport = chatJsonSupport; } @@ -50,6 +57,8 @@ public class ChatPersistMySqlApplyService { } Map sessionUpserts = new LinkedHashMap<>(); + Map roundUpserts = new LinkedHashMap<>(); + List roundSelections = new ArrayList<>(); List appendCommands = new ArrayList<>(); Map summaryCommands = new LinkedHashMap<>(); List renamePayloads = new ArrayList<>(); @@ -67,6 +76,18 @@ public class ChatPersistMySqlApplyService { sessionUpserts.put(command.getSessionId(), command); } } + case ROUND_UPSERTED -> { + ChatRoundUpsertCommand command = chatJsonSupport.fromJson(event.getPayload(), ChatRoundUpsertCommand.class); + if (command != null && command.getRoundId() != null) { + roundUpserts.put(command.getRoundId(), command); + } + } + case ROUND_VARIANT_SELECTED -> { + ChatRoundSelectCommand command = chatJsonSupport.fromJson(event.getPayload(), ChatRoundSelectCommand.class); + if (command != null && command.getRoundId() != null) { + roundSelections.add(command); + } + } case USER_MESSAGE_APPENDED, ASSISTANT_MESSAGE_APPENDED -> { ChatAppendMessageCommand command = chatJsonSupport.fromJson(event.getPayload(), ChatAppendMessageCommand.class); if (command == null || command.getSessionId() == null || command.getMessageId() == null) { @@ -96,6 +117,9 @@ public class ChatPersistMySqlApplyService { if (!sessionUpserts.isEmpty()) { sessionRepository.createOrTouchBatch(new ArrayList<>(sessionUpserts.values())); } + if (!roundUpserts.isEmpty()) { + roundRepository.createOrTouchBatch(new ArrayList<>(roundUpserts.values())); + } if (!months.isEmpty()) { for (YearMonth month : months) { tableManager.ensureMonthTable(month); @@ -113,6 +137,9 @@ public class ChatPersistMySqlApplyService { } sessionRepository.updateSummaries(new ArrayList<>(summaryCommands.values())); } + if (!roundSelections.isEmpty()) { + roundRepository.selectVariants(roundSelections); + } if (!renamePayloads.isEmpty()) { sessionRepository.renameSessions(renamePayloads); } @@ -127,15 +154,26 @@ public class ChatPersistMySqlApplyService { ChatSessionSummaryCommand created = new ChatSessionSummaryCommand(); created.setSessionId(command.getSessionId()); created.setUserId(command.getUserId()); + created.setLastMessageAt(null); + created.setAccessAt(null); + created.setModifiedAt(null); created.setMessageIncrement(0); return created; }); summary.setMessageIncrement(summary.getMessageIncrement() + 1); - if (summary.getLastMessageAt() == null || !command.getCreated().before(summary.getLastMessageAt())) { + if (ChatConstants.MESSAGE_KIND_ASSISTANT_VARIANT.equals(command.getMessageKind()) + && command.getVariantIndex() != null + && command.getVariantIndex() > 1) { + summary.setMessageIncrement(Math.max(summary.getMessageIncrement() - 1, 0)); + } + Date commandCreated = defaultDate(command.getCreated()); + if (summary.getLastMessageAt() == null || !commandCreated.before(summary.getLastMessageAt())) { summary.setLastSenderId(command.getSenderId()); summary.setLastSenderName(command.getSenderName()); summary.setLastMessagePreview(trimPreview(command.getContentText())); - summary.setLastMessageAt(command.getCreated()); + summary.setLastMessageAt(commandCreated); + summary.setAccessAt(commandCreated); + summary.setModifiedAt(commandCreated); summary.setOperatorId(command.getCreatedBy()); } } diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatRoundCommandService.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatRoundCommandService.java new file mode 100644 index 0000000..3ce1800 --- /dev/null +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatRoundCommandService.java @@ -0,0 +1,26 @@ +package tech.easyflow.chatlog.service; + +import tech.easyflow.chatlog.domain.command.ChatRoundSelectCommand; +import tech.easyflow.chatlog.domain.command.ChatRoundUpsertCommand; +import tech.easyflow.chatlog.domain.dto.ChatRoundRecord; + +/** + * 聊天轮次写服务。 + */ +public interface ChatRoundCommandService { + + /** + * 创建或更新轮次聚合。 + * + * @param command 轮次命令 + * @return 最新轮次记录 + */ + ChatRoundRecord createOrTouchRound(ChatRoundUpsertCommand command); + + /** + * 切换轮次当前选中的答案版本。 + * + * @param command 切换命令 + */ + void selectVariant(ChatRoundSelectCommand command); +} diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatRoundOperateService.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatRoundOperateService.java new file mode 100644 index 0000000..d311f0c --- /dev/null +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatRoundOperateService.java @@ -0,0 +1,42 @@ +package tech.easyflow.chatlog.service; + +import tech.easyflow.chatlog.domain.dto.ChatMessageRecord; +import tech.easyflow.chatlog.domain.dto.ChatRoundRecord; + +import java.math.BigInteger; +import java.util.List; + +/** + * 聊天轮次业务操作服务。 + */ +public interface ChatRoundOperateService { + + /** + * 校验并返回允许重答的轮次。 + * + * @param sessionId 会话 ID + * @param roundId 轮次 ID + * @return 轮次记录 + */ + ChatRoundRecord requireRegeneratableRound(BigInteger sessionId, BigInteger roundId); + + /** + * 查询轮次下所有答案版本。 + * + * @param sessionId 会话 ID + * @param roundId 轮次 ID + * @return 答案版本列表 + */ + List listVariants(BigInteger sessionId, BigInteger roundId); + + /** + * 切换指定轮次当前选中的答案版本。 + * + * @param sessionId 会话 ID + * @param roundId 轮次 ID + * @param variantIndex 目标版本序号 + * @param operatorId 操作人 + * @return 选中的答案消息 + */ + ChatMessageRecord selectVariant(BigInteger sessionId, BigInteger roundId, Integer variantIndex, BigInteger operatorId); +} diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatRoundQueryService.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatRoundQueryService.java new file mode 100644 index 0000000..97f0c5a --- /dev/null +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatRoundQueryService.java @@ -0,0 +1,57 @@ +package tech.easyflow.chatlog.service; + +import tech.easyflow.chatlog.domain.dto.ChatMessageRecord; +import tech.easyflow.chatlog.domain.dto.ChatRoundRecord; + +import java.math.BigInteger; +import java.util.List; + +/** + * 聊天轮次读服务。 + */ +public interface ChatRoundQueryService { + + /** + * 查询会话最新轮次。 + * + * @param sessionId 会话 ID + * @return 最新轮次 + */ + ChatRoundRecord getLatestRound(BigInteger sessionId); + + /** + * 查询指定轮次。 + * + * @param sessionId 会话 ID + * @param roundId 轮次 ID + * @return 轮次记录 + */ + ChatRoundRecord getRound(BigInteger sessionId, BigInteger roundId); + + /** + * 查询轮次下所有助手答案版本。 + * + * @param sessionId 会话 ID + * @param roundId 轮次 ID + * @return 答案版本列表 + */ + List listRoundVariants(BigInteger sessionId, BigInteger roundId); + + /** + * 查询轮次下指定答案版本。 + * + * @param sessionId 会话 ID + * @param roundId 轮次 ID + * @param variantIndex 答案版本序号 + * @return 答案版本记录 + */ + ChatMessageRecord getRoundVariant(BigInteger sessionId, BigInteger roundId, Integer variantIndex); + + /** + * 判断会话是否已经启用轮次模型。 + * + * @param sessionId 会话 ID + * @return 是否存在轮次 + */ + boolean hasRounds(BigInteger sessionId); +} diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatSessionQueryService.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatSessionQueryService.java index 80ca64b..397b45e 100644 --- a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatSessionQueryService.java +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/ChatSessionQueryService.java @@ -1,6 +1,7 @@ package tech.easyflow.chatlog.service; import tech.easyflow.chatlog.domain.dto.ChatMessageRecord; +import tech.easyflow.chatlog.domain.dto.ChatHistoryPage; import tech.easyflow.chatlog.domain.dto.ChatSessionPage; import tech.easyflow.chatlog.domain.dto.ChatSessionSummary; import tech.easyflow.chatlog.domain.query.ChatPageQuery; @@ -18,5 +19,22 @@ public interface ChatSessionQueryService { ChatSessionSummary getSessionSummary(BigInteger sessionId); + /** + * 分页查询当前会话的主线可见消息。 + * + * @param sessionId 会话 ID + * @param query 分页参数 + * @return 主线消息分页 + */ + ChatHistoryPage pageMainlineMessages(BigInteger sessionId, ChatPageQuery query); + + /** + * 查询当前会话的全部主线可见消息。 + * + * @param sessionId 会话 ID + * @return 主线消息列表 + */ + List listMainlineMessages(BigInteger sessionId); + List getRecentTail(BigInteger sessionId, int limit); } diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/impl/ChatRoundCommandServiceImpl.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/impl/ChatRoundCommandServiceImpl.java new file mode 100644 index 0000000..35f206e --- /dev/null +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/impl/ChatRoundCommandServiceImpl.java @@ -0,0 +1,31 @@ +package tech.easyflow.chatlog.service.impl; + +import org.springframework.stereotype.Service; +import tech.easyflow.chatlog.domain.command.ChatRoundSelectCommand; +import tech.easyflow.chatlog.domain.command.ChatRoundUpsertCommand; +import tech.easyflow.chatlog.domain.dto.ChatRoundRecord; +import tech.easyflow.chatlog.service.ChatPersistDispatcher; +import tech.easyflow.chatlog.service.ChatRoundCommandService; + +/** + * 聊天轮次写服务实现。 + */ +@Service +public class ChatRoundCommandServiceImpl implements ChatRoundCommandService { + + private final ChatPersistDispatcher chatPersistDispatcher; + + public ChatRoundCommandServiceImpl(ChatPersistDispatcher chatPersistDispatcher) { + this.chatPersistDispatcher = chatPersistDispatcher; + } + + @Override + public ChatRoundRecord createOrTouchRound(ChatRoundUpsertCommand command) { + return chatPersistDispatcher.createOrTouchRound(command); + } + + @Override + public void selectVariant(ChatRoundSelectCommand command) { + chatPersistDispatcher.selectRoundVariant(command); + } +} diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/impl/ChatRoundOperateServiceImpl.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/impl/ChatRoundOperateServiceImpl.java new file mode 100644 index 0000000..3fbc598 --- /dev/null +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/impl/ChatRoundOperateServiceImpl.java @@ -0,0 +1,101 @@ +package tech.easyflow.chatlog.service.impl; + +import org.springframework.stereotype.Service; +import tech.easyflow.chatlog.domain.command.ChatRoundSelectCommand; +import tech.easyflow.chatlog.domain.dto.ChatMessageRecord; +import tech.easyflow.chatlog.domain.dto.ChatRoundRecord; +import tech.easyflow.chatlog.service.ChatRoundCommandService; +import tech.easyflow.chatlog.service.ChatRoundOperateService; +import tech.easyflow.chatlog.service.ChatRoundQueryService; +import tech.easyflow.chatlog.support.ChatConstants; +import tech.easyflow.common.web.exceptions.BusinessException; + +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +/** + * 聊天轮次业务操作服务实现。 + */ +@Service +public class ChatRoundOperateServiceImpl implements ChatRoundOperateService { + + private final ChatRoundQueryService chatRoundQueryService; + private final ChatRoundCommandService chatRoundCommandService; + + public ChatRoundOperateServiceImpl(ChatRoundQueryService chatRoundQueryService, + ChatRoundCommandService chatRoundCommandService) { + this.chatRoundQueryService = chatRoundQueryService; + this.chatRoundCommandService = chatRoundCommandService; + } + + @Override + public ChatRoundRecord requireRegeneratableRound(BigInteger sessionId, BigInteger roundId) { + ChatRoundRecord round = requireLatestRound(sessionId, roundId); + if (round.getSelectedAssistantMessageId() == null || round.getSelectedVariantIndex() == null + || round.getSelectedVariantIndex() <= 0) { + throw new BusinessException("当前轮次暂无可重答的回答"); + } + return round; + } + + @Override + public List listVariants(BigInteger sessionId, BigInteger roundId) { + ChatRoundRecord round = chatRoundQueryService.getRound(sessionId, roundId); + if (round == null) { + throw new BusinessException("轮次不存在"); + } + ChatRoundRecord latestRound = chatRoundQueryService.getLatestRound(sessionId); + boolean switchable = latestRound != null + && Objects.equals(latestRound.getId(), round.getId()) + && !ChatConstants.ROUND_STATUS_LOCKED.equalsIgnoreCase(round.getStatus()); + List variants = new ArrayList<>(chatRoundQueryService.listRoundVariants(sessionId, roundId)); + for (ChatMessageRecord variant : variants) { + variant.setVariantCount(round.getVariantCount()); + variant.setSelectedVariantIndex(round.getSelectedVariantIndex()); + variant.setSwitchable(switchable); + } + return variants; + } + + @Override + public ChatMessageRecord selectVariant(BigInteger sessionId, BigInteger roundId, Integer variantIndex, BigInteger operatorId) { + ChatRoundRecord round = requireLatestRound(sessionId, roundId); + if (variantIndex == null || variantIndex <= 0) { + throw new BusinessException("目标答案版本无效"); + } + ChatMessageRecord selected = chatRoundQueryService.getRoundVariant(sessionId, roundId, variantIndex); + if (selected == null) { + throw new BusinessException("目标答案版本不存在"); + } + + ChatRoundSelectCommand command = new ChatRoundSelectCommand(); + command.setSessionId(sessionId); + command.setRoundId(roundId); + command.setSelectedVariantIndex(variantIndex); + command.setSelectedAssistantMessageId(selected.getId()); + command.setSelectedAssistantMessage(selected); + command.setOperatorId(operatorId); + chatRoundCommandService.selectVariant(command); + selected.setSelectedVariantIndex(variantIndex); + selected.setVariantCount(round.getVariantCount()); + selected.setSwitchable(true); + return selected; + } + + private ChatRoundRecord requireLatestRound(BigInteger sessionId, BigInteger roundId) { + ChatRoundRecord round = chatRoundQueryService.getRound(sessionId, roundId); + if (round == null) { + throw new BusinessException("轮次不存在"); + } + ChatRoundRecord latestRound = chatRoundQueryService.getLatestRound(sessionId); + if (latestRound == null || !Objects.equals(latestRound.getId(), round.getId())) { + throw new BusinessException("当前轮次已有后续对话,不支持切换答案版本"); + } + if (ChatConstants.ROUND_STATUS_LOCKED.equalsIgnoreCase(round.getStatus())) { + throw new BusinessException("当前轮次已有后续对话,不支持切换答案版本"); + } + return round; + } +} diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/impl/ChatRoundQueryServiceImpl.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/impl/ChatRoundQueryServiceImpl.java new file mode 100644 index 0000000..edb8c9a --- /dev/null +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/impl/ChatRoundQueryServiceImpl.java @@ -0,0 +1,76 @@ +package tech.easyflow.chatlog.service.impl; + +import org.springframework.stereotype.Service; +import tech.easyflow.chatlog.cache.ChatHotStateService; +import tech.easyflow.chatlog.domain.dto.ChatMessageRecord; +import tech.easyflow.chatlog.domain.dto.ChatRoundRecord; +import tech.easyflow.chatlog.repository.mysql.MySqlChatLogRepository; +import tech.easyflow.chatlog.repository.mysql.MySqlChatLogTableManager; +import tech.easyflow.chatlog.repository.mysql.MySqlChatRoundRepository; +import tech.easyflow.chatlog.service.ChatRoundQueryService; + +import java.math.BigInteger; +import java.util.List; + +/** + * 聊天轮次读服务实现。 + */ +@Service +public class ChatRoundQueryServiceImpl implements ChatRoundQueryService { + + private final MySqlChatRoundRepository roundRepository; + private final MySqlChatLogRepository logRepository; + private final MySqlChatLogTableManager tableManager; + private final ChatHotStateService chatHotStateService; + + public ChatRoundQueryServiceImpl(MySqlChatRoundRepository roundRepository, + MySqlChatLogRepository logRepository, + MySqlChatLogTableManager tableManager, + ChatHotStateService chatHotStateService) { + this.roundRepository = roundRepository; + this.logRepository = logRepository; + this.tableManager = tableManager; + this.chatHotStateService = chatHotStateService; + } + + @Override + public ChatRoundRecord getLatestRound(BigInteger sessionId) { + ChatRoundRecord cached = chatHotStateService.getLatestRound(sessionId); + if (cached != null) { + return cached; + } + ChatRoundRecord record = roundRepository.findLatestRound(sessionId); + if (record != null) { + chatHotStateService.cacheRound(record); + } + return record; + } + + @Override + public ChatRoundRecord getRound(BigInteger sessionId, BigInteger roundId) { + ChatRoundRecord cached = chatHotStateService.getRound(sessionId, roundId); + if (cached != null) { + return cached; + } + ChatRoundRecord record = roundRepository.findRound(sessionId, roundId); + if (record != null) { + chatHotStateService.cacheRound(record); + } + return record; + } + + @Override + public List listRoundVariants(BigInteger sessionId, BigInteger roundId) { + return logRepository.listRoundVariants(sessionId, roundId, tableManager.listRecentExistingMonths(3)); + } + + @Override + public ChatMessageRecord getRoundVariant(BigInteger sessionId, BigInteger roundId, Integer variantIndex) { + return logRepository.findRoundVariant(sessionId, roundId, variantIndex, tableManager.listRecentExistingMonths(3)); + } + + @Override + public boolean hasRounds(BigInteger sessionId) { + return roundRepository.existsRounds(sessionId); + } +} diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/impl/ChatSessionQueryServiceImpl.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/impl/ChatSessionQueryServiceImpl.java index d33d425..b81759b 100644 --- a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/impl/ChatSessionQueryServiceImpl.java +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/impl/ChatSessionQueryServiceImpl.java @@ -2,6 +2,7 @@ package tech.easyflow.chatlog.service.impl; import org.springframework.stereotype.Service; import tech.easyflow.chatlog.cache.ChatHotStateService; +import tech.easyflow.chatlog.domain.dto.ChatHistoryPage; import tech.easyflow.chatlog.domain.dto.ChatMessageRecord; import tech.easyflow.chatlog.domain.dto.ChatSessionPage; import tech.easyflow.chatlog.domain.dto.ChatSessionSummary; @@ -12,7 +13,12 @@ import tech.easyflow.chatlog.repository.mysql.MySqlChatSessionRepository; import tech.easyflow.chatlog.service.ChatSessionQueryService; import java.math.BigInteger; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; @Service public class ChatSessionQueryServiceImpl implements ChatSessionQueryService { @@ -34,21 +40,7 @@ public class ChatSessionQueryServiceImpl implements ChatSessionQueryService { @Override public List listSessions(BigInteger userId, BigInteger assistantId, ChatPageQuery query) { - if (assistantId == null) { - List sessionIds = chatHotStateService.listSessionIds(userId, query.getOffset(), query.getPageSize()); - if (!sessionIds.isEmpty()) { - List cached = chatHotStateService.getSessionSummaries(sessionIds); - if (cached.size() == sessionIds.size()) { - return cached; - } - } - List sessions = sessionRepository.listSessions(userId, null, query); - chatHotStateService.cacheSessionSummaries(sessions); - return sessions; - } - List sessions = sessionRepository.listSessions(userId, assistantId, query); - chatHotStateService.cacheSessionSummaries(sessions); - return sessions; + return sessionRepository.listSessions(userId, assistantId, query); } @Override @@ -62,21 +54,6 @@ public class ChatSessionQueryServiceImpl implements ChatSessionQueryService { page.setPageNumber(query.getPageNumber()); page.setPageSize(query.getPageSize()); - if (assistantId == null && chatHotStateService.hasSessionIndex(userId)) { - List sessionIds = chatHotStateService.listSessionIds(userId, query.getOffset(), query.getPageSize()); - if (sessionIds.isEmpty()) { - page.setTotal(chatHotStateService.countSessions(userId)); - page.setRecords(List.of()); - return page; - } - List cached = chatHotStateService.getSessionSummaries(sessionIds); - if (cached.size() == sessionIds.size()) { - page.setTotal(chatHotStateService.countSessions(userId)); - page.setRecords(cached); - return page; - } - } - page.setTotal(sessionRepository.countSessions(userId, assistantId)); page.setRecords(listSessions(userId, assistantId, query)); return page; @@ -95,14 +72,74 @@ public class ChatSessionQueryServiceImpl implements ChatSessionQueryService { return summary; } + @Override + public ChatHistoryPage pageMainlineMessages(BigInteger sessionId, ChatPageQuery query) { + ChatHistoryPage page = new ChatHistoryPage(); + page.setPageNumber(query.getPageNumber()); + page.setPageSize(query.getPageSize()); + ChatSessionSummary summary = getSessionSummary(sessionId); + long total = summary == null || summary.getMessageCount() == null ? 0L : summary.getMessageCount(); + List records = logRepository.listMainlineMessages( + sessionId, + tableManager.listRecentExistingMonths(3), + query.getOffset(), + Math.toIntExact(query.getPageSize()) + ); + page.setRecords(records); + page.setTotal(Math.max(total, query.getOffset() + records.size())); + return page; + } + + @Override + public List listMainlineMessages(BigInteger sessionId) { + return logRepository.listMainlineMessages(sessionId, tableManager.listRecentExistingMonths(3)); + } + @Override public List getRecentTail(BigInteger sessionId, int limit) { List cached = chatHotStateService.getSessionTail(sessionId); - if (cached != null) { + if (cached != null && isTailReliable(cached)) { return cached.subList(0, Math.min(limit, cached.size())); } List records = logRepository.listRecentTail(sessionId, tableManager.listRecentExistingMonths(3), limit); chatHotStateService.setSessionTail(sessionId, records); return records; } + + /** + * 校验 Redis tail 是否符合当前主线版本语义,防止过期选中版本把可见回答过滤掉。 + * + * @param records Redis tail 消息 + * @return true 表示可直接使用缓存 + */ + private boolean isTailReliable(List records) { + Map selectedVariantByRound = new LinkedHashMap<>(); + Map> assistantVariantsByRound = new LinkedHashMap<>(); + for (ChatMessageRecord record : records) { + if (record == null || record.getRoundId() == null) { + continue; + } + Integer selectedVariantIndex = record.getSelectedVariantIndex(); + if (selectedVariantIndex != null && selectedVariantIndex > 0) { + Integer previous = selectedVariantByRound.putIfAbsent(record.getRoundId(), selectedVariantIndex); + if (previous != null && !Objects.equals(previous, selectedVariantIndex)) { + return false; + } + } + if ("assistant".equalsIgnoreCase(record.getSenderRole()) + && record.getVariantIndex() != null + && record.getVariantIndex() > 0) { + assistantVariantsByRound + .computeIfAbsent(record.getRoundId(), key -> new LinkedHashSet<>()) + .add(record.getVariantIndex()); + } + } + for (Map.Entry entry : selectedVariantByRound.entrySet()) { + Set visibleVariants = assistantVariantsByRound.get(entry.getKey()); + if (visibleVariants != null && !visibleVariants.isEmpty() && !visibleVariants.contains(entry.getValue())) { + return false; + } + } + return true; + } } diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/impl/ChatlogRuntimeListener.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/impl/ChatlogRuntimeListener.java index 345c751..988466d 100644 --- a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/impl/ChatlogRuntimeListener.java +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/service/impl/ChatlogRuntimeListener.java @@ -4,11 +4,19 @@ import com.mybatisflex.core.keygen.impl.SnowFlakeIDKeyGenerator; import org.springframework.core.annotation.Order; import org.springframework.stereotype.Component; import tech.easyflow.chatlog.domain.command.ChatAppendMessageCommand; +import tech.easyflow.chatlog.domain.command.ChatRoundUpsertCommand; import tech.easyflow.chatlog.domain.command.ChatSessionUpsertCommand; +import tech.easyflow.chatlog.domain.dto.ChatRoundRecord; +import tech.easyflow.chatlog.domain.dto.ChatSessionExtPayload; import tech.easyflow.chatlog.domain.dto.ChatMessageRecord; import tech.easyflow.chatlog.service.ChatPersistDispatcher; +import tech.easyflow.chatlog.service.ChatRoundOperateService; +import tech.easyflow.chatlog.service.ChatRoundQueryService; import tech.easyflow.chatlog.service.ChatSessionQueryService; +import tech.easyflow.chatlog.support.ChatConstants; +import tech.easyflow.chatlog.support.ChatJsonSupport; import tech.easyflow.common.web.exceptions.BusinessException; +import tech.easyflow.core.runtime.ChatRuntimeExtKeys; import tech.easyflow.core.runtime.ChatRuntimeHistoryPayloadHelper; import tech.easyflow.core.runtime.ChatRuntimeContext; import tech.easyflow.core.runtime.ChatRuntimeListener; @@ -27,12 +35,21 @@ public class ChatlogRuntimeListener implements ChatRuntimeListener { private final SnowFlakeIDKeyGenerator idGenerator = new SnowFlakeIDKeyGenerator(); private final ChatPersistDispatcher chatPersistDispatcher; + private final ChatRoundOperateService chatRoundOperateService; + private final ChatRoundQueryService chatRoundQueryService; private final ChatSessionQueryService chatSessionQueryService; + private final ChatJsonSupport chatJsonSupport; public ChatlogRuntimeListener(ChatPersistDispatcher chatPersistDispatcher, - ChatSessionQueryService chatSessionQueryService) { + ChatRoundOperateService chatRoundOperateService, + ChatRoundQueryService chatRoundQueryService, + ChatSessionQueryService chatSessionQueryService, + ChatJsonSupport chatJsonSupport) { this.chatPersistDispatcher = chatPersistDispatcher; + this.chatRoundOperateService = chatRoundOperateService; + this.chatRoundQueryService = chatRoundQueryService; this.chatSessionQueryService = chatSessionQueryService; + this.chatJsonSupport = chatJsonSupport; } @Override @@ -48,6 +65,7 @@ public class ChatlogRuntimeListener implements ChatRuntimeListener { command.setAssistantCode(context.getAssistantCode()); command.setAssistantName(context.getAssistantName()); command.setTitle(context.getSessionTitle()); + command.setExtJson(resolveExtJson(context)); command.setOperatorId(defaultNumber(context.getUserId())); chatPersistDispatcher.createOrTouchSession(command); } catch (RuntimeException ex) { @@ -58,6 +76,9 @@ public class ChatlogRuntimeListener implements ChatRuntimeListener { @Override public void onUserMessage(ChatRuntimeContext context, ChatRuntimeMessage message) { try { + if (prepareRoundContext(context, message)) { + return; + } chatPersistDispatcher.appendUserMessage(toAppendCommand(context, message)); } catch (RuntimeException ex) { throw persistFailed(ex); @@ -67,7 +88,36 @@ public class ChatlogRuntimeListener implements ChatRuntimeListener { @Override public void onAssistantCompleted(ChatRuntimeContext context, ChatRuntimeMessage message) { try { + applyAssistantRoundMetadata(context, message); chatPersistDispatcher.appendAssistantMessage(toAppendCommand(context, message)); + chatPersistDispatcher.createOrTouchRound(buildAssistantCompletedRoundCommand(context, message)); + } catch (RuntimeException ex) { + throw persistFailed(ex); + } + } + + @Override + public void onChatFailed(ChatRuntimeContext context, Throwable throwable) { + try { + BigInteger roundId = resolveNumber(context, ChatRuntimeExtKeys.CURRENT_ROUND_ID); + if (context == null || context.getSessionId() == null || roundId == null) { + return; + } + ChatRoundRecord currentRound = chatRoundQueryService.getRound(context.getSessionId(), roundId); + if (currentRound == null) { + return; + } + ChatRoundUpsertCommand command = new ChatRoundUpsertCommand(); + command.setRoundId(currentRound.getId()); + command.setSessionId(currentRound.getSessionId()); + command.setRoundNo(currentRound.getRoundNo()); + command.setUserMessageId(currentRound.getUserMessageId()); + command.setSelectedAssistantMessageId(currentRound.getSelectedAssistantMessageId()); + command.setSelectedVariantIndex(currentRound.getSelectedVariantIndex()); + command.setVariantCount(currentRound.getVariantCount()); + command.setStatus(ChatConstants.ROUND_STATUS_READY); + command.setOperatorId(defaultNumber(context.getUserId())); + chatPersistDispatcher.createOrTouchRound(command); } catch (RuntimeException ex) { throw persistFailed(ex); } @@ -78,7 +128,15 @@ public class ChatlogRuntimeListener implements ChatRuntimeListener { if (context == null || context.getSessionId() == null || limit <= 0) { return Collections.emptyList(); } - List records = new ArrayList<>(chatSessionQueryService.getRecentTail(context.getSessionId(), limit)); + BigInteger regenerateRoundId = resolveNumber(context, ChatRuntimeExtKeys.REGENERATE_ROUND_ID); + int queryLimit = regenerateRoundId == null ? limit : limit + 4; + List records = new ArrayList<>(chatSessionQueryService.getRecentTail(context.getSessionId(), queryLimit)); + if (regenerateRoundId != null) { + records.removeIf(record -> regenerateRoundId.equals(record.getRoundId())); + if (records.size() > limit) { + records = new ArrayList<>(records.subList(0, limit)); + } + } Collections.reverse(records); List messages = new ArrayList<>(records.size()); for (ChatMessageRecord record : records) { @@ -118,11 +176,127 @@ public class ChatlogRuntimeListener implements ChatRuntimeListener { command.setContentType(message.getContentType()); command.setContentText(message.getContentText()); command.setContentPayload(message.getContentPayload()); + command.setRoundId(message.getRoundId()); + command.setRoundNo(message.getRoundNo()); + command.setMessageKind(message.getMessageKind()); + command.setVariantIndex(message.getVariantIndex()); command.setCreatedBy(defaultNumber(context.getUserId())); command.setCreated(message.getCreatedAt()); return command; } + private boolean prepareRoundContext(ChatRuntimeContext context, ChatRuntimeMessage message) { + if (context == null || message == null || context.getSessionId() == null) { + return false; + } + BigInteger regenerateRoundId = resolveNumber(context, ChatRuntimeExtKeys.REGENERATE_ROUND_ID); + if (regenerateRoundId != null) { + ChatRoundRecord round = chatRoundOperateService.requireRegeneratableRound(context.getSessionId(), regenerateRoundId); + context.getExt().put(ChatRuntimeExtKeys.CURRENT_ROUND_ID, round.getId()); + context.getExt().put(ChatRuntimeExtKeys.CURRENT_ROUND_NO, round.getRoundNo()); + context.getExt().put(ChatRuntimeExtKeys.CURRENT_VARIANT_INDEX, Math.max(round.getVariantCount() + 1, 1)); + ChatRoundUpsertCommand command = new ChatRoundUpsertCommand(); + command.setRoundId(round.getId()); + command.setSessionId(round.getSessionId()); + command.setRoundNo(round.getRoundNo()); + command.setUserMessageId(round.getUserMessageId()); + command.setSelectedAssistantMessageId(round.getSelectedAssistantMessageId()); + command.setSelectedVariantIndex(round.getSelectedVariantIndex()); + command.setVariantCount(round.getVariantCount()); + command.setStatus(ChatConstants.ROUND_STATUS_ANSWERING); + command.setOperatorId(defaultNumber(context.getUserId())); + chatPersistDispatcher.createOrTouchRound(command); + return true; + } + ChatRoundRecord latestRound = chatRoundQueryService.getLatestRound(context.getSessionId()); + if (latestRound != null && latestRound.getId() != null + && !ChatConstants.ROUND_STATUS_LOCKED.equalsIgnoreCase(latestRound.getStatus())) { + ChatRoundUpsertCommand lockCommand = new ChatRoundUpsertCommand(); + lockCommand.setRoundId(latestRound.getId()); + lockCommand.setSessionId(latestRound.getSessionId()); + lockCommand.setRoundNo(latestRound.getRoundNo()); + lockCommand.setUserMessageId(latestRound.getUserMessageId()); + lockCommand.setSelectedAssistantMessageId(latestRound.getSelectedAssistantMessageId()); + lockCommand.setSelectedVariantIndex(latestRound.getSelectedVariantIndex()); + lockCommand.setVariantCount(latestRound.getVariantCount()); + lockCommand.setStatus(ChatConstants.ROUND_STATUS_LOCKED); + lockCommand.setOperatorId(defaultNumber(context.getUserId())); + chatPersistDispatcher.createOrTouchRound(lockCommand); + } + BigInteger roundId = BigInteger.valueOf(idGenerator.nextId()); + int roundNo = latestRound == null || latestRound.getRoundNo() == null ? 1 : latestRound.getRoundNo() + 1; + if (message.getMessageId() == null) { + message.setMessageId(BigInteger.valueOf(idGenerator.nextId())); + } + context.getExt().put(ChatRuntimeExtKeys.CURRENT_ROUND_ID, roundId); + context.getExt().put(ChatRuntimeExtKeys.CURRENT_ROUND_NO, roundNo); + context.getExt().put(ChatRuntimeExtKeys.CURRENT_VARIANT_INDEX, 1); + message.setRoundId(roundId); + message.setRoundNo(roundNo); + message.setMessageKind(ChatConstants.MESSAGE_KIND_USER_PROMPT); + message.setVariantIndex(null); + ChatRoundUpsertCommand command = new ChatRoundUpsertCommand(); + command.setRoundId(roundId); + command.setSessionId(context.getSessionId()); + command.setRoundNo(roundNo); + command.setUserMessageId(message.getMessageId()); + command.setSelectedVariantIndex(0); + command.setVariantCount(0); + command.setStatus(ChatConstants.ROUND_STATUS_ANSWERING); + command.setOperatorId(defaultNumber(context.getUserId())); + chatPersistDispatcher.createOrTouchRound(command); + return false; + } + + private void applyAssistantRoundMetadata(ChatRuntimeContext context, ChatRuntimeMessage message) { + if (message.getMessageId() == null) { + message.setMessageId(BigInteger.valueOf(idGenerator.nextId())); + } + message.setRoundId(resolveNumber(context, ChatRuntimeExtKeys.CURRENT_ROUND_ID)); + message.setRoundNo(resolveInteger(context, ChatRuntimeExtKeys.CURRENT_ROUND_NO)); + message.setVariantIndex(resolveInteger(context, ChatRuntimeExtKeys.CURRENT_VARIANT_INDEX)); + message.setMessageKind(ChatConstants.MESSAGE_KIND_ASSISTANT_VARIANT); + } + + private ChatRoundUpsertCommand buildAssistantCompletedRoundCommand(ChatRuntimeContext context, ChatRuntimeMessage message) { + ChatRoundUpsertCommand command = new ChatRoundUpsertCommand(); + command.setRoundId(message.getRoundId()); + command.setSessionId(context.getSessionId()); + command.setRoundNo(message.getRoundNo()); + ChatRoundRecord existing = chatRoundQueryService.getRound(context.getSessionId(), message.getRoundId()); + if (existing != null) { + command.setUserMessageId(existing.getUserMessageId()); + } + command.setSelectedAssistantMessageId(message.getMessageId()); + command.setSelectedVariantIndex(message.getVariantIndex()); + command.setVariantCount(message.getVariantIndex()); + command.setStatus(ChatConstants.ROUND_STATUS_READY); + command.setOperatorId(defaultNumber(context.getUserId())); + return command; + } + + private BigInteger resolveNumber(ChatRuntimeContext context, String key) { + if (context == null || context.getExt() == null || key == null) { + return null; + } + Object value = context.getExt().get(key); + if (value == null) { + return null; + } + return new BigInteger(String.valueOf(value)); + } + + private Integer resolveInteger(ChatRuntimeContext context, String key) { + if (context == null || context.getExt() == null || key == null) { + return null; + } + Object value = context.getExt().get(key); + if (value == null) { + return null; + } + return Integer.parseInt(String.valueOf(value)); + } + private BigInteger defaultNumber(BigInteger value) { return value == null ? BigInteger.ZERO : value; } @@ -133,4 +307,27 @@ public class ChatlogRuntimeListener implements ChatRuntimeListener { } return new BusinessException("聊天记录持久化失败,请稍后重试"); } + + private String resolveExtJson(ChatRuntimeContext context) { + if (context == null || context.getExt() == null || context.getExt().isEmpty()) { + return null; + } + if (!context.getExt().containsKey(ChatRuntimeExtKeys.EXTRA_KNOWLEDGE_IDS)) { + return null; + } + Object rawExtraKnowledgeIds = context.getExt().get(ChatRuntimeExtKeys.EXTRA_KNOWLEDGE_IDS); + if (!(rawExtraKnowledgeIds instanceof List rawList)) { + return null; + } + List extraKnowledgeIds = new ArrayList<>(); + for (Object item : rawList) { + if (item == null) { + continue; + } + extraKnowledgeIds.add(new BigInteger(String.valueOf(item))); + } + ChatSessionExtPayload payload = new ChatSessionExtPayload(); + payload.setExtraKnowledgeIds(extraKnowledgeIds); + return chatJsonSupport.toJson(payload); + } } diff --git a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/support/ChatConstants.java b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/support/ChatConstants.java index 4cf66be..b0d232d 100644 --- a/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/support/ChatConstants.java +++ b/easyflow-modules/easyflow-module-chatlog/src/main/java/tech/easyflow/chatlog/support/ChatConstants.java @@ -3,12 +3,18 @@ package tech.easyflow.chatlog.support; public final class ChatConstants { public static final String SESSION_TABLE = "chat_session"; + public static final String ROUND_TABLE = "chat_round"; public static final String CHAT_LOG_TEMPLATE = "chat_log_template"; public static final String CHAT_LOG_PREFIX = "chat_log_"; public static final String CHAT_PERSIST_TOPIC = "chat-persist"; public static final String CHAT_PERSIST_GROUP = "chat-persist-group"; public static final String CHECKPOINT_SYNC_CODE_SESSION = "chat_session_sync"; public static final String CHECKPOINT_SYNC_CODE_LOG = "chat_log_sync"; + public static final String ROUND_STATUS_ANSWERING = "ANSWERING"; + public static final String ROUND_STATUS_READY = "READY"; + public static final String ROUND_STATUS_LOCKED = "LOCKED"; + public static final String MESSAGE_KIND_USER_PROMPT = "USER_PROMPT"; + public static final String MESSAGE_KIND_ASSISTANT_VARIANT = "ASSISTANT_VARIANT"; private ChatConstants() { } diff --git a/easyflow-modules/easyflow-module-chatlog/src/test/java/tech/easyflow/chatlog/service/ChatPersistMySqlApplyServiceTest.java b/easyflow-modules/easyflow-module-chatlog/src/test/java/tech/easyflow/chatlog/service/ChatPersistMySqlApplyServiceTest.java index a5f6b25..53d9390 100644 --- a/easyflow-modules/easyflow-module-chatlog/src/test/java/tech/easyflow/chatlog/service/ChatPersistMySqlApplyServiceTest.java +++ b/easyflow-modules/easyflow-module-chatlog/src/test/java/tech/easyflow/chatlog/service/ChatPersistMySqlApplyServiceTest.java @@ -1,18 +1,29 @@ package tech.easyflow.chatlog.service; +import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.Assert; import org.junit.Test; import tech.easyflow.chatlog.domain.command.ChatAppendMessageCommand; +import tech.easyflow.chatlog.domain.command.ChatSessionSummaryCommand; import tech.easyflow.chatlog.domain.command.ChatSessionUpsertCommand; +import tech.easyflow.chatlog.domain.event.ChatPersistEvent; +import tech.easyflow.chatlog.domain.event.ChatPersistEventType; +import tech.easyflow.chatlog.repository.mysql.MySqlChatLogRepository; +import tech.easyflow.chatlog.repository.mysql.MySqlChatLogTableManager; +import tech.easyflow.chatlog.repository.mysql.MySqlChatRoundRepository; +import tech.easyflow.chatlog.repository.mysql.MySqlChatSessionRepository; +import tech.easyflow.chatlog.support.ChatJsonSupport; import java.math.BigInteger; +import java.time.YearMonth; +import java.util.ArrayList; import java.util.Date; import java.util.List; public class ChatPersistMySqlApplyServiceTest { private final ChatPersistMySqlApplyService service = - new ChatPersistMySqlApplyService(null, null, null, null); + new ChatPersistMySqlApplyService(null, null, null, null, new ChatJsonSupport(new ObjectMapper())); @Test public void shouldBuildMissingSessionUpsertFromMessageMetadata() { @@ -69,4 +80,101 @@ public class ChatPersistMySqlApplyServiceTest { Assert.assertEquals("会话-202", upsert.getTitle()); Assert.assertEquals(BigInteger.valueOf(7), upsert.getOperatorId()); } + + @Test + public void shouldNotDoubleCountSummaryWhenMessageEventReplayed() { + ChatJsonSupport jsonSupport = new ChatJsonSupport(new ObjectMapper()); + FakeSessionRepository sessionRepository = new FakeSessionRepository(); + FakeLogRepository logRepository = new FakeLogRepository(jsonSupport); + ChatPersistMySqlApplyService applyService = new ChatPersistMySqlApplyService( + sessionRepository, + logRepository, + new FakeRoundRepository(), + new FakeTableManager(), + jsonSupport + ); + ChatAppendMessageCommand command = new ChatAppendMessageCommand(); + command.setMessageId(BigInteger.valueOf(301)); + command.setSessionId(BigInteger.valueOf(401)); + command.setTenantId(BigInteger.ONE); + command.setDeptId(BigInteger.ONE); + command.setUserId(BigInteger.valueOf(7)); + command.setAssistantId(BigInteger.valueOf(8)); + command.setSenderId(BigInteger.valueOf(7)); + command.setSenderName("admin"); + command.setSenderRole("user"); + command.setContentText("第一条消息"); + command.setCreatedBy(BigInteger.valueOf(7)); + command.setCreated(new Date(4_000L)); + + ChatPersistEvent event = new ChatPersistEvent(); + event.setEventId("message-301"); + event.setEventType(ChatPersistEventType.USER_MESSAGE_APPENDED); + event.setSessionId(command.getSessionId()); + event.setPayload(jsonSupport.toJson(command)); + + applyService.apply(List.of(event)); + applyService.apply(List.of(event)); + + Assert.assertEquals(1, sessionRepository.summaryCommands.size()); + ChatSessionSummaryCommand summaryCommand = sessionRepository.summaryCommands.get(0); + Assert.assertEquals(1, summaryCommand.getMessageIncrement()); + Assert.assertEquals("第一条消息", summaryCommand.getLastMessagePreview()); + Assert.assertEquals(new Date(4_000L), summaryCommand.getLastMessageAt()); + Assert.assertEquals(new Date(4_000L), summaryCommand.getAccessAt()); + } + + private static final class FakeSessionRepository extends MySqlChatSessionRepository { + + private final List summaryCommands = new ArrayList<>(); + + private FakeSessionRepository() { + super(null, null); + } + + @Override + public void createOrTouchBatch(List commands) { + } + + @Override + public void updateSummaries(List commands) { + summaryCommands.addAll(commands); + } + } + + private static final class FakeLogRepository extends MySqlChatLogRepository { + + private boolean inserted; + + private FakeLogRepository(ChatJsonSupport jsonSupport) { + super(null, null, jsonSupport); + } + + @Override + public List appendMessages(List commands) { + if (inserted) { + return List.of(); + } + inserted = true; + return commands; + } + } + + private static final class FakeRoundRepository extends MySqlChatRoundRepository { + + private FakeRoundRepository() { + super(null); + } + } + + private static final class FakeTableManager extends MySqlChatLogTableManager { + + private FakeTableManager() { + super(null, null); + } + + @Override + public void ensureMonthTable(YearMonth month) { + } + } } diff --git a/easyflow-modules/easyflow-module-chatlog/src/test/java/tech/easyflow/chatlog/service/impl/ChatRoundOperateServiceImplTest.java b/easyflow-modules/easyflow-module-chatlog/src/test/java/tech/easyflow/chatlog/service/impl/ChatRoundOperateServiceImplTest.java new file mode 100644 index 0000000..e5077cd --- /dev/null +++ b/easyflow-modules/easyflow-module-chatlog/src/test/java/tech/easyflow/chatlog/service/impl/ChatRoundOperateServiceImplTest.java @@ -0,0 +1,206 @@ +package tech.easyflow.chatlog.service.impl; + +import org.junit.Assert; +import org.junit.Test; +import tech.easyflow.chatlog.domain.command.ChatRoundSelectCommand; +import tech.easyflow.chatlog.domain.command.ChatRoundUpsertCommand; +import tech.easyflow.chatlog.domain.dto.ChatMessageRecord; +import tech.easyflow.chatlog.domain.dto.ChatRoundRecord; +import tech.easyflow.chatlog.service.ChatRoundCommandService; +import tech.easyflow.chatlog.service.ChatRoundQueryService; +import tech.easyflow.chatlog.support.ChatConstants; +import tech.easyflow.common.web.exceptions.BusinessException; + +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; + +/** + * {@link ChatRoundOperateServiceImpl} 单元测试。 + */ +public class ChatRoundOperateServiceImplTest { + + /** + * 切换答案版本时应精准查询目标版本,避免先加载全部版本再过滤。 + */ + @Test + public void selectVariantShouldReadTargetVariantDirectly() { + FakeRoundQueryService queryService = new FakeRoundQueryService(); + queryService.round = round(BigInteger.valueOf(1001), BigInteger.valueOf(2001), 2, ChatConstants.ROUND_STATUS_READY); + queryService.latestRound = queryService.round; + queryService.targetVariant = message(BigInteger.valueOf(3002), 2); + FakeRoundCommandService commandService = new FakeRoundCommandService(); + ChatRoundOperateServiceImpl service = new ChatRoundOperateServiceImpl(queryService, commandService); + + ChatMessageRecord selected = service.selectVariant( + BigInteger.valueOf(1001), + BigInteger.valueOf(2001), + 2, + BigInteger.valueOf(7) + ); + + Assert.assertEquals(BigInteger.valueOf(3002), selected.getId()); + Assert.assertEquals(Integer.valueOf(2), selected.getSelectedVariantIndex()); + Assert.assertEquals(Integer.valueOf(2), selected.getVariantCount()); + Assert.assertEquals(Boolean.TRUE, selected.getSwitchable()); + Assert.assertEquals(1, queryService.getRoundVariantCalls); + Assert.assertEquals(0, queryService.listRoundVariantsCalls); + Assert.assertNotNull(commandService.selectedCommand); + Assert.assertEquals(BigInteger.valueOf(3002), commandService.selectedCommand.getSelectedAssistantMessageId()); + } + + /** + * 列出答案版本时应由业务层统一补齐当前选中态和可切换状态。 + */ + @Test + public void listVariantsShouldFillVariantMetadata() { + FakeRoundQueryService queryService = new FakeRoundQueryService(); + queryService.round = round(BigInteger.valueOf(1001), BigInteger.valueOf(2001), 2, ChatConstants.ROUND_STATUS_READY); + queryService.latestRound = queryService.round; + queryService.variants = List.of(message(BigInteger.valueOf(3001), 1), message(BigInteger.valueOf(3002), 2)); + ChatRoundOperateServiceImpl service = new ChatRoundOperateServiceImpl(queryService, new FakeRoundCommandService()); + + List variants = service.listVariants(BigInteger.valueOf(1001), BigInteger.valueOf(2001)); + + Assert.assertEquals(2, variants.size()); + for (ChatMessageRecord variant : variants) { + Assert.assertEquals(Integer.valueOf(2), variant.getVariantCount()); + Assert.assertEquals(Integer.valueOf(2), variant.getSelectedVariantIndex()); + Assert.assertEquals(Boolean.TRUE, variant.getSwitchable()); + } + } + + /** + * 已锁定轮次禁止切换,避免改变已有后续上下文。 + */ + @Test(expected = BusinessException.class) + public void selectVariantShouldRejectLockedRound() { + FakeRoundQueryService queryService = new FakeRoundQueryService(); + queryService.round = round(BigInteger.valueOf(1001), BigInteger.valueOf(2001), 2, ChatConstants.ROUND_STATUS_LOCKED); + queryService.latestRound = queryService.round; + ChatRoundOperateServiceImpl service = new ChatRoundOperateServiceImpl(queryService, new FakeRoundCommandService()); + + service.selectVariant(BigInteger.valueOf(1001), BigInteger.valueOf(2001), 1, BigInteger.valueOf(7)); + } + + /** + * 新增迁移必须为热表版本切换查询补齐索引。 + * + * @throws Exception 读取迁移文件失败时抛出 + */ + @Test + public void migrationShouldCreateRoundVariantIndex() throws Exception { + String sql = Files.readString( + resolveMigrationPath("V18__mysql_chat_round_variant_index.sql"), + StandardCharsets.UTF_8 + ); + + Assert.assertTrue(sql.contains("idx_chat_log_round_variant")); + Assert.assertTrue(sql.contains("`session_id`, `round_id`, `message_kind`, `variant_index`, `created`, `id`")); + Assert.assertFalse(sql.contains("V16__mysql_chat_round_variant")); + } + + /** + * 从当前测试工作目录向上查找迁移文件,兼容根工程与模块工程两种运行方式。 + * + * @param fileName 迁移文件名 + * @return 迁移文件路径 + * @throws Exception 未找到迁移文件时抛出 + */ + private static Path resolveMigrationPath(String fileName) throws Exception { + Path current = Path.of("").toAbsolutePath(); + while (current != null) { + Path candidate = current.resolve( + "easyflow-starter/easyflow-starter-all/src/main/resources/db/migration/mysql/" + fileName + ); + if (Files.exists(candidate)) { + return candidate; + } + current = current.getParent(); + } + throw new java.nio.file.NoSuchFileException(fileName); + } + + private static ChatRoundRecord round(BigInteger sessionId, BigInteger roundId, int selectedVariantIndex, String status) { + ChatRoundRecord round = new ChatRoundRecord(); + round.setId(roundId); + round.setSessionId(sessionId); + round.setRoundNo(1); + round.setSelectedVariantIndex(selectedVariantIndex); + round.setVariantCount(2); + round.setStatus(status); + return round; + } + + private static ChatMessageRecord message(BigInteger id, int variantIndex) { + ChatMessageRecord record = new ChatMessageRecord(); + record.setId(id); + record.setSessionId(BigInteger.valueOf(1001)); + record.setRoundId(BigInteger.valueOf(2001)); + record.setVariantIndex(variantIndex); + record.setSenderRole("assistant"); + record.setMessageKind(ChatConstants.MESSAGE_KIND_ASSISTANT_VARIANT); + record.setContentText("答案 " + variantIndex); + return record; + } + + /** + * 轮次读服务测试替身。 + */ + private static final class FakeRoundQueryService implements ChatRoundQueryService { + + private ChatRoundRecord round; + private ChatRoundRecord latestRound; + private ChatMessageRecord targetVariant; + private List variants = List.of(); + private int getRoundVariantCalls; + private int listRoundVariantsCalls; + + @Override + public ChatRoundRecord getLatestRound(BigInteger sessionId) { + return latestRound; + } + + @Override + public ChatRoundRecord getRound(BigInteger sessionId, BigInteger roundId) { + return round; + } + + @Override + public List listRoundVariants(BigInteger sessionId, BigInteger roundId) { + listRoundVariantsCalls += 1; + return variants; + } + + @Override + public ChatMessageRecord getRoundVariant(BigInteger sessionId, BigInteger roundId, Integer variantIndex) { + getRoundVariantCalls += 1; + return targetVariant; + } + + @Override + public boolean hasRounds(BigInteger sessionId) { + return round != null; + } + } + + /** + * 轮次写服务测试替身。 + */ + private static final class FakeRoundCommandService implements ChatRoundCommandService { + + private ChatRoundSelectCommand selectedCommand; + + @Override + public ChatRoundRecord createOrTouchRound(ChatRoundUpsertCommand command) { + return null; + } + + @Override + public void selectVariant(ChatRoundSelectCommand command) { + selectedCommand = command; + } + } +} diff --git a/easyflow-modules/easyflow-module-chatlog/src/test/java/tech/easyflow/chatlog/service/impl/ChatSessionQueryServiceImplTest.java b/easyflow-modules/easyflow-module-chatlog/src/test/java/tech/easyflow/chatlog/service/impl/ChatSessionQueryServiceImplTest.java new file mode 100644 index 0000000..10bc784 --- /dev/null +++ b/easyflow-modules/easyflow-module-chatlog/src/test/java/tech/easyflow/chatlog/service/impl/ChatSessionQueryServiceImplTest.java @@ -0,0 +1,222 @@ +package tech.easyflow.chatlog.service.impl; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.Assert; +import org.junit.Test; +import tech.easyflow.chatlog.cache.ChatHotStateService; +import tech.easyflow.chatlog.config.ChatCacheProperties; +import tech.easyflow.chatlog.domain.dto.ChatHistoryPage; +import tech.easyflow.chatlog.domain.dto.ChatMessageRecord; +import tech.easyflow.chatlog.domain.dto.ChatSessionPage; +import tech.easyflow.chatlog.domain.dto.ChatSessionSummary; +import tech.easyflow.chatlog.domain.query.ChatPageQuery; +import tech.easyflow.chatlog.repository.mysql.MySqlChatLogRepository; +import tech.easyflow.chatlog.repository.mysql.MySqlChatLogTableManager; +import tech.easyflow.chatlog.repository.mysql.MySqlChatSessionRepository; +import tech.easyflow.chatlog.support.ChatJsonSupport; + +import java.math.BigInteger; +import java.time.YearMonth; +import java.util.ArrayList; +import java.util.List; + +/** + * {@link ChatSessionQueryServiceImpl} 单元测试。 + */ +public class ChatSessionQueryServiceImplTest { + + /** + * 会话列表必须以 MySQL 会话表为唯一权威来源,不再使用 Redis 列表索引。 + */ + @Test + public void pageSessionsShouldUseMysqlRepositoryAsAuthority() { + FakeSessionRepository sessionRepository = new FakeSessionRepository(); + sessionRepository.sessions = List.of(session(BigInteger.valueOf(1001), 4)); + sessionRepository.count = 1; + ChatSessionQueryServiceImpl service = new ChatSessionQueryServiceImpl( + sessionRepository, + new FakeLogRepository(), + new FakeTableManager(List.of()), + new FakeHotStateService() + ); + + ChatSessionPage page = service.pageSessions(BigInteger.valueOf(7), null, new ChatPageQuery()); + + Assert.assertEquals(1, page.getTotal()); + Assert.assertEquals(1, page.getRecords().size()); + Assert.assertEquals(1, sessionRepository.countSessionsCalls); + Assert.assertEquals(1, sessionRepository.listSessionsCalls); + } + + /** + * 工作台消息分页必须走 MySQL 热表主线查询,并保持分页参数语义。 + */ + @Test + public void pageMainlineMessagesShouldReadMysqlHotTables() { + FakeSessionRepository sessionRepository = new FakeSessionRepository(); + sessionRepository.summary = session(BigInteger.valueOf(2001), 6); + FakeLogRepository logRepository = new FakeLogRepository(); + logRepository.records = List.of(message(3001), message(3002)); + List months = List.of(YearMonth.of(2026, 5)); + ChatSessionQueryServiceImpl service = new ChatSessionQueryServiceImpl( + sessionRepository, + logRepository, + new FakeTableManager(months), + new FakeHotStateService() + ); + ChatPageQuery query = new ChatPageQuery(); + query.setPageNumber(2); + query.setPageSize(2); + + ChatHistoryPage page = service.pageMainlineMessages(BigInteger.valueOf(2001), query); + + Assert.assertEquals(6, page.getTotal()); + Assert.assertEquals(2, page.getRecords().size()); + Assert.assertEquals(BigInteger.valueOf(2001), logRepository.capturedSessionId); + Assert.assertEquals(months, logRepository.capturedMonths); + Assert.assertEquals(2, logRepository.capturedOffset); + Assert.assertEquals(2, logRepository.capturedLimit); + } + + /** + * 当 MySQL 摘要计数滞后时,分页 total 至少覆盖当前已返回的数据范围。 + */ + @Test + public void pageMainlineMessagesShouldNotReturnTotalSmallerThanCurrentPage() { + FakeSessionRepository sessionRepository = new FakeSessionRepository(); + sessionRepository.summary = session(BigInteger.valueOf(2002), 1); + FakeLogRepository logRepository = new FakeLogRepository(); + logRepository.records = List.of(message(4001), message(4002)); + ChatSessionQueryServiceImpl service = new ChatSessionQueryServiceImpl( + sessionRepository, + logRepository, + new FakeTableManager(List.of(YearMonth.of(2026, 5))), + new FakeHotStateService() + ); + ChatPageQuery query = new ChatPageQuery(); + query.setPageNumber(2); + query.setPageSize(2); + + ChatHistoryPage page = service.pageMainlineMessages(BigInteger.valueOf(2002), query); + + Assert.assertEquals(4, page.getTotal()); + } + + private static ChatSessionSummary session(BigInteger id, int messageCount) { + ChatSessionSummary summary = new ChatSessionSummary(); + summary.setId(id); + summary.setUserId(BigInteger.valueOf(7)); + summary.setMessageCount(messageCount); + return summary; + } + + private static ChatMessageRecord message(long id) { + ChatMessageRecord record = new ChatMessageRecord(); + record.setId(BigInteger.valueOf(id)); + return record; + } + + /** + * MySQL 会话仓储测试替身。 + */ + private static final class FakeSessionRepository extends MySqlChatSessionRepository { + + private long count; + private int countSessionsCalls; + private int listSessionsCalls; + private ChatSessionSummary summary; + private List sessions = new ArrayList<>(); + + private FakeSessionRepository() { + super(null, null); + } + + @Override + public List listSessions(BigInteger userId, BigInteger assistantId, ChatPageQuery query) { + listSessionsCalls += 1; + return sessions; + } + + @Override + public long countSessions(BigInteger userId, BigInteger assistantId) { + countSessionsCalls += 1; + return count; + } + + @Override + public ChatSessionSummary findBySessionId(BigInteger sessionId) { + return summary; + } + } + + /** + * MySQL 消息仓储测试替身。 + */ + private static final class FakeLogRepository extends MySqlChatLogRepository { + + private BigInteger capturedSessionId; + private List capturedMonths; + private long capturedOffset; + private int capturedLimit; + private List records = new ArrayList<>(); + + private FakeLogRepository() { + super(null, null, new ChatJsonSupport(new ObjectMapper())); + } + + @Override + public List listMainlineMessages(BigInteger sessionId, List months, long offset, int limit) { + capturedSessionId = sessionId; + capturedMonths = months; + capturedOffset = offset; + capturedLimit = limit; + return records; + } + } + + /** + * MySQL 热表管理器测试替身。 + */ + private static final class FakeTableManager extends MySqlChatLogTableManager { + + private final List months; + + private FakeTableManager(List months) { + super(null, null); + this.months = months; + } + + @Override + public List listRecentExistingMonths(int retentionMonths) { + return months; + } + } + + /** + * Redis 热态测试替身,避免单测依赖真实 Redis。 + */ + private static final class FakeHotStateService extends ChatHotStateService { + + private FakeHotStateService() { + super(null, new ObjectMapper(), new ChatCacheProperties()); + } + + @Override + public ChatSessionSummary getSessionSummary(BigInteger sessionId) { + return null; + } + + @Override + public void cacheSessionSummary(ChatSessionSummary summary) { + } + + @Override + public List getSessionTail(BigInteger sessionId) { + return null; + } + + @Override + public void setSessionTail(BigInteger sessionId, List records) { + } + } +} diff --git a/easyflow-modules/easyflow-module-chatlog/src/test/java/tech/easyflow/chatlog/service/impl/ChatlogRuntimeListenerTest.java b/easyflow-modules/easyflow-module-chatlog/src/test/java/tech/easyflow/chatlog/service/impl/ChatlogRuntimeListenerTest.java new file mode 100644 index 0000000..01dd67e --- /dev/null +++ b/easyflow-modules/easyflow-module-chatlog/src/test/java/tech/easyflow/chatlog/service/impl/ChatlogRuntimeListenerTest.java @@ -0,0 +1,219 @@ +package tech.easyflow.chatlog.service.impl; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.Assert; +import org.junit.Test; +import tech.easyflow.chatlog.domain.dto.ChatRoundRecord; +import tech.easyflow.chatlog.domain.command.ChatSessionUpsertCommand; +import tech.easyflow.chatlog.domain.dto.ChatMessageRecord; +import tech.easyflow.chatlog.domain.dto.ChatSessionExtPayload; +import tech.easyflow.chatlog.domain.dto.ChatSessionSummary; +import tech.easyflow.chatlog.service.ChatPersistDispatcher; +import tech.easyflow.chatlog.service.ChatRoundOperateService; +import tech.easyflow.chatlog.service.ChatRoundQueryService; +import tech.easyflow.chatlog.service.ChatSessionQueryService; +import tech.easyflow.chatlog.support.ChatJsonSupport; +import tech.easyflow.core.runtime.ChatRuntimeContext; +import tech.easyflow.core.runtime.ChatRuntimeExtKeys; + +import java.math.BigInteger; +import java.util.Date; +import java.util.List; + +/** + * {@link ChatlogRuntimeListener} 单元测试。 + */ +public class ChatlogRuntimeListenerTest { + + /** + * 会话准备阶段应把额外知识库选择写入 ext_json。 + */ + @Test + public void onSessionPreparedShouldWriteExtraKnowledgeIdsToExtJson() { + CapturingChatPersistDispatcher dispatcher = new CapturingChatPersistDispatcher(); + ChatlogRuntimeListener listener = new ChatlogRuntimeListener( + dispatcher, + new NoopChatRoundOperateService(), + new NoopChatRoundQueryService(), + new NoopChatSessionQueryService(), + new ChatJsonSupport(new ObjectMapper()) + ); + ChatRuntimeContext context = new ChatRuntimeContext(); + context.setSessionId(BigInteger.valueOf(1001)); + context.setTenantId(BigInteger.ONE); + context.setDeptId(BigInteger.TEN); + context.setUserId(BigInteger.valueOf(7)); + context.setUserAccount("admin"); + context.setAssistantId(BigInteger.valueOf(88)); + context.setAssistantCode("bot-88"); + context.setAssistantName("测试助手"); + context.setSessionTitle("你好"); + context.getExt().put( + ChatRuntimeExtKeys.EXTRA_KNOWLEDGE_IDS, + List.of(BigInteger.valueOf(11), BigInteger.valueOf(22)) + ); + + listener.onSessionPrepared(context); + + Assert.assertNotNull(dispatcher.captured); + ChatSessionExtPayload payload = new ChatJsonSupport(new ObjectMapper()) + .fromJson(dispatcher.captured.getExtJson(), ChatSessionExtPayload.class); + Assert.assertEquals( + List.of(BigInteger.valueOf(11), BigInteger.valueOf(22)), + payload.getExtraKnowledgeIds() + ); + } + + /** + * 重新生成时历史上下文应排除当前轮旧问题和旧答案。 + */ + @Test + public void loadMessagesShouldExcludeRegenerateRoundHistory() { + ChatlogRuntimeListener listener = new ChatlogRuntimeListener( + null, + new NoopChatRoundOperateService(), + new NoopChatRoundQueryService(), + new TailChatSessionQueryService(List.of( + record(4, 2, "assistant", "旧答案"), + record(3, 2, "user", "当前问题"), + record(2, 1, "assistant", "上一轮答案"), + record(1, 1, "user", "上一轮问题") + )), + new ChatJsonSupport(new ObjectMapper()) + ); + ChatRuntimeContext context = new ChatRuntimeContext(); + context.setSessionId(BigInteger.valueOf(1001)); + context.getExt().put(ChatRuntimeExtKeys.REGENERATE_ROUND_ID, BigInteger.valueOf(2)); + + List messages = listener.loadMessages(context, 10); + + Assert.assertEquals(2, messages.size()); + Assert.assertEquals("上一轮问题", messages.get(0).getContentText()); + Assert.assertEquals("上一轮答案", messages.get(1).getContentText()); + } + + private static ChatMessageRecord record(long id, int roundId, String role, String text) { + ChatMessageRecord record = new ChatMessageRecord(); + record.setId(BigInteger.valueOf(id)); + record.setSessionId(BigInteger.valueOf(1001)); + record.setRoundId(BigInteger.valueOf(roundId)); + record.setSenderRole(role); + record.setContentType("TEXT"); + record.setContentText(text); + record.setCreated(new Date(id)); + return record; + } + + private static class CapturingChatPersistDispatcher extends ChatPersistDispatcher { + + private ChatSessionUpsertCommand captured; + + private CapturingChatPersistDispatcher() { + super(null, null, null, null); + } + + @Override + public ChatSessionSummary createOrTouchSession(ChatSessionUpsertCommand command) { + this.captured = command; + return new ChatSessionSummary(); + } + } + + private static class NoopChatRoundOperateService implements ChatRoundOperateService { + + @Override + public ChatRoundRecord requireRegeneratableRound(BigInteger sessionId, BigInteger roundId) { + return null; + } + + @Override + public List listVariants(BigInteger sessionId, BigInteger roundId) { + return List.of(); + } + + @Override + public tech.easyflow.chatlog.domain.dto.ChatMessageRecord selectVariant(BigInteger sessionId, BigInteger roundId, Integer variantIndex, BigInteger operatorId) { + return null; + } + } + + private static class NoopChatRoundQueryService implements ChatRoundQueryService { + + @Override + public ChatRoundRecord getLatestRound(BigInteger sessionId) { + return null; + } + + @Override + public ChatRoundRecord getRound(BigInteger sessionId, BigInteger roundId) { + return null; + } + + @Override + public List listRoundVariants(BigInteger sessionId, BigInteger roundId) { + return List.of(); + } + + @Override + public tech.easyflow.chatlog.domain.dto.ChatMessageRecord getRoundVariant(BigInteger sessionId, BigInteger roundId, Integer variantIndex) { + return null; + } + + @Override + public boolean hasRounds(BigInteger sessionId) { + return false; + } + } + + private static class NoopChatSessionQueryService implements ChatSessionQueryService { + + @Override + public List listSessions(BigInteger userId, BigInteger assistantId, tech.easyflow.chatlog.domain.query.ChatPageQuery query) { + return List.of(); + } + + @Override + public long countSessions(BigInteger userId, BigInteger assistantId) { + return 0; + } + + @Override + public tech.easyflow.chatlog.domain.dto.ChatSessionPage pageSessions(BigInteger userId, BigInteger assistantId, tech.easyflow.chatlog.domain.query.ChatPageQuery query) { + return new tech.easyflow.chatlog.domain.dto.ChatSessionPage(); + } + + @Override + public ChatSessionSummary getSessionSummary(BigInteger sessionId) { + return null; + } + + @Override + public tech.easyflow.chatlog.domain.dto.ChatHistoryPage pageMainlineMessages(BigInteger sessionId, tech.easyflow.chatlog.domain.query.ChatPageQuery query) { + return new tech.easyflow.chatlog.domain.dto.ChatHistoryPage(); + } + + @Override + public List listMainlineMessages(BigInteger sessionId) { + return List.of(); + } + + @Override + public List getRecentTail(BigInteger sessionId, int limit) { + return List.of(); + } + } + + private static class TailChatSessionQueryService extends NoopChatSessionQueryService { + + private final List records; + + private TailChatSessionQueryService(List records) { + this.records = records; + } + + @Override + public List getRecentTail(BigInteger sessionId, int limit) { + return records.subList(0, Math.min(records.size(), limit)); + } + } +} diff --git a/easyflow-starter/easyflow-starter-all/src/main/resources/db/migration/mysql/V14__mysql_chat_session_ext_json.sql b/easyflow-starter/easyflow-starter-all/src/main/resources/db/migration/mysql/V14__mysql_chat_session_ext_json.sql new file mode 100644 index 0000000..dff408c --- /dev/null +++ b/easyflow-starter/easyflow-starter-all/src/main/resources/db/migration/mysql/V14__mysql_chat_session_ext_json.sql @@ -0,0 +1,2 @@ +ALTER TABLE `chat_session` + ADD COLUMN `ext_json` json NULL COMMENT '会话扩展信息' AFTER `title`; diff --git a/easyflow-starter/easyflow-starter-all/src/main/resources/db/migration/mysql/V15__mysql_chat_workspace_menu.sql b/easyflow-starter/easyflow-starter-all/src/main/resources/db/migration/mysql/V15__mysql_chat_workspace_menu.sql new file mode 100644 index 0000000..4b2039e --- /dev/null +++ b/easyflow-starter/easyflow-starter-all/src/main/resources/db/migration/mysql/V15__mysql_chat_workspace_menu.sql @@ -0,0 +1,20 @@ +SET NAMES utf8mb4; + +INSERT INTO `tb_sys_menu` ( + `id`, `parent_id`, `menu_type`, `menu_title`, `menu_url`, `component`, `menu_icon`, + `is_show`, `permission_tag`, `sort_no`, `status`, `created`, `created_by`, `modified`, `modified_by`, `remark` +) +SELECT + 367200000000000001, 0, 0, 'menus.ai.chat', '/ai/chat', '/ai/chat/index', 'svg:talk', + 1, '', 12, 0, '2026-05-12 10:00:00', 1, '2026-05-12 10:00:00', 1, '管理端聊天工作台菜单' +FROM DUAL +WHERE NOT EXISTS ( + SELECT 1 FROM `tb_sys_menu` WHERE `id` = 367200000000000001 +); + +INSERT INTO `tb_sys_role_menu` (`id`, `role_id`, `menu_id`) +SELECT 367200000000000101, 1, 367200000000000001 +FROM DUAL +WHERE NOT EXISTS ( + SELECT 1 FROM `tb_sys_role_menu` WHERE `id` = 367200000000000101 +); diff --git a/easyflow-starter/easyflow-starter-all/src/main/resources/db/migration/mysql/V16__mysql_chat_round_variant.sql b/easyflow-starter/easyflow-starter-all/src/main/resources/db/migration/mysql/V16__mysql_chat_round_variant.sql new file mode 100644 index 0000000..0e42e85 --- /dev/null +++ b/easyflow-starter/easyflow-starter-all/src/main/resources/db/migration/mysql/V16__mysql_chat_round_variant.sql @@ -0,0 +1,162 @@ +CREATE TABLE IF NOT EXISTS `chat_round` +( + `id` bigint UNSIGNED NOT NULL COMMENT '轮次ID', + `session_id` bigint UNSIGNED NOT NULL COMMENT '会话ID', + `round_no` int NOT NULL COMMENT '轮次序号', + `user_message_id` bigint UNSIGNED NULL DEFAULT NULL COMMENT '用户消息ID', + `selected_assistant_message_id` bigint UNSIGNED NULL DEFAULT NULL COMMENT '当前选中的助手答案消息ID', + `selected_variant_index` int NOT NULL DEFAULT 0 COMMENT '当前选中的答案版本序号', + `variant_count` int NOT NULL DEFAULT 0 COMMENT '答案版本总数', + `status` varchar(32) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL DEFAULT 'READY' COMMENT '轮次状态', + `created` datetime NOT NULL COMMENT '创建时间', + `modified` datetime NOT NULL COMMENT '修改时间', + PRIMARY KEY (`id`) USING BTREE, + UNIQUE KEY `uk_chat_round_session_round_no` (`session_id`, `round_no`) USING BTREE, + KEY `idx_chat_round_session_modified` (`session_id`, `modified`, `id`) USING BTREE +) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COLLATE = utf8mb4_0900_ai_ci COMMENT = '聊天轮次热数据表'; + +SET @chat_log_template_alter = ( + SELECT CASE + WHEN COUNT(1) = 0 THEN 'SELECT 1' + ELSE CONCAT( + 'ALTER TABLE `chat_log_template` ', + GROUP_CONCAT(stmt ORDER BY ord SEPARATOR ', ') + ) + END + FROM ( + SELECT 1 AS ord, + 'ADD COLUMN `round_id` bigint UNSIGNED NULL DEFAULT NULL COMMENT ''轮次ID'' AFTER `assistant_id`' AS stmt + WHERE NOT EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = DATABASE() + AND table_name = 'chat_log_template' + AND column_name = 'round_id' + ) + UNION ALL + SELECT 2 AS ord, + 'ADD COLUMN `round_no` int NULL DEFAULT NULL COMMENT ''轮次序号'' AFTER `round_id`' AS stmt + WHERE NOT EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = DATABASE() + AND table_name = 'chat_log_template' + AND column_name = 'round_no' + ) + UNION ALL + SELECT 3 AS ord, + 'ADD COLUMN `message_kind` varchar(32) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NULL DEFAULT NULL COMMENT ''消息类型'' AFTER `sender_role`' AS stmt + WHERE NOT EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = DATABASE() + AND table_name = 'chat_log_template' + AND column_name = 'message_kind' + ) + UNION ALL + SELECT 4 AS ord, + 'ADD COLUMN `variant_index` int NULL DEFAULT NULL COMMENT ''答案版本序号'' AFTER `message_kind`' AS stmt + WHERE NOT EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = DATABASE() + AND table_name = 'chat_log_template' + AND column_name = 'variant_index' + ) + ) changes +); + +PREPARE stmt_chat_log_template_alter FROM @chat_log_template_alter; +EXECUTE stmt_chat_log_template_alter; +DEALLOCATE PREPARE stmt_chat_log_template_alter; + +DROP PROCEDURE IF EXISTS migrate_chat_round_log_columns; + +DELIMITER $$ + +CREATE PROCEDURE migrate_chat_round_log_columns() +BEGIN + DECLARE done INT DEFAULT 0; + DECLARE v_table_name varchar(128); + DECLARE v_sql LONGTEXT; + DECLARE table_cursor CURSOR FOR + SELECT table_name + FROM information_schema.tables + WHERE table_schema = DATABASE() + AND table_name LIKE 'chat_log\\_%' + AND table_name <> 'chat_log_template'; + DECLARE CONTINUE HANDLER FOR NOT FOUND SET done = 1; + + OPEN table_cursor; + + table_loop: + LOOP + FETCH table_cursor INTO v_table_name; + IF done = 1 THEN + LEAVE table_loop; + END IF; + + SET v_sql = ( + SELECT CASE + WHEN COUNT(1) = 0 THEN 'SELECT 1' + ELSE CONCAT( + 'ALTER TABLE `', v_table_name, '` ', + GROUP_CONCAT(stmt ORDER BY ord SEPARATOR ', ') + ) + END + FROM ( + SELECT 1 AS ord, + 'ADD COLUMN `round_id` bigint UNSIGNED NULL DEFAULT NULL COMMENT ''轮次ID'' AFTER `assistant_id`' AS stmt + WHERE NOT EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = DATABASE() + AND table_name = v_table_name + AND column_name = 'round_id' + ) + UNION ALL + SELECT 2 AS ord, + 'ADD COLUMN `round_no` int NULL DEFAULT NULL COMMENT ''轮次序号'' AFTER `round_id`' AS stmt + WHERE NOT EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = DATABASE() + AND table_name = v_table_name + AND column_name = 'round_no' + ) + UNION ALL + SELECT 3 AS ord, + 'ADD COLUMN `message_kind` varchar(32) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NULL DEFAULT NULL COMMENT ''消息类型'' AFTER `sender_role`' AS stmt + WHERE NOT EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = DATABASE() + AND table_name = v_table_name + AND column_name = 'message_kind' + ) + UNION ALL + SELECT 4 AS ord, + 'ADD COLUMN `variant_index` int NULL DEFAULT NULL COMMENT ''答案版本序号'' AFTER `message_kind`' AS stmt + WHERE NOT EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = DATABASE() + AND table_name = v_table_name + AND column_name = 'variant_index' + ) + ) changes + ); + + SET @chat_round_log_table_alter = v_sql; + PREPARE stmt_chat_round_log_alter FROM @chat_round_log_table_alter; + EXECUTE stmt_chat_round_log_alter; + DEALLOCATE PREPARE stmt_chat_round_log_alter; + END LOOP; + + CLOSE table_cursor; +END $$ + +DELIMITER ; + +CALL migrate_chat_round_log_columns(); +DROP PROCEDURE IF EXISTS migrate_chat_round_log_columns; diff --git a/easyflow-starter/easyflow-starter-all/src/main/resources/db/migration/mysql/V18__mysql_chat_round_variant_index.sql b/easyflow-starter/easyflow-starter-all/src/main/resources/db/migration/mysql/V18__mysql_chat_round_variant_index.sql new file mode 100644 index 0000000..ed255f9 --- /dev/null +++ b/easyflow-starter/easyflow-starter-all/src/main/resources/db/migration/mysql/V18__mysql_chat_round_variant_index.sql @@ -0,0 +1,58 @@ +SET @chat_log_template_round_variant_index = ( + SELECT CASE + WHEN COUNT(1) > 0 THEN 'SELECT 1' + ELSE 'ALTER TABLE `chat_log_template` ADD INDEX `idx_chat_log_round_variant` (`session_id`, `round_id`, `message_kind`, `variant_index`, `created`, `id`)' + END + FROM information_schema.statistics + WHERE table_schema = DATABASE() + AND table_name = 'chat_log_template' + AND index_name = 'idx_chat_log_round_variant' +); + +PREPARE stmt_chat_log_template_round_variant_index FROM @chat_log_template_round_variant_index; +EXECUTE stmt_chat_log_template_round_variant_index; +DEALLOCATE PREPARE stmt_chat_log_template_round_variant_index; + +DROP PROCEDURE IF EXISTS migrate_chat_log_round_variant_index; + +CREATE PROCEDURE migrate_chat_log_round_variant_index() +BEGIN + DECLARE done INT DEFAULT 0; + DECLARE v_table_name VARCHAR(128); + DECLARE table_cursor CURSOR FOR + SELECT table_name + FROM information_schema.tables + WHERE table_schema = DATABASE() + AND table_name LIKE 'chat_log\_%' + AND table_name <> 'chat_log_template'; + DECLARE CONTINUE HANDLER FOR NOT FOUND SET done = 1; + + OPEN table_cursor; + read_loop: LOOP + FETCH table_cursor INTO v_table_name; + IF done = 1 THEN + LEAVE read_loop; + END IF; + + IF NOT EXISTS ( + SELECT 1 + FROM information_schema.statistics + WHERE table_schema = DATABASE() + AND table_name = v_table_name + AND index_name = 'idx_chat_log_round_variant' + ) THEN + SET @chat_log_round_variant_index = CONCAT( + 'ALTER TABLE `', + v_table_name, + '` ADD INDEX `idx_chat_log_round_variant` (`session_id`, `round_id`, `message_kind`, `variant_index`, `created`, `id`)' + ); + PREPARE stmt_chat_log_round_variant_index FROM @chat_log_round_variant_index; + EXECUTE stmt_chat_log_round_variant_index; + DEALLOCATE PREPARE stmt_chat_log_round_variant_index; + END IF; + END LOOP; + CLOSE table_cursor; +END; + +CALL migrate_chat_log_round_variant_index(); +DROP PROCEDURE IF EXISTS migrate_chat_log_round_variant_index; diff --git a/easyflow-ui-admin/app/src/api/request.ts b/easyflow-ui-admin/app/src/api/request.ts index 2f75251..859eb02 100644 --- a/easyflow-ui-admin/app/src/api/request.ts +++ b/easyflow-ui-admin/app/src/api/request.ts @@ -197,6 +197,28 @@ export class SseClient { return; } + const contentType = res.headers.get('content-type') || ''; + if (!contentType.includes('text/event-stream')) { + let errorMessage = '请求失败,请稍后再试'; + try { + const body = await res.json(); + errorMessage = + body?.error ?? body?.message ?? body?.data?.message ?? errorMessage; + } catch { + try { + const text = await res.text(); + if (text.trim()) { + errorMessage = text.trim(); + } + } catch { + // ignore body parse failures and keep fallback message + } + } + showErrorOnce(errorMessage); + options?.onError?.(new Error(errorMessage)); + return; + } + // 在开始事件流之前检查是否还是同一个请求 if (this.currentRequestId !== currentRequestId) { return; diff --git a/easyflow-ui-admin/app/src/components/chat-workspace/ChatAnswerVariantNavigator.vue b/easyflow-ui-admin/app/src/components/chat-workspace/ChatAnswerVariantNavigator.vue new file mode 100644 index 0000000..d989888 --- /dev/null +++ b/easyflow-ui-admin/app/src/components/chat-workspace/ChatAnswerVariantNavigator.vue @@ -0,0 +1,147 @@ + + + + + diff --git a/easyflow-ui-admin/app/src/components/chat-workspace/ChatContextCapsuleBar.vue b/easyflow-ui-admin/app/src/components/chat-workspace/ChatContextCapsuleBar.vue new file mode 100644 index 0000000..83cf8b2 --- /dev/null +++ b/easyflow-ui-admin/app/src/components/chat-workspace/ChatContextCapsuleBar.vue @@ -0,0 +1,429 @@ + + + + + diff --git a/easyflow-ui-admin/app/src/components/chat-workspace/ChatMessageActionBar.vue b/easyflow-ui-admin/app/src/components/chat-workspace/ChatMessageActionBar.vue new file mode 100644 index 0000000..4e80c0c --- /dev/null +++ b/easyflow-ui-admin/app/src/components/chat-workspace/ChatMessageActionBar.vue @@ -0,0 +1,153 @@ + + + + + diff --git a/easyflow-ui-admin/app/src/components/chat-workspace/ChatWelcomeAssistantPicker.vue b/easyflow-ui-admin/app/src/components/chat-workspace/ChatWelcomeAssistantPicker.vue new file mode 100644 index 0000000..168073a --- /dev/null +++ b/easyflow-ui-admin/app/src/components/chat-workspace/ChatWelcomeAssistantPicker.vue @@ -0,0 +1,199 @@ + + + + + diff --git a/easyflow-ui-admin/app/src/locales/langs/en-US/menus.json b/easyflow-ui-admin/app/src/locales/langs/en-US/menus.json index 779dcb9..73a4dbb 100644 --- a/easyflow-ui-admin/app/src/locales/langs/en-US/menus.json +++ b/easyflow-ui-admin/app/src/locales/langs/en-US/menus.json @@ -22,6 +22,7 @@ "oauth": "OAuth" }, "ai": { + "chat": "Chat", "bots": "ChatAssistant", "title": "AI", "resources": "Resources", diff --git a/easyflow-ui-admin/app/src/locales/langs/zh-CN/menus.json b/easyflow-ui-admin/app/src/locales/langs/zh-CN/menus.json index 414899b..a9bf1ba 100644 --- a/easyflow-ui-admin/app/src/locales/langs/zh-CN/menus.json +++ b/easyflow-ui-admin/app/src/locales/langs/zh-CN/menus.json @@ -22,6 +22,7 @@ "oauth": "认证设置" }, "ai": { + "chat": "聊天", "bots": "聊天助手", "title": "AI能力", "resources": "素材库", diff --git a/easyflow-ui-admin/app/src/views/ai/chat/index.vue b/easyflow-ui-admin/app/src/views/ai/chat/index.vue new file mode 100644 index 0000000..57c07b9 --- /dev/null +++ b/easyflow-ui-admin/app/src/views/ai/chat/index.vue @@ -0,0 +1,2571 @@ + + + + + diff --git a/easyflow-ui-admin/packages/types/src/chat-time.ts b/easyflow-ui-admin/packages/types/src/chat-time.ts index ea48ad8..0e3f0a9 100644 --- a/easyflow-ui-admin/packages/types/src/chat-time.ts +++ b/easyflow-ui-admin/packages/types/src/chat-time.ts @@ -2,14 +2,31 @@ type ChatTimeTimelineRole = 'assistant' | 'tool' | 'user'; type ChatTimeToolStatus = 'TOOL_CALL' | 'TOOL_RESULT'; type ChatTimeThinkingStatus = 'end' | 'thinking'; +interface ChatTimeRoundMeta { + messageKind?: string; + roundId?: number | string; + roundNo?: number; + selectedVariantIndex?: number; + switchable?: boolean; + variantCount?: number; + variantIndex?: number; +} + interface ChatTimeTimelineItemBase { created: number | string; id: string; loading?: boolean; + messageKind?: string; placement: 'end' | 'start'; + roundId?: string; + roundNo?: number; role: ChatTimeTimelineRole; + selectedVariantIndex?: number; senderName?: string; + switchable?: boolean; typing?: boolean; + variantCount?: number; + variantIndex?: number; } interface ChatTimeAssistantThinkingSegment { @@ -66,14 +83,22 @@ interface ChatTimeHistoryRecord { loading?: boolean; placement?: 'end' | 'start'; role?: string; + roundId?: number | string; + roundNo?: number; + selectedVariantIndex?: number; senderName?: string; senderRole?: string; + switchable?: boolean; typing?: boolean; + variantCount?: number; + variantIndex?: number; + messageKind?: string; } -interface ChatTimeToolMutationPayload { +interface ChatTimeToolMutationPayload extends ChatTimeRoundMeta { created?: number | string; name?: string; + regenerate?: boolean; result?: any; toolCallId?: string; value?: any; @@ -85,6 +110,7 @@ export type { ChatTimeAssistantTextSegment, ChatTimeAssistantThinkingSegment, ChatTimeHistoryRecord, + ChatTimeRoundMeta, ChatTimeThinkingStatus, ChatTimeTimelineItem, ChatTimeTimelineItemBase, diff --git a/easyflow-ui-admin/packages/utils/src/helpers/chat-time.ts b/easyflow-ui-admin/packages/utils/src/helpers/chat-time.ts index 825938e..35268e3 100644 --- a/easyflow-ui-admin/packages/utils/src/helpers/chat-time.ts +++ b/easyflow-ui-admin/packages/utils/src/helpers/chat-time.ts @@ -1,8 +1,10 @@ import type { ChatTimeAssistantItem, ChatTimeHistoryRecord, + ChatTimeRoundMeta, ChatTimeThinkingStatus, ChatTimeTimelineItem, + ChatTimeTimelineItemBase, ChatTimeToolItem, ChatTimeToolMutationPayload, ChatTimeToolStatus, @@ -28,17 +30,71 @@ class ChatTimeTimelineBuilder { content?: string; created?: number | string; id?: string; + messageKind?: string; + roundId?: number | string; + roundNo?: number; senderName?: string; }, ) { - items.push({ + const item: ChatTimeTimelineItem = { content: normalizePlainText(payload.content), created: normalizeTimestamp(payload.created), id: payload.id || uuid(), placement: 'end', role: 'user', senderName: payload.senderName, - }); + }; + applyRoundMeta(item, payload); + items.push(item); + } + + /** + * 将最新一条待绑定的用户消息补齐到当前轮次。 + */ + static bindLatestPendingUserMessage( + items: ChatTimeTimelineItem[], + meta?: ChatTimeRoundMeta, + ) { + const roundId = normalizeRoundId(meta?.roundId); + if (!roundId) { + return; + } + for (let index = items.length - 1; index >= 0; index -= 1) { + const item = items[index]; + if (!item) { + continue; + } + if (item.role !== 'user') { + continue; + } + if (item.roundId) { + return; + } + applyRoundMeta(item, { + roundId, + roundNo: meta?.roundNo, + }); + return; + } + } + + /** + * 更新指定轮次的可切换状态。 + */ + static setRoundSwitchable( + items: ChatTimeTimelineItem[], + roundId: number | string | undefined, + switchable: boolean, + ) { + const normalizedRoundId = normalizeRoundId(roundId); + if (!normalizedRoundId) { + return; + } + for (const item of items) { + if (item.roundId === normalizedRoundId && item.role !== 'user') { + item.switchable = switchable; + } + } } /** @@ -48,12 +104,14 @@ class ChatTimeTimelineBuilder { items: ChatTimeTimelineItem[], delta?: string, created?: number | string, + meta?: ChatTimeRoundMeta, ) { const normalizedDelta = normalizePlainText(delta); if (!normalizedDelta) { return; } - const assistant = ensureAssistantTail(items, created); + prepareRoundVariant(items, meta); + const assistant = ensureAssistantTail(items, created, meta); const tail = assistant.segments[assistant.segments.length - 1]; if (tail?.type === 'thinking' && tail.status === 'thinking') { tail.content += normalizedDelta; @@ -77,12 +135,14 @@ class ChatTimeTimelineBuilder { items: ChatTimeTimelineItem[], delta?: string, created?: number | string, + meta?: ChatTimeRoundMeta, ) { const normalizedDelta = normalizeAssistantText(delta); if (!normalizedDelta) { return; } - const assistant = ensureAssistantTail(items, created); + prepareRoundVariant(items, meta); + const assistant = ensureAssistantTail(items, created, meta); stopThinkingForAssistant(assistant); const tail = assistant.segments[assistant.segments.length - 1]; if (tail?.type === 'text') { @@ -117,12 +177,14 @@ class ChatTimeTimelineBuilder { items: ChatTimeTimelineItem[], payload: ChatTimeToolMutationPayload, ) { + prepareRoundVariant(items, payload); this.stopThinking(items); const toolItem = ensureToolItem( items, payload.toolCallId, payload.created, payload.name, + payload, ); toolItem.arguments = normalizePayloadValue(payload.value); toolItem.content = ''; @@ -136,11 +198,13 @@ class ChatTimeTimelineBuilder { items: ChatTimeTimelineItem[], payload: ChatTimeToolMutationPayload, ) { + prepareRoundVariant(items, payload); const toolItem = ensureToolItem( items, payload.toolCallId, payload.created, payload.name, + payload, ); toolItem.result = normalizePayloadValue(payload.result); toolItem.content = toolItem.result; @@ -178,7 +242,7 @@ class ChatTimeTimelineBuilder { * 结束当前轮的 assistant 状态。 */ static finalize(items: ChatTimeTimelineItem[]) { - const last = items[items.length - 1]; + const last = findLastAssistant(items); if (!isAssistantItem(last)) { return; } @@ -186,6 +250,26 @@ class ChatTimeTimelineBuilder { last.loading = false; last.typing = false; } + + /** + * 按轮次替换当前主线可见的 assistant/tool 片段。 + */ + static replaceRoundMessages( + items: ChatTimeTimelineItem[], + roundId: number | string | undefined, + nextMessages: ChatTimeTimelineItem[], + ) { + const normalizedRoundId = normalizeRoundId(roundId); + if (!normalizedRoundId) { + return; + } + const range = resolveRoundReplaceRange(items, normalizedRoundId); + if (range) { + items.splice(range.start, range.deleteCount, ...nextMessages); + return; + } + items.splice(resolveRoundInsertIndex(items, normalizedRoundId), 0, ...nextMessages); + } } /** @@ -196,7 +280,9 @@ class ChatTimeHistoryMapper { * 从聊天历史记录恢复时间线。 */ static fromHistoryRecords(records: ChatTimeHistoryRecord[]) { - return records.flatMap((record) => this.fromHistoryRecord(record)); + return normalizeVisibleHistoryRecords(records).flatMap((record) => + this.fromHistoryRecord(record), + ); } /** @@ -249,8 +335,15 @@ class ChatTimeHistoryMapper { const assistant = createAssistantItem(record.created, { id: record.id == null ? undefined : String(record.id), loading: record.loading, + messageKind: record.messageKind, + roundId: normalizeRoundId(record.roundId), + roundNo: record.roundNo, + selectedVariantIndex: record.selectedVariantIndex, senderName: record.senderName, + switchable: record.switchable, typing: record.typing, + variantCount: record.variantCount, + variantIndex: record.variantIndex, }); const tools: ChatTimeTimelineItem[] = []; @@ -267,7 +360,7 @@ class ChatTimeHistoryMapper { continue; } - const toolItem = createToolItemFromChain(rawChain, record.created); + const toolItem = createToolItemFromChain(rawChain, record.created, record); if (toolItem) { tools.push(toolItem); } @@ -316,6 +409,7 @@ class ChatTimeHistoryMapper { rawMessage, toolMetaMap, record.created, + record, ), ); } @@ -325,11 +419,84 @@ class ChatTimeHistoryMapper { } } +function normalizeVisibleHistoryRecords(records: ChatTimeHistoryRecord[]) { + const dedupedRecords = dedupeHistoryRecords(records); + const userSelectedVariantByRound = new Map(); + const assistantSelectedVariantByRound = new Map(); + const fallbackVariantByRound = new Map(); + for (const record of dedupedRecords) { + const roundId = normalizeRoundId(record.roundId); + if (!roundId) { + continue; + } + const selectedVariantIndex = normalizePositiveInteger( + record.selectedVariantIndex, + ); + if (selectedVariantIndex) { + if (isUserHistoryRecord(record)) { + userSelectedVariantByRound.set(roundId, selectedVariantIndex); + } else { + assistantSelectedVariantByRound.set(roundId, selectedVariantIndex); + } + } + const variantIndex = normalizePositiveInteger(record.variantIndex); + if (!isUserHistoryRecord(record) && variantIndex) { + fallbackVariantByRound.set(roundId, variantIndex); + } + } + return dedupedRecords.filter((record) => { + const roundId = normalizeRoundId(record.roundId); + if (!roundId || isUserHistoryRecord(record)) { + return true; + } + const variantIndex = normalizePositiveInteger(record.variantIndex); + if (!variantIndex) { + return true; + } + const selectedVariantIndex = + userSelectedVariantByRound.get(roundId) || + assistantSelectedVariantByRound.get(roundId) || + fallbackVariantByRound.get(roundId); + return !selectedVariantIndex || variantIndex === selectedVariantIndex; + }); +} + +function dedupeHistoryRecords(records: ChatTimeHistoryRecord[]) { + const seen = new Set(); + const result: ChatTimeHistoryRecord[] = []; + for (const record of records) { + const key = resolveHistoryRecordKey(record); + if (seen.has(key)) { + continue; + } + seen.add(key); + result.push(record); + } + return result; +} + +function resolveHistoryRecordKey(record: ChatTimeHistoryRecord) { + if (record.id != null) { + return `id:${String(record.id)}`; + } + return [ + 'fallback', + normalizeRoundId(record.roundId) || '', + normalizeRole(record.senderRole || record.role), + normalizePositiveInteger(record.variantIndex) || '', + normalizePlainText(record.contentText || record.content), + ].join(':'); +} + +function isUserHistoryRecord(record: ChatTimeHistoryRecord) { + return normalizeRole(record.senderRole || record.role) === 'user'; +} + function createAssistantItem( created?: number | string, - patch?: Partial, + patch?: Omit, 'roundId'> & ChatTimeRoundMeta, ): ChatTimeAssistantItem { - return { + const item: ChatTimeAssistantItem = { content: patch?.content || '', created: normalizeTimestamp(created), id: patch?.id || uuid(), @@ -340,6 +507,8 @@ function createAssistantItem( senderName: patch?.senderName, typing: patch?.typing, }; + applyRoundMeta(item, patch); + return item; } function createAssistantItemFromStructuredMessage( @@ -360,8 +529,15 @@ function createAssistantItemFromStructuredMessage( ? undefined : `${String(record.id)}-assistant-${assistantIndex}`, loading: false, + messageKind: record.messageKind, + roundId: normalizeRoundId(record.roundId), + roundNo: record.roundNo, + selectedVariantIndex: record.selectedVariantIndex, senderName: record.senderName, + switchable: record.switchable, typing: false, + variantCount: record.variantCount, + variantIndex: record.variantIndex, }); if (reasoning) { assistant.segments.push({ @@ -381,6 +557,7 @@ function createAssistantItemFromStructuredMessage( function createToolItemFromChain( rawChain: Record, created?: number | string, + record?: ChatTimeHistoryRecord, ) { const toolCallId = normalizePlainText(rawChain.id); const name = normalizePlainText(rawChain.name); @@ -393,10 +570,17 @@ function createToolItemFromChain( arguments: status === 'TOOL_CALL' ? argumentsValue : undefined, created, id: toolCallId || uuid(), + messageKind: record?.messageKind, name, + roundId: record?.roundId, + roundNo: record?.roundNo, result: status === 'TOOL_RESULT' ? argumentsValue : undefined, + selectedVariantIndex: record?.selectedVariantIndex, status, + switchable: record?.switchable, toolCallId, + variantCount: record?.variantCount, + variantIndex: record?.variantIndex, }); } @@ -404,6 +588,7 @@ function createToolItemFromStructuredMessage( rawMessage: Record, toolMetaMap: Map, created?: number | string, + record?: ChatTimeHistoryRecord, ) { const toolCallId = normalizePlainText( rawMessage.toolCallId ?? rawMessage.tool_call_id, @@ -414,10 +599,17 @@ function createToolItemFromStructuredMessage( arguments: toolMeta?.arguments, created, id: toolCallId || uuid(), + messageKind: record?.messageKind, name: toolMeta?.name, + roundId: record?.roundId, + roundNo: record?.roundNo, result, + selectedVariantIndex: record?.selectedVariantIndex, status: 'TOOL_RESULT', + switchable: record?.switchable, toolCallId, + variantCount: record?.variantCount, + variantIndex: record?.variantIndex, }); } @@ -429,12 +621,19 @@ function createToolItemFromTopLevelRecord(record: ChatTimeHistoryRecord) { return createToolItem({ created: record.created, id: record.id == null ? toolCallId || uuid() : String(record.id), + messageKind: record.messageKind, name: normalizePlainText(payload.name), + roundId: record.roundId, + roundNo: record.roundNo, result: normalizePayloadValue( payload.content ?? payload.result ?? record.contentText ?? record.content, ), + selectedVariantIndex: record.selectedVariantIndex, status: 'TOOL_RESULT', + switchable: record.switchable, toolCallId, + variantCount: record.variantCount, + variantIndex: record.variantIndex, }); } @@ -442,12 +641,19 @@ function createToolItem(payload: { arguments?: string; created?: number | string; id?: string; + messageKind?: string; name?: string; + roundId?: number | string; + roundNo?: number; result?: string; + selectedVariantIndex?: number; status: ChatTimeToolStatus; + switchable?: boolean; toolCallId?: string; + variantCount?: number; + variantIndex?: number; }): ChatTimeToolItem { - return { + const item: ChatTimeToolItem = { arguments: payload.arguments, content: payload.result || '', created: normalizeTimestamp(payload.created), @@ -459,10 +665,12 @@ function createToolItem(payload: { status: payload.status, toolCallId: payload.toolCallId || payload.id || uuid(), }; + applyRoundMeta(item, payload); + return item; } function createUserItem(record: ChatTimeHistoryRecord): ChatTimeTimelineItem { - return { + const item: ChatTimeTimelineItem = { content: normalizePlainText(record.contentText || record.content), created: normalizeTimestamp(record.created), id: record.id == null ? uuid() : String(record.id), @@ -472,6 +680,8 @@ function createUserItem(record: ChatTimeHistoryRecord): ChatTimeTimelineItem { senderName: record.senderName, typing: record.typing, }; + applyRoundMeta(item, record); + return item; } function appendAssistantText(item: ChatTimeAssistantItem, content: string) { @@ -507,14 +717,17 @@ function collectToolMeta( function ensureAssistantTail( items: ChatTimeTimelineItem[], created?: number | string, + meta?: ChatTimeRoundMeta, ) { const last = items[items.length - 1]; - if (isAssistantItem(last)) { + if (isAssistantItem(last) && isSameRoundVariant(last, meta)) { + applyRoundMeta(last, meta); return last; } const assistant = createAssistantItem(created, { loading: true, typing: true, + ...normalizeRoundMeta(meta), }); items.push(assistant); return assistant; @@ -525,38 +738,64 @@ function ensureToolItem( toolCallId?: string, created?: number | string, name?: string, + meta?: ChatTimeRoundMeta, ) { const normalizedToolCallId = normalizePlainText(toolCallId); - const found = findToolItem(items, normalizedToolCallId); + const found = findToolItem(items, normalizedToolCallId, meta); if (found) { if (name) { found.name = name; } + applyRoundMeta(found, meta); return found; } const toolItem = createToolItem({ created, id: normalizedToolCallId || uuid(), + messageKind: meta?.messageKind, name, + roundId: meta?.roundId, + roundNo: meta?.roundNo, + selectedVariantIndex: meta?.selectedVariantIndex, status: 'TOOL_CALL', + switchable: meta?.switchable, toolCallId: normalizedToolCallId, + variantCount: meta?.variantCount, + variantIndex: meta?.variantIndex, }); items.push(toolItem); return toolItem; } -function findToolItem(items: ChatTimeTimelineItem[], toolCallId?: string) { +function findToolItem( + items: ChatTimeTimelineItem[], + toolCallId?: string, + meta?: ChatTimeRoundMeta, +) { + const normalizedRoundId = normalizeRoundId(meta?.roundId); + const normalizedVariantIndex = normalizePositiveInteger(meta?.variantIndex); if (toolCallId) { for (let index = items.length - 1; index >= 0; index -= 1) { const item = items[index]; - if (isToolItem(item) && item.toolCallId === toolCallId) { + if ( + isToolItem(item) && + item.toolCallId === toolCallId && + matchesRoundVariant(item, normalizedRoundId, normalizedVariantIndex) + ) { return item; } } } for (let index = items.length - 1; index >= 0; index -= 1) { const item = items[index]; - if (isToolItem(item) && item.status === 'TOOL_CALL') { + if (!item) { + continue; + } + if ( + isToolItem(item) && + item.status === 'TOOL_CALL' && + matchesRoundVariant(item, normalizedRoundId, normalizedVariantIndex) + ) { return item; } } @@ -584,6 +823,182 @@ function isToolItem(item?: ChatTimeTimelineItem): item is ChatTimeToolItem { return item?.role === 'tool'; } +function findLastAssistant(items: ChatTimeTimelineItem[]) { + for (let index = items.length - 1; index >= 0; index -= 1) { + const item = items[index]; + if (isAssistantItem(item)) { + return item; + } + } + return undefined; +} + +function prepareRoundVariant( + items: ChatTimeTimelineItem[], + meta?: ChatTimeRoundMeta, +) { + const normalizedRoundId = normalizeRoundId(meta?.roundId); + const normalizedVariantIndex = normalizePositiveInteger(meta?.variantIndex); + if (!normalizedRoundId || !normalizedVariantIndex) { + return; + } + const assistant = items.find( + (item) => item.role === 'assistant' && item.roundId === normalizedRoundId, + ); + if (assistant?.variantIndex === normalizedVariantIndex) { + return; + } + for (let index = items.length - 1; index >= 0; index -= 1) { + const item = items[index]; + if (!item) { + continue; + } + if (item.roundId === normalizedRoundId && item.role !== 'user') { + items.splice(index, 1); + } + } +} + +function resolveRoundInsertIndex( + items: ChatTimeTimelineItem[], + roundId: string, +) { + const firstRoundItemIndex = items.findIndex( + (item) => item.roundId === roundId && item.role !== 'user', + ); + if (firstRoundItemIndex >= 0) { + return firstRoundItemIndex; + } + for (let index = items.length - 1; index >= 0; index -= 1) { + const item = items[index]; + if (!item) { + continue; + } + if (item.roundId === roundId && item.role === 'user') { + return index + 1; + } + } + return items.length; +} + +function resolveRoundReplaceRange( + items: ChatTimeTimelineItem[], + roundId: string, +) { + let start = -1; + let end = -1; + for (let index = 0; index < items.length; index += 1) { + const item = items[index]; + if (item?.roundId === roundId && item.role !== 'user') { + if (start < 0) { + start = index; + } + end = index; + } else if (start >= 0) { + break; + } + } + if (start < 0) { + return null; + } + return { + deleteCount: end - start + 1, + start, + }; +} + +function matchesRoundVariant( + item: ChatTimeTimelineItem, + roundId?: string, + variantIndex?: number, +) { + if (roundId && item.roundId !== roundId) { + return false; + } + if (variantIndex && item.variantIndex && item.variantIndex !== variantIndex) { + return false; + } + return true; +} + +function isSameRoundVariant( + item: ChatTimeTimelineItem, + meta?: ChatTimeRoundMeta, +) { + const normalizedRoundId = normalizeRoundId(meta?.roundId); + const normalizedVariantIndex = normalizePositiveInteger(meta?.variantIndex); + if (!normalizedRoundId || !normalizedVariantIndex) { + return true; + } + return ( + item.roundId === normalizedRoundId && + normalizePositiveInteger(item.variantIndex) === normalizedVariantIndex + ); +} + +function applyRoundMeta( + target: Partial, + source?: ChatTimeRoundMeta | null, +) { + if (!source) { + return; + } + const roundId = normalizeRoundId(source.roundId); + if (roundId) { + target.roundId = roundId; + } + const roundNo = normalizePositiveInteger(source.roundNo); + if (roundNo) { + target.roundNo = roundNo; + } + const variantIndex = normalizePositiveInteger(source.variantIndex); + if (variantIndex) { + target.variantIndex = variantIndex; + } + const variantCount = normalizePositiveInteger(source.variantCount); + if (variantCount) { + target.variantCount = variantCount; + } + const selectedVariantIndex = normalizePositiveInteger( + source.selectedVariantIndex, + ); + if (selectedVariantIndex) { + target.selectedVariantIndex = selectedVariantIndex; + } + if (typeof source.switchable === 'boolean') { + target.switchable = source.switchable; + } + const messageKind = normalizePlainText(source.messageKind).trim(); + if (messageKind) { + target.messageKind = messageKind; + } +} + +function normalizeRoundMeta(meta?: ChatTimeRoundMeta): ChatTimeRoundMeta { + return { + messageKind: meta?.messageKind, + roundId: normalizeRoundId(meta?.roundId), + roundNo: meta?.roundNo, + selectedVariantIndex: meta?.selectedVariantIndex, + switchable: meta?.switchable, + variantCount: meta?.variantCount, + variantIndex: meta?.variantIndex, + }; +} + +function normalizeRoundId(value: any) { + const normalized = normalizePlainText(value).trim(); + return normalized || undefined; +} + +function normalizePositiveInteger(value: any) { + if (value == null || value === '') { + return undefined; + } + const parsed = Number.parseInt(String(value), 10); + return Number.isFinite(parsed) && parsed > 0 ? parsed : undefined; +} + function normalizeAssistantText(value: any) { return normalizePlainText(value) .replace(/^Final Answer:\s*/i, '') diff --git a/easyflow-ui-admin/packages/utils/src/helpers/chat-variant-switch.ts b/easyflow-ui-admin/packages/utils/src/helpers/chat-variant-switch.ts new file mode 100644 index 0000000..1d9181b --- /dev/null +++ b/easyflow-ui-admin/packages/utils/src/helpers/chat-variant-switch.ts @@ -0,0 +1,187 @@ +type VariantRecord = { + selectedVariantIndex?: number | string; + variantIndex?: number | string; +}; + +interface ChatVariantSwitchControllerOptions { + mapRecords: (records: TRecord[]) => TItem[]; + onError?: (error: unknown) => void; + onStateChange?: () => void; + replaceRound: (items: TItem[], roundId: string, nextItems: TItem[]) => void; +} + +interface EnsureVariantsOptions { + fetchVariants: () => Promise; + roundId: number | string; + sessionId: number | string; +} + +interface SwitchVariantOptions + extends EnsureVariantsOptions { + items: TItem[]; + onLocalSwitch?: (record: TRecord) => void; + persistVariant: () => Promise; + targetVariantIndex: number; +} + +function variantCacheKey(sessionId: number | string, roundId: number | string) { + return `${String(sessionId)}:${String(roundId)}`; +} + +function normalizeVariantIndex(value: unknown) { + const parsed = Number.parseInt(String(value || ''), 10); + return Number.isFinite(parsed) && parsed > 0 ? parsed : 0; +} + +function markVariantSelected( + record: TRecord, + selectedVariantIndex: number, +): TRecord { + return { + ...record, + selectedVariantIndex, + }; +} + +function syncCachedSelection( + records: TRecord[], + selectedVariantIndex: number, + selectedRecord?: TRecord, +) { + return records.map((record) => { + const isSelected = + selectedRecord && + normalizeVariantIndex(record.variantIndex) === + normalizeVariantIndex(selectedRecord.variantIndex); + return markVariantSelected( + isSelected ? { ...record, ...selectedRecord } : record, + selectedVariantIndex, + ); + }); +} + +export function createChatVariantSwitchController< + TRecord extends VariantRecord, + TItem, +>(options: ChatVariantSwitchControllerOptions) { + const cache = new Map(); + const fetchTasks = new Map>(); + const switchingKeys = new Set(); + + function notifyStateChange() { + options.onStateChange?.(); + } + + async function ensureVariants(params: EnsureVariantsOptions) { + const key = variantCacheKey(params.sessionId, params.roundId); + const cached = cache.get(key); + if (cached) { + return cached; + } + const existingTask = fetchTasks.get(key); + if (existingTask) { + return existingTask; + } + const task = params + .fetchVariants() + .then((records) => { + cache.set(key, records); + return records; + }) + .finally(() => { + fetchTasks.delete(key); + }); + fetchTasks.set(key, task); + return task; + } + + function prefetchVariants(params: EnsureVariantsOptions) { + void ensureVariants(params).catch(() => { + // 预取失败不打断当前页面,用户点击时仍会再次拉取。 + }); + } + + function hasCachedVariant( + sessionId: number | string, + roundId: number | string, + variantIndex: number, + ) { + const records = cache.get(variantCacheKey(sessionId, roundId)); + return Boolean( + records?.some( + (record) => normalizeVariantIndex(record.variantIndex) === variantIndex, + ), + ); + } + + function isSwitching(sessionId?: number | string, roundId?: number | string) { + if (!sessionId || !roundId) { + return false; + } + return switchingKeys.has(variantCacheKey(sessionId, roundId)); + } + + async function switchVariant(params: SwitchVariantOptions) { + const key = variantCacheKey(params.sessionId, params.roundId); + if (switchingKeys.has(key)) { + return null; + } + switchingKeys.add(key); + notifyStateChange(); + const snapshot = [...params.items]; + try { + const records = await ensureVariants(params); + const target = records.find( + (record) => + normalizeVariantIndex(record.variantIndex) === params.targetVariantIndex, + ); + if (!target) { + throw new Error('目标答案版本不存在'); + } + const localTarget = markVariantSelected(target, params.targetVariantIndex); + const nextItems = options.mapRecords([localTarget]); + if (nextItems.length === 0) { + throw new Error('目标答案版本渲染失败'); + } + options.replaceRound( + params.items, + String(params.roundId), + nextItems, + ); + params.onLocalSwitch?.(localTarget); + const persistedRecord = await params.persistVariant(); + const selectedRecord = markVariantSelected( + persistedRecord || localTarget, + params.targetVariantIndex, + ); + cache.set( + key, + syncCachedSelection(records, params.targetVariantIndex, selectedRecord), + ); + return selectedRecord; + } catch (error) { + params.items.splice(0, params.items.length, ...snapshot); + options.onError?.(error); + return null; + } finally { + switchingKeys.delete(key); + notifyStateChange(); + } + } + + function cacheVariants( + sessionId: number | string, + roundId: number | string, + records: TRecord[], + ) { + cache.set(variantCacheKey(sessionId, roundId), records); + } + + return { + cacheVariants, + hasCachedVariant, + isSwitching, + prefetchVariants, + switchVariant, + }; +} diff --git a/easyflow-ui-admin/packages/utils/src/helpers/generate-routes-backend.ts b/easyflow-ui-admin/packages/utils/src/helpers/generate-routes-backend.ts index 0f6cac2..10ed526 100644 --- a/easyflow-ui-admin/packages/utils/src/helpers/generate-routes-backend.ts +++ b/easyflow-ui-admin/packages/utils/src/helpers/generate-routes-backend.ts @@ -59,6 +59,9 @@ function convertRoutes( const pageKey = normalizePath.endsWith('.vue') ? normalizePath : `${normalizePath}.vue`; + if (pageKey === '/ai/chat/index.vue' && route.meta) { + route.meta.fullPathKey = false; + } if (pageMap[pageKey]) { route.component = pageMap[pageKey]; } else { diff --git a/easyflow-ui-admin/packages/utils/src/helpers/index.ts b/easyflow-ui-admin/packages/utils/src/helpers/index.ts index ce140fc..01f9356 100644 --- a/easyflow-ui-admin/packages/utils/src/helpers/index.ts +++ b/easyflow-ui-admin/packages/utils/src/helpers/index.ts @@ -1,4 +1,5 @@ export * from './chat-time'; +export * from './chat-variant-switch'; export * from './find-menu-by-path'; export * from './generate-menus'; export * from './generate-routes-backend';