fix(ai): harden SSE stream error handling

- make chat SSE timeout configurable and default to 10 minutes

- stop upstream stream client when emitter send fails

- add full exception logging and frontend error notification on stream failures
This commit is contained in:
2026-02-25 18:55:20 +08:00
parent 58bd8b7ff1
commit 371b8cf891
4 changed files with 173 additions and 33 deletions

View File

@@ -1,26 +1,35 @@
package tech.easyflow.core.chat.protocol.sse; package tech.easyflow.core.chat.protocol.sse;
import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSON;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import tech.easyflow.common.util.StringUtil;
import tech.easyflow.common.util.SpringContextUtil; import tech.easyflow.common.util.SpringContextUtil;
import tech.easyflow.core.chat.protocol.ChatEnvelope; import tech.easyflow.core.chat.protocol.ChatEnvelope;
import java.io.IOException; import java.io.IOException;
import java.time.Duration; import java.time.Duration;
import java.util.concurrent.atomic.AtomicBoolean;
public class ChatSseEmitter { public class ChatSseEmitter {
private static final long DEFAULT_TIMEOUT = Duration.ofMinutes(5).toMillis(); private static final Logger LOG = LoggerFactory.getLogger(ChatSseEmitter.class);
private static final long DEFAULT_TIMEOUT = Duration.ofMinutes(10).toMillis();
private static final String PROP_SSE_TIMEOUT_MS = "easyflow.chat.sse-timeout-ms";
private static final String PROP_SSE_TIMEOUT_MS_ALT = "easyflow.chat.sse.timeout-ms";
private final SseEmitter emitter; private final SseEmitter emitter;
private final AtomicBoolean closed = new AtomicBoolean(false);
public ChatSseEmitter() { public ChatSseEmitter() {
this(DEFAULT_TIMEOUT); this(resolveTimeoutMillis());
} }
public ChatSseEmitter(long timeoutMillis) { public ChatSseEmitter(long timeoutMillis) {
this.emitter = new SseEmitter(timeoutMillis); this.emitter = new SseEmitter(timeoutMillis);
registerLifecycleCallbacks();
} }
public SseEmitter getEmitter() { public SseEmitter getEmitter() {
@@ -28,42 +37,51 @@ public class ChatSseEmitter {
} }
/** 发送普通 ChatEnvelopeevent: message */ /** 发送普通 ChatEnvelopeevent: message */
public void send(ChatEnvelope<?> envelope) { public boolean send(ChatEnvelope<?> envelope) {
send("message", envelope); return send("message", envelope);
} }
/** 发送 error 事件 */ /** 发送 error 事件 */
public void sendError(ChatEnvelope<?> envelope) { public boolean sendError(ChatEnvelope<?> envelope) {
send("error", envelope); return send("error", envelope);
} }
/** 发送 done 事件并关闭 */ /** 发送 done 事件并关闭 */
public void sendDone(ChatEnvelope<?> envelope) { public boolean sendDone(ChatEnvelope<?> envelope) {
send("done", envelope); boolean sent = send("done", envelope);
complete(); complete();
return sent;
} }
/** 🔥 新增:发送并立即关闭 */ /** 🔥 新增:发送并立即关闭 */
public void sendAndClose(ChatEnvelope<?> envelope) { public boolean sendAndClose(ChatEnvelope<?> envelope) {
send("message", envelope); boolean sent = send("message", envelope);
if (!sent) {
return false;
}
ThreadPoolTaskExecutor threadPoolTaskExecutor = SpringContextUtil.getBean("sseThreadPool"); ThreadPoolTaskExecutor threadPoolTaskExecutor = SpringContextUtil.getBean("sseThreadPool");
threadPoolTaskExecutor.execute(() -> { threadPoolTaskExecutor.execute(() -> {
try { try {
Thread.sleep(500); Thread.sleep(500);
complete(); complete();
} catch (InterruptedException e) { } catch (InterruptedException e) {
throw new RuntimeException(e); Thread.currentThread().interrupt();
LOG.error("ChatSseEmitter sendAndClose interrupted, message={}, exception={}", e.getMessage(), e.toString(), e);
} }
}); });
return true;
} }
/** 通知前端保存该消息 */ /** 通知前端保存该消息 */
public void sendMessageNeedSave(ChatEnvelope<?> envelope) { public boolean sendMessageNeedSave(ChatEnvelope<?> envelope) {
send("needSaveMessage", envelope); return send("needSaveMessage", envelope);
} }
/** SSE 底层发送 */ /** SSE 底层发送 */
private void send(String event, ChatEnvelope<?> envelope) { private boolean send(String event, ChatEnvelope<?> envelope) {
if (closed.get()) {
return false;
}
try { try {
String json = JSON.toJSONString(envelope); String json = JSON.toJSONString(envelope);
emitter.send( emitter.send(
@@ -71,16 +89,80 @@ public class ChatSseEmitter {
.name(event) .name(event)
.data(json) .data(json)
); );
return true;
} catch (IllegalStateException e) {
closed.compareAndSet(false, true);
LOG.error("ChatSseEmitter send failed(event={}), message={}, exception={}", event, e.getMessage(), e.toString(), e);
return false;
} catch (IOException e) { } catch (IOException e) {
emitter.completeWithError(e); LOG.error("ChatSseEmitter send io failed(event={}), message={}, exception={}", event, e.getMessage(), e.toString(), e);
safeCompleteWithError(e);
return false;
} catch (Exception e) {
LOG.error("ChatSseEmitter send unexpected failed(event={}), message={}, exception={}", event, e.getMessage(), e.toString(), e);
safeCompleteWithError(e);
return false;
} }
} }
public void complete() { public void complete() {
emitter.complete(); if (closed.compareAndSet(false, true)) {
emitter.complete();
}
} }
public void completeWithError(Throwable ex) { public void completeWithError(Throwable ex) {
emitter.completeWithError(ex); if (ex == null) {
complete();
return;
}
safeCompleteWithError(ex);
}
public boolean isClosed() {
return closed.get();
}
private static long resolveTimeoutMillis() {
Long fromProp = SpringContextUtil.getProperty(PROP_SSE_TIMEOUT_MS, Long.class, null);
if (fromProp == null) {
fromProp = SpringContextUtil.getProperty(PROP_SSE_TIMEOUT_MS_ALT, Long.class, null);
}
if (fromProp != null && fromProp > 0) {
return fromProp;
}
String fromString = SpringContextUtil.getProperty(PROP_SSE_TIMEOUT_MS);
if (StringUtil.noText(fromString)) {
fromString = SpringContextUtil.getProperty(PROP_SSE_TIMEOUT_MS_ALT);
}
if (StringUtil.hasText(fromString)) {
try {
long parsed = Long.parseLong(fromString.trim());
if (parsed > 0) {
return parsed;
}
} catch (Exception e) {
LOG.error("Invalid sse timeout config: key={}, value={}, message={}, exception={}",
PROP_SSE_TIMEOUT_MS, fromString, e.getMessage(), e.toString(), e);
}
}
return DEFAULT_TIMEOUT;
}
private void registerLifecycleCallbacks() {
emitter.onCompletion(() -> closed.compareAndSet(false, true));
emitter.onTimeout(() -> closed.compareAndSet(false, true));
emitter.onError(ex -> closed.compareAndSet(false, true));
}
private void safeCompleteWithError(Throwable ex) {
if (closed.compareAndSet(false, true)) {
try {
emitter.completeWithError(ex);
} catch (Exception completeEx) {
LOG.error("ChatSseEmitter completeWithError failed, message={}, exception={}",
completeEx.getMessage(), completeEx.toString(), completeEx);
}
}
} }
} }

View File

@@ -8,6 +8,9 @@ import com.easyagents.core.model.chat.StreamResponseListener;
import com.easyagents.core.model.chat.response.AiMessageResponse; import com.easyagents.core.model.chat.response.AiMessageResponse;
import com.easyagents.core.model.client.StreamContext; import com.easyagents.core.model.client.StreamContext;
import com.easyagents.core.prompt.MemoryPrompt; import com.easyagents.core.prompt.MemoryPrompt;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import tech.easyflow.common.util.StringUtil;
import tech.easyflow.core.chat.protocol.ChatDomain; import tech.easyflow.core.chat.protocol.ChatDomain;
import tech.easyflow.core.chat.protocol.ChatEnvelope; import tech.easyflow.core.chat.protocol.ChatEnvelope;
import tech.easyflow.core.chat.protocol.ChatType; import tech.easyflow.core.chat.protocol.ChatType;
@@ -15,7 +18,6 @@ import tech.easyflow.core.chat.protocol.MessageRole;
import tech.easyflow.core.chat.protocol.payload.ErrorPayload; import tech.easyflow.core.chat.protocol.payload.ErrorPayload;
import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter; import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter;
import java.io.IOException;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@@ -23,6 +25,8 @@ import java.util.concurrent.atomic.AtomicBoolean;
public class ChatStreamListener implements StreamResponseListener { public class ChatStreamListener implements StreamResponseListener {
private static final Logger LOG = LoggerFactory.getLogger(ChatStreamListener.class);
private final String conversationId; private final String conversationId;
private final ChatModel chatModel; private final ChatModel chatModel;
private final MemoryPrompt memoryPrompt; private final MemoryPrompt memoryPrompt;
@@ -51,6 +55,10 @@ public class ChatStreamListener implements StreamResponseListener {
@Override @Override
public void onMessage(StreamContext context, AiMessageResponse aiMessageResponse) { public void onMessage(StreamContext context, AiMessageResponse aiMessageResponse) {
try { try {
if (completed.get() || sseEmitter.isClosed()) {
stopStreamClient(context, "emitter_closed_before_message", null);
return;
}
AiMessage aiMessage = aiMessageResponse.getMessage(); AiMessage aiMessage = aiMessageResponse.getMessage();
if (aiMessage == null) { if (aiMessage == null) {
return; return;
@@ -81,7 +89,12 @@ public class ChatStreamListener implements StreamResponseListener {
} }
} catch (Exception e) { } catch (Exception e) {
throw new RuntimeException(e); LOG.error("Chat stream onMessage failed, conversationId={}, message={}, exception={}",
conversationId, e.getMessage(), e.toString(), e);
if (completed.compareAndSet(false, true)) {
sendSystemError(sseEmitter, "流式响应异常,请重试", e);
}
stopStreamClient(context, "on_message_exception", e);
} }
} }
@@ -89,15 +102,17 @@ public class ChatStreamListener implements StreamResponseListener {
public void onStop(StreamContext context) { public void onStop(StreamContext context) {
// 仅当canStop为true最后一次无后续工具调用的响应执行业务逻辑 // 仅当canStop为true最后一次无后续工具调用的响应执行业务逻辑
if (this.canStop && completed.compareAndSet(false, true)) { if (this.canStop && completed.compareAndSet(false, true)) {
System.out.println("onStop");
if (context.getThrowable() != null) { if (context.getThrowable() != null) {
sendSystemError(sseEmitter, context.getThrowable().getMessage()); sendSystemError(sseEmitter, context.getThrowable().getMessage(), context.getThrowable());
return; return;
} }
memoryPrompt.addMessage(context.getFullMessage()); memoryPrompt.addMessage(context.getFullMessage());
ChatEnvelope<Map<String, String>> chatEnvelope = new ChatEnvelope<>(); ChatEnvelope<Map<String, String>> chatEnvelope = new ChatEnvelope<>();
chatEnvelope.setDomain(ChatDomain.SYSTEM); chatEnvelope.setDomain(ChatDomain.SYSTEM);
sseEmitter.sendDone(chatEnvelope); boolean doneSent = sseEmitter.sendDone(chatEnvelope);
if (!doneSent) {
LOG.warn("sendDone skipped because emitter is closed, conversationId={}", conversationId);
}
StreamResponseListener.super.onStop(context); StreamResponseListener.super.onStop(context);
} }
@@ -105,13 +120,17 @@ public class ChatStreamListener implements StreamResponseListener {
@Override @Override
public void onFailure(StreamContext context, Throwable throwable) { public void onFailure(StreamContext context, Throwable throwable) {
if (throwable != null && completed.compareAndSet(false, true)) { if (throwable != null) {
throwable.printStackTrace(); LOG.error("Chat stream onFailure, conversationId={}, message={}, exception={}",
sendSystemError(sseEmitter, throwable.getMessage()); conversationId, throwable.getMessage(), throwable.toString(), throwable);
} }
if (throwable != null && completed.compareAndSet(false, true)) {
sendSystemError(sseEmitter, throwable.getMessage(), throwable);
}
stopStreamClient(context, "on_failure", throwable);
} }
private void sendChatEnvelope(ChatSseEmitter sseEmitter, String deltaContent, ChatType chatType) throws IOException { private void sendChatEnvelope(ChatSseEmitter sseEmitter, String deltaContent, ChatType chatType) {
if (deltaContent == null || deltaContent.isEmpty()) { if (deltaContent == null || deltaContent.isEmpty()) {
return; return;
} }
@@ -126,21 +145,49 @@ public class ChatStreamListener implements StreamResponseListener {
deltaMap.put("delta", deltaContent); deltaMap.put("delta", deltaContent);
chatEnvelope.setPayload(deltaMap); chatEnvelope.setPayload(deltaMap);
sseEmitter.send(chatEnvelope); boolean sent = sseEmitter.send(chatEnvelope);
if (!sent) {
throw new IllegalStateException("SSE emitter has already completed while sending chat envelope");
}
} }
public void sendSystemError(ChatSseEmitter sseEmitter, public void sendSystemError(ChatSseEmitter sseEmitter,
String message) { String message,
Throwable throwable) {
String errorMessage = StringUtil.hasText(message) ? message : "系统繁忙,请稍后重试";
if (throwable != null) {
LOG.error("sendSystemError, conversationId={}, message={}, exception={}",
conversationId, throwable.getMessage(), throwable.toString(), throwable);
}
ChatEnvelope<ErrorPayload> envelope = new ChatEnvelope<>(); ChatEnvelope<ErrorPayload> envelope = new ChatEnvelope<>();
ErrorPayload payload = new ErrorPayload(); ErrorPayload payload = new ErrorPayload();
payload.setMessage(message); payload.setMessage(errorMessage);
payload.setCode("SYSTEM_ERROR"); payload.setCode("SYSTEM_ERROR");
payload.setRetryable(false); payload.setRetryable(false);
envelope.setPayload(payload); envelope.setPayload(payload);
envelope.setDomain(ChatDomain.SYSTEM); envelope.setDomain(ChatDomain.SYSTEM);
envelope.setType(ChatType.ERROR); envelope.setType(ChatType.ERROR);
sseEmitter.sendError(envelope); boolean sent = sseEmitter.sendError(envelope);
if (!sent) {
LOG.warn("sendSystemError skipped because emitter is closed, conversationId={}", conversationId);
}
sseEmitter.complete(); sseEmitter.complete();
} }
private void stopStreamClient(StreamContext context, String reason, Throwable source) {
try {
if (context != null && context.getClient() != null) {
context.getClient().stop();
LOG.warn("Stopped stream client, conversationId={}, reason={}", conversationId, reason);
}
} catch (Exception stopEx) {
LOG.error("Stop stream client failed, conversationId={}, reason={}, message={}, exception={}",
conversationId, reason, stopEx.getMessage(), stopEx.toString(), stopEx);
if (source != null) {
LOG.error("Original stream failure, conversationId={}, message={}, exception={}",
conversationId, source.getMessage(), source.toString(), source);
}
}
}
} }

View File

@@ -4,7 +4,8 @@ import com.easyagents.core.memory.DefaultChatMemory;
import com.easyagents.core.message.*; import com.easyagents.core.message.*;
import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.serializer.SerializerFeature; import com.alibaba.fastjson.serializer.SerializerFeature;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import tech.easyflow.ai.entity.BotMessage; import tech.easyflow.ai.entity.BotMessage;
import tech.easyflow.core.chat.protocol.ChatDomain; import tech.easyflow.core.chat.protocol.ChatDomain;
import tech.easyflow.core.chat.protocol.ChatEnvelope; import tech.easyflow.core.chat.protocol.ChatEnvelope;
@@ -12,7 +13,6 @@ import tech.easyflow.core.chat.protocol.ChatType;
import tech.easyflow.core.chat.protocol.MessageRole; import tech.easyflow.core.chat.protocol.MessageRole;
import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter; import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter;
import java.io.IOException;
import java.math.BigInteger; import java.math.BigInteger;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
@@ -23,6 +23,8 @@ import java.util.stream.Collectors;
public class DefaultBotMessageMemory extends DefaultChatMemory { public class DefaultBotMessageMemory extends DefaultChatMemory {
private static final Logger LOG = LoggerFactory.getLogger(DefaultBotMessageMemory.class);
private final ChatSseEmitter sseEmitter; private final ChatSseEmitter sseEmitter;
private final List<Map<String, String>> messages; private final List<Map<String, String>> messages;
@@ -72,7 +74,13 @@ public class DefaultBotMessageMemory extends DefaultChatMemory {
if (dbMessage.getRole().equals(MessageRole.USER.getValue())) { if (dbMessage.getRole().equals(MessageRole.USER.getValue())) {
messages.remove(messages.size() - 1); messages.remove(messages.size() - 1);
} }
sseEmitter.sendMessageNeedSave(chatEnvelope); boolean sent = sseEmitter.sendMessageNeedSave(chatEnvelope);
if (!sent) {
IllegalStateException e = new IllegalStateException("SSE emitter has already completed when sending needSaveMessage");
LOG.error("sendMessageNeedSave failed, role={}, message={}, exception={}",
dbMessage.getRole(), e.getMessage(), e.toString(), e);
throw e;
}
messages.add(res); messages.add(res);
} }

View File

@@ -70,6 +70,9 @@ spring:
enabled: true enabled: true
easyflow: easyflow:
chat:
# SSE 超时时间(毫秒),默认 10 分钟,可按需调整
sse-timeout-ms: 600000
# 语音播放、识别服务(阿里云) # 语音播放、识别服务(阿里云)
audio: audio:
type: aliAudioService type: aliAudioService