feat: 增强多实例分布式部署兼容
- 增加定时任务分布式锁并覆盖 chatlog、文档导入和 Agent HITL 过期扫描
- 增强 Redis MQ 多实例 consumer 标识、pending reclaim 和单条处理能力
- 增加文档导入状态 Redis 广播和 Agent HITL 跨节点路由确认
This commit is contained in:
@@ -39,7 +39,23 @@
|
||||
<artifactId>fastjson</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-aop</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<version>${junit.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-core</artifactId>
|
||||
<version>5.12.0</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
package tech.easyflow.common.cache;
|
||||
|
||||
import java.lang.annotation.ElementType;
|
||||
import java.lang.annotation.Retention;
|
||||
import java.lang.annotation.RetentionPolicy;
|
||||
import java.lang.annotation.Target;
|
||||
|
||||
/**
 * Marks a Spring scheduled method that must run under a Redis distributed lock.
 *
 * <p>Methods carrying this annotation are intercepted by {@code DistributedScheduledLockAspect};
 * when the lock cannot be obtained, the invocation is skipped for that round.</p>
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface DistributedScheduledLock {

    /**
     * Redis key used to acquire the lock.
     *
     * @return the Redis lock key
     */
    String key();

    /**
     * Number of seconds to wait for the lock.
     *
     * <p>The aspect treats negative values as {@code 0}, i.e. a single acquisition attempt.</p>
     *
     * @return seconds to wait for the lock
     */
    long waitSeconds() default 0L;

    /**
     * Lock lease duration in seconds.
     *
     * <p>The aspect raises values below {@code 1} to {@code 1} second.</p>
     *
     * @return lock lease in seconds
     */
    long leaseSeconds() default 300L;
}
|
||||
@@ -0,0 +1,111 @@
|
||||
package tech.easyflow.common.cache;
|
||||
|
||||
import org.aspectj.lang.ProceedingJoinPoint;
|
||||
import org.aspectj.lang.annotation.Around;
|
||||
import org.aspectj.lang.annotation.Aspect;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import jakarta.annotation.PreDestroy;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.ScheduledExecutorService;
|
||||
import java.util.concurrent.ScheduledFuture;
|
||||
import java.util.concurrent.ThreadFactory;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
/**
|
||||
* 定时任务分布式锁切面。
|
||||
*/
|
||||
@Aspect
|
||||
@Component
|
||||
public class DistributedScheduledLockAspect {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(DistributedScheduledLockAspect.class);
|
||||
|
||||
private final RedisLockExecutor redisLockExecutor;
|
||||
private final ScheduledExecutorService renewExecutor;
|
||||
|
||||
/**
|
||||
* 创建定时任务分布式锁切面。
|
||||
*
|
||||
* @param redisLockExecutor Redis 分布式锁执行器
|
||||
*/
|
||||
public DistributedScheduledLockAspect(RedisLockExecutor redisLockExecutor) {
|
||||
this.redisLockExecutor = redisLockExecutor;
|
||||
this.renewExecutor = Executors.newScheduledThreadPool(
|
||||
1,
|
||||
new DistributedScheduledLockThreadFactory()
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* 拦截带分布式调度锁的定时任务。
|
||||
*
|
||||
* @param joinPoint 切点
|
||||
* @param lock 锁注解
|
||||
* @return 原方法返回值;未抢到锁时返回 null
|
||||
* @throws Throwable 原方法执行异常或 Redis 访问异常
|
||||
*/
|
||||
@Around("@annotation(lock)")
|
||||
public Object around(ProceedingJoinPoint joinPoint, DistributedScheduledLock lock) throws Throwable {
|
||||
Duration waitTimeout = Duration.ofSeconds(Math.max(lock.waitSeconds(), 0L));
|
||||
Duration leaseTimeout = Duration.ofSeconds(Math.max(lock.leaseSeconds(), 1L));
|
||||
RedisLockExecutor.LockHandle handle = redisLockExecutor.tryAcquire(lock.key(), waitTimeout, leaseTimeout);
|
||||
if (handle == null) {
|
||||
LOG.info("定时任务分布式锁已被其他实例持有,跳过本轮执行: lockKey={}, method={}",
|
||||
lock.key(), joinPoint.getSignature().toShortString());
|
||||
return null;
|
||||
}
|
||||
ScheduledFuture<?> renewTask = scheduleRenew(lock.key(), handle, leaseTimeout);
|
||||
try {
|
||||
return joinPoint.proceed();
|
||||
} finally {
|
||||
renewTask.cancel(false);
|
||||
handle.release();
|
||||
}
|
||||
}
|
||||
|
||||
private ScheduledFuture<?> scheduleRenew(String lockKey,
|
||||
RedisLockExecutor.LockHandle handle,
|
||||
Duration leaseTimeout) {
|
||||
long renewIntervalMillis = Math.max(leaseTimeout.toMillis() / 3L, 1000L);
|
||||
return renewExecutor.scheduleWithFixedDelay(() -> {
|
||||
if (!handle.renew()) {
|
||||
LOG.warn("定时任务分布式锁续期失败: lockKey={}", lockKey);
|
||||
}
|
||||
}, renewIntervalMillis, renewIntervalMillis, TimeUnit.MILLISECONDS);
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭调度锁续期线程池。
|
||||
*/
|
||||
@PreDestroy
|
||||
public void destroy() {
|
||||
renewExecutor.shutdownNow();
|
||||
}
|
||||
|
||||
/**
|
||||
* 调度锁续期线程工厂。
|
||||
*/
|
||||
private static final class DistributedScheduledLockThreadFactory implements ThreadFactory {
|
||||
|
||||
private final AtomicInteger index = new AtomicInteger(1);
|
||||
|
||||
/**
|
||||
* 创建续期线程。
|
||||
*
|
||||
* @param runnable 线程任务
|
||||
* @return 续期线程
|
||||
*/
|
||||
@Override
|
||||
public Thread newThread(Runnable runnable) {
|
||||
Thread thread = new Thread(runnable);
|
||||
thread.setName("distributed-scheduled-lock-renew-" + index.getAndIncrement());
|
||||
thread.setDaemon(true);
|
||||
return thread;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,9 @@ import java.util.Collections;
|
||||
import java.util.UUID;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
/**
|
||||
* Redis 分布式锁执行器。
|
||||
*/
|
||||
@Component
|
||||
public class RedisLockExecutor {
|
||||
|
||||
@@ -42,6 +45,14 @@ public class RedisLockExecutor {
|
||||
@Autowired
|
||||
private StringRedisTemplate stringRedisTemplate;
|
||||
|
||||
/**
|
||||
* 在分布式锁保护下执行无返回任务。
|
||||
*
|
||||
* @param lockKey 锁 key
|
||||
* @param waitTimeout 等待锁的最大时间
|
||||
* @param leaseTimeout 锁租约时间
|
||||
* @param task 业务任务
|
||||
*/
|
||||
public void executeWithLock(String lockKey, Duration waitTimeout, Duration leaseTimeout, Runnable task) {
|
||||
executeWithLock(lockKey, waitTimeout, leaseTimeout, () -> {
|
||||
task.run();
|
||||
@@ -49,6 +60,16 @@ public class RedisLockExecutor {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 在分布式锁保护下执行有返回任务。
|
||||
*
|
||||
* @param lockKey 锁 key
|
||||
* @param waitTimeout 等待锁的最大时间
|
||||
* @param leaseTimeout 锁租约时间
|
||||
* @param task 业务任务
|
||||
* @param <T> 返回类型
|
||||
* @return 任务返回值
|
||||
*/
|
||||
public <T> T executeWithLock(String lockKey, Duration waitTimeout, Duration leaseTimeout, Supplier<T> task) {
|
||||
LockHandle handle = acquire(lockKey, waitTimeout, leaseTimeout);
|
||||
try {
|
||||
@@ -70,24 +91,46 @@ public class RedisLockExecutor {
|
||||
* @return 锁句柄
|
||||
*/
|
||||
public LockHandle acquire(String lockKey, Duration waitTimeout, Duration leaseTimeout) {
|
||||
LockHandle handle = tryAcquire(lockKey, waitTimeout, leaseTimeout);
|
||||
if (handle == null) {
|
||||
throw new IllegalStateException("获取分布式锁失败,请稍后重试,lockKey=" + lockKey);
|
||||
}
|
||||
return handle;
|
||||
}
|
||||
|
||||
/**
|
||||
* 尝试获取显式释放的分布式锁句柄。
|
||||
*
|
||||
* <p>返回 {@code null} 表示锁当前被其他节点持有。Redis 访问失败或等待过程被中断仍会抛出异常,
|
||||
* 调用方可据此区分“正常跳过”和“基础设施异常”。</p>
|
||||
*
|
||||
* @param lockKey 锁 key
|
||||
* @param waitTimeout 等待时间
|
||||
* @param leaseTimeout 租约时间
|
||||
* @return 获取成功时返回锁句柄,否则返回 null
|
||||
*/
|
||||
public LockHandle tryAcquire(String lockKey, Duration waitTimeout, Duration leaseTimeout) {
|
||||
String lockValue = UUID.randomUUID().toString();
|
||||
boolean acquired = false;
|
||||
long deadline = System.nanoTime() + waitTimeout.toNanos();
|
||||
try {
|
||||
while (System.nanoTime() <= deadline) {
|
||||
do {
|
||||
Boolean success = stringRedisTemplate.opsForValue().setIfAbsent(lockKey, lockValue, leaseTimeout);
|
||||
if (Boolean.TRUE.equals(success)) {
|
||||
acquired = true;
|
||||
break;
|
||||
}
|
||||
Thread.sleep(RETRY_INTERVAL_MILLIS);
|
||||
if (System.nanoTime() >= deadline) {
|
||||
break;
|
||||
}
|
||||
Thread.sleep(RETRY_INTERVAL_MILLIS);
|
||||
} while (System.nanoTime() <= deadline);
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new IllegalStateException("等待分布式锁被中断,lockKey=" + lockKey, e);
|
||||
}
|
||||
if (!acquired) {
|
||||
throw new IllegalStateException("获取分布式锁失败,请稍后重试,lockKey=" + lockKey);
|
||||
return null;
|
||||
}
|
||||
return new LockHandle(lockKey, lockValue, leaseTimeout);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
package tech.easyflow.common.cache;
|
||||
|
||||
import org.aspectj.lang.ProceedingJoinPoint;
|
||||
import org.aspectj.lang.Signature;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.mockito.ArgumentMatchers;
|
||||
import org.mockito.Mockito;
|
||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||
import org.springframework.data.redis.core.ValueOperations;
|
||||
import org.springframework.data.redis.core.script.RedisScript;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.Method;
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
/**
|
||||
* {@link DistributedScheduledLockAspect} 回归测试。
|
||||
*/
|
||||
public class DistributedScheduledLockAspectTest {
|
||||
|
||||
/**
|
||||
* 验证未抢到调度锁时跳过原方法。
|
||||
*
|
||||
* @throws Throwable 切面执行异常
|
||||
*/
|
||||
@Test
|
||||
public void aroundShouldSkipTaskWhenLockIsHeld() throws Throwable {
|
||||
RedisLockExecutor executor = createExecutor(false);
|
||||
DistributedScheduledLockAspect aspect = new DistributedScheduledLockAspect(executor);
|
||||
AtomicInteger proceedCount = new AtomicInteger();
|
||||
|
||||
Object result = aspect.around(
|
||||
mockJoinPoint(proceedCount),
|
||||
annotatedMethod("lockedTask").getAnnotation(DistributedScheduledLock.class)
|
||||
);
|
||||
|
||||
Assert.assertNull(result);
|
||||
Assert.assertEquals(0, proceedCount.get());
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证抢到调度锁时执行原方法并释放锁。
|
||||
*
|
||||
* @throws Throwable 切面执行异常
|
||||
*/
|
||||
@Test
|
||||
public void aroundShouldProceedAndReleaseWhenLockAcquired() throws Throwable {
|
||||
RedisLockExecutor executor = createExecutor(true);
|
||||
DistributedScheduledLockAspect aspect = new DistributedScheduledLockAspect(executor);
|
||||
AtomicInteger proceedCount = new AtomicInteger();
|
||||
|
||||
Object result = aspect.around(
|
||||
mockJoinPoint(proceedCount),
|
||||
annotatedMethod("lockedTask").getAnnotation(DistributedScheduledLock.class)
|
||||
);
|
||||
|
||||
Assert.assertEquals("ok", result);
|
||||
Assert.assertEquals(1, proceedCount.get());
|
||||
}
|
||||
|
||||
@DistributedScheduledLock(key = "easyflow:test:scheduled", leaseSeconds = 30L)
|
||||
private void lockedTask() {
|
||||
}
|
||||
|
||||
private Method annotatedMethod(String methodName) throws NoSuchMethodException {
|
||||
Method method = DistributedScheduledLockAspectTest.class.getDeclaredMethod(methodName);
|
||||
method.setAccessible(true);
|
||||
return method;
|
||||
}
|
||||
|
||||
private ProceedingJoinPoint mockJoinPoint(AtomicInteger proceedCount) throws Throwable {
|
||||
ProceedingJoinPoint joinPoint = Mockito.mock(ProceedingJoinPoint.class);
|
||||
Signature signature = Mockito.mock(Signature.class);
|
||||
Mockito.when(signature.toShortString()).thenReturn("lockedTask()");
|
||||
Mockito.when(joinPoint.getSignature()).thenReturn(signature);
|
||||
Mockito.when(joinPoint.proceed()).thenAnswer(invocation -> {
|
||||
proceedCount.incrementAndGet();
|
||||
return "ok";
|
||||
});
|
||||
return joinPoint;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private RedisLockExecutor createExecutor(boolean acquired) throws Exception {
|
||||
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||
ValueOperations<String, String> valueOperations = Mockito.mock(ValueOperations.class);
|
||||
Mockito.when(valueOperations.setIfAbsent(
|
||||
ArgumentMatchers.anyString(),
|
||||
ArgumentMatchers.anyString(),
|
||||
ArgumentMatchers.any(Duration.class)
|
||||
)).thenReturn(acquired);
|
||||
Mockito.when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||
Mockito.when(redisTemplate.execute(
|
||||
ArgumentMatchers.<RedisScript<Long>>any(),
|
||||
ArgumentMatchers.<List<String>>any(),
|
||||
ArgumentMatchers.<Object[]>any()
|
||||
)).thenReturn(1L);
|
||||
|
||||
RedisLockExecutor executor = new RedisLockExecutor();
|
||||
Field field = RedisLockExecutor.class.getDeclaredField("stringRedisTemplate");
|
||||
field.setAccessible(true);
|
||||
field.set(executor, redisTemplate);
|
||||
return executor;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package tech.easyflow.common.cache;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.mockito.ArgumentMatchers;
|
||||
import org.mockito.Mockito;
|
||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||
import org.springframework.data.redis.core.ValueOperations;
|
||||
import org.springframework.data.redis.core.script.RedisScript;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link RedisLockExecutor} 回归测试。
|
||||
*/
|
||||
public class RedisLockExecutorTest {
|
||||
|
||||
/**
|
||||
* 验证锁被占用时返回 null,便于调度任务跳过本轮执行。
|
||||
*
|
||||
* @throws Exception 反射注入异常
|
||||
*/
|
||||
@Test
|
||||
public void tryAcquireShouldReturnNullWhenLockIsHeld() throws Exception {
|
||||
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||
ValueOperations<String, String> valueOperations = mockValueOperations(false);
|
||||
Mockito.when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||
|
||||
RedisLockExecutor executor = new RedisLockExecutor();
|
||||
setRedisTemplate(executor, redisTemplate);
|
||||
|
||||
RedisLockExecutor.LockHandle handle = executor.tryAcquire(
|
||||
"easyflow:test:lock",
|
||||
Duration.ZERO,
|
||||
Duration.ofSeconds(30)
|
||||
);
|
||||
|
||||
Assert.assertNull(handle);
|
||||
Mockito.verify(valueOperations).setIfAbsent(
|
||||
ArgumentMatchers.eq("easyflow:test:lock"),
|
||||
ArgumentMatchers.anyString(),
|
||||
ArgumentMatchers.eq(Duration.ofSeconds(30))
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证锁获取成功后释放会执行 owner token 校验脚本。
|
||||
*
|
||||
* @throws Exception 反射注入异常
|
||||
*/
|
||||
@Test
|
||||
public void acquiredHandleShouldReleaseLockWithOwnerToken() throws Exception {
|
||||
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||
ValueOperations<String, String> valueOperations = mockValueOperations(true);
|
||||
Mockito.when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||
Mockito.when(redisTemplate.execute(
|
||||
ArgumentMatchers.<RedisScript<Long>>any(),
|
||||
ArgumentMatchers.<List<String>>any(),
|
||||
ArgumentMatchers.<Object[]>any()
|
||||
)).thenReturn(1L);
|
||||
|
||||
RedisLockExecutor executor = new RedisLockExecutor();
|
||||
setRedisTemplate(executor, redisTemplate);
|
||||
|
||||
RedisLockExecutor.LockHandle handle = executor.tryAcquire(
|
||||
"easyflow:test:lock",
|
||||
Duration.ZERO,
|
||||
Duration.ofSeconds(30)
|
||||
);
|
||||
|
||||
Assert.assertNotNull(handle);
|
||||
handle.release();
|
||||
Mockito.verify(redisTemplate).execute(
|
||||
ArgumentMatchers.<RedisScript<Long>>any(),
|
||||
ArgumentMatchers.eq(List.of("easyflow:test:lock")),
|
||||
ArgumentMatchers.<Object[]>any()
|
||||
);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private ValueOperations<String, String> mockValueOperations(boolean acquired) {
|
||||
ValueOperations<String, String> valueOperations = Mockito.mock(ValueOperations.class);
|
||||
Mockito.when(valueOperations.setIfAbsent(
|
||||
ArgumentMatchers.anyString(),
|
||||
ArgumentMatchers.anyString(),
|
||||
ArgumentMatchers.any(Duration.class)
|
||||
)).thenReturn(acquired);
|
||||
return valueOperations;
|
||||
}
|
||||
|
||||
private void setRedisTemplate(RedisLockExecutor executor, StringRedisTemplate redisTemplate) throws Exception {
|
||||
Field field = RedisLockExecutor.class.getDeclaredField("stringRedisTemplate");
|
||||
field.setAccessible(true);
|
||||
field.set(executor, redisTemplate);
|
||||
}
|
||||
}
|
||||
@@ -27,5 +27,17 @@
|
||||
<artifactId>commons-pool2</artifactId>
|
||||
<version>2.11.1</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<version>${junit.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-core</artifactId>
|
||||
<version>5.12.0</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
package tech.easyflow.common.mq.config;
|
||||
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
/**
|
||||
* EasyFlow MQ 配置。
|
||||
*/
|
||||
@ConfigurationProperties(prefix = "easyflow.mq")
|
||||
public class MQProperties {
|
||||
|
||||
@@ -35,6 +39,7 @@ public class MQProperties {
|
||||
|
||||
private int database = 1;
|
||||
private String streamPrefix = "easyflow:mq";
|
||||
private String consumerInstanceId = defaultConsumerInstanceId();
|
||||
private int chatPersistShardCount = 4;
|
||||
private int consumerBatchSize = 200;
|
||||
private Duration consumerBlockTimeout = Duration.ofMillis(2000);
|
||||
@@ -59,6 +64,26 @@ public class MQProperties {
|
||||
this.streamPrefix = streamPrefix;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 Redis Stream 消费实例 ID。
|
||||
*
|
||||
* @return 消费实例 ID
|
||||
*/
|
||||
public String getConsumerInstanceId() {
|
||||
return consumerInstanceId;
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置 Redis Stream 消费实例 ID。
|
||||
*
|
||||
* @param consumerInstanceId 消费实例 ID
|
||||
*/
|
||||
public void setConsumerInstanceId(String consumerInstanceId) {
|
||||
this.consumerInstanceId = StringUtils.hasText(consumerInstanceId)
|
||||
? consumerInstanceId.trim()
|
||||
: defaultConsumerInstanceId();
|
||||
}
|
||||
|
||||
public int getChatPersistShardCount() {
|
||||
return chatPersistShardCount;
|
||||
}
|
||||
@@ -191,5 +216,13 @@ public class MQProperties {
|
||||
this.minIdle = minIdle;
|
||||
}
|
||||
}
|
||||
|
||||
private static String defaultConsumerInstanceId() {
|
||||
String hostName = System.getenv("HOSTNAME");
|
||||
if (StringUtils.hasText(hostName)) {
|
||||
return hostName.trim();
|
||||
}
|
||||
return java.util.UUID.randomUUID().toString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ public class MQSubscription {
|
||||
private String topic;
|
||||
private String consumerGroup;
|
||||
private int shardCount;
|
||||
private boolean batchEnabled = true;
|
||||
|
||||
public String getTopic() {
|
||||
return topic;
|
||||
@@ -29,4 +30,22 @@ public class MQSubscription {
|
||||
public void setShardCount(int shardCount) {
|
||||
this.shardCount = shardCount;
|
||||
}
|
||||
|
||||
/**
|
||||
* 是否启用批量消费。
|
||||
*
|
||||
* @return true 表示启用批量消费
|
||||
*/
|
||||
public boolean isBatchEnabled() {
|
||||
return batchEnabled;
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置是否启用批量消费。
|
||||
*
|
||||
* @param batchEnabled 是否启用批量消费
|
||||
*/
|
||||
public void setBatchEnabled(boolean batchEnabled) {
|
||||
this.batchEnabled = batchEnabled;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.regex.Pattern;
|
||||
import java.util.concurrent.ArrayBlockingQueue;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.ThreadPoolExecutor;
|
||||
@@ -39,6 +40,7 @@ import java.util.concurrent.atomic.AtomicInteger;
|
||||
public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifecycle {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(RedisMQConsumerContainer.class);
|
||||
private static final Pattern UNSAFE_CONSUMER_NAME_CHARS = Pattern.compile("[^A-Za-z0-9_.-]");
|
||||
|
||||
private final RedisConnectionFactory redisConnectionFactory;
|
||||
private final StringRedisTemplate stringRedisTemplate;
|
||||
@@ -154,13 +156,24 @@ public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifec
|
||||
|
||||
private void consumeLoop(MQConsumerHandler handler, MQSubscription subscription, int shard) {
|
||||
String streamKey = keySupport.streamKey(subscription.getTopic(), shard);
|
||||
String consumerName = subscription.getConsumerGroup() + "-" + shard;
|
||||
String consumerName = buildConsumerName(subscription.getConsumerGroup(), shard);
|
||||
ensureConsumerGroup(streamKey, subscription.getConsumerGroup());
|
||||
LOG.info("MQ 消费循环已启动: topic={}, group={}, shard={}, consumer={}, streamKey={}, handler={}",
|
||||
subscription.getTopic(), subscription.getConsumerGroup(), shard, consumerName, streamKey, handler.getClass().getSimpleName());
|
||||
while (running) {
|
||||
try {
|
||||
List<MapRecord<String, Object, Object>> pendingRecords =
|
||||
reclaimPending(streamKey, subscription.getConsumerGroup(), consumerName);
|
||||
if (!pendingRecords.isEmpty()) {
|
||||
List<MQMessage> pendingMessages = toMessages(streamKey, pendingRecords);
|
||||
if (!pendingMessages.isEmpty()) {
|
||||
LOG.info("MQ 收到重领 pending 消息批次: topic={}, group={}, shard={}, consumer={}, streamKey={}, count={}",
|
||||
subscription.getTopic(), subscription.getConsumerGroup(), shard, consumerName,
|
||||
streamKey, pendingMessages.size());
|
||||
handleMessages(handler, subscription, streamKey, subscription.getConsumerGroup(), pendingMessages);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
List<MapRecord<String, Object, Object>> records = stringRedisTemplate.opsForStream().read(
|
||||
Consumer.from(subscription.getConsumerGroup(), consumerName),
|
||||
StreamReadOptions.empty()
|
||||
@@ -177,7 +190,7 @@ public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifec
|
||||
}
|
||||
LOG.info("MQ 收到消息批次: topic={}, group={}, shard={}, consumer={}, streamKey={}, count={}",
|
||||
subscription.getTopic(), subscription.getConsumerGroup(), shard, consumerName, streamKey, messages.size());
|
||||
handleMessages(handler, streamKey, subscription.getConsumerGroup(), messages);
|
||||
handleMessages(handler, subscription, streamKey, subscription.getConsumerGroup(), messages);
|
||||
} catch (Exception exception) {
|
||||
LOG.error("MQ 消费循环异常: topic={}, group={}, shard={}, consumer={}, streamKey={}, handler={}",
|
||||
subscription.getTopic(),
|
||||
@@ -192,7 +205,20 @@ public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifec
|
||||
}
|
||||
}
|
||||
|
||||
private void reclaimPending(String streamKey, String group, String consumerName) {
|
||||
/**
|
||||
* 构建 Redis Stream consumer name。
|
||||
*
|
||||
* @param consumerGroup 消费组
|
||||
* @param shard 分片序号
|
||||
* @return consumer name
|
||||
*/
|
||||
String buildConsumerName(String consumerGroup, int shard) {
|
||||
String instanceId = properties.getRedis().getConsumerInstanceId();
|
||||
String safeInstanceId = UNSAFE_CONSUMER_NAME_CHARS.matcher(instanceId).replaceAll("-");
|
||||
return consumerGroup + "-" + shard + "-" + safeInstanceId;
|
||||
}
|
||||
|
||||
List<MapRecord<String, Object, Object>> reclaimPending(String streamKey, String group, String consumerName) {
|
||||
Duration idle = properties.getRedis().getPendingClaimIdle();
|
||||
try (RedisConnection connection = redisConnectionFactory.getConnection()) {
|
||||
RedisStreamCommands.XPendingOptions options = RedisStreamCommands.XPendingOptions
|
||||
@@ -200,7 +226,7 @@ public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifec
|
||||
var pendingMessages = connection.streamCommands()
|
||||
.xPending(streamKey.getBytes(StandardCharsets.UTF_8), group, options);
|
||||
if (pendingMessages == null || pendingMessages.isEmpty()) {
|
||||
return;
|
||||
return List.of();
|
||||
}
|
||||
List<RecordId> ids = new ArrayList<>();
|
||||
for (PendingMessage pendingMessage : pendingMessages) {
|
||||
@@ -209,15 +235,16 @@ public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifec
|
||||
}
|
||||
}
|
||||
if (ids.isEmpty()) {
|
||||
return;
|
||||
return List.of();
|
||||
}
|
||||
stringRedisTemplate.opsForStream().claim(
|
||||
List<MapRecord<String, Object, Object>> records = stringRedisTemplate.opsForStream().claim(
|
||||
streamKey,
|
||||
group,
|
||||
consumerName,
|
||||
idle,
|
||||
ids.toArray(new RecordId[0])
|
||||
);
|
||||
return records == null ? List.of() : records;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -233,7 +260,7 @@ public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifec
|
||||
}
|
||||
}
|
||||
|
||||
private List<MQMessage> toMessages(String streamKey, List<MapRecord<String, Object, Object>> records) {
|
||||
List<MQMessage> toMessages(String streamKey, List<MapRecord<String, Object, Object>> records) {
|
||||
List<MQMessage> messages = new ArrayList<>(records.size());
|
||||
for (MapRecord<String, Object, Object> record : records) {
|
||||
Object payload = record.getValue().get("payload");
|
||||
@@ -269,7 +296,15 @@ public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifec
|
||||
}
|
||||
}
|
||||
|
||||
private void handleMessages(MQConsumerHandler handler, String streamKey, String group, List<MQMessage> messages) throws Exception {
|
||||
void handleMessages(MQConsumerHandler handler,
|
||||
MQSubscription subscription,
|
||||
String streamKey,
|
||||
String group,
|
||||
List<MQMessage> messages) throws Exception {
|
||||
if (!subscription.isBatchEnabled()) {
|
||||
handleMessagesIndividually(handler, streamKey, group, messages);
|
||||
return;
|
||||
}
|
||||
try {
|
||||
LOG.info("MQ 开始批量处理消息: group={}, streamKey={}, count={}, handler={}",
|
||||
group, streamKey, messages.size(), handler.getClass().getSimpleName());
|
||||
@@ -288,6 +323,13 @@ public class RedisMQConsumerContainer implements MQConsumerContainer, SmartLifec
|
||||
}
|
||||
}
|
||||
|
||||
handleMessagesIndividually(handler, streamKey, group, messages);
|
||||
}
|
||||
|
||||
private void handleMessagesIndividually(MQConsumerHandler handler,
|
||||
String streamKey,
|
||||
String group,
|
||||
List<MQMessage> messages) {
|
||||
for (MQMessage message : messages) {
|
||||
try {
|
||||
LOG.info("MQ 开始单条处理消息: group={}, streamKey={}, messageId={}, handler={}",
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
package tech.easyflow.common.mq.redis;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.mockito.ArgumentMatchers;
|
||||
import org.mockito.Mockito;
|
||||
import org.springframework.data.redis.connection.RedisConnection;
|
||||
import org.springframework.data.redis.connection.RedisConnectionFactory;
|
||||
import org.springframework.data.redis.connection.RedisStreamCommands;
|
||||
import org.springframework.data.redis.connection.stream.Consumer;
|
||||
import org.springframework.data.redis.connection.stream.MapRecord;
|
||||
import org.springframework.data.redis.connection.stream.PendingMessage;
|
||||
import org.springframework.data.redis.connection.stream.PendingMessages;
|
||||
import org.springframework.data.redis.connection.stream.RecordId;
|
||||
import org.springframework.data.redis.core.StreamOperations;
|
||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||
import tech.easyflow.common.mq.config.MQProperties;
|
||||
import tech.easyflow.common.mq.core.MQConsumerHandler;
|
||||
import tech.easyflow.common.mq.core.MQDeadLetterService;
|
||||
import tech.easyflow.common.mq.core.MQMessage;
|
||||
import tech.easyflow.common.mq.core.MQMessageConverter;
|
||||
import tech.easyflow.common.mq.core.MQSubscription;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* {@link RedisMQConsumerContainer} 回归测试。
|
||||
*/
|
||||
public class RedisMQConsumerContainerTest {
|
||||
|
||||
/**
|
||||
* 验证 consumer name 包含稳定实例 ID,且消费组名称不被改变。
|
||||
*/
|
||||
@Test
|
||||
public void buildConsumerNameShouldAppendSanitizedInstanceId() {
|
||||
MQProperties properties = new MQProperties();
|
||||
properties.getRedis().setConsumerInstanceId("node/a:1");
|
||||
RedisMQConsumerContainer container = new RedisMQConsumerContainer(
|
||||
null,
|
||||
null,
|
||||
properties,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
List.of()
|
||||
);
|
||||
|
||||
String consumerName = container.buildConsumerName("chat-persist", 2);
|
||||
|
||||
Assert.assertEquals("chat-persist-2-node-a-1", consumerName);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证关闭批量消费后,容器按单条处理并独立确认消息。
|
||||
*
|
||||
* @throws Exception 消息处理异常
|
||||
*/
|
||||
@Test
|
||||
public void handleMessagesShouldProcessIndividuallyWhenBatchDisabled() throws Exception {
|
||||
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||
@SuppressWarnings("unchecked")
|
||||
StreamOperations<String, Object, Object> streamOperations = Mockito.mock(StreamOperations.class);
|
||||
Mockito.when(redisTemplate.opsForStream()).thenReturn(streamOperations);
|
||||
RecordingHandler handler = new RecordingHandler();
|
||||
MQSubscription subscription = new MQSubscription();
|
||||
subscription.setBatchEnabled(false);
|
||||
RedisMQConsumerContainer container = container(redisTemplate, null);
|
||||
MQMessage first = message("message-1", "1-0");
|
||||
MQMessage second = message("message-2", "2-0");
|
||||
|
||||
container.handleMessages(handler, subscription, "stream-1", "group-1", List.of(first, second));
|
||||
|
||||
Assert.assertEquals(List.of(List.of("message-1"), List.of("message-2")), handler.calls);
|
||||
Mockito.verify(streamOperations).acknowledge("stream-1", "group-1", "1-0");
|
||||
Mockito.verify(streamOperations).acknowledge("stream-1", "group-1", "2-0");
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证 pending 消息被 claim 后可以转换为 MQ 消息继续消费。
|
||||
*/
|
||||
@Test
|
||||
public void reclaimPendingShouldReturnClaimedRecordsForConsumption() {
|
||||
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||
@SuppressWarnings("unchecked")
|
||||
StreamOperations<String, Object, Object> streamOperations = Mockito.mock(StreamOperations.class);
|
||||
Mockito.when(redisTemplate.opsForStream()).thenReturn(streamOperations);
|
||||
RedisConnectionFactory connectionFactory = Mockito.mock(RedisConnectionFactory.class);
|
||||
RedisConnection connection = Mockito.mock(RedisConnection.class);
|
||||
RedisStreamCommands streamCommands = Mockito.mock(RedisStreamCommands.class);
|
||||
Mockito.when(connectionFactory.getConnection()).thenReturn(connection);
|
||||
Mockito.when(connection.streamCommands()).thenReturn(streamCommands);
|
||||
PendingMessage pendingMessage = new PendingMessage(
|
||||
RecordId.of("1-0"), Consumer.from("group-1", "old-consumer"), Duration.ofMinutes(2), 1);
|
||||
Mockito.when(streamCommands.xPending(
|
||||
ArgumentMatchers.eq("stream-1".getBytes(java.nio.charset.StandardCharsets.UTF_8)),
|
||||
ArgumentMatchers.eq("group-1"),
|
||||
ArgumentMatchers.any(RedisStreamCommands.XPendingOptions.class)))
|
||||
.thenReturn(new PendingMessages("group-1", List.of(pendingMessage)));
|
||||
Map<Object, Object> payload = Map.of("payload", "message-1");
|
||||
MapRecord<String, Object, Object> record = MapRecord
|
||||
.create("stream-1", payload)
|
||||
.withId(RecordId.of("1-0"));
|
||||
Mockito.when(streamOperations.claim(
|
||||
ArgumentMatchers.eq("stream-1"),
|
||||
ArgumentMatchers.eq("group-1"),
|
||||
ArgumentMatchers.eq("consumer-1"),
|
||||
ArgumentMatchers.any(Duration.class),
|
||||
ArgumentMatchers.any(RecordId[].class)))
|
||||
.thenReturn(List.of(record));
|
||||
RedisMQConsumerContainer container = container(redisTemplate, connectionFactory);
|
||||
|
||||
List<MapRecord<String, Object, Object>> records =
|
||||
container.reclaimPending("stream-1", "group-1", "consumer-1");
|
||||
List<MQMessage> messages = container.toMessages("stream-1", records);
|
||||
|
||||
Assert.assertEquals(1, records.size());
|
||||
Assert.assertEquals(1, messages.size());
|
||||
Assert.assertEquals("message-1", messages.get(0).getMessageId());
|
||||
Assert.assertEquals("1-0", messages.get(0).getStreamMessageId());
|
||||
}
|
||||
|
||||
private RedisMQConsumerContainer container(StringRedisTemplate redisTemplate,
|
||||
RedisConnectionFactory connectionFactory) {
|
||||
MQProperties properties = new MQProperties();
|
||||
return new RedisMQConsumerContainer(
|
||||
connectionFactory,
|
||||
redisTemplate,
|
||||
properties,
|
||||
new PlainMessageConverter(),
|
||||
Mockito.mock(MQDeadLetterService.class),
|
||||
null,
|
||||
List.of()
|
||||
);
|
||||
}
|
||||
|
||||
private MQMessage message(String messageId, String streamMessageId) {
|
||||
MQMessage message = new MQMessage();
|
||||
message.setMessageId(messageId);
|
||||
message.setStreamMessageId(streamMessageId);
|
||||
return message;
|
||||
}
|
||||
|
||||
private static final class RecordingHandler implements MQConsumerHandler {
|
||||
|
||||
private final List<List<String>> calls = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public MQSubscription subscription() {
|
||||
return new MQSubscription();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handle(List<MQMessage> messages) {
|
||||
calls.add(messages.stream().map(MQMessage::getMessageId).toList());
|
||||
}
|
||||
}
|
||||
|
||||
private static final class PlainMessageConverter implements MQMessageConverter {
|
||||
|
||||
@Override
|
||||
public String serialize(MQMessage message) {
|
||||
return message.getMessageId();
|
||||
}
|
||||
|
||||
@Override
|
||||
public MQMessage deserialize(String payload) {
|
||||
MQMessage message = new MQMessage();
|
||||
message.setMessageId(payload);
|
||||
return message;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,10 @@
|
||||
<groupId>tech.easyflow</groupId>
|
||||
<artifactId>easyflow-common-cache</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>tech.easyflow</groupId>
|
||||
<artifactId>easyflow-common-mq</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>tech.easyflow</groupId>
|
||||
<artifactId>easyflow-common-web</artifactId>
|
||||
@@ -63,5 +67,11 @@
|
||||
<version>${junit.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-core</artifactId>
|
||||
<version>5.12.0</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package tech.easyflow.agent.config;
|
||||
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.UUID;
|
||||
|
||||
/**
|
||||
* Agent 运行态生产化配置。
|
||||
@@ -15,6 +17,36 @@ public class AgentRuntimeProperties {
|
||||
*/
|
||||
private Duration sessionCacheTtl = Duration.ofHours(24);
|
||||
|
||||
/**
|
||||
* 当前 Agent 运行实例 ID。
|
||||
*/
|
||||
private String instanceId = defaultInstanceId();
|
||||
|
||||
/**
|
||||
* Agent 运行路由 TTL。
|
||||
*/
|
||||
private Duration routeTtl = Duration.ofHours(24);
|
||||
|
||||
/**
|
||||
* Agent 运行命令 topic 前缀。
|
||||
*/
|
||||
private String commandTopicPrefix = "easyflow:agent-runtime-command";
|
||||
|
||||
/**
|
||||
* Agent 运行命令结果等待超时时间。
|
||||
*/
|
||||
private Duration commandResultTimeout = Duration.ofSeconds(5);
|
||||
|
||||
/**
|
||||
* Agent 运行命令结果缓存 TTL。
|
||||
*/
|
||||
private Duration commandResultTtl = Duration.ofMinutes(5);
|
||||
|
||||
/**
|
||||
* 当前进程启动代 ID。
|
||||
*/
|
||||
private final String bootId = UUID.randomUUID().toString();
|
||||
|
||||
/**
|
||||
* HITL pending 默认过期时间。
|
||||
*/
|
||||
@@ -53,6 +85,107 @@ public class AgentRuntimeProperties {
|
||||
this.sessionCacheTtl = sessionCacheTtl == null ? Duration.ofHours(24) : sessionCacheTtl;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前 Agent 运行实例 ID。
|
||||
*
|
||||
* @return 实例 ID
|
||||
*/
|
||||
public String getInstanceId() {
|
||||
return instanceId;
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置当前 Agent 运行实例 ID。
|
||||
*
|
||||
* @param instanceId 实例 ID
|
||||
*/
|
||||
public void setInstanceId(String instanceId) {
|
||||
this.instanceId = StringUtils.hasText(instanceId) ? instanceId.trim() : defaultInstanceId();
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 Agent 运行路由 TTL。
|
||||
*
|
||||
* @return 路由 TTL
|
||||
*/
|
||||
public Duration getRouteTtl() {
|
||||
return routeTtl;
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置 Agent 运行路由 TTL。
|
||||
*
|
||||
* @param routeTtl 路由 TTL
|
||||
*/
|
||||
public void setRouteTtl(Duration routeTtl) {
|
||||
this.routeTtl = routeTtl == null ? Duration.ofHours(24) : routeTtl;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 Agent 运行命令 topic 前缀。
|
||||
*
|
||||
* @return 命令 topic 前缀
|
||||
*/
|
||||
public String getCommandTopicPrefix() {
|
||||
return commandTopicPrefix;
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置 Agent 运行命令 topic 前缀。
|
||||
*
|
||||
* @param commandTopicPrefix 命令 topic 前缀
|
||||
*/
|
||||
public void setCommandTopicPrefix(String commandTopicPrefix) {
|
||||
this.commandTopicPrefix = StringUtils.hasText(commandTopicPrefix)
|
||||
? commandTopicPrefix.trim()
|
||||
: "easyflow:agent-runtime-command";
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 Agent 运行命令结果等待超时时间。
|
||||
*
|
||||
* @return 等待超时时间
|
||||
*/
|
||||
public Duration getCommandResultTimeout() {
|
||||
return commandResultTimeout;
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置 Agent 运行命令结果等待超时时间。
|
||||
*
|
||||
* @param commandResultTimeout 等待超时时间
|
||||
*/
|
||||
public void setCommandResultTimeout(Duration commandResultTimeout) {
|
||||
this.commandResultTimeout = commandResultTimeout == null ? Duration.ofSeconds(5) : commandResultTimeout;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 Agent 运行命令结果缓存 TTL。
|
||||
*
|
||||
* @return 结果缓存 TTL
|
||||
*/
|
||||
public Duration getCommandResultTtl() {
|
||||
return commandResultTtl;
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置 Agent 运行命令结果缓存 TTL。
|
||||
*
|
||||
* @param commandResultTtl 结果缓存 TTL
|
||||
*/
|
||||
public void setCommandResultTtl(Duration commandResultTtl) {
|
||||
this.commandResultTtl = commandResultTtl == null ? Duration.ofMinutes(5) : commandResultTtl;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前进程启动代 ID。
|
||||
*
|
||||
* @return 启动代 ID
|
||||
*/
|
||||
public String getBootId() {
|
||||
return bootId;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 HITL pending 默认过期时间。
|
||||
*
|
||||
@@ -124,4 +257,16 @@ public class AgentRuntimeProperties {
|
||||
public void setLockRenewInterval(Duration lockRenewInterval) {
|
||||
this.lockRenewInterval = lockRenewInterval == null ? Duration.ofMinutes(1) : lockRenewInterval;
|
||||
}
|
||||
|
||||
private static String defaultInstanceId() {
|
||||
String envInstanceId = System.getenv("EASYFLOW_INSTANCE_ID");
|
||||
if (StringUtils.hasText(envInstanceId)) {
|
||||
return envInstanceId.trim();
|
||||
}
|
||||
String hostName = System.getenv("HOSTNAME");
|
||||
if (StringUtils.hasText(hostName)) {
|
||||
return hostName.trim();
|
||||
}
|
||||
return UUID.randomUUID().toString();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
package tech.easyflow.agent.distributed;
|
||||
|
||||
/**
 * Actions that a remote Agent runtime command can request.
 */
public enum AgentRuntimeCommandAction {

    /** Approve the tool execution. */
    APPROVE,

    /** Reject the tool execution. */
    REJECT
}
|
||||
@@ -0,0 +1,127 @@
|
||||
package tech.easyflow.agent.distributed;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
import tech.easyflow.agent.config.AgentRuntimeProperties;
|
||||
import tech.easyflow.agent.runtime.AgentRunService;
|
||||
import tech.easyflow.common.mq.config.MQProperties;
|
||||
import tech.easyflow.common.mq.core.MQConsumerHandler;
|
||||
import tech.easyflow.common.mq.core.MQMessage;
|
||||
import tech.easyflow.common.mq.core.MQSubscription;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Agent 运行态远程命令消费者。
|
||||
*/
|
||||
@Component
|
||||
public class AgentRuntimeCommandConsumer implements MQConsumerHandler {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(AgentRuntimeCommandConsumer.class);
|
||||
|
||||
private final ObjectMapper objectMapper;
|
||||
private final AgentRuntimeProperties properties;
|
||||
private final MQProperties mqProperties;
|
||||
private final AgentRunService agentRunService;
|
||||
private final AgentRuntimeCommandResultRegistry resultRegistry;
|
||||
|
||||
/**
|
||||
* 创建 Agent 运行态远程命令消费者。
|
||||
*
|
||||
* @param objectMapper JSON 序列化器
|
||||
* @param properties Agent 运行配置
|
||||
* @param mqProperties MQ 配置
|
||||
* @param agentRunService Agent 运行服务
|
||||
* @param resultRegistry 远程命令结果注册表
|
||||
*/
|
||||
public AgentRuntimeCommandConsumer(ObjectMapper objectMapper,
|
||||
AgentRuntimeProperties properties,
|
||||
MQProperties mqProperties,
|
||||
AgentRunService agentRunService,
|
||||
AgentRuntimeCommandResultRegistry resultRegistry) {
|
||||
this.objectMapper = objectMapper;
|
||||
this.properties = properties;
|
||||
this.mqProperties = mqProperties;
|
||||
this.agentRunService = agentRunService;
|
||||
this.resultRegistry = resultRegistry;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MQSubscription subscription() {
|
||||
MQSubscription subscription = new MQSubscription();
|
||||
subscription.setTopic(commandTopic());
|
||||
subscription.setConsumerGroup(commandTopic());
|
||||
subscription.setShardCount(Math.max(mqProperties.getRedis().getChatPersistShardCount(), 1));
|
||||
subscription.setBatchEnabled(false);
|
||||
return subscription;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handle(List<MQMessage> messages) {
|
||||
if (messages == null || messages.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
for (MQMessage message : messages) {
|
||||
try {
|
||||
handleCommand(message, objectMapper.readValue(message.getBody(), AgentRuntimeCommandMessage.class));
|
||||
} catch (Exception e) {
|
||||
LOG.warn("Agent 远程运行命令解析失败: messageId={}", message.getMessageId(), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void handleCommand(MQMessage message, AgentRuntimeCommandMessage command) {
|
||||
if (command == null || command.getAction() == null) {
|
||||
LOG.warn("跳过非法 Agent 远程运行命令: messageId={}", message.getMessageId());
|
||||
return;
|
||||
}
|
||||
if (!properties.getInstanceId().equals(command.getTargetNodeId())) {
|
||||
LOG.warn("跳过非本节点 Agent 远程运行命令: messageId={}, targetNodeId={}, currentNodeId={}",
|
||||
message.getMessageId(), command.getTargetNodeId(), properties.getInstanceId());
|
||||
return;
|
||||
}
|
||||
try {
|
||||
if (command.getAction() == AgentRuntimeCommandAction.APPROVE) {
|
||||
agentRunService.approveRuntimeLocal(
|
||||
command.getRequestId(), command.getResumeToken(), command.getOperatorId(), command.getUserId());
|
||||
} else if (command.getAction() == AgentRuntimeCommandAction.REJECT) {
|
||||
agentRunService.rejectRuntimeLocal(
|
||||
command.getRequestId(), command.getResumeToken(), command.getReason(),
|
||||
command.getOperatorId(), command.getUserId());
|
||||
} else {
|
||||
markFailureQuietly(command, new IllegalArgumentException("不支持的 Agent 远程运行命令"));
|
||||
LOG.warn("跳过不支持的 Agent 远程运行命令: messageId={}, commandId={}, action={}",
|
||||
message.getMessageId(), command.getCommandId(), command.getAction());
|
||||
return;
|
||||
}
|
||||
} catch (RuntimeException e) {
|
||||
markFailureQuietly(command, e);
|
||||
LOG.warn("Agent 远程运行命令处理失败: messageId={}, commandId={}",
|
||||
message.getMessageId(), command.getCommandId(), e);
|
||||
return;
|
||||
}
|
||||
markSuccessQuietly(command);
|
||||
}
|
||||
|
||||
private String commandTopic() {
|
||||
return properties.getCommandTopicPrefix() + ":" + properties.getInstanceId();
|
||||
}
|
||||
|
||||
private void markSuccessQuietly(AgentRuntimeCommandMessage command) {
|
||||
try {
|
||||
resultRegistry.markSuccess(command.getCommandId());
|
||||
} catch (RuntimeException e) {
|
||||
LOG.error("Agent 远程运行命令成功结果写入失败: commandId={}", command.getCommandId(), e);
|
||||
}
|
||||
}
|
||||
|
||||
private void markFailureQuietly(AgentRuntimeCommandMessage command, RuntimeException cause) {
|
||||
try {
|
||||
resultRegistry.markFailure(command.getCommandId(), cause.getMessage());
|
||||
} catch (RuntimeException e) {
|
||||
LOG.error("Agent 远程运行命令失败结果写入失败: commandId={}", command.getCommandId(), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
package tech.easyflow.agent.distributed;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.Date;
|
||||
|
||||
/**
|
||||
* Agent 运行态远程恢复命令消息。
|
||||
*/
|
||||
public class AgentRuntimeCommandMessage {
|
||||
|
||||
private String commandId;
|
||||
private String requestId;
|
||||
private String resumeToken;
|
||||
private AgentRuntimeCommandAction action;
|
||||
private String reason;
|
||||
private BigInteger operatorId;
|
||||
private String userId;
|
||||
private String targetNodeId;
|
||||
private Date occurredAt;
|
||||
|
||||
public String getCommandId() {
|
||||
return commandId;
|
||||
}
|
||||
|
||||
public void setCommandId(String commandId) {
|
||||
this.commandId = commandId;
|
||||
}
|
||||
|
||||
public String getRequestId() {
|
||||
return requestId;
|
||||
}
|
||||
|
||||
public void setRequestId(String requestId) {
|
||||
this.requestId = requestId;
|
||||
}
|
||||
|
||||
public String getResumeToken() {
|
||||
return resumeToken;
|
||||
}
|
||||
|
||||
public void setResumeToken(String resumeToken) {
|
||||
this.resumeToken = resumeToken;
|
||||
}
|
||||
|
||||
public AgentRuntimeCommandAction getAction() {
|
||||
return action;
|
||||
}
|
||||
|
||||
public void setAction(AgentRuntimeCommandAction action) {
|
||||
this.action = action;
|
||||
}
|
||||
|
||||
public String getReason() {
|
||||
return reason;
|
||||
}
|
||||
|
||||
public void setReason(String reason) {
|
||||
this.reason = reason;
|
||||
}
|
||||
|
||||
public BigInteger getOperatorId() {
|
||||
return operatorId;
|
||||
}
|
||||
|
||||
public void setOperatorId(BigInteger operatorId) {
|
||||
this.operatorId = operatorId;
|
||||
}
|
||||
|
||||
public String getUserId() {
|
||||
return userId;
|
||||
}
|
||||
|
||||
public void setUserId(String userId) {
|
||||
this.userId = userId;
|
||||
}
|
||||
|
||||
public String getTargetNodeId() {
|
||||
return targetNodeId;
|
||||
}
|
||||
|
||||
public void setTargetNodeId(String targetNodeId) {
|
||||
this.targetNodeId = targetNodeId;
|
||||
}
|
||||
|
||||
public Date getOccurredAt() {
|
||||
return occurredAt;
|
||||
}
|
||||
|
||||
public void setOccurredAt(Date occurredAt) {
|
||||
this.occurredAt = occurredAt;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
package tech.easyflow.agent.distributed;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Service;
|
||||
import tech.easyflow.agent.config.AgentRuntimeProperties;
|
||||
import tech.easyflow.common.mq.core.MQMessage;
|
||||
import tech.easyflow.common.mq.core.MQProducer;
|
||||
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.Date;
|
||||
import java.util.UUID;
|
||||
|
||||
/**
|
||||
* Agent 运行态远程命令生产者。
|
||||
*/
|
||||
@Service
|
||||
public class AgentRuntimeCommandProducer {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(AgentRuntimeCommandProducer.class);
|
||||
|
||||
private final MQProducer mqProducer;
|
||||
private final ObjectMapper objectMapper;
|
||||
private final AgentRuntimeProperties properties;
|
||||
private final AgentRuntimeCommandResultRegistry resultRegistry;
|
||||
|
||||
/**
|
||||
* 测试子类构造器。
|
||||
*/
|
||||
protected AgentRuntimeCommandProducer() {
|
||||
this.mqProducer = null;
|
||||
this.objectMapper = null;
|
||||
this.properties = null;
|
||||
this.resultRegistry = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建 Agent 运行态远程命令生产者。
|
||||
*
|
||||
* @param mqProducer MQ 生产者
|
||||
* @param objectMapper JSON 序列化器
|
||||
* @param properties Agent 运行配置
|
||||
* @param resultRegistry 远程命令结果注册表
|
||||
*/
|
||||
public AgentRuntimeCommandProducer(MQProducer mqProducer,
|
||||
ObjectMapper objectMapper,
|
||||
AgentRuntimeProperties properties,
|
||||
AgentRuntimeCommandResultRegistry resultRegistry) {
|
||||
this.mqProducer = mqProducer;
|
||||
this.objectMapper = objectMapper;
|
||||
this.properties = properties;
|
||||
this.resultRegistry = resultRegistry;
|
||||
}
|
||||
|
||||
/**
|
||||
* 投递远程批准命令。
|
||||
*
|
||||
* @param targetNodeId 目标节点 ID
|
||||
* @param requestId 请求 ID
|
||||
* @param resumeToken 恢复令牌
|
||||
* @param operatorId 操作人 ID
|
||||
* @param userId 用户 ID
|
||||
*/
|
||||
public void sendApprove(String targetNodeId,
|
||||
String requestId,
|
||||
String resumeToken,
|
||||
BigInteger operatorId,
|
||||
String userId) {
|
||||
sendAndWait(targetNodeId, requestId, resumeToken, AgentRuntimeCommandAction.APPROVE, null, operatorId, userId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 投递远程拒绝命令。
|
||||
*
|
||||
* @param targetNodeId 目标节点 ID
|
||||
* @param requestId 请求 ID
|
||||
* @param resumeToken 恢复令牌
|
||||
* @param reason 拒绝原因
|
||||
* @param operatorId 操作人 ID
|
||||
* @param userId 用户 ID
|
||||
*/
|
||||
public void sendReject(String targetNodeId,
|
||||
String requestId,
|
||||
String resumeToken,
|
||||
String reason,
|
||||
BigInteger operatorId,
|
||||
String userId) {
|
||||
sendAndWait(targetNodeId, requestId, resumeToken, AgentRuntimeCommandAction.REJECT, reason, operatorId, userId);
|
||||
}
|
||||
|
||||
private void sendAndWait(String targetNodeId,
|
||||
String requestId,
|
||||
String resumeToken,
|
||||
AgentRuntimeCommandAction action,
|
||||
String reason,
|
||||
BigInteger operatorId,
|
||||
String userId) {
|
||||
if (targetNodeId == null || targetNodeId.isBlank()) {
|
||||
throw new BusinessException("Agent 运行节点不可用,请重新发起对话");
|
||||
}
|
||||
AgentRuntimeCommandMessage command = new AgentRuntimeCommandMessage();
|
||||
command.setCommandId(UUID.randomUUID().toString());
|
||||
command.setRequestId(requestId);
|
||||
command.setResumeToken(resumeToken);
|
||||
command.setAction(action);
|
||||
command.setReason(reason);
|
||||
command.setOperatorId(operatorId);
|
||||
command.setUserId(userId);
|
||||
command.setTargetNodeId(targetNodeId);
|
||||
command.setOccurredAt(new Date());
|
||||
|
||||
MQMessage message = new MQMessage();
|
||||
message.setMessageId(command.getCommandId());
|
||||
message.setTopic(commandTopic(targetNodeId));
|
||||
message.setKey(command.getCommandId());
|
||||
message.setCreatedAt(command.getOccurredAt());
|
||||
try {
|
||||
message.setBody(objectMapper.writeValueAsString(command));
|
||||
String recordId = mqProducer.send(message);
|
||||
LOG.info("Agent 远程运行命令已投递: action={}, requestId={}, targetNodeId={}, recordId={}",
|
||||
action, requestId, targetNodeId, recordId);
|
||||
AgentRuntimeCommandResult result = resultRegistry.waitForResult(command.getCommandId());
|
||||
if (!result.isSuccess()) {
|
||||
throw new BusinessException(result.getMessage());
|
||||
}
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new BusinessException("Agent 运行命令序列化失败");
|
||||
} catch (BusinessException e) {
|
||||
throw e;
|
||||
} catch (RuntimeException e) {
|
||||
LOG.error("Agent 远程运行命令投递失败: action={}, requestId={}, targetNodeId={}",
|
||||
action, requestId, targetNodeId, e);
|
||||
throw new BusinessException("Agent 运行节点不可用,请重新发起对话");
|
||||
} finally {
|
||||
deleteResultQuietly(command.getCommandId());
|
||||
}
|
||||
}
|
||||
|
||||
private String commandTopic(String nodeId) {
|
||||
return properties.getCommandTopicPrefix() + ":" + nodeId;
|
||||
}
|
||||
|
||||
private void deleteResultQuietly(String commandId) {
|
||||
try {
|
||||
resultRegistry.deleteResult(commandId);
|
||||
} catch (RuntimeException e) {
|
||||
LOG.warn("Agent 远程运行命令结果清理失败,等待 TTL 兜底: commandId={}", commandId, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package tech.easyflow.agent.distributed;
|
||||
|
||||
/**
 * Result of a remote Agent runtime command.
 */
public class AgentRuntimeCommandResult {

    /** Whether the command executed successfully. */
    private boolean success;
    /** Message accompanying the result. */
    private String message;

    /**
     * Tells whether the command executed successfully.
     *
     * @return {@code true} if the command succeeded
     */
    public boolean isSuccess() {
        return this.success;
    }

    /**
     * Sets whether the command executed successfully.
     *
     * @param success whether the command succeeded
     */
    public void setSuccess(boolean success) {
        this.success = success;
    }

    /**
     * Returns the result message.
     *
     * @return the result message
     */
    public String getMessage() {
        return this.message;
    }

    /**
     * Sets the result message.
     *
     * @param message the result message
     */
    public void setMessage(String message) {
        this.message = message;
    }
}
|
||||
@@ -0,0 +1,134 @@
|
||||
package tech.easyflow.agent.distributed;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||
import org.springframework.stereotype.Component;
|
||||
import tech.easyflow.agent.config.AgentRuntimeProperties;
|
||||
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||
|
||||
/**
|
||||
* Agent 运行态远程命令结果注册表。
|
||||
*/
|
||||
@Component
|
||||
public class AgentRuntimeCommandResultRegistry {
|
||||
|
||||
private static final String RESULT_PREFIX = "easyflow:agent:runtime:command-result:";
|
||||
private static final long POLL_INTERVAL_MILLIS = 50L;
|
||||
|
||||
private final StringRedisTemplate stringRedisTemplate;
|
||||
private final ObjectMapper objectMapper;
|
||||
private final AgentRuntimeProperties properties;
|
||||
|
||||
/**
|
||||
* 创建 Agent 运行态远程命令结果注册表。
|
||||
*
|
||||
* @param stringRedisTemplate Redis 字符串模板
|
||||
* @param objectMapper JSON 序列化器
|
||||
* @param properties Agent 运行配置
|
||||
*/
|
||||
public AgentRuntimeCommandResultRegistry(StringRedisTemplate stringRedisTemplate,
|
||||
ObjectMapper objectMapper,
|
||||
AgentRuntimeProperties properties) {
|
||||
this.stringRedisTemplate = stringRedisTemplate;
|
||||
this.objectMapper = objectMapper;
|
||||
this.properties = properties;
|
||||
}
|
||||
|
||||
/**
|
||||
* 写入成功结果。
|
||||
*
|
||||
* @param commandId 命令 ID
|
||||
*/
|
||||
public void markSuccess(String commandId) {
|
||||
AgentRuntimeCommandResult result = new AgentRuntimeCommandResult();
|
||||
result.setSuccess(true);
|
||||
result.setMessage("OK");
|
||||
writeResult(commandId, result);
|
||||
}
|
||||
|
||||
/**
|
||||
* 写入失败结果。
|
||||
*
|
||||
* @param commandId 命令 ID
|
||||
* @param message 失败消息
|
||||
*/
|
||||
public void markFailure(String commandId, String message) {
|
||||
AgentRuntimeCommandResult result = new AgentRuntimeCommandResult();
|
||||
result.setSuccess(false);
|
||||
result.setMessage(message == null || message.isBlank() ? "Agent 运行节点不可用,请重新发起对话" : message);
|
||||
writeResult(commandId, result);
|
||||
}
|
||||
|
||||
/**
|
||||
* 等待远程命令结果。
|
||||
*
|
||||
* @param commandId 命令 ID
|
||||
* @return 命令结果
|
||||
*/
|
||||
public AgentRuntimeCommandResult waitForResult(String commandId) {
|
||||
long deadline = System.nanoTime() + properties.getCommandResultTimeout().toNanos();
|
||||
while (System.nanoTime() <= deadline) {
|
||||
AgentRuntimeCommandResult result = readResult(commandId);
|
||||
if (result != null) {
|
||||
return result;
|
||||
}
|
||||
sleep();
|
||||
}
|
||||
throw new BusinessException("Agent 运行节点响应超时,请稍后重试");
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除远程命令结果。
|
||||
*
|
||||
* @param commandId 命令 ID
|
||||
*/
|
||||
public void deleteResult(String commandId) {
|
||||
if (commandId == null || commandId.isBlank()) {
|
||||
return;
|
||||
}
|
||||
stringRedisTemplate.delete(resultKey(commandId));
|
||||
}
|
||||
|
||||
private AgentRuntimeCommandResult readResult(String commandId) {
|
||||
if (commandId == null || commandId.isBlank()) {
|
||||
return null;
|
||||
}
|
||||
String value = stringRedisTemplate.opsForValue().get(resultKey(commandId));
|
||||
if (value == null || value.isBlank()) {
|
||||
return null;
|
||||
}
|
||||
try {
|
||||
return objectMapper.readValue(value, AgentRuntimeCommandResult.class);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new BusinessException("Agent 运行命令结果解析失败");
|
||||
}
|
||||
}
|
||||
|
||||
private void writeResult(String commandId, AgentRuntimeCommandResult result) {
|
||||
if (commandId == null || commandId.isBlank()) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
stringRedisTemplate.opsForValue().set(
|
||||
resultKey(commandId),
|
||||
objectMapper.writeValueAsString(result),
|
||||
properties.getCommandResultTtl());
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new IllegalStateException("Agent 运行命令结果序列化失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
private String resultKey(String commandId) {
|
||||
return RESULT_PREFIX + commandId;
|
||||
}
|
||||
|
||||
private void sleep() {
|
||||
try {
|
||||
Thread.sleep(POLL_INTERVAL_MILLIS);
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new BusinessException("Agent 运行节点响应等待被中断");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package tech.easyflow.agent.distributed;
|
||||
|
||||
import jakarta.annotation.PostConstruct;
|
||||
import org.springframework.scheduling.annotation.Scheduled;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
/**
|
||||
* Agent 运行节点心跳维护器。
|
||||
*/
|
||||
@Component
|
||||
public class AgentRuntimeNodeHeartbeat {
|
||||
|
||||
private static final Duration HEARTBEAT_TTL = Duration.ofSeconds(90);
|
||||
|
||||
private final AgentRuntimeRouteRegistry routeRegistry;
|
||||
|
||||
/**
|
||||
* 创建 Agent 运行节点心跳维护器。
|
||||
*
|
||||
* @param routeRegistry Agent 运行态 Redis 路由注册表
|
||||
*/
|
||||
public AgentRuntimeNodeHeartbeat(AgentRuntimeRouteRegistry routeRegistry) {
|
||||
this.routeRegistry = routeRegistry;
|
||||
}
|
||||
|
||||
/**
|
||||
* 启动时立即写入一次当前节点心跳。
|
||||
*/
|
||||
@PostConstruct
|
||||
public void init() {
|
||||
refresh();
|
||||
}
|
||||
|
||||
/**
|
||||
* 定期刷新当前节点心跳。
|
||||
*/
|
||||
@Scheduled(fixedDelayString = "${easyflow.agent.runtime.node-heartbeat-delay:30000}", initialDelay = 30000L)
|
||||
public void refresh() {
|
||||
routeRegistry.heartbeat(HEARTBEAT_TTL);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package tech.easyflow.agent.distributed;
|
||||
|
||||
/**
 * Owner route of an Agent runtime request.
 */
public class AgentRuntimeRoute {

    /** ID of the owner node. */
    private String nodeId;
    /** Boot-generation ID of the owner. */
    private String bootId;

    /**
     * Returns the owner node ID.
     *
     * @return the owner node ID
     */
    public String getNodeId() {
        return this.nodeId;
    }

    /**
     * Sets the owner node ID.
     *
     * @param nodeId the owner node ID
     */
    public void setNodeId(String nodeId) {
        this.nodeId = nodeId;
    }

    /**
     * Returns the owner's boot-generation ID.
     *
     * @return the boot ID
     */
    public String getBootId() {
        return this.bootId;
    }

    /**
     * Sets the owner's boot-generation ID.
     *
     * @param bootId the boot ID
     */
    public void setBootId(String bootId) {
        this.bootId = bootId;
    }
}
|
||||
@@ -0,0 +1,222 @@
|
||||
package tech.easyflow.agent.distributed;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||
import org.springframework.stereotype.Component;
|
||||
import tech.easyflow.agent.config.AgentRuntimeProperties;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
/**
|
||||
* Agent 运行态 Redis 路由注册表。
|
||||
*/
|
||||
@Component
|
||||
public class AgentRuntimeRouteRegistry {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(AgentRuntimeRouteRegistry.class);
|
||||
|
||||
private static final String REQUEST_ROUTE_PREFIX = "easyflow:agent:runtime:request:";
|
||||
private static final String TOKEN_ROUTE_PREFIX = "easyflow:agent:runtime:resume-token:";
|
||||
private static final String NODE_HEARTBEAT_PREFIX = "easyflow:agent:runtime:node:";
|
||||
|
||||
private final StringRedisTemplate stringRedisTemplate;
|
||||
private final AgentRuntimeProperties properties;
|
||||
private final ObjectMapper objectMapper;
|
||||
|
||||
/**
 * Creates the Redis-backed Agent runtime route registry.
 *
 * @param stringRedisTemplate Redis string template
 * @param properties          Agent runtime configuration
 * @param objectMapper        JSON mapper
 */
public AgentRuntimeRouteRegistry(StringRedisTemplate stringRedisTemplate,
                                 AgentRuntimeProperties properties,
                                 ObjectMapper objectMapper) {
    this.objectMapper = objectMapper;
    this.properties = properties;
    this.stringRedisTemplate = stringRedisTemplate;
}
|
||||
|
||||
/**
|
||||
* 注册运行请求 owner 节点。
|
||||
*
|
||||
* @param requestId 请求 ID
|
||||
*/
|
||||
public void registerRun(String requestId) {
|
||||
if (requestId == null || requestId.isBlank()) {
|
||||
return;
|
||||
}
|
||||
stringRedisTemplate.opsForValue().set(requestKey(requestId), serializeRoute(currentRoute()), properties.getRouteTtl());
|
||||
}
|
||||
|
||||
/**
|
||||
* 注册恢复令牌与请求 ID 的关系。
|
||||
*
|
||||
* @param requestId 请求 ID
|
||||
* @param resumeToken 恢复令牌
|
||||
*/
|
||||
public void registerResumeToken(String requestId, String resumeToken) {
|
||||
if (requestId == null || requestId.isBlank() || resumeToken == null || resumeToken.isBlank()) {
|
||||
return;
|
||||
}
|
||||
stringRedisTemplate.opsForValue().set(tokenKey(resumeToken), requestId, properties.getRouteTtl());
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询请求 ID 所属节点。
|
||||
*
|
||||
* @param requestId 请求 ID
|
||||
* @return owner 节点 ID
|
||||
*/
|
||||
public String findOwnerNode(String requestId) {
|
||||
AgentRuntimeRoute route = findOwnerRoute(requestId);
|
||||
return route == null ? null : route.getNodeId();
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询请求 ID 所属路由。
|
||||
*
|
||||
* @param requestId 请求 ID
|
||||
* @return owner 路由
|
||||
*/
|
||||
public AgentRuntimeRoute findOwnerRoute(String requestId) {
|
||||
if (requestId == null || requestId.isBlank()) {
|
||||
return null;
|
||||
}
|
||||
String value = stringRedisTemplate.opsForValue().get(requestKey(requestId));
|
||||
if (value == null || value.isBlank()) {
|
||||
return null;
|
||||
}
|
||||
return deserializeRoute(value);
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据恢复令牌查询请求 ID。
|
||||
*
|
||||
* @param resumeToken 恢复令牌
|
||||
* @return 请求 ID
|
||||
*/
|
||||
public String findRequestIdByResumeToken(String resumeToken) {
|
||||
if (resumeToken == null || resumeToken.isBlank()) {
|
||||
return null;
|
||||
}
|
||||
return stringRedisTemplate.opsForValue().get(tokenKey(resumeToken));
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除指定运行请求的路由。
|
||||
*
|
||||
* @param requestId 请求 ID
|
||||
*/
|
||||
public void removeRun(String requestId) {
|
||||
if (requestId == null || requestId.isBlank()) {
|
||||
return;
|
||||
}
|
||||
deleteQuietly(requestKey(requestId));
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除指定恢复令牌的路由。
|
||||
*
|
||||
* @param resumeToken 恢复令牌
|
||||
*/
|
||||
public void removeResumeToken(String resumeToken) {
|
||||
if (resumeToken == null || resumeToken.isBlank()) {
|
||||
return;
|
||||
}
|
||||
deleteQuietly(tokenKey(resumeToken));
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前节点 ID。
|
||||
*
|
||||
* @return 当前节点 ID
|
||||
*/
|
||||
public String currentNodeId() {
|
||||
return properties.getInstanceId();
|
||||
}
|
||||
|
||||
/**
|
||||
* 刷新当前节点存活心跳。
|
||||
*
|
||||
* @param ttl 心跳 TTL
|
||||
*/
|
||||
public void heartbeat(Duration ttl) {
|
||||
stringRedisTemplate.opsForValue().set(nodeKey(properties.getInstanceId()), properties.getBootId(), ttl);
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询指定节点是否仍有存活心跳。
|
||||
*
|
||||
* @param nodeId 节点 ID
|
||||
* @return true 表示节点心跳仍有效
|
||||
*/
|
||||
public boolean isNodeAlive(String nodeId) {
|
||||
return currentNodeBootId(nodeId) != null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询指定节点当前启动代 ID。
|
||||
*
|
||||
* @param nodeId 节点 ID
|
||||
* @return 启动代 ID
|
||||
*/
|
||||
public String currentNodeBootId(String nodeId) {
|
||||
if (nodeId == null || nodeId.isBlank()) {
|
||||
return null;
|
||||
}
|
||||
return stringRedisTemplate.opsForValue().get(nodeKey(nodeId));
|
||||
}
|
||||
|
||||
private String requestKey(String requestId) {
|
||||
return REQUEST_ROUTE_PREFIX + requestId;
|
||||
}
|
||||
|
||||
private String tokenKey(String resumeToken) {
|
||||
return TOKEN_ROUTE_PREFIX + resumeToken;
|
||||
}
|
||||
|
||||
private String nodeKey(String nodeId) {
|
||||
return NODE_HEARTBEAT_PREFIX + nodeId;
|
||||
}
|
||||
|
||||
private AgentRuntimeRoute currentRoute() {
|
||||
AgentRuntimeRoute route = new AgentRuntimeRoute();
|
||||
route.setNodeId(properties.getInstanceId());
|
||||
route.setBootId(properties.getBootId());
|
||||
return route;
|
||||
}
|
||||
|
||||
private String serializeRoute(AgentRuntimeRoute route) {
|
||||
try {
|
||||
return objectMapper.writeValueAsString(route);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new IllegalStateException("Agent 运行路由序列化失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
private AgentRuntimeRoute deserializeRoute(String value) {
|
||||
try {
|
||||
if (value.trim().startsWith("{")) {
|
||||
return objectMapper.readValue(value, AgentRuntimeRoute.class);
|
||||
}
|
||||
AgentRuntimeRoute legacyRoute = new AgentRuntimeRoute();
|
||||
legacyRoute.setNodeId(value);
|
||||
return legacyRoute;
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new IllegalStateException("Agent 运行路由反序列化失败", e);
|
||||
}
|
||||
}
|
||||
|
||||
private void deleteQuietly(String key) {
|
||||
try {
|
||||
stringRedisTemplate.delete(key);
|
||||
} catch (RuntimeException e) {
|
||||
LOG.warn("清理 Agent 运行态 Redis 路由失败: key={}", key, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -6,8 +6,10 @@ import com.easyagents.agent.runtime.event.AgentRuntimeEvent;
|
||||
import com.easyagents.agent.runtime.hitl.AgentResumeToken;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Component;
|
||||
import reactor.core.Disposable;
|
||||
import tech.easyflow.agent.distributed.AgentRuntimeRouteRegistry;
|
||||
import tech.easyflow.agent.runtime.lock.AgentRunLock;
|
||||
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||
import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter;
|
||||
@@ -34,6 +36,17 @@ public class AgentRunRegistry {
|
||||
private final Map<String, String> resumeTokenIndex = new ConcurrentHashMap<>();
|
||||
private final Map<String, Set<String>> requestTokens = new ConcurrentHashMap<>();
|
||||
private final Map<String, RunOwner> owners = new ConcurrentHashMap<>();
|
||||
private AgentRuntimeRouteRegistry routeRegistry;
|
||||
|
||||
/**
|
||||
* 设置 Agent 运行态 Redis 路由注册表。
|
||||
*
|
||||
* @param routeRegistry Redis 路由注册表
|
||||
*/
|
||||
@Autowired(required = false)
|
||||
public void setRouteRegistry(AgentRuntimeRouteRegistry routeRegistry) {
|
||||
this.routeRegistry = routeRegistry;
|
||||
}
|
||||
|
||||
/**
|
||||
* 注册运行态。
|
||||
@@ -57,6 +70,9 @@ public class AgentRunRegistry {
|
||||
throw new BusinessException("当前 Agent 运行请求已存在");
|
||||
}
|
||||
owners.put(context.requestId(), context.owner());
|
||||
if (routeRegistry != null) {
|
||||
routeRegistry.registerRun(context.requestId());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -126,6 +142,9 @@ public class AgentRunRegistry {
|
||||
if (requestId != null && resumeToken != null && !resumeToken.isBlank()) {
|
||||
resumeTokenIndex.put(resumeToken, requestId);
|
||||
requestTokens.computeIfAbsent(requestId, ignored -> ConcurrentHashMap.newKeySet()).add(resumeToken);
|
||||
if (routeRegistry != null) {
|
||||
routeRegistry.registerResumeToken(requestId, resumeToken);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,7 +166,15 @@ public class AgentRunRegistry {
|
||||
owners.remove(requestId);
|
||||
Set<String> tokens = requestTokens.remove(requestId);
|
||||
if (tokens != null) {
|
||||
tokens.forEach(resumeTokenIndex::remove);
|
||||
tokens.forEach(token -> {
|
||||
resumeTokenIndex.remove(token);
|
||||
if (routeRegistry != null) {
|
||||
routeRegistry.removeResumeToken(token);
|
||||
}
|
||||
});
|
||||
}
|
||||
if (routeRegistry != null) {
|
||||
routeRegistry.removeRun(requestId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -257,6 +284,9 @@ public class AgentRunRegistry {
|
||||
tokens.remove(resumeToken);
|
||||
}
|
||||
resumeTokenIndex.remove(resumeToken);
|
||||
if (routeRegistry != null) {
|
||||
routeRegistry.removeResumeToken(resumeToken);
|
||||
}
|
||||
AgentResumeToken token = new AgentResumeToken();
|
||||
token.setValue(resumeToken);
|
||||
AgentResumeRequest request = new AgentResumeRequest();
|
||||
|
||||
@@ -19,6 +19,10 @@ import tech.easyflow.agent.entity.Agent;
|
||||
import tech.easyflow.agent.entity.AgentKnowledgeBinding;
|
||||
import tech.easyflow.agent.entity.AgentToolBinding;
|
||||
import tech.easyflow.agent.enums.AgentToolType;
|
||||
import tech.easyflow.agent.distributed.AgentRuntimeCommandAction;
|
||||
import tech.easyflow.agent.distributed.AgentRuntimeCommandProducer;
|
||||
import tech.easyflow.agent.distributed.AgentRuntimeRoute;
|
||||
import tech.easyflow.agent.distributed.AgentRuntimeRouteRegistry;
|
||||
import tech.easyflow.agent.runtime.event.AgentRunEventRecorder;
|
||||
import tech.easyflow.agent.runtime.hitl.AgentHitlPendingService;
|
||||
import tech.easyflow.agent.runtime.lock.AgentRunLock;
|
||||
@@ -78,6 +82,10 @@ public class AgentRunService {
|
||||
@Resource
|
||||
private AgentRunRegistry agentRunRegistry;
|
||||
@Resource
|
||||
private AgentRuntimeRouteRegistry agentRuntimeRouteRegistry;
|
||||
@Resource
|
||||
private AgentRuntimeCommandProducer agentRuntimeCommandProducer;
|
||||
@Resource
|
||||
private AgentRunLock agentRunLock;
|
||||
@Resource
|
||||
private AgentHitlPendingService agentHitlPendingService;
|
||||
@@ -231,6 +239,22 @@ public class AgentRunService {
|
||||
}
|
||||
|
||||
private void approveRuntime(String requestId, String resumeToken, BigInteger operatorId, String userId) {
|
||||
if (!agentRunRegistry.containsResumeTarget(requestId, resumeToken)) {
|
||||
dispatchRemoteRuntimeCommand(requestId, resumeToken, AgentRuntimeCommandAction.APPROVE, null, operatorId, userId);
|
||||
return;
|
||||
}
|
||||
approveRuntimeLocal(requestId, resumeToken, operatorId, userId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 在当前节点批准工具执行。
|
||||
*
|
||||
* @param requestId 请求 ID
|
||||
* @param resumeToken 恢复令牌
|
||||
* @param operatorId 操作人 ID
|
||||
* @param userId 用户 ID
|
||||
*/
|
||||
public void approveRuntimeLocal(String requestId, String resumeToken, BigInteger operatorId, String userId) {
|
||||
if (agentRunRegistry.isDraftResumeTarget(requestId, resumeToken)) {
|
||||
agentRunRegistry.approve(requestId, resumeToken, userId);
|
||||
return;
|
||||
@@ -252,6 +276,23 @@ public class AgentRunService {
|
||||
}
|
||||
|
||||
private void rejectRuntime(String requestId, String resumeToken, String reason, BigInteger operatorId, String userId) {
|
||||
if (!agentRunRegistry.containsResumeTarget(requestId, resumeToken)) {
|
||||
dispatchRemoteRuntimeCommand(requestId, resumeToken, AgentRuntimeCommandAction.REJECT, reason, operatorId, userId);
|
||||
return;
|
||||
}
|
||||
rejectRuntimeLocal(requestId, resumeToken, reason, operatorId, userId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 在当前节点拒绝工具执行。
|
||||
*
|
||||
* @param requestId 请求 ID
|
||||
* @param resumeToken 恢复令牌
|
||||
* @param reason 拒绝原因
|
||||
* @param operatorId 操作人 ID
|
||||
* @param userId 用户 ID
|
||||
*/
|
||||
public void rejectRuntimeLocal(String requestId, String resumeToken, String reason, BigInteger operatorId, String userId) {
|
||||
if (agentRunRegistry.isDraftResumeTarget(requestId, resumeToken)) {
|
||||
agentRunRegistry.reject(requestId, resumeToken, userId, reason);
|
||||
return;
|
||||
@@ -260,6 +301,46 @@ public class AgentRunService {
|
||||
() -> agentHitlPendingService.reject(resumeToken, operatorId, reason));
|
||||
}
|
||||
|
||||
private void dispatchRemoteRuntimeCommand(String requestId,
|
||||
String resumeToken,
|
||||
AgentRuntimeCommandAction action,
|
||||
String reason,
|
||||
BigInteger operatorId,
|
||||
String userId) {
|
||||
String resolvedRequestId = resolveRequestIdForRemoteDispatch(requestId, resumeToken);
|
||||
AgentRuntimeRoute ownerRoute = agentRuntimeRouteRegistry.findOwnerRoute(resolvedRequestId);
|
||||
String ownerNodeId = ownerRoute == null ? null : ownerRoute.getNodeId();
|
||||
if (ownerNodeId == null || ownerNodeId.isBlank()) {
|
||||
throw new BusinessException("Agent 运行节点不可用,请重新发起对话");
|
||||
}
|
||||
if (ownerNodeId.equals(agentRuntimeRouteRegistry.currentNodeId())) {
|
||||
throw new BusinessException("Agent 运行节点不可用,请重新发起对话");
|
||||
}
|
||||
if (!agentRuntimeRouteRegistry.isNodeAlive(ownerNodeId)) {
|
||||
throw new BusinessException("Agent 运行节点不可用,请重新发起对话");
|
||||
}
|
||||
String currentOwnerBootId = agentRuntimeRouteRegistry.currentNodeBootId(ownerNodeId);
|
||||
if (ownerRoute.getBootId() == null || !ownerRoute.getBootId().equals(currentOwnerBootId)) {
|
||||
throw new BusinessException("Agent 运行节点不可用,请重新发起对话");
|
||||
}
|
||||
if (action == AgentRuntimeCommandAction.APPROVE) {
|
||||
agentRuntimeCommandProducer.sendApprove(ownerNodeId, resolvedRequestId, resumeToken, operatorId, userId);
|
||||
return;
|
||||
}
|
||||
agentRuntimeCommandProducer.sendReject(ownerNodeId, resolvedRequestId, resumeToken, reason, operatorId, userId);
|
||||
}
|
||||
|
||||
private String resolveRequestIdForRemoteDispatch(String requestId, String resumeToken) {
|
||||
if (requestId != null && !requestId.isBlank()) {
|
||||
return requestId;
|
||||
}
|
||||
String resolvedRequestId = agentRuntimeRouteRegistry.findRequestIdByResumeToken(resumeToken);
|
||||
if (resolvedRequestId == null || resolvedRequestId.isBlank()) {
|
||||
throw new BusinessException("Agent 运行节点不可用,请重新发起对话");
|
||||
}
|
||||
return resolvedRequestId;
|
||||
}
|
||||
|
||||
private void startRuntime(Agent agent,
|
||||
String prompt,
|
||||
String requestId,
|
||||
|
||||
@@ -5,6 +5,7 @@ import org.slf4j.LoggerFactory;
|
||||
import org.springframework.scheduling.annotation.Scheduled;
|
||||
import org.springframework.stereotype.Component;
|
||||
import tech.easyflow.agent.entity.AgentHitlPending;
|
||||
import tech.easyflow.common.cache.DistributedScheduledLock;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@@ -32,6 +33,7 @@ public class AgentHitlPendingExpirationTask {
|
||||
* 定期将超时 pending 标记为 EXPIRED。
|
||||
*/
|
||||
@Scheduled(fixedDelayString = "${easyflow.agent.runtime.hitl-expire-scan-delay:60000}", initialDelay = 60000L)
|
||||
@DistributedScheduledLock(key = "easyflow:schedule:agent-hitl:expire-pending", leaseSeconds = 300L)
|
||||
public void expirePending() {
|
||||
try {
|
||||
List<AgentHitlPending> expired = pendingService.expirePending(BATCH_SIZE);
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
package tech.easyflow.agent.distributed;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import tech.easyflow.agent.config.AgentRuntimeProperties;
|
||||
import tech.easyflow.agent.distributed.AgentRuntimeCommandAction;
|
||||
import tech.easyflow.agent.distributed.AgentRuntimeCommandConsumer;
|
||||
import tech.easyflow.agent.distributed.AgentRuntimeCommandMessage;
|
||||
import tech.easyflow.agent.distributed.AgentRuntimeCommandResultRegistry;
|
||||
import tech.easyflow.agent.runtime.AgentRunService;
|
||||
import tech.easyflow.common.mq.config.MQProperties;
|
||||
import tech.easyflow.common.mq.core.MQMessage;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link AgentRuntimeCommandConsumer} 回归测试。
|
||||
*/
|
||||
public class AgentRuntimeCommandConsumerTest {
|
||||
|
||||
/**
|
||||
* 验证消费者只处理发给当前节点的命令。
|
||||
*
|
||||
* @throws Exception 消息序列化异常
|
||||
*/
|
||||
@Test
|
||||
public void consumerShouldHandleOnlyCurrentNodeCommand() throws Exception {
|
||||
AgentRuntimeProperties properties = new AgentRuntimeProperties();
|
||||
properties.setInstanceId("node-a");
|
||||
MQProperties mqProperties = new MQProperties();
|
||||
mqProperties.getRedis().setChatPersistShardCount(4);
|
||||
RecordingAgentRunService service = new RecordingAgentRunService();
|
||||
RecordingCommandResultRegistry resultRegistry = new RecordingCommandResultRegistry();
|
||||
AgentRuntimeCommandConsumer consumer =
|
||||
new AgentRuntimeCommandConsumer(new ObjectMapper(), properties, mqProperties, service, resultRegistry);
|
||||
|
||||
consumer.handle(List.of(message(command("cmd-1", "node-b")), message(command("cmd-2", "node-a"))));
|
||||
|
||||
Assert.assertEquals(1, service.approveCount);
|
||||
Assert.assertEquals("request-cmd-2", service.lastRequestId);
|
||||
Assert.assertEquals(4, consumer.subscription().getShardCount());
|
||||
Assert.assertFalse(consumer.subscription().isBatchEnabled());
|
||||
Assert.assertEquals("cmd-2", resultRegistry.lastSuccessCommandId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证 owner 本机执行失败时写入失败结果,避免 MQ 重试重复消费一次性 token。
|
||||
*
|
||||
* @throws Exception 消息序列化异常
|
||||
*/
|
||||
@Test
|
||||
public void consumerShouldMarkFailureWhenLocalRuntimeFails() throws Exception {
|
||||
AgentRuntimeProperties properties = new AgentRuntimeProperties();
|
||||
properties.setInstanceId("node-a");
|
||||
MQProperties mqProperties = new MQProperties();
|
||||
FailingAgentRunService service = new FailingAgentRunService();
|
||||
RecordingCommandResultRegistry resultRegistry = new RecordingCommandResultRegistry();
|
||||
AgentRuntimeCommandConsumer consumer =
|
||||
new AgentRuntimeCommandConsumer(new ObjectMapper(), properties, mqProperties, service, resultRegistry);
|
||||
|
||||
consumer.handle(List.of(message(command("cmd-1", "node-a"))));
|
||||
|
||||
Assert.assertEquals("cmd-1", resultRegistry.lastFailureCommandId);
|
||||
Assert.assertEquals("runtime missing", resultRegistry.lastFailureMessage);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证成功结果写入失败不会再次执行或改写为失败结果。
|
||||
*
|
||||
* @throws Exception 消息序列化异常
|
||||
*/
|
||||
@Test
|
||||
public void consumerShouldNotMarkFailureWhenSuccessResultWriteFails() throws Exception {
|
||||
AgentRuntimeProperties properties = new AgentRuntimeProperties();
|
||||
properties.setInstanceId("node-a");
|
||||
MQProperties mqProperties = new MQProperties();
|
||||
RecordingAgentRunService service = new RecordingAgentRunService();
|
||||
FailingSuccessResultRegistry resultRegistry = new FailingSuccessResultRegistry();
|
||||
AgentRuntimeCommandConsumer consumer =
|
||||
new AgentRuntimeCommandConsumer(new ObjectMapper(), properties, mqProperties, service, resultRegistry);
|
||||
|
||||
consumer.handle(List.of(message(command("cmd-1", "node-a"))));
|
||||
|
||||
Assert.assertEquals(1, service.approveCount);
|
||||
Assert.assertNull(resultRegistry.lastFailureCommandId);
|
||||
}
|
||||
|
||||
private AgentRuntimeCommandMessage command(String commandId, String targetNodeId) {
|
||||
AgentRuntimeCommandMessage command = new AgentRuntimeCommandMessage();
|
||||
command.setCommandId(commandId);
|
||||
command.setRequestId("request-" + commandId);
|
||||
command.setResumeToken("token-" + commandId);
|
||||
command.setAction(AgentRuntimeCommandAction.APPROVE);
|
||||
command.setOperatorId(BigInteger.ONE);
|
||||
command.setUserId("1");
|
||||
command.setTargetNodeId(targetNodeId);
|
||||
return command;
|
||||
}
|
||||
|
||||
private MQMessage message(AgentRuntimeCommandMessage command) throws Exception {
|
||||
MQMessage message = new MQMessage();
|
||||
message.setMessageId(command.getCommandId());
|
||||
message.setBody(new ObjectMapper().writeValueAsString(command));
|
||||
return message;
|
||||
}
|
||||
|
||||
private static final class RecordingAgentRunService extends AgentRunService {
|
||||
|
||||
private int approveCount;
|
||||
private String lastRequestId;
|
||||
|
||||
@Override
|
||||
public void approveRuntimeLocal(String requestId, String resumeToken, BigInteger operatorId, String userId) {
|
||||
approveCount++;
|
||||
lastRequestId = requestId;
|
||||
}
|
||||
}
|
||||
|
||||
private static class RecordingCommandResultRegistry extends AgentRuntimeCommandResultRegistry {
|
||||
|
||||
private String lastSuccessCommandId;
|
||||
String lastFailureCommandId;
|
||||
private String lastFailureMessage;
|
||||
|
||||
private RecordingCommandResultRegistry() {
|
||||
super(null, null, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void markSuccess(String commandId) {
|
||||
lastSuccessCommandId = commandId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void markFailure(String commandId, String message) {
|
||||
lastFailureCommandId = commandId;
|
||||
lastFailureMessage = message;
|
||||
}
|
||||
}
|
||||
|
||||
private static final class FailingAgentRunService extends AgentRunService {
|
||||
|
||||
@Override
|
||||
public void approveRuntimeLocal(String requestId, String resumeToken, BigInteger operatorId, String userId) {
|
||||
throw new RuntimeException("runtime missing");
|
||||
}
|
||||
}
|
||||
|
||||
private static final class FailingSuccessResultRegistry extends RecordingCommandResultRegistry {
|
||||
|
||||
@Override
|
||||
public void markSuccess(String commandId) {
|
||||
super.markSuccess(commandId);
|
||||
throw new RuntimeException("redis down");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package tech.easyflow.agent.distributed;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.mockito.ArgumentMatchers;
|
||||
import org.mockito.Mockito;
|
||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||
import org.springframework.data.redis.core.ValueOperations;
|
||||
import tech.easyflow.agent.config.AgentRuntimeProperties;
|
||||
import tech.easyflow.agent.distributed.AgentRuntimeCommandResult;
|
||||
import tech.easyflow.agent.distributed.AgentRuntimeCommandResultRegistry;
|
||||
import tech.easyflow.common.web.exceptions.BusinessException;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
/**
|
||||
* {@link AgentRuntimeCommandResultRegistry} 回归测试。
|
||||
*/
|
||||
public class AgentRuntimeCommandResultRegistryTest {
|
||||
|
||||
/**
|
||||
* 验证成功结果可被等待方读取。
|
||||
*/
|
||||
@Test
|
||||
public void waitForResultShouldReturnSuccessResult() {
|
||||
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||
@SuppressWarnings("unchecked")
|
||||
ValueOperations<String, String> valueOperations = Mockito.mock(ValueOperations.class);
|
||||
Mockito.when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||
Mockito.when(valueOperations.get("easyflow:agent:runtime:command-result:cmd-1"))
|
||||
.thenReturn("{\"success\":true,\"message\":\"OK\"}");
|
||||
AgentRuntimeCommandResultRegistry registry = registry(redisTemplate);
|
||||
|
||||
AgentRuntimeCommandResult result = registry.waitForResult("cmd-1");
|
||||
|
||||
Assert.assertTrue(result.isSuccess());
|
||||
Assert.assertEquals("OK", result.getMessage());
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证写入失败结果时使用配置的 TTL。
|
||||
*/
|
||||
@Test
|
||||
public void markFailureShouldWriteResultWithTtl() {
|
||||
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||
@SuppressWarnings("unchecked")
|
||||
ValueOperations<String, String> valueOperations = Mockito.mock(ValueOperations.class);
|
||||
Mockito.when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||
AgentRuntimeProperties properties = properties();
|
||||
AgentRuntimeCommandResultRegistry registry =
|
||||
new AgentRuntimeCommandResultRegistry(redisTemplate, new ObjectMapper(), properties);
|
||||
|
||||
registry.markFailure("cmd-1", "failed");
|
||||
|
||||
Mockito.verify(valueOperations).set(
|
||||
ArgumentMatchers.eq("easyflow:agent:runtime:command-result:cmd-1"),
|
||||
ArgumentMatchers.contains("\"success\":false"),
|
||||
ArgumentMatchers.eq(properties.getCommandResultTtl()));
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证等待超时时抛出明确业务异常。
|
||||
*/
|
||||
@Test
|
||||
public void waitForResultShouldThrowBusinessExceptionWhenTimeout() {
|
||||
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||
@SuppressWarnings("unchecked")
|
||||
ValueOperations<String, String> valueOperations = Mockito.mock(ValueOperations.class);
|
||||
Mockito.when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||
Mockito.when(valueOperations.get(ArgumentMatchers.anyString())).thenReturn(null);
|
||||
AgentRuntimeCommandResultRegistry registry = registry(redisTemplate);
|
||||
|
||||
BusinessException exception = Assert.assertThrows(
|
||||
BusinessException.class,
|
||||
() -> registry.waitForResult("cmd-1"));
|
||||
|
||||
Assert.assertEquals("Agent 运行节点响应超时,请稍后重试", exception.getMessage());
|
||||
}
|
||||
|
||||
private AgentRuntimeCommandResultRegistry registry(StringRedisTemplate redisTemplate) {
|
||||
return new AgentRuntimeCommandResultRegistry(redisTemplate, new ObjectMapper(), properties());
|
||||
}
|
||||
|
||||
private AgentRuntimeProperties properties() {
|
||||
AgentRuntimeProperties properties = new AgentRuntimeProperties();
|
||||
properties.setCommandResultTimeout(Duration.ofMillis(10));
|
||||
properties.setCommandResultTtl(Duration.ofMinutes(5));
|
||||
return properties;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
package tech.easyflow.agent.distributed;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.mockito.ArgumentMatchers;
|
||||
import org.mockito.Mockito;
|
||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||
import org.springframework.data.redis.core.ValueOperations;
|
||||
import tech.easyflow.agent.config.AgentRuntimeProperties;
|
||||
import tech.easyflow.agent.distributed.AgentRuntimeRouteRegistry;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
/**
|
||||
* {@link AgentRuntimeRouteRegistry} 回归测试。
|
||||
*/
|
||||
public class AgentRuntimeRouteRegistryTest {
|
||||
|
||||
/**
|
||||
* 验证注册运行态和恢复令牌时写入 Redis 路由。
|
||||
*/
|
||||
@Test
|
||||
public void registerShouldWriteRunAndTokenRoutes() {
|
||||
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||
@SuppressWarnings("unchecked")
|
||||
ValueOperations<String, String> valueOperations = Mockito.mock(ValueOperations.class);
|
||||
Mockito.when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||
AgentRuntimeProperties properties = properties("node-a");
|
||||
AgentRuntimeRouteRegistry registry = registry(redisTemplate, properties);
|
||||
|
||||
registry.registerRun("request-1");
|
||||
registry.registerResumeToken("request-1", "token-1");
|
||||
|
||||
Mockito.verify(valueOperations).set(
|
||||
ArgumentMatchers.eq("easyflow:agent:runtime:request:request-1"),
|
||||
ArgumentMatchers.contains("\"nodeId\":\"node-a\""),
|
||||
ArgumentMatchers.eq(Duration.ofHours(24)));
|
||||
Mockito.verify(valueOperations).set(
|
||||
"easyflow:agent:runtime:resume-token:token-1", "request-1", Duration.ofHours(24));
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证运行结束时清理 Redis 路由。
|
||||
*/
|
||||
@Test
|
||||
public void removeShouldDeleteRunAndTokenRoutes() {
|
||||
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||
AgentRuntimeRouteRegistry registry = registry(redisTemplate, properties("node-a"));
|
||||
|
||||
registry.removeRun("request-1");
|
||||
registry.removeResumeToken("token-1");
|
||||
|
||||
Mockito.verify(redisTemplate).delete("easyflow:agent:runtime:request:request-1");
|
||||
Mockito.verify(redisTemplate).delete("easyflow:agent:runtime:resume-token:token-1");
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证查询 owner 节点和 token 反查请求 ID。
|
||||
*/
|
||||
@Test
|
||||
public void findShouldReadRoutes() {
|
||||
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||
@SuppressWarnings("unchecked")
|
||||
ValueOperations<String, String> valueOperations = Mockito.mock(ValueOperations.class);
|
||||
Mockito.when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||
Mockito.when(valueOperations.get(ArgumentMatchers.eq("easyflow:agent:runtime:request:request-1")))
|
||||
.thenReturn("{\"nodeId\":\"node-a\",\"bootId\":\"boot-a\"}");
|
||||
Mockito.when(valueOperations.get(ArgumentMatchers.eq("easyflow:agent:runtime:resume-token:token-1")))
|
||||
.thenReturn("request-1");
|
||||
AgentRuntimeRouteRegistry registry = registry(redisTemplate, properties("node-a"));
|
||||
|
||||
Assert.assertEquals("node-a", registry.findOwnerNode("request-1"));
|
||||
Assert.assertEquals("boot-a", registry.findOwnerRoute("request-1").getBootId());
|
||||
Assert.assertEquals("request-1", registry.findRequestIdByResumeToken("token-1"));
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证节点心跳写入和存活查询。
|
||||
*/
|
||||
@Test
|
||||
public void heartbeatShouldWriteAndReadNodeAliveState() {
|
||||
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||
@SuppressWarnings("unchecked")
|
||||
ValueOperations<String, String> valueOperations = Mockito.mock(ValueOperations.class);
|
||||
Mockito.when(redisTemplate.opsForValue()).thenReturn(valueOperations);
|
||||
AgentRuntimeProperties properties = properties("node-a");
|
||||
Mockito.when(valueOperations.get("easyflow:agent:runtime:node:node-a")).thenReturn(properties.getBootId());
|
||||
AgentRuntimeRouteRegistry registry = registry(redisTemplate, properties);
|
||||
|
||||
registry.heartbeat(Duration.ofSeconds(90));
|
||||
|
||||
Mockito.verify(valueOperations).set("easyflow:agent:runtime:node:node-a", properties.getBootId(), Duration.ofSeconds(90));
|
||||
Assert.assertTrue(registry.isNodeAlive("node-a"));
|
||||
Assert.assertEquals(properties.getBootId(), registry.currentNodeBootId("node-a"));
|
||||
}
|
||||
|
||||
private AgentRuntimeProperties properties(String instanceId) {
|
||||
AgentRuntimeProperties properties = new AgentRuntimeProperties();
|
||||
properties.setInstanceId(instanceId);
|
||||
properties.setRouteTtl(Duration.ofHours(24));
|
||||
return properties;
|
||||
}
|
||||
|
||||
private AgentRuntimeRouteRegistry registry(StringRedisTemplate redisTemplate, AgentRuntimeProperties properties) {
|
||||
return new AgentRuntimeRouteRegistry(redisTemplate, properties, new ObjectMapper());
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,9 @@ import tech.easyflow.agent.entity.AgentHitlPending;
|
||||
import tech.easyflow.agent.entity.Agent;
|
||||
import tech.easyflow.agent.entity.AgentKnowledgeBinding;
|
||||
import tech.easyflow.agent.entity.AgentToolBinding;
|
||||
import tech.easyflow.agent.distributed.AgentRuntimeCommandProducer;
|
||||
import tech.easyflow.agent.distributed.AgentRuntimeRoute;
|
||||
import tech.easyflow.agent.distributed.AgentRuntimeRouteRegistry;
|
||||
import tech.easyflow.agent.runtime.event.AgentRunEventRecorder;
|
||||
import tech.easyflow.agent.runtime.hitl.AgentHitlPendingService;
|
||||
import tech.easyflow.agent.runtime.lock.AgentRunLock;
|
||||
@@ -532,6 +535,139 @@ public class AgentRunServiceDraftAndHitlTest {
|
||||
Assert.assertEquals(1, pendingService.approveCount);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证本机存在恢复目标时不投递远程命令。
|
||||
*
|
||||
* @throws Exception 反射调用失败时抛出
|
||||
*/
|
||||
@Test
|
||||
public void approveShouldNotDispatchRemoteWhenLocalRuntimeExists() throws Exception {
|
||||
AgentRunService service = new AgentRunService();
|
||||
AgentRunRegistry registry = new AgentRunRegistry();
|
||||
RecordingAgentHitlPendingService pendingService = new RecordingAgentHitlPendingService();
|
||||
RecordingRouteRegistry routeRegistry = new RecordingRouteRegistry("node-a");
|
||||
RecordingCommandProducer commandProducer = new RecordingCommandProducer();
|
||||
setField(service, "agentRunRegistry", registry);
|
||||
setField(service, "agentHitlPendingService", pendingService);
|
||||
setField(service, "agentRuntimeRouteRegistry", routeRegistry);
|
||||
setField(service, "agentRuntimeCommandProducer", commandProducer);
|
||||
|
||||
registry.register(runContext("request-local-approve", "session-local-approve", true));
|
||||
registry.registerResumeToken("request-local-approve", "token-local-approve");
|
||||
invoke(service, "approveRuntime",
|
||||
new Class<?>[]{String.class, String.class, BigInteger.class, String.class},
|
||||
"request-local-approve", "token-local-approve", BigInteger.ONE, "1");
|
||||
|
||||
Assert.assertEquals(1, pendingService.approveCount);
|
||||
Assert.assertEquals(0, commandProducer.approveCount);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证本机无运行态但 Redis owner 存在时投递远程命令。
|
||||
*
|
||||
* @throws Exception 反射调用失败时抛出
|
||||
*/
|
||||
@Test
|
||||
public void approveShouldDispatchRemoteWhenOwnerIsRemoteNode() throws Exception {
|
||||
AgentRunService service = new AgentRunService();
|
||||
RecordingRouteRegistry routeRegistry = new RecordingRouteRegistry("node-b");
|
||||
routeRegistry.requestIdByToken = "request-remote-approve";
|
||||
routeRegistry.ownerNode = "node-a";
|
||||
routeRegistry.ownerBootId = "boot-a";
|
||||
routeRegistry.currentOwnerBootId = "boot-a";
|
||||
routeRegistry.nodeAlive = true;
|
||||
RecordingCommandProducer commandProducer = new RecordingCommandProducer();
|
||||
setField(service, "agentRunRegistry", new AgentRunRegistry());
|
||||
setField(service, "agentRuntimeRouteRegistry", routeRegistry);
|
||||
setField(service, "agentRuntimeCommandProducer", commandProducer);
|
||||
|
||||
invoke(service, "approveRuntime",
|
||||
new Class<?>[]{String.class, String.class, BigInteger.class, String.class},
|
||||
null, "token-remote-approve", BigInteger.ONE, "1");
|
||||
|
||||
Assert.assertEquals(1, commandProducer.approveCount);
|
||||
Assert.assertEquals("node-a", commandProducer.lastTargetNodeId);
|
||||
Assert.assertEquals("request-remote-approve", commandProducer.lastRequestId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证 owner 缺失时明确失败。
|
||||
*
|
||||
* @throws Exception 反射调用失败时抛出
|
||||
*/
|
||||
@Test
|
||||
public void approveShouldFailWhenOwnerRouteMissing() throws Exception {
|
||||
AgentRunService service = new AgentRunService();
|
||||
RecordingRouteRegistry routeRegistry = new RecordingRouteRegistry("node-b");
|
||||
routeRegistry.requestIdByToken = "request-missing-owner";
|
||||
setField(service, "agentRunRegistry", new AgentRunRegistry());
|
||||
setField(service, "agentRuntimeRouteRegistry", routeRegistry);
|
||||
setField(service, "agentRuntimeCommandProducer", new RecordingCommandProducer());
|
||||
|
||||
try {
|
||||
invoke(service, "approveRuntime",
|
||||
new Class<?>[]{String.class, String.class, BigInteger.class, String.class},
|
||||
null, "token-missing-owner", BigInteger.ONE, "1");
|
||||
Assert.fail("expected BusinessException");
|
||||
} catch (Exception e) {
|
||||
Assert.assertTrue(rootCause(e) instanceof BusinessException);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证 owner 重启后启动代不匹配会明确失败。
|
||||
*
|
||||
* @throws Exception 反射调用失败时抛出
|
||||
*/
|
||||
@Test
|
||||
public void approveShouldFailWhenOwnerBootIdChanged() throws Exception {
|
||||
AgentRunService service = new AgentRunService();
|
||||
RecordingRouteRegistry routeRegistry = new RecordingRouteRegistry("node-b");
|
||||
routeRegistry.requestIdByToken = "request-restarted-owner";
|
||||
routeRegistry.ownerNode = "node-a";
|
||||
routeRegistry.ownerBootId = "boot-old";
|
||||
routeRegistry.currentOwnerBootId = "boot-new";
|
||||
routeRegistry.nodeAlive = true;
|
||||
setField(service, "agentRunRegistry", new AgentRunRegistry());
|
||||
setField(service, "agentRuntimeRouteRegistry", routeRegistry);
|
||||
setField(service, "agentRuntimeCommandProducer", new RecordingCommandProducer());
|
||||
|
||||
try {
|
||||
invoke(service, "approveRuntime",
|
||||
new Class<?>[]{String.class, String.class, BigInteger.class, String.class},
|
||||
null, "token-restarted-owner", BigInteger.ONE, "1");
|
||||
Assert.fail("expected BusinessException");
|
||||
} catch (Exception e) {
|
||||
Assert.assertTrue(rootCause(e) instanceof BusinessException);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证 owner 路由存在但节点心跳缺失时明确失败。
|
||||
*
|
||||
* @throws Exception 反射调用失败时抛出
|
||||
*/
|
||||
@Test
|
||||
public void approveShouldFailWhenOwnerNodeHeartbeatMissing() throws Exception {
|
||||
AgentRunService service = new AgentRunService();
|
||||
RecordingRouteRegistry routeRegistry = new RecordingRouteRegistry("node-b");
|
||||
routeRegistry.requestIdByToken = "request-offline-owner";
|
||||
routeRegistry.ownerNode = "node-a";
|
||||
routeRegistry.nodeAlive = false;
|
||||
setField(service, "agentRunRegistry", new AgentRunRegistry());
|
||||
setField(service, "agentRuntimeRouteRegistry", routeRegistry);
|
||||
setField(service, "agentRuntimeCommandProducer", new RecordingCommandProducer());
|
||||
|
||||
try {
|
||||
invoke(service, "approveRuntime",
|
||||
new Class<?>[]{String.class, String.class, BigInteger.class, String.class},
|
||||
null, "token-offline-owner", BigInteger.ONE, "1");
|
||||
Assert.fail("expected BusinessException");
|
||||
} catch (Exception e) {
|
||||
Assert.assertTrue(rootCause(e) instanceof BusinessException);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证清理草稿会话只清草稿 store,不触碰 MySQL pending 清理。
|
||||
*
|
||||
@@ -785,6 +921,72 @@ public class AgentRunServiceDraftAndHitlTest {
|
||||
}
|
||||
}
|
||||
|
||||
private static class RecordingRouteRegistry extends AgentRuntimeRouteRegistry {
|
||||
|
||||
private final String currentNodeId;
|
||||
private String ownerNode;
|
||||
private String ownerBootId;
|
||||
private String currentOwnerBootId;
|
||||
private String requestIdByToken;
|
||||
private boolean nodeAlive;
|
||||
|
||||
private RecordingRouteRegistry(String currentNodeId) {
|
||||
super(null, null, null);
|
||||
this.currentNodeId = currentNodeId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String findOwnerNode(String requestId) {
|
||||
return ownerNode;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AgentRuntimeRoute findOwnerRoute(String requestId) {
|
||||
AgentRuntimeRoute route = new AgentRuntimeRoute();
|
||||
route.setNodeId(ownerNode);
|
||||
route.setBootId(ownerBootId);
|
||||
return route;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String findRequestIdByResumeToken(String resumeToken) {
|
||||
return requestIdByToken;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String currentNodeId() {
|
||||
return currentNodeId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isNodeAlive(String nodeId) {
|
||||
return nodeAlive;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String currentNodeBootId(String nodeId) {
|
||||
return currentOwnerBootId;
|
||||
}
|
||||
}
|
||||
|
||||
private static class RecordingCommandProducer extends AgentRuntimeCommandProducer {
|
||||
|
||||
private int approveCount;
|
||||
private String lastTargetNodeId;
|
||||
private String lastRequestId;
|
||||
|
||||
@Override
|
||||
public void sendApprove(String targetNodeId,
|
||||
String requestId,
|
||||
String resumeToken,
|
||||
BigInteger operatorId,
|
||||
String userId) {
|
||||
approveCount++;
|
||||
lastTargetNodeId = targetNodeId;
|
||||
lastRequestId = requestId;
|
||||
}
|
||||
}
|
||||
|
||||
private static class RecordingAgentRuntimeFactory implements AgentRuntimeFactory {
|
||||
|
||||
private final AgentRuntime runtime;
|
||||
|
||||
@@ -131,5 +131,11 @@
|
||||
<version>${junit.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-core</artifactId>
|
||||
<version>5.12.0</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
|
||||
@@ -5,11 +5,13 @@ import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.annotation.ComponentScan;
|
||||
import tech.easyflow.ai.documentimport.task.DocumentImportParseMonitorProperties;
|
||||
import tech.easyflow.ai.documentimport.task.DocumentImportStatusBroadcastProperties;
|
||||
|
||||
@MapperScan("tech.easyflow.ai.mapper")
|
||||
@ComponentScan("tech.easyflow.ai")
|
||||
@EnableConfigurationProperties({
|
||||
DocumentImportParseMonitorProperties.class,
|
||||
DocumentImportStatusBroadcastProperties.class,
|
||||
RagHealthProperties.class
|
||||
})
|
||||
@AutoConfiguration
|
||||
|
||||
@@ -2,6 +2,7 @@ package tech.easyflow.ai.documentimport.task;
|
||||
|
||||
import org.springframework.scheduling.annotation.Scheduled;
|
||||
import org.springframework.stereotype.Component;
|
||||
import tech.easyflow.common.cache.DistributedScheduledLock;
|
||||
|
||||
/**
|
||||
* 知识库文档解析任务收敛器。
|
||||
@@ -27,6 +28,7 @@ public class DocumentImportParseMonitor {
|
||||
fixedDelayString = "${easyflow.ai.document-import.parse-monitor.fixed-delay:10000}",
|
||||
initialDelayString = "${easyflow.ai.document-import.parse-monitor.initial-delay:10000}"
|
||||
)
|
||||
@DistributedScheduledLock(key = "easyflow:schedule:document-import:parse-monitor", leaseSeconds = 300L)
|
||||
public void reconcileRunningParseTasks() {
|
||||
appService.monitorRunningParseTasks();
|
||||
}
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
package tech.easyflow.ai.documentimport.task;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.data.redis.connection.Message;
|
||||
import org.springframework.data.redis.connection.MessageListener;
|
||||
import org.springframework.data.redis.connection.RedisConnectionFactory;
|
||||
import org.springframework.data.redis.listener.ChannelTopic;
|
||||
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
||||
/**
|
||||
* 文档导入状态 Redis 广播配置。
|
||||
*/
|
||||
@Configuration
|
||||
public class DocumentImportStatusBroadcastConfig {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(DocumentImportStatusBroadcastConfig.class);
|
||||
|
||||
/**
|
||||
* 创建文档导入状态广播监听容器。
|
||||
*
|
||||
* @param connectionFactory Redis 连接工厂
|
||||
* @param streamService 文档导入状态流服务
|
||||
* @param properties 文档导入监控配置
|
||||
* @return Redis 消息监听容器
|
||||
*/
|
||||
@Bean
|
||||
public RedisMessageListenerContainer documentImportStatusListenerContainer(
|
||||
RedisConnectionFactory connectionFactory,
|
||||
DocumentImportTaskStatusStreamService streamService,
|
||||
DocumentImportStatusBroadcastProperties properties
|
||||
) {
|
||||
RedisMessageListenerContainer container = new RedisMessageListenerContainer();
|
||||
container.setConnectionFactory(connectionFactory);
|
||||
container.addMessageListener(
|
||||
new DocumentImportStatusMessageListener(streamService),
|
||||
new ChannelTopic(properties.getStatusBroadcastChannel())
|
||||
);
|
||||
return container;
|
||||
}
|
||||
|
||||
/**
|
||||
* 文档导入状态广播监听器。
|
||||
*/
|
||||
private static final class DocumentImportStatusMessageListener implements MessageListener {
|
||||
|
||||
private final DocumentImportTaskStatusStreamService streamService;
|
||||
|
||||
/**
|
||||
* 创建监听器。
|
||||
*
|
||||
* @param streamService 文档导入状态流服务
|
||||
*/
|
||||
private DocumentImportStatusMessageListener(DocumentImportTaskStatusStreamService streamService) {
|
||||
this.streamService = streamService;
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理 Redis 广播消息。
|
||||
*
|
||||
* @param message 消息
|
||||
* @param pattern 订阅模式
|
||||
*/
|
||||
@Override
|
||||
public void onMessage(Message message, byte[] pattern) {
|
||||
String payload = new String(message.getBody(), StandardCharsets.UTF_8);
|
||||
try {
|
||||
streamService.publishLocal(new BigInteger(payload));
|
||||
} catch (RuntimeException e) {
|
||||
LOG.warn("处理文档导入状态广播失败: payload={}", payload, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package tech.easyflow.ai.documentimport.task;
|
||||
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
|
||||
/**
|
||||
* 文档导入状态广播配置。
|
||||
*/
|
||||
@ConfigurationProperties(prefix = "easyflow.ai.document-import")
|
||||
public class DocumentImportStatusBroadcastProperties {
|
||||
|
||||
private String statusBroadcastChannel = "easyflow:document-import:status";
|
||||
|
||||
/**
|
||||
* 获取文档导入状态广播通道。
|
||||
*
|
||||
* @return Redis 广播通道
|
||||
*/
|
||||
public String getStatusBroadcastChannel() {
|
||||
return statusBroadcastChannel;
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置文档导入状态广播通道。
|
||||
*
|
||||
* @param statusBroadcastChannel Redis 广播通道
|
||||
*/
|
||||
public void setStatusBroadcastChannel(String statusBroadcastChannel) {
|
||||
if (statusBroadcastChannel == null || statusBroadcastChannel.trim().isEmpty()) {
|
||||
this.statusBroadcastChannel = "easyflow:document-import:status";
|
||||
return;
|
||||
}
|
||||
this.statusBroadcastChannel = statusBroadcastChannel.trim();
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package tech.easyflow.ai.documentimport.task;
|
||||
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.support.TransactionSynchronization;
|
||||
@@ -43,6 +44,12 @@ public class DocumentImportTaskStatusStreamService {
|
||||
@Resource(name = "sseThreadPool")
|
||||
private ThreadPoolTaskExecutor sseThreadPool;
|
||||
|
||||
@Resource
|
||||
private StringRedisTemplate stringRedisTemplate;
|
||||
|
||||
@Resource
|
||||
private DocumentImportStatusBroadcastProperties statusBroadcastProperties;
|
||||
|
||||
/**
|
||||
* 订阅知识库文档任务状态流。
|
||||
*
|
||||
@@ -75,7 +82,7 @@ public class DocumentImportTaskStatusStreamService {
|
||||
if (documentId == null) {
|
||||
return;
|
||||
}
|
||||
Runnable publishAction = () -> publishNow(documentId);
|
||||
Runnable publishAction = () -> publishStatusChange(documentId);
|
||||
if (TransactionSynchronizationManager.isSynchronizationActive()
|
||||
&& TransactionSynchronizationManager.isActualTransactionActive()) {
|
||||
TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {
|
||||
@@ -89,7 +96,22 @@ public class DocumentImportTaskStatusStreamService {
|
||||
publishAction.run();
|
||||
}
|
||||
|
||||
private void publishNow(BigInteger documentId) {
|
||||
/**
|
||||
* 处理 Redis 广播收到的文档状态变更。
|
||||
*
|
||||
* @param documentId 文档 ID
|
||||
*/
|
||||
public void publishLocal(BigInteger documentId) {
|
||||
publishNow(documentId);
|
||||
}
|
||||
|
||||
private void publishStatusChange(BigInteger documentId) {
|
||||
// 先推送本机连接,降低单机部署和广播链路延迟。
|
||||
publishNow(documentId);
|
||||
stringRedisTemplate.convertAndSend(statusBroadcastProperties.getStatusBroadcastChannel(), documentId.toString());
|
||||
}
|
||||
|
||||
void publishNow(BigInteger documentId) {
|
||||
Document document = documentMapper.selectOneById(documentId);
|
||||
if (document == null || document.getCollectionId() == null) {
|
||||
return;
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
package tech.easyflow.ai.documentimport.task;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.mockito.ArgumentMatchers;
|
||||
import org.mockito.Mockito;
|
||||
import org.springframework.data.redis.core.StringRedisTemplate;
|
||||
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
||||
import tech.easyflow.ai.entity.Document;
|
||||
import tech.easyflow.ai.mapper.DocumentMapper;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.math.BigInteger;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
/**
|
||||
* {@link DocumentImportTaskStatusStreamService} 回归测试。
|
||||
*/
|
||||
public class DocumentImportTaskStatusStreamServiceTest {
|
||||
|
||||
/**
|
||||
* 验证文档状态变更会向 Redis 广播文档 ID。
|
||||
*
|
||||
* @throws Exception 反射注入异常
|
||||
*/
|
||||
@Test
|
||||
public void publishAfterCommitShouldBroadcastDocumentId() throws Exception {
|
||||
StringRedisTemplate redisTemplate = Mockito.mock(StringRedisTemplate.class);
|
||||
DocumentImportTaskStatusStreamService service = new DocumentImportTaskStatusStreamService();
|
||||
setField(service, "documentMapper", mockDocumentMapper());
|
||||
setField(service, "sseThreadPool", directExecutor());
|
||||
setField(service, "stringRedisTemplate", redisTemplate);
|
||||
setField(service, "statusBroadcastProperties", statusBroadcastProperties());
|
||||
|
||||
service.publishAfterCommit(BigInteger.valueOf(101));
|
||||
|
||||
Mockito.verify(redisTemplate).convertAndSend("easyflow:document-import:test-status", "101");
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证收到 Redis 广播后会重新查询文档状态。
|
||||
*
|
||||
* @throws Exception 反射注入异常
|
||||
*/
|
||||
@Test
|
||||
public void publishLocalShouldReloadDocumentStatus() throws Exception {
|
||||
AtomicReference<BigInteger> selectedIdRef = new AtomicReference<BigInteger>();
|
||||
DocumentImportTaskStatusStreamService service = new DocumentImportTaskStatusStreamService();
|
||||
setField(service, "documentMapper", mockDocumentMapper(selectedIdRef));
|
||||
setField(service, "sseThreadPool", directExecutor());
|
||||
setField(service, "stringRedisTemplate", Mockito.mock(StringRedisTemplate.class));
|
||||
setField(service, "statusBroadcastProperties", statusBroadcastProperties());
|
||||
|
||||
service.publishLocal(BigInteger.valueOf(202));
|
||||
|
||||
Assert.assertEquals(BigInteger.valueOf(202), selectedIdRef.get());
|
||||
}
|
||||
|
||||
private DocumentImportStatusBroadcastProperties statusBroadcastProperties() {
|
||||
DocumentImportStatusBroadcastProperties properties = new DocumentImportStatusBroadcastProperties();
|
||||
properties.setStatusBroadcastChannel("easyflow:document-import:test-status");
|
||||
return properties;
|
||||
}
|
||||
|
||||
private DocumentMapper mockDocumentMapper() {
|
||||
return mockDocumentMapper(new AtomicReference<BigInteger>());
|
||||
}
|
||||
|
||||
private DocumentMapper mockDocumentMapper(AtomicReference<BigInteger> selectedIdRef) {
|
||||
DocumentMapper mapper = Mockito.mock(DocumentMapper.class);
|
||||
Mockito.when(mapper.selectOneById(ArgumentMatchers.any())).thenAnswer(invocation -> {
|
||||
Object id = invocation.getArgument(0);
|
||||
selectedIdRef.set((BigInteger) id);
|
||||
Document document = new Document();
|
||||
document.setId((BigInteger) id);
|
||||
document.setCollectionId(BigInteger.valueOf(1));
|
||||
return document;
|
||||
});
|
||||
return mapper;
|
||||
}
|
||||
|
||||
private ThreadPoolTaskExecutor directExecutor() {
|
||||
ThreadPoolTaskExecutor executor = Mockito.mock(ThreadPoolTaskExecutor.class);
|
||||
Mockito.doAnswer(invocation -> {
|
||||
Runnable runnable = invocation.getArgument(0);
|
||||
runnable.run();
|
||||
return null;
|
||||
}).when(executor).execute(ArgumentMatchers.any(Runnable.class));
|
||||
return executor;
|
||||
}
|
||||
|
||||
private void setField(Object target, String fieldName, Object value) throws Exception {
|
||||
Field field = DocumentImportTaskStatusStreamService.class.getDeclaredField(fieldName);
|
||||
field.setAccessible(true);
|
||||
field.set(target, value);
|
||||
}
|
||||
}
|
||||
@@ -4,19 +4,33 @@ import org.springframework.scheduling.annotation.Scheduled;
|
||||
import org.springframework.stereotype.Component;
|
||||
import tech.easyflow.chatlog.config.ChatSyncProperties;
|
||||
import tech.easyflow.chatlog.service.ChatSyncService;
|
||||
import tech.easyflow.common.cache.DistributedScheduledLock;
|
||||
|
||||
/**
|
||||
* 聊天记录同步定时任务。
|
||||
*/
|
||||
@Component
|
||||
public class ChatSyncScheduler {
|
||||
|
||||
private final ChatSyncService chatSyncService;
|
||||
private final ChatSyncProperties syncProperties;
|
||||
|
||||
/**
 * Creates the chat-log sync scheduler.
 *
 * @param chatSyncService chat sync service the scheduled jobs delegate to
 * @param syncProperties  sync configuration
 */
public ChatSyncScheduler(ChatSyncService chatSyncService, ChatSyncProperties syncProperties) {
    this.syncProperties = syncProperties;
    this.chatSyncService = chatSyncService;
}
|
||||
|
||||
/**
|
||||
* 同步聊天会话摘要。
|
||||
*/
|
||||
@Scheduled(fixedDelayString = "${easyflow.chat.sync.fixed-delay:30000}", initialDelay = 10000L)
|
||||
@DistributedScheduledLock(key = "easyflow:schedule:chat-sync:sessions", leaseSeconds = 300L)
|
||||
public void syncSessions() {
|
||||
if (!syncProperties.isEnabled()) {
|
||||
return;
|
||||
@@ -24,7 +38,11 @@ public class ChatSyncScheduler {
|
||||
chatSyncService.syncSessions();
|
||||
}
|
||||
|
||||
/**
|
||||
* 同步聊天日志明细。
|
||||
*/
|
||||
@Scheduled(fixedDelayString = "${easyflow.chat.sync.fixed-delay:30000}", initialDelay = 15000L)
|
||||
@DistributedScheduledLock(key = "easyflow:schedule:chat-sync:logs", leaseSeconds = 300L)
|
||||
public void syncLogs() {
|
||||
if (!syncProperties.isEnabled()) {
|
||||
return;
|
||||
@@ -32,7 +50,11 @@ public class ChatSyncScheduler {
|
||||
chatSyncService.syncLogs();
|
||||
}
|
||||
|
||||
/**
|
||||
* 修复近期聊天日志同步缺口。
|
||||
*/
|
||||
@Scheduled(cron = "0 15 3 * * *")
|
||||
@DistributedScheduledLock(key = "easyflow:schedule:chat-sync:repair-logs", leaseSeconds = 300L)
|
||||
public void repairLogs() {
|
||||
if (!syncProperties.isEnabled()) {
|
||||
return;
|
||||
@@ -40,7 +62,11 @@ public class ChatSyncScheduler {
|
||||
chatSyncService.repairLogs();
|
||||
}
|
||||
|
||||
/**
|
||||
* 维护聊天日志 MySQL 分表。
|
||||
*/
|
||||
@Scheduled(cron = "0 0 2 * * *")
|
||||
@DistributedScheduledLock(key = "easyflow:schedule:chat-sync:maintain-mysql-tables", leaseSeconds = 300L)
|
||||
public void maintainMysqlTables() {
|
||||
chatSyncService.maintainMysqlTables();
|
||||
}
|
||||
|
||||
@@ -39,6 +39,7 @@ easyflow:
|
||||
redis:
|
||||
database: 1
|
||||
stream-prefix: easyflow:mq
|
||||
consumer-instance-id: ${EASYFLOW_INSTANCE_ID:${HOSTNAME:${random.uuid}}}
|
||||
chat-persist-shard-count: 4
|
||||
consumer-batch-size: 200
|
||||
consumer-block-timeout: 2000ms
|
||||
@@ -74,11 +75,19 @@ easyflow:
|
||||
validate-on-migrate: true
|
||||
storage:
|
||||
type: xFileStorage
|
||||
agent:
|
||||
runtime:
|
||||
instance-id: ${EASYFLOW_INSTANCE_ID:${HOSTNAME:${random.uuid}}}
|
||||
route-ttl: 24h
|
||||
command-topic-prefix: easyflow:agent-runtime-command
|
||||
command-result-timeout: 5s
|
||||
command-result-ttl: 5m
|
||||
ai:
|
||||
rag:
|
||||
health:
|
||||
cache-ttl: 5s
|
||||
document-import:
|
||||
status-broadcast-channel: easyflow:document-import:status
|
||||
parse-monitor:
|
||||
fixed-delay: 10000
|
||||
initial-delay: 10000
|
||||
|
||||
@@ -106,14 +106,15 @@ easyflow:
|
||||
redis:
|
||||
database: 1
|
||||
stream-prefix: easyflow:mq
|
||||
consumer-instance-id: ${EASYFLOW_INSTANCE_ID:${HOSTNAME:${random.uuid}}}
|
||||
chat-persist-shard-count: 4
|
||||
consumer-batch-size: 200
|
||||
consumer-block-timeout: 2000ms
|
||||
pending-claim-idle: 60000ms
|
||||
max-retry: 16
|
||||
consumer-executor:
|
||||
core-size: 4
|
||||
max-size: 12
|
||||
core-size: 16
|
||||
max-size: 24
|
||||
queue-capacity: 64
|
||||
keep-alive-seconds: 60
|
||||
pool:
|
||||
@@ -148,6 +149,13 @@ easyflow:
|
||||
access-key-secret: xxx
|
||||
app-key: xxx
|
||||
voice: siyue
|
||||
agent:
|
||||
runtime:
|
||||
instance-id: ${EASYFLOW_INSTANCE_ID:${HOSTNAME:${random.uuid}}}
|
||||
route-ttl: 24h
|
||||
command-topic-prefix: easyflow:agent-runtime-command
|
||||
command-result-timeout: 5s
|
||||
command-result-ttl: 5m
|
||||
login:
|
||||
# 放行接口路径
|
||||
excludes: /api/v1/auth/**, /static/**, /userCenter/auth/**, /userCenter/public/**
|
||||
@@ -169,6 +177,7 @@ easyflow:
|
||||
health:
|
||||
cache-ttl: 5s
|
||||
document-import:
|
||||
status-broadcast-channel: easyflow:document-import:status
|
||||
parse-monitor:
|
||||
fixed-delay: 10000
|
||||
initial-delay: 10000
|
||||
|
||||
Reference in New Issue
Block a user