Compare commits

...

11 Commits

Author SHA1 Message Date
b6213d0933 feat: 增强知识库分块策略流程
- 增加导入分析预览提交与预览态缓存键

- 支持知识库分块策略配置与分块预览

- 重构知识库导入与确认导入前端流程
2026-03-29 17:27:12 +08:00
22ceabff96 feat: 增加工作流和知识库三级权限
- 抽取统一资源访问骨架与部门可见范围判断

- 接入工作流和知识库的 READ/MANAGE 权限校验

- 增加可见范围配置与只读态前端交互
2026-03-29 17:25:55 +08:00
f49d94e2fe feat: 增加分类权限控制
- 新增角色分类授权模型与超级管理员配置接口

- 接入助手、插件、工作流、知识库、素材的分类可见性过滤

- 增加角色页分类权限树与插件多分类可见性支持
2026-03-29 17:16:37 +08:00
aaf4c61ff8 feat: 新增统一模型网关与模型管理工作区
- 新增 OpenAI 兼容统一模型调用链路、模型发布配置与批量发布能力

- 重构模型管理页面入口与统一网关工作区,更新服务商 logo 资源与模型 ID 文案

- 收口全新库初始化脚本,仅保留服务商种子并整理统一网关 migration
2026-03-26 20:48:18 +08:00
b777cb3641 chore: 移除系统模块启动调试输出
- 删除 module-system 自动配置中的控制台打印
2026-03-24 18:39:37 +08:00
c78db961c5 fix: 修复管理端静态资源基础路径
- 统一内置品牌资源与空状态图片的 BASE_URL 解析

- 应用启动时自动归一化历史偏好里的内置资源路径

- 多个空状态组件改为复用公共资源地址工具
2026-03-24 18:39:16 +08:00
da536ea742 fix: 统一上传响应与表单校验处理
- 上传组件统一解析后端响应并暴露错误事件

- AI 资源、模型提供商和工作流表单补齐程序化字段校验同步

- 修正 MinIO 对外访问域名配置
2026-03-24 18:38:42 +08:00
799174406e fix: 修复插件保存后的分类关联
- 插件保存接口返回实体以便前端拿到真实插件 ID

- 分类关联更新改为按差异增删并补充事务保护

- 新建插件后缺失 ID 时明确抛出错误
2026-03-24 18:38:09 +08:00
6e1bd73cd8 feat: 支持系统账号批量操作
- 新增账号批量删除和批量重置密码接口及结果返回

- 用户列表增加批量操作工具栏与结果提示

- 账号删除切换为逻辑删除语义
2026-03-24 18:37:32 +08:00
d510034abb feat: 新增管理端工作台总览
- 新增 Dashboard 统计接口、菜单迁移与权限点

- 管理端工作台页面切换为真实概览数据和趋势图

- 默认首页切换到工作台
2026-03-24 18:36:54 +08:00
b1a16ccf18 feat: 支持通过Flyway自动初始化数据库
- 将 starter 初始化脚本迁移到 db/migration,并保留 V1-V3 作为首批迁移

- 清理旧 sql/initdb 挂载与历史分段 SQL,避免 Docker 启动时重复导库

- 更新 README、应用配置和中间件编排,统一空库启动方式
2026-03-24 12:36:44 +08:00
257 changed files with 13900 additions and 3897 deletions

View File

@@ -28,9 +28,14 @@ EasyFlow 是一个面向企业场景的 Java AI 应用开发平台,提供智
## 快速启动(开发环境) ## 快速启动(开发环境)
### 1. 初始化数据库 ### 1. 初始化数据库
在 MySQL 中导入 只需要提前创建空库,例如
- `sql/01-easyflow-v2.ddl.sql`
- `sql/02-easyflow-v2.data.sql` ```sql
CREATE DATABASE easyflow CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci;
```
表结构、Quartz 表和基础数据会在后端启动时由 Flyway 自动迁移完成。
当前初始化脚本仅保留 3 段:`V1__quartz.sql``V2__easyflow_schema.sql``V3__easyflow_seed.sql`
### 2. 启动后端 ### 2. 启动后端
在项目根目录执行: 在项目根目录执行:
@@ -42,6 +47,8 @@ java -jar easyflow-starter/easyflow-starter-all/target/easyflow-starter-all-0.0.
默认端口:`8111`(见 `easyflow-starter/easyflow-starter-all/src/main/resources/application.yml`)。 默认端口:`8111`(见 `easyflow-starter/easyflow-starter-all/src/main/resources/application.yml`)。
首次连接空库时,应用会自动执行 `db/migration` 下的迁移脚本并初始化默认管理员、菜单、角色和模型服务商数据。
### 3. 启动前端 ### 3. 启动前端
管理后台: 管理后台:
@@ -59,7 +66,7 @@ pnpm install
pnpm dev pnpm dev
``` ```
默认测试账号:`admin / 123456` 默认测试账号:`admin / Easy@2026`
## 后端 Jar 包构建与部署 ## 后端 Jar 包构建与部署
@@ -78,6 +85,8 @@ mvn -DskipTests -Dmaven.javadoc.skip=true clean package
java -jar easyflow-starter/easyflow-starter-all/target/easyflow-starter-all-0.0.1.jar --spring.profiles.active=prod java -jar easyflow-starter/easyflow-starter-all/target/easyflow-starter-all-0.0.1.jar --spring.profiles.active=prod
``` ```
生产环境同样只需要保证目标数据库已创建为空库Flyway 会在应用启动时自动完成迁移。
可通过环境变量覆盖关键配置(示例): 可通过环境变量覆盖关键配置(示例):
```bash ```bash

View File

@@ -13,18 +13,17 @@ services:
- --max_connections=500 - --max_connections=500
- --sql_mode=STRICT_TRANS_TABLES,NO_ENGINE_SUBSTITUTION - --sql_mode=STRICT_TRANS_TABLES,NO_ENGINE_SUBSTITUTION
environment: environment:
TZ: ${TZ:-Asia/Shanghai} TZ: Asia/Shanghai
MYSQL_ROOT_PASSWORD: ${MYSQL_ROOT_PASSWORD:-root} MYSQL_ROOT_PASSWORD: root
MYSQL_DATABASE: ${MYSQL_DATABASE:-easyflow} MYSQL_DATABASE: easyflow
MYSQL_USER: ${MYSQL_USER:-easyflow} MYSQL_USER: easyflow
MYSQL_PASSWORD: ${MYSQL_PASSWORD:-123456} MYSQL_PASSWORD: "123456"
ports: ports:
- "${MYSQL_PORT:-3306}:3306" - "3306:3306"
volumes: volumes:
- ./data/mysql:/var/lib/mysql - ./data/mysql:/var/lib/mysql
- ./easyflow/sql/initdb:/docker-entrypoint-initdb.d:ro
healthcheck: healthcheck:
test: ["CMD", "mysqladmin", "ping", "-h", "127.0.0.1", "-uroot", "-p${MYSQL_ROOT_PASSWORD:-root}"] test: ["CMD", "mysqladmin", "ping", "-h", "127.0.0.1", "-uroot", "-proot"]
interval: 10s interval: 10s
timeout: 5s timeout: 5s
retries: 10 retries: 10
@@ -34,13 +33,13 @@ services:
image: swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/kubesphere/redis:7.2.4-alpine-linuxarm64 image: swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/kubesphere/redis:7.2.4-alpine-linuxarm64
container_name: easyflow-redis container_name: easyflow-redis
restart: unless-stopped restart: unless-stopped
command: ["redis-server", "--appendonly", "yes", "--requirepass", "${REDIS_PASSWORD:-123456}"] command: ["redis-server", "--appendonly", "yes", "--requirepass", "123456"]
ports: ports:
- "${REDIS_PORT:-6379}:6379" - "6379:6379"
volumes: volumes:
- ./data/redis:/data - ./data/redis:/data
healthcheck: healthcheck:
test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD:-123456}", "ping"] test: ["CMD", "redis-cli", "-a", "123456", "ping"]
interval: 10s interval: 10s
timeout: 5s timeout: 5s
retries: 10 retries: 10
@@ -51,7 +50,7 @@ services:
container_name: easyflow-etcd container_name: easyflow-etcd
restart: unless-stopped restart: unless-stopped
environment: environment:
TZ: ${TZ:-Asia/Shanghai} TZ: Asia/Shanghai
ETCD_AUTO_COMPACTION_MODE: revision ETCD_AUTO_COMPACTION_MODE: revision
ETCD_AUTO_COMPACTION_RETENTION: "1000" ETCD_AUTO_COMPACTION_RETENTION: "1000"
ETCD_QUOTA_BACKEND_BYTES: "4294967296" ETCD_QUOTA_BACKEND_BYTES: "4294967296"
@@ -70,12 +69,12 @@ services:
restart: unless-stopped restart: unless-stopped
command: server /data --address ":9000" --console-address ":9001" command: server /data --address ":9000" --console-address ":9001"
environment: environment:
TZ: ${TZ:-Asia/Shanghai} TZ: Asia/Shanghai
MINIO_ROOT_USER: ${MINIO_ROOT_USER:-easyflowadmin} MINIO_ROOT_USER: easyflowadmin
MINIO_ROOT_PASSWORD: ${MINIO_ROOT_PASSWORD:-easyflowadmin123} MINIO_ROOT_PASSWORD: easyflowadmin123
ports: ports:
- "${MINIO_PORT:-9000}:9000" - "9000:9000"
- "${MINIO_CONSOLE_PORT:-9001}:9001" - "9001:9001"
volumes: volumes:
- ./data/minio:/data - ./data/minio:/data
@@ -86,12 +85,12 @@ services:
depends_on: depends_on:
- minio - minio
environment: environment:
MINIO_ROOT_USER: ${MINIO_ROOT_USER:-easyflowadmin} MINIO_ROOT_USER: easyflowadmin
MINIO_ROOT_PASSWORD: ${MINIO_ROOT_PASSWORD:-easyflowadmin123} MINIO_ROOT_PASSWORD: easyflowadmin123
MINIO_ENDPOINT: ${MINIO_ENDPOINT:-http://minio:9000} MINIO_ENDPOINT: http://minio:9000
MINIO_BUCKETS: ${MINIO_BUCKETS:-easyflow,milvus} MINIO_BUCKETS: easyflow,milvus
MINIO_PUBLIC_BUCKETS: ${MINIO_PUBLIC_BUCKETS:-easyflow} MINIO_PUBLIC_BUCKETS: easyflow
MINIO_ALIAS: ${MINIO_ALIAS:-local} MINIO_ALIAS: local
volumes: volumes:
- ./scripts/minio/init-minio.sh:/scripts/init-minio.sh:ro - ./scripts/minio/init-minio.sh:/scripts/init-minio.sh:ro
entrypoint: ["/bin/sh", "/scripts/init-minio.sh"] entrypoint: ["/bin/sh", "/scripts/init-minio.sh"]
@@ -106,11 +105,11 @@ services:
environment: environment:
ETCD_ENDPOINTS: etcd:2379 ETCD_ENDPOINTS: etcd:2379
COMMON_STORAGETYPE: minio COMMON_STORAGETYPE: minio
MINIO_ADDRESS: ${MILVUS_MINIO_ADDRESS:-minio:9000} MINIO_ADDRESS: minio:9000
MINIO_ACCESS_KEY_ID: ${MINIO_ROOT_USER:-easyflowadmin} MINIO_ACCESS_KEY_ID: easyflowadmin
MINIO_SECRET_ACCESS_KEY: ${MINIO_ROOT_PASSWORD:-easyflowadmin123} MINIO_SECRET_ACCESS_KEY: easyflowadmin123
MINIO_USE_SSL: "false" MINIO_USE_SSL: "false"
MINIO_BUCKET_NAME: ${MILVUS_MINIO_BUCKET:-milvus} MINIO_BUCKET_NAME: milvus
depends_on: depends_on:
etcd: etcd:
condition: service_started condition: service_started
@@ -119,7 +118,7 @@ services:
minio-init: minio-init:
condition: service_completed_successfully condition: service_completed_successfully
ports: ports:
- "${MILVUS_GRPC_PORT:-19530}:19530" - "19530:19530"
- "${MILVUS_HTTP_PORT:-9091}:9091" - "9091:9091"
volumes: volumes:
- ./data/milvus:/var/lib/milvus - ./data/milvus:/var/lib/milvus

View File

@@ -61,10 +61,6 @@ services:
MYSQL_DATABASE: easyflow MYSQL_DATABASE: easyflow
ports: ports:
- "3306:3306" - "3306:3306"
volumes:
- ./sql:/docker-entrypoint-initdb.d
# 这个命令在默认 entrypoint 运行前修复权限
entrypoint: sh -c "chown -R mysql:mysql /docker-entrypoint-initdb.d && docker-entrypoint.sh mysqld"
networks: networks:
- easyflow-net - easyflow-net
healthcheck: healthcheck:

View File

@@ -1,11 +1,20 @@
package tech.easyflow.admin.controller.ai; package tech.easyflow.admin.controller.ai;
import com.mybatisflex.core.query.QueryWrapper;
import tech.easyflow.ai.entity.BotCategory; import tech.easyflow.ai.entity.BotCategory;
import tech.easyflow.ai.service.BotCategoryService; import tech.easyflow.ai.service.BotCategoryService;
import tech.easyflow.common.annotation.UsePermission; import tech.easyflow.common.annotation.UsePermission;
import tech.easyflow.common.domain.Result;
import tech.easyflow.common.web.controller.BaseCurdController; import tech.easyflow.common.web.controller.BaseCurdController;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import tech.easyflow.system.entity.vo.RoleCategoryAccessSnapshot;
import tech.easyflow.system.service.CategoryPermissionService;
import javax.annotation.Resource;
import java.util.Collections;
import java.util.List;
/** /**
* bot分类 控制层。 * bot分类 控制层。
@@ -17,7 +26,24 @@ import org.springframework.web.bind.annotation.RestController;
@RequestMapping("/api/v1/botCategory") @RequestMapping("/api/v1/botCategory")
@UsePermission(moduleName = "/api/v1/bot") @UsePermission(moduleName = "/api/v1/bot")
public class BotCategoryController extends BaseCurdController<BotCategoryService, BotCategory> { public class BotCategoryController extends BaseCurdController<BotCategoryService, BotCategory> {
@Resource
private CategoryPermissionService categoryPermissionService;
public BotCategoryController(BotCategoryService service) { public BotCategoryController(BotCategoryService service) {
super(service); super(service);
} }
}
@GetMapping("visibleList")
public Result<List<BotCategory>> visibleList(BotCategory entity, Boolean asTree, String sortKey, String sortType) {
QueryWrapper queryWrapper = QueryWrapper.create(entity, buildOperators(entity));
RoleCategoryAccessSnapshot access = categoryPermissionService.getCurrentAccess("BOT");
if (access.isRestricted()) {
if (access.getCategoryIds().isEmpty()) {
return Result.ok(Collections.emptyList());
}
queryWrapper.in(BotCategory::getId, access.getCategoryIds());
}
queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy()));
return Result.ok(service.list(queryWrapper));
}
}

View File

@@ -3,9 +3,11 @@ package tech.easyflow.admin.controller.ai;
import cn.dev33.satoken.annotation.SaCheckPermission; import cn.dev33.satoken.annotation.SaCheckPermission;
import cn.dev33.satoken.annotation.SaIgnore; import cn.dev33.satoken.annotation.SaIgnore;
import cn.dev33.satoken.stp.StpUtil;
import com.easyagents.core.model.chat.ChatModel; import com.easyagents.core.model.chat.ChatModel;
import com.easyagents.core.model.chat.ChatOptions; import com.easyagents.core.model.chat.ChatOptions;
import com.alicp.jetcache.Cache; import com.alicp.jetcache.Cache;
import com.mybatisflex.core.paginate.Page;
import com.mybatisflex.core.keygen.impl.SnowFlakeIDKeyGenerator; import com.mybatisflex.core.keygen.impl.SnowFlakeIDKeyGenerator;
import com.mybatisflex.core.query.QueryWrapper; import com.mybatisflex.core.query.QueryWrapper;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
@@ -25,6 +27,8 @@ import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.common.web.jsonbody.JsonBody; import tech.easyflow.common.web.jsonbody.JsonBody;
import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter; import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter;
import tech.easyflow.core.chat.protocol.sse.ChatSseUtil; import tech.easyflow.core.chat.protocol.sse.ChatSseUtil;
import tech.easyflow.system.entity.vo.RoleCategoryAccessSnapshot;
import tech.easyflow.system.service.CategoryPermissionService;
import javax.annotation.Resource; import javax.annotation.Resource;
import java.io.Serializable; import java.io.Serializable;
@@ -34,6 +38,8 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import static tech.easyflow.ai.entity.table.BotTableDef.BOT;
/** /**
* 控制层。 * 控制层。
* *
@@ -55,6 +61,8 @@ public class BotController extends BaseCurdController<BotService, Bot> {
private Cache<String, Object> cache; private Cache<String, Object> cache;
@Resource @Resource
private AudioServiceManager audioServiceManager; private AudioServiceManager audioServiceManager;
@Resource
private CategoryPermissionService categoryPermissionService;
public BotController(BotService service, ModelService modelService, BotWorkflowService botWorkflowService, public BotController(BotService service, ModelService modelService, BotWorkflowService botWorkflowService,
BotDocumentCollectionService botDocumentCollectionService, BotMessageService botMessageService) { BotDocumentCollectionService botDocumentCollectionService, BotMessageService botMessageService) {
@@ -164,7 +172,11 @@ public class BotController extends BaseCurdController<BotService, Bot> {
@GetMapping("getDetail") @GetMapping("getDetail")
@SaIgnore @SaIgnore
public Result<Bot> getDetail(String id) { public Result<Bot> getDetail(String id) {
return Result.ok(botService.getDetail(id)); Bot bot = botService.getDetail(id);
if (bot != null && StpUtil.isLogin()) {
categoryPermissionService.assertCategoryResourceVisible("BOT", bot.getCreatedBy(), bot.getCategoryId(), "无权限访问聊天助手");
}
return Result.ok(bot);
} }
@Override @Override
@@ -174,6 +186,9 @@ public class BotController extends BaseCurdController<BotService, Bot> {
if (data == null) { if (data == null) {
return Result.ok(data); return Result.ok(data);
} }
if (StpUtil.isLogin()) {
categoryPermissionService.assertCategoryResourceVisible("BOT", data.getCreatedBy(), data.getCategoryId(), "无权限访问聊天助手");
}
Map<String, Object> llmOptions = data.getModelOptions(); Map<String, Object> llmOptions = data.getModelOptions();
if (llmOptions == null) { if (llmOptions == null) {
@@ -205,6 +220,32 @@ public class BotController extends BaseCurdController<BotService, Bot> {
return Result.ok(data); return Result.ok(data);
} }
@Override
public Result<List<Bot>> list(Bot entity, Boolean asTree, String sortKey, String sortType) {
QueryWrapper queryWrapper = QueryWrapper.create(entity, buildOperators(entity));
applyCategoryPermission(queryWrapper);
queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy()));
return Result.ok(service.list(queryWrapper));
}
@Override
protected Page<Bot> queryPage(Page<Bot> page, QueryWrapper queryWrapper) {
applyCategoryPermission(queryWrapper);
return super.queryPage(page, queryWrapper);
}
private void applyCategoryPermission(QueryWrapper queryWrapper) {
RoleCategoryAccessSnapshot access = categoryPermissionService.getCurrentAccess("BOT");
if (!access.isRestricted()) {
return;
}
if (access.getCategoryIds().isEmpty()) {
queryWrapper.eq(Bot::getCreatedBy, access.getAccountId());
return;
}
queryWrapper.and(BOT.CREATED_BY.eq(access.getAccountId()).or(BOT.CATEGORY_ID.in(access.getCategoryIds())));
}
@Override @Override
protected Result<?> onSaveOrUpdateBefore(Bot entity, boolean isSave) { protected Result<?> onSaveOrUpdateBefore(Bot entity, boolean isSave) {

View File

@@ -2,7 +2,11 @@ package tech.easyflow.admin.controller.ai;
import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.PostMapping;
import tech.easyflow.ai.entity.BotDocumentCollection; import tech.easyflow.ai.entity.BotDocumentCollection;
import tech.easyflow.ai.entity.DocumentCollection;
import tech.easyflow.ai.permission.KnowledgeReadAccessSnapshot;
import tech.easyflow.ai.permission.KnowledgeVisibilityQueryHelper;
import tech.easyflow.ai.service.BotDocumentCollectionService; import tech.easyflow.ai.service.BotDocumentCollectionService;
import tech.easyflow.ai.service.DocumentCollectionService;
import tech.easyflow.common.annotation.UsePermission; import tech.easyflow.common.annotation.UsePermission;
import tech.easyflow.common.domain.Result; import tech.easyflow.common.domain.Result;
import tech.easyflow.common.web.controller.BaseCurdController; import tech.easyflow.common.web.controller.BaseCurdController;
@@ -11,8 +15,13 @@ import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import tech.easyflow.common.web.jsonbody.JsonBody; import tech.easyflow.common.web.jsonbody.JsonBody;
import tech.easyflow.system.enums.CategoryResourceType;
import tech.easyflow.system.enums.ResourceAction;
import tech.easyflow.system.service.ResourceAccessService;
import javax.annotation.Resource;
import java.math.BigInteger; import java.math.BigInteger;
import java.util.ArrayList;
import java.util.List; import java.util.List;
/** /**
@@ -25,6 +34,13 @@ import java.util.List;
@RequestMapping("/api/v1/botKnowledge") @RequestMapping("/api/v1/botKnowledge")
@UsePermission(moduleName = "/api/v1/bot") @UsePermission(moduleName = "/api/v1/bot")
public class BotDocumentCollectionController extends BaseCurdController<BotDocumentCollectionService, BotDocumentCollection> { public class BotDocumentCollectionController extends BaseCurdController<BotDocumentCollectionService, BotDocumentCollection> {
@Resource
private DocumentCollectionService documentCollectionService;
@Resource
private KnowledgeVisibilityQueryHelper knowledgeVisibilityQueryHelper;
@Resource
private ResourceAccessService resourceAccessService;
public BotDocumentCollectionController(BotDocumentCollectionService service) { public BotDocumentCollectionController(BotDocumentCollectionService service) {
super(service); super(service);
} }
@@ -35,12 +51,32 @@ public class BotDocumentCollectionController extends BaseCurdController<BotDocum
QueryWrapper queryWrapper = QueryWrapper.create(entity, buildOperators(entity)); QueryWrapper queryWrapper = QueryWrapper.create(entity, buildOperators(entity));
queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy())); queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy()));
List<BotDocumentCollection> botDocumentCollections = service.getMapper().selectListWithRelationsByQuery(queryWrapper); List<BotDocumentCollection> botDocumentCollections = service.getMapper().selectListWithRelationsByQuery(queryWrapper);
return Result.ok(botDocumentCollections); List<BotDocumentCollection> visibleList = new ArrayList<>();
KnowledgeReadAccessSnapshot snapshot = knowledgeVisibilityQueryHelper.getCurrentReadSnapshot();
for (BotDocumentCollection relation : botDocumentCollections) {
DocumentCollection knowledge = relation.getKnowledge();
if (knowledge == null || knowledgeVisibilityQueryHelper.canRead(knowledge, snapshot)) {
visibleList.add(relation);
}
}
return Result.ok(visibleList);
} }
@PostMapping("updateBotKnowledgeIds") @PostMapping("updateBotKnowledgeIds")
public Result<?> save(@JsonBody("botId") BigInteger botId, @JsonBody("knowledgeIds") BigInteger [] knowledgeIds) { public Result<?> save(@JsonBody("botId") BigInteger botId, @JsonBody("knowledgeIds") BigInteger [] knowledgeIds) {
if (knowledgeIds != null) {
for (BigInteger knowledgeId : knowledgeIds) {
if (knowledgeId == null) {
continue;
}
DocumentCollection collection = documentCollectionService.getById(knowledgeId);
if (collection == null) {
continue;
}
resourceAccessService.assertAccess(CategoryResourceType.KNOWLEDGE, collection, ResourceAction.READ, "无权限绑定知识库");
}
}
service.saveBotAndKnowledge(botId, knowledgeIds); service.saveBotAndKnowledge(botId, knowledgeIds);
return Result.ok(); return Result.ok();
} }
} }

View File

@@ -3,19 +3,25 @@ package tech.easyflow.admin.controller.ai;
import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.PostMapping;
import tech.easyflow.ai.entity.Plugin; import tech.easyflow.ai.entity.Plugin;
import tech.easyflow.ai.entity.BotPlugin; import tech.easyflow.ai.entity.BotPlugin;
import tech.easyflow.ai.entity.PluginItem;
import tech.easyflow.common.annotation.UsePermission; import tech.easyflow.common.annotation.UsePermission;
import tech.easyflow.common.domain.Result; import tech.easyflow.common.domain.Result;
import tech.easyflow.common.tree.Tree; import tech.easyflow.common.tree.Tree;
import tech.easyflow.common.web.controller.BaseCurdController; import tech.easyflow.common.web.controller.BaseCurdController;
import tech.easyflow.ai.service.BotPluginService; import tech.easyflow.ai.service.BotPluginService;
import tech.easyflow.ai.service.PluginService;
import tech.easyflow.ai.service.PluginItemService;
import tech.easyflow.ai.service.PluginVisibilityService;
import com.mybatisflex.core.query.QueryWrapper; import com.mybatisflex.core.query.QueryWrapper;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import tech.easyflow.common.web.jsonbody.JsonBody; import tech.easyflow.common.web.jsonbody.JsonBody;
import tech.easyflow.system.service.CategoryPermissionService;
import javax.annotation.Resource; import javax.annotation.Resource;
import java.math.BigInteger; import java.math.BigInteger;
import java.util.ArrayList;
import java.util.List; import java.util.List;
/** /**
@@ -35,6 +41,12 @@ public class BotPluginController extends BaseCurdController<BotPluginService, Bo
@Resource @Resource
private BotPluginService botPluginService; private BotPluginService botPluginService;
@Resource
private PluginItemService pluginItemService;
@Resource
private PluginService pluginService;
@Resource
private PluginVisibilityService pluginVisibilityService;
@GetMapping("list") @GetMapping("list")
public Result<List<BotPlugin>> list(BotPlugin entity, Boolean asTree, String sortKey, String sortType){ public Result<List<BotPlugin>> list(BotPlugin entity, Boolean asTree, String sortKey, String sortType){
@@ -43,15 +55,29 @@ public class BotPluginController extends BaseCurdController<BotPluginService, Bo
queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy())); queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy()));
List<BotPlugin> botPlugins = service.getMapper().selectListWithRelationsByQuery(queryWrapper); List<BotPlugin> botPlugins = service.getMapper().selectListWithRelationsByQuery(queryWrapper);
List<BotPlugin> visibleList = new ArrayList<>();
for (BotPlugin relation : botPlugins) {
Plugin plugin = relation.getAiPlugin();
if (plugin == null || pluginVisibilityService.canAccessPlugin(plugin.getCreatedBy(), plugin.getId())) {
visibleList.add(relation);
}
}
List<BotPlugin> list = Tree.tryToTree(botPlugins, asTree); List<BotPlugin> list = Tree.tryToTree(visibleList, asTree);
return Result.ok(list); return Result.ok(list);
} }
@PostMapping("/getList") @PostMapping("/getList")
public Result<List<Plugin>> getList(@JsonBody(value = "botId", required = true) String botId){ public Result<List<Plugin>> getList(@JsonBody(value = "botId", required = true) String botId){
return Result.ok(botPluginService.getList(botId)); List<Plugin> plugins = botPluginService.getList(botId);
List<Plugin> visibleList = new ArrayList<>();
for (Plugin plugin : plugins) {
if (plugin == null || pluginVisibilityService.canAccessPlugin(plugin.getCreatedBy(), plugin.getId())) {
visibleList.add(plugin);
}
}
return Result.ok(visibleList);
} }
@PostMapping("/getBotPluginToolIds") @PostMapping("/getBotPluginToolIds")
@@ -67,6 +93,23 @@ public class BotPluginController extends BaseCurdController<BotPluginService, Bo
@PostMapping("updateBotPluginToolIds") @PostMapping("updateBotPluginToolIds")
public Result<?> save(@JsonBody("botId") BigInteger botId, @JsonBody("pluginToolIds") BigInteger [] pluginToolIds) { public Result<?> save(@JsonBody("botId") BigInteger botId, @JsonBody("pluginToolIds") BigInteger [] pluginToolIds) {
if (pluginToolIds != null) {
for (BigInteger pluginToolId : pluginToolIds) {
if (pluginToolId == null) {
continue;
}
PluginItem pluginItem = pluginItemService.getById(pluginToolId);
if (pluginItem == null) {
continue;
}
if (pluginItem.getPluginId() != null) {
Plugin plugin = pluginService.getById(pluginItem.getPluginId());
if (plugin != null) {
pluginVisibilityService.assertPluginVisible(plugin.getCreatedBy(), plugin.getId(), "无权限绑定插件");
}
}
}
}
service.saveBotAndPluginTool(botId, pluginToolIds); service.saveBotAndPluginTool(botId, pluginToolIds);
return Result.ok(); return Result.ok();
} }

View File

@@ -2,7 +2,11 @@ package tech.easyflow.admin.controller.ai;
import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.PostMapping;
import tech.easyflow.ai.entity.BotWorkflow; import tech.easyflow.ai.entity.BotWorkflow;
import tech.easyflow.ai.entity.Workflow;
import tech.easyflow.ai.permission.WorkflowReadAccessSnapshot;
import tech.easyflow.ai.permission.WorkflowVisibilityQueryHelper;
import tech.easyflow.ai.service.BotWorkflowService; import tech.easyflow.ai.service.BotWorkflowService;
import tech.easyflow.ai.service.WorkflowService;
import tech.easyflow.common.annotation.UsePermission; import tech.easyflow.common.annotation.UsePermission;
import tech.easyflow.common.domain.Result; import tech.easyflow.common.domain.Result;
import tech.easyflow.common.tree.Tree; import tech.easyflow.common.tree.Tree;
@@ -12,10 +16,16 @@ import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import tech.easyflow.common.web.jsonbody.JsonBody; import tech.easyflow.common.web.jsonbody.JsonBody;
import tech.easyflow.system.enums.CategoryResourceType;
import tech.easyflow.system.enums.ResourceAction;
import tech.easyflow.system.service.ResourceAccessService;
import java.math.BigInteger; import java.math.BigInteger;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import javax.annotation.Resource;
/** /**
* 控制层。 * 控制层。
* *
@@ -26,6 +36,13 @@ import java.util.List;
@RequestMapping("/api/v1/botWorkflow") @RequestMapping("/api/v1/botWorkflow")
@UsePermission(moduleName = "/api/v1/bot") @UsePermission(moduleName = "/api/v1/bot")
public class BotWorkflowController extends BaseCurdController<BotWorkflowService, BotWorkflow> { public class BotWorkflowController extends BaseCurdController<BotWorkflowService, BotWorkflow> {
@Resource
private WorkflowService workflowService;
@Resource
private WorkflowVisibilityQueryHelper workflowVisibilityQueryHelper;
@Resource
private ResourceAccessService resourceAccessService;
public BotWorkflowController(BotWorkflowService service) { public BotWorkflowController(BotWorkflowService service) {
super(service); super(service);
} }
@@ -36,13 +53,33 @@ public class BotWorkflowController extends BaseCurdController<BotWorkflowService
QueryWrapper queryWrapper = QueryWrapper.create(entity, buildOperators(entity)); QueryWrapper queryWrapper = QueryWrapper.create(entity, buildOperators(entity));
queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy())); queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy()));
List<BotWorkflow> botWorkflows = service.getMapper().selectListWithRelationsByQuery(queryWrapper); List<BotWorkflow> botWorkflows = service.getMapper().selectListWithRelationsByQuery(queryWrapper);
List<BotWorkflow> list = Tree.tryToTree(botWorkflows, asTree); List<BotWorkflow> visibleList = new ArrayList<>();
WorkflowReadAccessSnapshot snapshot = workflowVisibilityQueryHelper.getCurrentReadSnapshot();
for (BotWorkflow botWorkflow : botWorkflows) {
Workflow workflow = botWorkflow.getWorkflow();
if (workflow == null || workflowVisibilityQueryHelper.canRead(workflow, snapshot)) {
visibleList.add(botWorkflow);
}
}
List<BotWorkflow> list = Tree.tryToTree(visibleList, asTree);
return Result.ok(list); return Result.ok(list);
} }
@PostMapping("updateBotWorkflowIds") @PostMapping("updateBotWorkflowIds")
public Result<?> save(@JsonBody("botId") BigInteger botId, @JsonBody("workflowIds") BigInteger [] workflowIds) { public Result<?> save(@JsonBody("botId") BigInteger botId, @JsonBody("workflowIds") BigInteger [] workflowIds) {
if (workflowIds != null) {
for (BigInteger workflowId : workflowIds) {
if (workflowId == null) {
continue;
}
Workflow workflow = workflowService.getById(workflowId);
if (workflow == null) {
continue;
}
resourceAccessService.assertAccess(CategoryResourceType.WORKFLOW, workflow, ResourceAction.READ, "无权限绑定工作流");
}
}
service.saveBotAndWorkflowTool(botId, workflowIds); service.saveBotAndWorkflowTool(botId, workflowIds);
return Result.ok(); return Result.ok();
} }
} }

View File

@@ -2,6 +2,7 @@ package tech.easyflow.admin.controller.ai;
import cn.dev33.satoken.annotation.SaCheckPermission; import cn.dev33.satoken.annotation.SaCheckPermission;
import com.easyagents.core.model.embedding.EmbeddingModel; import com.easyagents.core.model.embedding.EmbeddingModel;
import com.mybatisflex.core.paginate.Page;
import tech.easyflow.ai.entity.DocumentChunk; import tech.easyflow.ai.entity.DocumentChunk;
import tech.easyflow.ai.entity.DocumentCollection; import tech.easyflow.ai.entity.DocumentCollection;
import tech.easyflow.ai.entity.Model; import tech.easyflow.ai.entity.Model;
@@ -16,9 +17,15 @@ import com.easyagents.core.document.Document;
import com.easyagents.core.store.DocumentStore; import com.easyagents.core.store.DocumentStore;
import com.easyagents.core.store.StoreOptions; import com.easyagents.core.store.StoreOptions;
import com.easyagents.core.store.StoreResult; import com.easyagents.core.store.StoreResult;
import jakarta.servlet.http.HttpServletRequest;
import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import tech.easyflow.system.enums.CategoryResourceType;
import tech.easyflow.system.enums.ResourceAction;
import tech.easyflow.system.enums.ResourceLookup;
import tech.easyflow.system.permission.resource.RequireResourceAccess;
import javax.annotation.Resource; import javax.annotation.Resource;
import java.math.BigInteger; import java.math.BigInteger;
@@ -51,8 +58,29 @@ public class DocumentChunkController extends BaseCurdController<DocumentChunkSer
super(service); super(service);
} }
@GetMapping("page")
@SaCheckPermission("/api/v1/documentCollection/query")
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.READ,
lookup = ResourceLookup.DOCUMENT_ID,
idExpr = "#request.getParameter('documentId')",
denyMessage = "无权限访问知识库"
)
@Override
public Result<Page<DocumentChunk>> page(HttpServletRequest request, String sortKey, String sortType, Long pageNumber, Long pageSize) {
return super.page(request, sortKey, sortType, pageNumber, pageSize);
}
@PostMapping("update") @PostMapping("update")
@SaCheckPermission("/api/v1/documentCollection/save") @SaCheckPermission("/api/v1/documentCollection/save")
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.MANAGE,
lookup = ResourceLookup.DOCUMENT_CHUNK_ID,
idExpr = "#documentChunk.id",
denyMessage = "无权限管理知识库"
)
public Result<?> update(@JsonBody DocumentChunk documentChunk) { public Result<?> update(@JsonBody DocumentChunk documentChunk) {
boolean success = service.updateById(documentChunk); boolean success = service.updateById(documentChunk);
if (success){ if (success){
@@ -87,6 +115,13 @@ public class DocumentChunkController extends BaseCurdController<DocumentChunkSer
@PostMapping("removeChunk") @PostMapping("removeChunk")
@SaCheckPermission("/api/v1/documentCollection/remove") @SaCheckPermission("/api/v1/documentCollection/remove")
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.MANAGE,
lookup = ResourceLookup.DOCUMENT_CHUNK_ID,
idExpr = "#chunkId",
denyMessage = "无权限管理知识库"
)
public Result<?> remove(@JsonBody(value = "id", required = true) BigInteger chunkId) { public Result<?> remove(@JsonBody(value = "id", required = true) BigInteger chunkId) {
DocumentChunk docChunk = documentChunkService.getById(chunkId); DocumentChunk docChunk = documentChunkService.getById(chunkId);
if (docChunk == null) { if (docChunk == null) {

View File

@@ -3,17 +3,17 @@ package tech.easyflow.admin.controller.ai;
import com.mybatisflex.core.query.QueryWrapper; import com.mybatisflex.core.query.QueryWrapper;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.GetMapping;
import tech.easyflow.ai.entity.DocumentCollection; import tech.easyflow.ai.entity.DocumentCollection;
import tech.easyflow.ai.entity.DocumentCollectionCategory; import tech.easyflow.ai.entity.DocumentCollectionCategory;
import tech.easyflow.ai.entity.WorkflowCategory;
import tech.easyflow.ai.mapper.DocumentCollectionMapper; import tech.easyflow.ai.mapper.DocumentCollectionMapper;
import tech.easyflow.ai.service.DocumentCollectionCategoryService; import tech.easyflow.ai.service.DocumentCollectionCategoryService;
import tech.easyflow.ai.service.DocumentCollectionService;
import tech.easyflow.ai.service.WorkflowCategoryService;
import tech.easyflow.common.annotation.UsePermission; import tech.easyflow.common.annotation.UsePermission;
import tech.easyflow.common.domain.Result; import tech.easyflow.common.domain.Result;
import tech.easyflow.common.web.controller.BaseCurdController; import tech.easyflow.common.web.controller.BaseCurdController;
import tech.easyflow.common.web.exceptions.BusinessException; import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.system.entity.vo.RoleCategoryAccessSnapshot;
import tech.easyflow.system.service.CategoryPermissionService;
import javax.annotation.Resource; import javax.annotation.Resource;
import java.io.Serializable; import java.io.Serializable;
@@ -34,6 +34,8 @@ public class DocumentCollectionCategoryController extends BaseCurdController<Doc
@Resource @Resource
private DocumentCollectionMapper documentCollectionMapper; private DocumentCollectionMapper documentCollectionMapper;
@Resource
private CategoryPermissionService categoryPermissionService;
public DocumentCollectionCategoryController(DocumentCollectionCategoryService service) { public DocumentCollectionCategoryController(DocumentCollectionCategoryService service) {
super(service); super(service);
@@ -51,4 +53,18 @@ public class DocumentCollectionCategoryController extends BaseCurdController<Doc
return super.onRemoveBefore(ids); return super.onRemoveBefore(ids);
} }
}
@GetMapping("visibleList")
public Result<List<DocumentCollectionCategory>> visibleList(DocumentCollectionCategory entity, Boolean asTree, String sortKey, String sortType) {
QueryWrapper queryWrapper = QueryWrapper.create(entity, buildOperators(entity));
RoleCategoryAccessSnapshot access = categoryPermissionService.getCurrentAccess("KNOWLEDGE");
if (access.isRestricted()) {
if (access.getCategoryIds().isEmpty()) {
return Result.ok(Collections.emptyList());
}
queryWrapper.in(DocumentCollectionCategory::getId, access.getCategoryIds());
}
queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy()));
return Result.ok(service.list(queryWrapper));
}
}

View File

@@ -2,14 +2,19 @@ package tech.easyflow.admin.controller.ai;
import cn.dev33.satoken.annotation.SaCheckPermission; import cn.dev33.satoken.annotation.SaCheckPermission;
import com.easyagents.core.document.Document; import com.easyagents.core.document.Document;
import com.mybatisflex.core.paginate.Page;
import com.mybatisflex.core.query.QueryWrapper; import com.mybatisflex.core.query.QueryWrapper;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import tech.easyflow.ai.permission.KnowledgeVisibilityQueryHelper;
import tech.easyflow.ai.documentimport.DocumentImportDtos;
import tech.easyflow.ai.entity.BotDocumentCollection; import tech.easyflow.ai.entity.BotDocumentCollection;
import tech.easyflow.ai.entity.DocumentCollection; import tech.easyflow.ai.entity.DocumentCollection;
import tech.easyflow.ai.entity.Model;
import tech.easyflow.ai.service.BotDocumentCollectionService; import tech.easyflow.ai.service.BotDocumentCollectionService;
import tech.easyflow.ai.service.DocumentChunkService; import tech.easyflow.ai.service.DocumentChunkService;
import tech.easyflow.ai.service.DocumentCollectionService; import tech.easyflow.ai.service.DocumentCollectionService;
@@ -17,6 +22,13 @@ import tech.easyflow.ai.service.ModelService;
import tech.easyflow.common.domain.Result; import tech.easyflow.common.domain.Result;
import tech.easyflow.common.web.controller.BaseCurdController; import tech.easyflow.common.web.controller.BaseCurdController;
import tech.easyflow.common.web.exceptions.BusinessException; import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.common.web.jsonbody.JsonBody;
import tech.easyflow.system.enums.CategoryResourceType;
import tech.easyflow.system.enums.ResourceAction;
import tech.easyflow.system.enums.ResourceLookup;
import tech.easyflow.system.enums.VisibilityScope;
import tech.easyflow.system.permission.resource.RequireResourceAccess;
import tech.easyflow.system.service.ResourceAccessService;
import javax.annotation.Resource; import javax.annotation.Resource;
import java.io.Serializable; import java.io.Serializable;
@@ -41,6 +53,10 @@ public class DocumentCollectionController extends BaseCurdController<DocumentCol
@Resource @Resource
private BotDocumentCollectionService botDocumentCollectionService; private BotDocumentCollectionService botDocumentCollectionService;
@Resource
private ResourceAccessService resourceAccessService;
@Resource
private KnowledgeVisibilityQueryHelper knowledgeVisibilityQueryHelper;
public DocumentCollectionController(DocumentCollectionService service, DocumentChunkService chunkService, ModelService llmService) { public DocumentCollectionController(DocumentCollectionService service, DocumentChunkService chunkService, ModelService llmService) {
super(service); super(service);
@@ -50,6 +66,11 @@ public class DocumentCollectionController extends BaseCurdController<DocumentCol
@Override @Override
protected Result<?> onSaveOrUpdateBefore(DocumentCollection entity, boolean isSave) { protected Result<?> onSaveOrUpdateBefore(DocumentCollection entity, boolean isSave) {
normalizeVisibilityScope(entity, isSave);
if (!isSave && entity.getId() != null) {
DocumentCollection existed = requireKnowledge(String.valueOf(entity.getId()));
resourceAccessService.assertAccess(CategoryResourceType.KNOWLEDGE, existed, ResourceAction.MANAGE, "无权限管理知识库");
}
String alias = entity.getAlias(); String alias = entity.getAlias();
String collectionType = entity.getCollectionType(); String collectionType = entity.getCollectionType();
@@ -96,6 +117,13 @@ public class DocumentCollectionController extends BaseCurdController<DocumentCol
@GetMapping("search") @GetMapping("search")
@SaCheckPermission("/api/v1/documentCollection/query") @SaCheckPermission("/api/v1/documentCollection/query")
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.READ,
lookup = ResourceLookup.KNOWLEDGE_ID,
idExpr = "#knowledgeId",
denyMessage = "无权限访问知识库"
)
public Result<List<Document>> search(@RequestParam BigInteger knowledgeId, @RequestParam String keyword) { public Result<List<Document>> search(@RequestParam BigInteger knowledgeId, @RequestParam String keyword) {
return Result.ok(service.search(knowledgeId, keyword)); return Result.ok(service.search(knowledgeId, keyword));
} }
@@ -103,6 +131,10 @@ public class DocumentCollectionController extends BaseCurdController<DocumentCol
@Override @Override
protected Result<Void> onRemoveBefore(Collection<Serializable> ids) { protected Result<Void> onRemoveBefore(Collection<Serializable> ids) {
for (Serializable id : ids) {
DocumentCollection collection = requireKnowledge(String.valueOf(id));
resourceAccessService.assertAccess(CategoryResourceType.KNOWLEDGE, collection, ResourceAction.MANAGE, "无权限管理知识库");
}
QueryWrapper queryWrapper = QueryWrapper.create(); QueryWrapper queryWrapper = QueryWrapper.create();
queryWrapper.in(BotDocumentCollection::getDocumentCollectionId, ids); queryWrapper.in(BotDocumentCollection::getDocumentCollectionId, ids);
@@ -116,7 +148,90 @@ public class DocumentCollectionController extends BaseCurdController<DocumentCol
} }
@Override @Override
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.READ,
lookup = ResourceLookup.KNOWLEDGE_ID_OR_SLUG,
idExpr = "#id",
denyMessage = "无权限访问知识库"
)
public Result<DocumentCollection> detail(String id) { public Result<DocumentCollection> detail(String id) {
return Result.ok(service.getDetail(id)); DocumentCollection detail = service.getDetail(id);
return Result.ok(detail);
}
@GetMapping("modelList")
@SaCheckPermission("/api/v1/documentCollection/query")
public Result<List<Model>> modelList(Model entity, Boolean asTree, String sortKey, String sortType) {
return Result.ok(llmService.listSelectableModels(entity, asTree, sortKey, sortType));
}
@PostMapping("splitterProfile/save")
@SaCheckPermission("/api/v1/documentCollection/save")
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.MANAGE,
lookup = ResourceLookup.KNOWLEDGE_ID,
idExpr = "#request.knowledgeId",
denyMessage = "无权限管理知识库"
)
public Result<Boolean> saveSplitterProfile(@JsonBody DocumentImportDtos.SplitterProfileSaveRequest request) {
if (request.getKnowledgeId() == null) {
throw new BusinessException("知识库ID不能为空");
}
DocumentCollection collection = service.getById(request.getKnowledgeId());
if (collection == null) {
throw new BusinessException("知识库不存在");
}
if (collection.isFaqCollection()) {
throw new BusinessException("FAQ知识库不支持文档导入策略");
}
Map<String, Object> options = collection.getOptions() == null
? new HashMap<>()
: new HashMap<>(collection.getOptions());
options.put(DocumentCollection.KEY_SPLITTER_DEFAULT_STRATEGY, request.getDefaultStrategyCode());
options.put(DocumentCollection.KEY_SPLITTER_AUTO_RECOMMEND_ENABLED, request.getAutoRecommendEnabled());
options.put(DocumentCollection.KEY_SPLITTER_FALLBACK_STRATEGY, request.getFallbackStrategyCode());
options.put(DocumentCollection.KEY_SPLITTER_STRATEGY_PROFILES, request.getStrategyProfiles());
DocumentCollection update = new DocumentCollection();
update.setId(collection.getId());
update.setOptions(options);
return Result.ok(service.updateById(update));
}
@Override
public Result<List<DocumentCollection>> list(DocumentCollection entity, Boolean asTree, String sortKey, String sortType) {
QueryWrapper queryWrapper = QueryWrapper.create(entity, buildOperators(entity));
knowledgeVisibilityQueryHelper.applyReadableAccess(queryWrapper);
queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy()));
return Result.ok(service.list(queryWrapper));
}
@Override
protected Page<DocumentCollection> queryPage(Page<DocumentCollection> page, QueryWrapper queryWrapper) {
knowledgeVisibilityQueryHelper.applyReadableAccess(queryWrapper);
return super.queryPage(page, queryWrapper);
}
private void normalizeVisibilityScope(DocumentCollection entity, boolean isSave) {
if (entity == null) {
return;
}
if (!StringUtils.hasLength(entity.getVisibilityScope())) {
if (isSave) {
entity.setVisibilityScope(VisibilityScope.PRIVATE.name());
}
return;
}
entity.setVisibilityScope(VisibilityScope.from(entity.getVisibilityScope()).name());
}
private DocumentCollection requireKnowledge(String idOrAlias) {
DocumentCollection collection = service.getDetail(idOrAlias);
if (collection == null) {
throw new BusinessException("知识库不存在");
}
return collection;
} }
} }

View File

@@ -1,13 +1,16 @@
package tech.easyflow.admin.controller.ai; package tech.easyflow.admin.controller.ai;
import cn.dev33.satoken.annotation.SaCheckPermission; import cn.dev33.satoken.annotation.SaCheckPermission;
import cn.hutool.core.io.IoUtil;
import com.mybatisflex.core.paginate.Page; import com.mybatisflex.core.paginate.Page;
import com.mybatisflex.core.query.QueryWrapper; import com.mybatisflex.core.query.QueryWrapper;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.ClassPathResource;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import tech.easyflow.ai.documentimport.DocumentImportDtos;
import tech.easyflow.ai.entity.Document; import tech.easyflow.ai.entity.Document;
import tech.easyflow.ai.entity.DocumentCollection; import tech.easyflow.ai.entity.DocumentCollection;
import tech.easyflow.ai.entity.DocumentCollectionSplitParams; import tech.easyflow.ai.entity.DocumentCollectionSplitParams;
@@ -24,11 +27,21 @@ import tech.easyflow.common.util.StringUtil;
import tech.easyflow.common.web.controller.BaseCurdController; import tech.easyflow.common.web.controller.BaseCurdController;
import tech.easyflow.common.web.exceptions.BusinessException; import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.common.web.jsonbody.JsonBody; import tech.easyflow.common.web.jsonbody.JsonBody;
import tech.easyflow.common.filestorage.FileStorageService;
import tech.easyflow.system.enums.CategoryResourceType;
import tech.easyflow.system.enums.ResourceAction;
import tech.easyflow.system.enums.ResourceLookup;
import tech.easyflow.system.permission.resource.RequireResourceAccess;
import tech.easyflow.system.service.ResourceAccessService;
import javax.annotation.Resource;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable; import java.io.Serializable;
import java.math.BigInteger; import java.math.BigInteger;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
@@ -58,6 +71,11 @@ public class DocumentController extends BaseCurdController<DocumentService, Docu
@Autowired @Autowired
private RedisLockExecutor redisLockExecutor; private RedisLockExecutor redisLockExecutor;
@Resource(name = "default")
private FileStorageService storageService;
@Autowired
private ResourceAccessService resourceAccessService;
@Value("${easyflow.storage.local.root:}") @Value("${easyflow.storage.local.root:}")
private String fileUploadPath; private String fileUploadPath;
@@ -73,6 +91,8 @@ public class DocumentController extends BaseCurdController<DocumentService, Docu
@Transactional @Transactional
@SaCheckPermission("/api/v1/documentCollection/remove") @SaCheckPermission("/api/v1/documentCollection/remove")
public Result<?> remove(@JsonBody(value = "id", required = true) String id) { public Result<?> remove(@JsonBody(value = "id", required = true) String id) {
Document document = requireDocument(new BigInteger(id));
getDocumentCollection(document.getCollectionId().toString(), ResourceAction.MANAGE, "无权限管理知识库");
List<Serializable> ids = Collections.singletonList(id); List<Serializable> ids = Collections.singletonList(id);
Result<?> result = onRemoveBefore(ids); Result<?> result = onRemoveBefore(ids);
if (result != null) return result; if (result != null) return result;
@@ -104,7 +124,7 @@ public class DocumentController extends BaseCurdController<DocumentService, Docu
throw new BusinessException("知识库id不能为空"); throw new BusinessException("知识库id不能为空");
} }
DocumentCollection knowledge = getDocumentCollection(kbSlug); DocumentCollection knowledge = getDocumentCollection(kbSlug, ResourceAction.READ, "无权限访问知识库");
QueryWrapper queryWrapper = QueryWrapper.create() QueryWrapper queryWrapper = QueryWrapper.create()
.eq(Document::getCollectionId, knowledge.getId()); .eq(Document::getCollectionId, knowledge.getId());
@@ -121,11 +141,33 @@ public class DocumentController extends BaseCurdController<DocumentService, Docu
if (StringUtil.noText(kbSlug)) { if (StringUtil.noText(kbSlug)) {
throw new BusinessException("知识库id不能为空"); throw new BusinessException("知识库id不能为空");
} }
DocumentCollection knowledge = getDocumentCollection(kbSlug); DocumentCollection knowledge = getDocumentCollection(kbSlug, ResourceAction.READ, "无权限访问知识库");
Page<Document> documentList = documentService.getDocumentList(knowledge.getId().toString(), pageSize, pageNumber,fileName); Page<Document> documentList = documentService.getDocumentList(knowledge.getId().toString(), pageSize, pageNumber,fileName);
return Result.ok(documentList); return Result.ok(documentList);
} }
@GetMapping("download")
@SaCheckPermission("/api/v1/documentCollection/query")
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.READ,
lookup = ResourceLookup.DOCUMENT_ID,
idExpr = "#documentId",
denyMessage = "无权限访问知识库"
)
public void download(@RequestParam BigInteger documentId, HttpServletResponse response) throws IOException {
Document document = requireDocument(documentId);
String fileName = resolveDownloadFileName(document);
response.setContentType("application/octet-stream");
response.setCharacterEncoding(StandardCharsets.UTF_8.name());
String encodedFileName = URLEncoder.encode(fileName, StandardCharsets.UTF_8).replaceAll("\\+", "%20");
response.setHeader("Content-disposition", "attachment;filename*=utf-8''" + encodedFileName);
try (InputStream inputStream = storageService.readStream(document.getDocumentPath())) {
IoUtil.copy(inputStream, response.getOutputStream());
response.flushBuffer();
}
}
@Override @Override
protected String getDefaultOrderBy() { protected String getDefaultOrderBy() {
@@ -138,6 +180,11 @@ public class DocumentController extends BaseCurdController<DocumentService, Docu
@Transactional @Transactional
@SaCheckPermission("/api/v1/documentCollection/save") @SaCheckPermission("/api/v1/documentCollection/save")
public Result<Boolean> update(@JsonBody Document entity) { public Result<Boolean> update(@JsonBody Document entity) {
if (entity.getId() == null) {
throw new BusinessException("文档不存在");
}
Document current = requireDocument(entity.getId());
getDocumentCollection(current.getCollectionId().toString(), ResourceAction.MANAGE, "无权限管理知识库");
super.update(entity); super.update(entity);
return Result.ok(updatePosition(entity)); return Result.ok(updatePosition(entity));
} }
@@ -152,10 +199,40 @@ public class DocumentController extends BaseCurdController<DocumentService, Docu
if (documentCollectionSplitParams.getKnowledgeId() == null) { if (documentCollectionSplitParams.getKnowledgeId() == null) {
throw new BusinessException("知识库id不能为空"); throw new BusinessException("知识库id不能为空");
} }
getDocumentCollection(documentCollectionSplitParams.getKnowledgeId().toString()); getDocumentCollection(documentCollectionSplitParams.getKnowledgeId().toString(), ResourceAction.MANAGE, "无权限管理知识库");
return documentService.textSplit(documentCollectionSplitParams); return documentService.textSplit(documentCollectionSplitParams);
} }
@PostMapping("import/analyze")
@SaCheckPermission("/api/v1/documentCollection/save")
public Result<DocumentImportDtos.AnalyzeResponse> analyzeImport(@JsonBody DocumentImportDtos.AnalyzeRequest request) {
if (request.getKnowledgeId() == null) {
throw new BusinessException("知识库id不能为空");
}
getDocumentCollection(request.getKnowledgeId().toString(), ResourceAction.MANAGE, "无权限管理知识库");
return documentService.analyzeImport(request);
}
@PostMapping("import/preview")
@SaCheckPermission("/api/v1/documentCollection/save")
public Result<DocumentImportDtos.PreviewResponse> previewImport(@JsonBody DocumentImportDtos.PreviewRequest request) {
if (request.getKnowledgeId() == null) {
throw new BusinessException("知识库id不能为空");
}
getDocumentCollection(request.getKnowledgeId().toString(), ResourceAction.MANAGE, "无权限管理知识库");
return documentService.previewImport(request);
}
@PostMapping("import/commit")
@SaCheckPermission("/api/v1/documentCollection/save")
public Result<DocumentImportDtos.CommitResponse> commitImport(@JsonBody DocumentImportDtos.CommitRequest request) {
if (request.getKnowledgeId() == null) {
throw new BusinessException("知识库id不能为空");
}
getDocumentCollection(request.getKnowledgeId().toString(), ResourceAction.MANAGE, "无权限管理知识库");
return documentService.commitImport(request);
}
/** /**
* 更新 entity * 更新 entity
* *
@@ -219,17 +296,42 @@ public class DocumentController extends BaseCurdController<DocumentService, Docu
} }
} }
private DocumentCollection getDocumentCollection(String idOrSlug) { private DocumentCollection getDocumentCollection(String idOrSlug, ResourceAction action, String denyMessage) {
DocumentCollection knowledge = StringUtil.isNumeric(idOrSlug) DocumentCollection knowledge = StringUtil.isNumeric(idOrSlug)
? knowledgeService.getById(idOrSlug) ? knowledgeService.getById(idOrSlug)
: knowledgeService.getOne(QueryWrapper.create().eq(DocumentCollection::getSlug, idOrSlug)); : knowledgeService.getOne(QueryWrapper.create().eq(DocumentCollection::getSlug, idOrSlug));
if (knowledge == null) { if (knowledge == null) {
throw new BusinessException("知识库不存在"); throw new BusinessException("知识库不存在");
} }
resourceAccessService.assertAccess(CategoryResourceType.KNOWLEDGE, knowledge, action, denyMessage);
if (knowledge.isFaqCollection()) { if (knowledge.isFaqCollection()) {
throw new BusinessException("FAQ知识库不支持文档操作"); throw new BusinessException("FAQ知识库不支持文档操作");
} }
return knowledge; return knowledge;
} }
private Document requireDocument(BigInteger documentId) {
if (documentId == null) {
throw new BusinessException("文档不存在");
}
Document document = service.getById(documentId);
if (document == null) {
throw new BusinessException("文档不存在");
}
return document;
}
private String resolveDownloadFileName(Document document) {
String fileName = document.getTitle();
if (!StringUtil.hasText(fileName)) {
String path = document.getDocumentPath();
if (!StringUtil.hasText(path)) {
return "document";
}
int slashIndex = Math.max(path.lastIndexOf('/'), path.lastIndexOf('\\'));
return slashIndex >= 0 ? path.substring(slashIndex + 1) : path;
}
return fileName;
}
} }

View File

@@ -12,6 +12,10 @@ import tech.easyflow.common.domain.Result;
import tech.easyflow.common.web.controller.BaseCurdController; import tech.easyflow.common.web.controller.BaseCurdController;
import tech.easyflow.common.web.exceptions.BusinessException; import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.common.web.jsonbody.JsonBody; import tech.easyflow.common.web.jsonbody.JsonBody;
import tech.easyflow.system.enums.CategoryResourceType;
import tech.easyflow.system.enums.ResourceAction;
import tech.easyflow.system.enums.ResourceLookup;
import tech.easyflow.system.permission.resource.RequireResourceAccess;
import java.io.Serializable; import java.io.Serializable;
import java.math.BigInteger; import java.math.BigInteger;
@@ -29,6 +33,13 @@ public class FaqCategoryController extends BaseCurdController<FaqCategoryService
@Override @Override
@GetMapping("list") @GetMapping("list")
@SaCheckPermission("/api/v1/documentCollection/query") @SaCheckPermission("/api/v1/documentCollection/query")
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.READ,
lookup = ResourceLookup.KNOWLEDGE_ID,
idExpr = "#entity == null ? null : #entity.collectionId",
denyMessage = "无权限访问知识库"
)
public Result<List<FaqCategory>> list(FaqCategory entity, Boolean asTree, String sortKey, String sortType) { public Result<List<FaqCategory>> list(FaqCategory entity, Boolean asTree, String sortKey, String sortType) {
BigInteger collectionId = entity == null ? null : entity.getCollectionId(); BigInteger collectionId = entity == null ? null : entity.getCollectionId();
if (collectionId == null) { if (collectionId == null) {
@@ -40,6 +51,13 @@ public class FaqCategoryController extends BaseCurdController<FaqCategoryService
@Override @Override
@PostMapping("save") @PostMapping("save")
@SaCheckPermission("/api/v1/documentCollection/save") @SaCheckPermission("/api/v1/documentCollection/save")
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.MANAGE,
lookup = ResourceLookup.KNOWLEDGE_ID,
idExpr = "#entity.collectionId",
denyMessage = "无权限管理知识库"
)
public Result<?> save(@JsonBody FaqCategory entity) { public Result<?> save(@JsonBody FaqCategory entity) {
return Result.ok(service.saveCategory(entity)); return Result.ok(service.saveCategory(entity));
} }
@@ -47,6 +65,13 @@ public class FaqCategoryController extends BaseCurdController<FaqCategoryService
@Override @Override
@PostMapping("update") @PostMapping("update")
@SaCheckPermission("/api/v1/documentCollection/save") @SaCheckPermission("/api/v1/documentCollection/save")
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.MANAGE,
lookup = ResourceLookup.FAQ_CATEGORY_ID,
idExpr = "#entity.id",
denyMessage = "无权限管理知识库"
)
public Result<?> update(@JsonBody FaqCategory entity) { public Result<?> update(@JsonBody FaqCategory entity) {
return Result.ok(service.updateCategory(entity)); return Result.ok(service.updateCategory(entity));
} }
@@ -54,6 +79,13 @@ public class FaqCategoryController extends BaseCurdController<FaqCategoryService
@Override @Override
@PostMapping("remove") @PostMapping("remove")
@SaCheckPermission("/api/v1/documentCollection/remove") @SaCheckPermission("/api/v1/documentCollection/remove")
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.MANAGE,
lookup = ResourceLookup.FAQ_CATEGORY_ID,
idExpr = "#id",
denyMessage = "无权限管理知识库"
)
public Result<?> remove(@JsonBody(value = "id", required = true) Serializable id) { public Result<?> remove(@JsonBody(value = "id", required = true) Serializable id) {
return Result.ok(service.removeCategory(new BigInteger(String.valueOf(id)))); return Result.ok(service.removeCategory(new BigInteger(String.valueOf(id))));
} }

View File

@@ -25,6 +25,10 @@ import tech.easyflow.common.vo.UploadResVo;
import tech.easyflow.common.web.controller.BaseCurdController; import tech.easyflow.common.web.controller.BaseCurdController;
import tech.easyflow.common.web.exceptions.BusinessException; import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.common.web.jsonbody.JsonBody; import tech.easyflow.common.web.jsonbody.JsonBody;
import tech.easyflow.system.enums.CategoryResourceType;
import tech.easyflow.system.enums.ResourceAction;
import tech.easyflow.system.enums.ResourceLookup;
import tech.easyflow.system.permission.resource.RequireResourceAccess;
import javax.annotation.Resource; import javax.annotation.Resource;
import java.io.Serializable; import java.io.Serializable;
@@ -67,13 +71,31 @@ public class FaqItemController extends BaseCurdController<FaqItemService, FaqIte
@Override @Override
@GetMapping("list") @GetMapping("list")
@SaCheckPermission("/api/v1/documentCollection/query") @SaCheckPermission("/api/v1/documentCollection/query")
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.READ,
lookup = ResourceLookup.KNOWLEDGE_ID,
idExpr = "#entity == null ? null : #entity.collectionId",
denyMessage = "无权限访问知识库"
)
public Result<java.util.List<FaqItem>> list(FaqItem entity, Boolean asTree, String sortKey, String sortType) { public Result<java.util.List<FaqItem>> list(FaqItem entity, Boolean asTree, String sortKey, String sortType) {
BigInteger collectionId = entity == null ? null : entity.getCollectionId();
if (collectionId == null) {
throw new BusinessException("知识库ID不能为空");
}
return super.list(entity, asTree, sortKey, sortType); return super.list(entity, asTree, sortKey, sortType);
} }
@Override @Override
@GetMapping("page") @GetMapping("page")
@SaCheckPermission("/api/v1/documentCollection/query") @SaCheckPermission("/api/v1/documentCollection/query")
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.READ,
lookup = ResourceLookup.KNOWLEDGE_ID,
idExpr = "#request.getParameter('collectionId')",
denyMessage = "无权限访问知识库"
)
public Result<Page<FaqItem>> page(HttpServletRequest request, String sortKey, String sortType, Long pageNumber, Long pageSize) { public Result<Page<FaqItem>> page(HttpServletRequest request, String sortKey, String sortType, Long pageNumber, Long pageSize) {
if (pageNumber == null || pageNumber < 1) { if (pageNumber == null || pageNumber < 1) {
pageNumber = 1L; pageNumber = 1L;
@@ -123,6 +145,13 @@ public class FaqItemController extends BaseCurdController<FaqItemService, FaqIte
@Override @Override
@GetMapping("detail") @GetMapping("detail")
@SaCheckPermission("/api/v1/documentCollection/query") @SaCheckPermission("/api/v1/documentCollection/query")
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.READ,
lookup = ResourceLookup.FAQ_ITEM_ID,
idExpr = "#id",
denyMessage = "无权限访问知识库"
)
public Result<FaqItem> detail(String id) { public Result<FaqItem> detail(String id) {
return super.detail(id); return super.detail(id);
} }
@@ -130,6 +159,13 @@ public class FaqItemController extends BaseCurdController<FaqItemService, FaqIte
@Override @Override
@PostMapping("save") @PostMapping("save")
@SaCheckPermission("/api/v1/documentCollection/save") @SaCheckPermission("/api/v1/documentCollection/save")
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.MANAGE,
lookup = ResourceLookup.KNOWLEDGE_ID,
idExpr = "#entity.collectionId",
denyMessage = "无权限管理知识库"
)
public Result<?> save(@JsonBody FaqItem entity) { public Result<?> save(@JsonBody FaqItem entity) {
return Result.ok(service.saveFaqItem(entity)); return Result.ok(service.saveFaqItem(entity));
} }
@@ -137,6 +173,13 @@ public class FaqItemController extends BaseCurdController<FaqItemService, FaqIte
@Override @Override
@PostMapping("update") @PostMapping("update")
@SaCheckPermission("/api/v1/documentCollection/save") @SaCheckPermission("/api/v1/documentCollection/save")
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.MANAGE,
lookup = ResourceLookup.FAQ_ITEM_ID,
idExpr = "#entity.id",
denyMessage = "无权限管理知识库"
)
public Result<?> update(@JsonBody FaqItem entity) { public Result<?> update(@JsonBody FaqItem entity) {
return Result.ok(service.updateFaqItem(entity)); return Result.ok(service.updateFaqItem(entity));
} }
@@ -144,12 +187,26 @@ public class FaqItemController extends BaseCurdController<FaqItemService, FaqIte
@Override @Override
@PostMapping("remove") @PostMapping("remove")
@SaCheckPermission("/api/v1/documentCollection/remove") @SaCheckPermission("/api/v1/documentCollection/remove")
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.MANAGE,
lookup = ResourceLookup.FAQ_ITEM_ID,
idExpr = "#id",
denyMessage = "无权限管理知识库"
)
public Result<?> remove(@JsonBody(value = "id", required = true) Serializable id) { public Result<?> remove(@JsonBody(value = "id", required = true) Serializable id) {
return Result.ok(service.removeFaqItem(new java.math.BigInteger(String.valueOf(id)))); return Result.ok(service.removeFaqItem(new java.math.BigInteger(String.valueOf(id))));
} }
@PostMapping(value = "uploadImage", produces = MediaType.APPLICATION_JSON_VALUE) @PostMapping(value = "uploadImage", produces = MediaType.APPLICATION_JSON_VALUE)
@SaCheckPermission("/api/v1/documentCollection/save") @SaCheckPermission("/api/v1/documentCollection/save")
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.MANAGE,
lookup = ResourceLookup.KNOWLEDGE_ID,
idExpr = "#collectionId",
denyMessage = "无权限管理知识库"
)
public Result<UploadResVo> uploadImage(MultipartFile file, BigInteger collectionId) { public Result<UploadResVo> uploadImage(MultipartFile file, BigInteger collectionId) {
if (collectionId == null) { if (collectionId == null) {
throw new BusinessException("知识库ID不能为空"); throw new BusinessException("知识库ID不能为空");
@@ -180,12 +237,26 @@ public class FaqItemController extends BaseCurdController<FaqItemService, FaqIte
@PostMapping(value = "importExcel", consumes = MediaType.MULTIPART_FORM_DATA_VALUE) @PostMapping(value = "importExcel", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
@SaCheckPermission("/api/v1/documentCollection/save") @SaCheckPermission("/api/v1/documentCollection/save")
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.MANAGE,
lookup = ResourceLookup.KNOWLEDGE_ID,
idExpr = "#collectionId",
denyMessage = "无权限管理知识库"
)
public Result<FaqImportResultVo> importExcel(MultipartFile file, BigInteger collectionId) { public Result<FaqImportResultVo> importExcel(MultipartFile file, BigInteger collectionId) {
return Result.ok(service.importFromExcel(collectionId, file)); return Result.ok(service.importFromExcel(collectionId, file));
} }
@GetMapping("downloadImportTemplate") @GetMapping("downloadImportTemplate")
@SaCheckPermission("/api/v1/documentCollection/query") @SaCheckPermission("/api/v1/documentCollection/query")
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.READ,
lookup = ResourceLookup.KNOWLEDGE_ID,
idExpr = "#collectionId",
denyMessage = "无权限访问知识库"
)
public void downloadImportTemplate(BigInteger collectionId, HttpServletResponse response) throws Exception { public void downloadImportTemplate(BigInteger collectionId, HttpServletResponse response) throws Exception {
if (collectionId == null) { if (collectionId == null) {
throw new BusinessException("知识库ID不能为空"); throw new BusinessException("知识库ID不能为空");
@@ -206,6 +277,13 @@ public class FaqItemController extends BaseCurdController<FaqItemService, FaqIte
@GetMapping("exportExcel") @GetMapping("exportExcel")
@SaCheckPermission("/api/v1/documentCollection/query") @SaCheckPermission("/api/v1/documentCollection/query")
@RequireResourceAccess(
resource = CategoryResourceType.KNOWLEDGE,
action = ResourceAction.READ,
lookup = ResourceLookup.KNOWLEDGE_ID,
idExpr = "#collectionId",
denyMessage = "无权限访问知识库"
)
public void exportExcel(BigInteger collectionId, HttpServletResponse response) throws Exception { public void exportExcel(BigInteger collectionId, HttpServletResponse response) throws Exception {
if (collectionId == null) { if (collectionId == null) {
throw new BusinessException("知识库ID不能为空"); throw new BusinessException("知识库ID不能为空");

View File

@@ -7,6 +7,7 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import tech.easyflow.ai.dto.ModelInvokeConfigDtos;
import tech.easyflow.ai.entity.Model; import tech.easyflow.ai.entity.Model;
import tech.easyflow.ai.entity.ModelProvider; import tech.easyflow.ai.entity.ModelProvider;
import tech.easyflow.ai.entity.table.ModelTableDef; import tech.easyflow.ai.entity.table.ModelTableDef;
@@ -69,6 +70,12 @@ public class ModelController extends BaseCurdController<ModelService, Model> {
return Result.ok(modelService.getList(entity)); return Result.ok(modelService.getList(entity));
} }
@GetMapping("invokeList")
@SaCheckPermission("/api/v1/model/query")
public Result<List<Model>> invokeList() {
return Result.ok(modelService.listInvokeModels());
}
@PostMapping("/addAiLlm") @PostMapping("/addAiLlm")
@SaCheckPermission("/api/v1/model/save") @SaCheckPermission("/api/v1/model/save")
public Result<Boolean> addAiLlm(Model entity) { public Result<Boolean> addAiLlm(Model entity) {
@@ -92,6 +99,31 @@ public class ModelController extends BaseCurdController<ModelService, Model> {
return Result.ok(); return Result.ok();
} }
@PostMapping("/updateInvokeConfig")
@SaCheckPermission("/api/v1/model/save")
public Result<Model> updateInvokeConfig(@RequestBody ModelInvokeConfigDtos.UpdateRequest request) {
return Result.ok(modelService.updateInvokeConfig(
request.getId(),
request.getInvokeCode(),
request.getPublishEnabled()
));
}
@PostMapping("/batchUpdateInvokePublishStatus")
@SaCheckPermission("/api/v1/model/save")
public Result<List<Model>> batchUpdateInvokePublishStatus(@RequestBody ModelInvokeConfigDtos.BatchPublishRequest request) {
return Result.ok(modelService.batchUpdateInvokePublishStatus(
request.getIds(),
request.getPublishEnabled()
));
}
@Override
protected Result<?> onSaveOrUpdateBefore(Model entity, boolean isSave) {
modelService.validateForSaveOrUpdate(entity, isSave);
return super.onSaveOrUpdateBefore(entity, isSave);
}
@GetMapping("/selectLlmByProviderCategory") @GetMapping("/selectLlmByProviderCategory")
@SaCheckPermission("/api/v1/model/query") @SaCheckPermission("/api/v1/model/query")
public Result<Map<String, List<Model>>> selectLlmByProviderCategory(Model entity, String sortKey, String sortType) { public Result<Map<String, List<Model>>> selectLlmByProviderCategory(Model entity, String sortKey, String sortType) {

View File

@@ -1,6 +1,7 @@
package tech.easyflow.admin.controller.ai; package tech.easyflow.admin.controller.ai;
import cn.dev33.satoken.annotation.SaCheckPermission; import cn.dev33.satoken.annotation.SaCheckPermission;
import com.mybatisflex.core.query.QueryWrapper;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RequestParam;
@@ -10,9 +11,13 @@ import tech.easyflow.ai.service.PluginCategoryService;
import tech.easyflow.common.annotation.UsePermission; import tech.easyflow.common.annotation.UsePermission;
import tech.easyflow.common.domain.Result; import tech.easyflow.common.domain.Result;
import tech.easyflow.common.web.controller.BaseCurdController; import tech.easyflow.common.web.controller.BaseCurdController;
import tech.easyflow.system.entity.vo.RoleCategoryAccessSnapshot;
import tech.easyflow.system.service.CategoryPermissionService;
import javax.annotation.Resource; import javax.annotation.Resource;
import java.math.BigInteger; import java.math.BigInteger;
import java.util.Collections;
import java.util.List;
/** /**
* 控制层。 * 控制层。
@@ -30,6 +35,8 @@ public class PluginCategoryController extends BaseCurdController<PluginCategoryS
@Resource @Resource
private PluginCategoryService pluginCategoryService; private PluginCategoryService pluginCategoryService;
@Resource
private CategoryPermissionService categoryPermissionService;
@GetMapping("/doRemoveCategory") @GetMapping("/doRemoveCategory")
@SaCheckPermission("/api/v1/plugin/remove") @SaCheckPermission("/api/v1/plugin/remove")
@@ -37,4 +44,18 @@ public class PluginCategoryController extends BaseCurdController<PluginCategoryS
return Result.ok(pluginCategoryService.doRemoveCategory(id)); return Result.ok(pluginCategoryService.doRemoveCategory(id));
} }
}
@GetMapping("/visibleList")
public Result<List<PluginCategory>> visibleList(PluginCategory entity, Boolean asTree, String sortKey, String sortType) {
QueryWrapper queryWrapper = QueryWrapper.create(entity, buildOperators(entity));
RoleCategoryAccessSnapshot access = categoryPermissionService.getCurrentAccess("PLUGIN");
if (access.isRestricted()) {
if (access.getCategoryIds().isEmpty()) {
return Result.ok(Collections.emptyList());
}
queryWrapper.in(PluginCategory::getId, access.getCategoryIds());
}
queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy()));
return Result.ok(service.list(queryWrapper));
}
}

View File

@@ -6,16 +6,25 @@ import com.mybatisflex.core.query.QueryWrapper;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.PostMapping;
import tech.easyflow.ai.entity.Model;
import tech.easyflow.ai.entity.Plugin; import tech.easyflow.ai.entity.Plugin;
import tech.easyflow.ai.service.ModelService;
import tech.easyflow.ai.service.PluginVisibilityService;
import tech.easyflow.common.domain.Result; import tech.easyflow.common.domain.Result;
import tech.easyflow.common.web.controller.BaseCurdController; import tech.easyflow.common.web.controller.BaseCurdController;
import tech.easyflow.ai.service.PluginService; import tech.easyflow.ai.service.PluginService;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import tech.easyflow.common.web.jsonbody.JsonBody; import tech.easyflow.common.web.jsonbody.JsonBody;
import tech.easyflow.system.entity.vo.RoleCategoryAccessSnapshot;
import tech.easyflow.system.service.CategoryPermissionService;
import javax.annotation.Resource; import javax.annotation.Resource;
import java.math.BigInteger;
import java.util.List; import java.util.List;
import java.util.Set;
import static tech.easyflow.ai.entity.table.PluginTableDef.PLUGIN;
/** /**
* 控制层。 * 控制层。
@@ -32,6 +41,12 @@ public class PluginController extends BaseCurdController<PluginService, Plugin>
@Resource @Resource
PluginService pluginService; PluginService pluginService;
@Resource
private CategoryPermissionService categoryPermissionService;
@Resource
private PluginVisibilityService pluginVisibilityService;
@Resource
private ModelService modelService;
@Override @Override
protected Result<?> onSaveOrUpdateBefore(Plugin entity, boolean isSave) { protected Result<?> onSaveOrUpdateBefore(Plugin entity, boolean isSave) {
@@ -40,7 +55,7 @@ public class PluginController extends BaseCurdController<PluginService, Plugin>
@PostMapping("/plugin/save") @PostMapping("/plugin/save")
@SaCheckPermission("/api/v1/plugin/save") @SaCheckPermission("/api/v1/plugin/save")
public Result<Boolean> savePlugin(@JsonBody Plugin plugin){ public Result<Plugin> savePlugin(@JsonBody Plugin plugin){
return Result.ok(pluginService.savePlugin(plugin)); return Result.ok(pluginService.savePlugin(plugin));
} }
@@ -62,7 +77,9 @@ public class PluginController extends BaseCurdController<PluginService, Plugin>
@PostMapping("/getList") @PostMapping("/getList")
@SaCheckPermission("/api/v1/plugin/query") @SaCheckPermission("/api/v1/plugin/query")
public Result<List<Plugin>> getList(){ public Result<List<Plugin>> getList(){
return Result.ok(pluginService.getList()); QueryWrapper queryWrapper = QueryWrapper.create().select();
applyCategoryPermission(queryWrapper);
return Result.ok(service.getMapper().selectListByQuery(queryWrapper));
} }
@GetMapping("/pageByCategory") @GetMapping("/pageByCategory")
@@ -76,6 +93,7 @@ public class PluginController extends BaseCurdController<PluginService, Plugin>
} }
if (category == 0){ if (category == 0){
QueryWrapper queryWrapper = buildQueryWrapper(request); QueryWrapper queryWrapper = buildQueryWrapper(request);
applyCategoryPermission(queryWrapper);
queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy())); queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy()));
return Result.ok(queryPage(new Page<>(pageNumber, pageSize), queryWrapper)); return Result.ok(queryPage(new Page<>(pageNumber, pageSize), queryWrapper));
} else { } else {
@@ -83,8 +101,41 @@ public class PluginController extends BaseCurdController<PluginService, Plugin>
} }
} }
@GetMapping("/modelList")
@SaCheckPermission("/api/v1/plugin/query")
public Result<List<Model>> modelList(Model entity, Boolean asTree, String sortKey, String sortType) {
return Result.ok(modelService.listSelectableModels(entity, asTree, sortKey, sortType));
}
@Override @Override
protected Page<Plugin> queryPage(Page<Plugin> page, QueryWrapper queryWrapper) { protected Page<Plugin> queryPage(Page<Plugin> page, QueryWrapper queryWrapper) {
applyCategoryPermission(queryWrapper);
return service.getMapper().paginateWithRelations(page, queryWrapper); return service.getMapper().paginateWithRelations(page, queryWrapper);
} }
@Override
public Result<Plugin> detail(String id) {
Plugin plugin = service.getById(id);
if (plugin != null) {
pluginVisibilityService.assertPluginVisible(plugin.getCreatedBy(), plugin.getId(), "无权限访问插件");
}
return Result.ok(plugin);
}
private void applyCategoryPermission(QueryWrapper queryWrapper) {
RoleCategoryAccessSnapshot access = categoryPermissionService.getCurrentAccess("PLUGIN");
if (!access.isRestricted()) {
return;
}
if (access.getCategoryIds().isEmpty()) {
queryWrapper.eq(Plugin::getCreatedBy, access.getAccountIdAsLong());
return;
}
Set<BigInteger> pluginIds = pluginVisibilityService.getCurrentVisiblePluginIds();
if (pluginIds.isEmpty()) {
queryWrapper.eq(Plugin::getCreatedBy, access.getAccountIdAsLong());
return;
}
queryWrapper.and(PLUGIN.CREATED_BY.eq(access.getAccountIdAsLong()).or(PLUGIN.ID.in(pluginIds)));
}
} }

View File

@@ -1,11 +1,20 @@
package tech.easyflow.admin.controller.ai; package tech.easyflow.admin.controller.ai;
import com.mybatisflex.core.query.QueryWrapper;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.GetMapping;
import tech.easyflow.ai.entity.ResourceCategory; import tech.easyflow.ai.entity.ResourceCategory;
import tech.easyflow.ai.service.ResourceCategoryService; import tech.easyflow.ai.service.ResourceCategoryService;
import tech.easyflow.common.annotation.UsePermission; import tech.easyflow.common.annotation.UsePermission;
import tech.easyflow.common.domain.Result;
import tech.easyflow.common.web.controller.BaseCurdController; import tech.easyflow.common.web.controller.BaseCurdController;
import tech.easyflow.system.entity.vo.RoleCategoryAccessSnapshot;
import tech.easyflow.system.service.CategoryPermissionService;
import javax.annotation.Resource;
import java.util.Collections;
import java.util.List;
/** /**
* 素材分类 * 素材分类
@@ -14,9 +23,24 @@ import tech.easyflow.common.web.controller.BaseCurdController;
@RequestMapping("/api/v1/resourceCategory") @RequestMapping("/api/v1/resourceCategory")
@UsePermission(moduleName = "/api/v1/resource") @UsePermission(moduleName = "/api/v1/resource")
public class ResourceCategoryController extends BaseCurdController<ResourceCategoryService, ResourceCategory> { public class ResourceCategoryController extends BaseCurdController<ResourceCategoryService, ResourceCategory> {
@Resource
private CategoryPermissionService categoryPermissionService;
public ResourceCategoryController(ResourceCategoryService service) { public ResourceCategoryController(ResourceCategoryService service) {
super(service); super(service);
} }
} @GetMapping("visibleList")
public Result<List<ResourceCategory>> visibleList(ResourceCategory entity, Boolean asTree, String sortKey, String sortType) {
QueryWrapper queryWrapper = QueryWrapper.create(entity, buildOperators(entity));
RoleCategoryAccessSnapshot access = categoryPermissionService.getCurrentAccess("RESOURCE");
if (access.isRestricted()) {
if (access.getCategoryIds().isEmpty()) {
return Result.ok(Collections.emptyList());
}
queryWrapper.in(ResourceCategory::getId, access.getCategoryIds());
}
queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy()));
return Result.ok(service.list(queryWrapper));
}
}

View File

@@ -12,10 +12,15 @@ import tech.easyflow.common.domain.Result;
import tech.easyflow.common.entity.LoginAccount; import tech.easyflow.common.entity.LoginAccount;
import tech.easyflow.common.satoken.util.SaTokenUtil; import tech.easyflow.common.satoken.util.SaTokenUtil;
import tech.easyflow.common.web.controller.BaseCurdController; import tech.easyflow.common.web.controller.BaseCurdController;
import tech.easyflow.system.entity.vo.RoleCategoryAccessSnapshot;
import tech.easyflow.system.service.CategoryPermissionService;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.math.BigInteger; import java.math.BigInteger;
import java.util.Date; import java.util.Date;
import java.util.List;
import static tech.easyflow.ai.entity.table.ResourceTableDef.RESOURCE;
/** /**
* 素材库 * 素材库
@@ -26,6 +31,9 @@ import java.util.Date;
@RestController @RestController
@RequestMapping("/api/v1/resource") @RequestMapping("/api/v1/resource")
public class ResourceController extends BaseCurdController<ResourceService, Resource> { public class ResourceController extends BaseCurdController<ResourceService, Resource> {
@javax.annotation.Resource
private CategoryPermissionService categoryPermissionService;
public ResourceController(ResourceService service) { public ResourceController(ResourceService service) {
super(service); super(service);
} }
@@ -50,7 +58,36 @@ public class ResourceController extends BaseCurdController<ResourceService, Reso
@Override @Override
protected Page<Resource> queryPage(Page<Resource> page, QueryWrapper queryWrapper) { protected Page<Resource> queryPage(Page<Resource> page, QueryWrapper queryWrapper) {
queryWrapper.eq(Resource::getCreatedBy, SaTokenUtil.getLoginAccount().getId().toString()); applyCategoryPermission(queryWrapper);
return super.queryPage(page, queryWrapper); return super.queryPage(page, queryWrapper);
} }
}
@Override
public Result<List<Resource>> list(Resource entity, Boolean asTree, String sortKey, String sortType) {
QueryWrapper queryWrapper = QueryWrapper.create(entity, buildOperators(entity));
applyCategoryPermission(queryWrapper);
queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy()));
return Result.ok(service.list(queryWrapper));
}
@Override
public Result<Resource> detail(String id) {
Resource resource = service.getById(id);
if (resource != null) {
categoryPermissionService.assertCategoryResourceVisible("RESOURCE", resource.getCreatedBy(), resource.getCategoryId(), "无权限访问素材");
}
return Result.ok(resource);
}
private void applyCategoryPermission(QueryWrapper queryWrapper) {
RoleCategoryAccessSnapshot access = categoryPermissionService.getCurrentAccess("RESOURCE");
if (!access.isRestricted()) {
return;
}
if (access.getCategoryIds().isEmpty()) {
queryWrapper.eq(Resource::getCreatedBy, access.getAccountId());
return;
}
queryWrapper.and(RESOURCE.CREATED_BY.eq(access.getAccountId()).or(RESOURCE.CATEGORY_ID.in(access.getCategoryIds())));
}
}

View File

@@ -1,12 +1,21 @@
package tech.easyflow.admin.controller.ai; package tech.easyflow.admin.controller.ai;
import com.mybatisflex.core.query.QueryWrapper;
import tech.easyflow.ai.entity.WorkflowCategory; import tech.easyflow.ai.entity.WorkflowCategory;
import tech.easyflow.ai.service.WorkflowCategoryService; import tech.easyflow.ai.service.WorkflowCategoryService;
import tech.easyflow.common.annotation.UsePermission; import tech.easyflow.common.annotation.UsePermission;
import tech.easyflow.common.domain.Result;
import tech.easyflow.common.web.controller.BaseCurdController; import tech.easyflow.common.web.controller.BaseCurdController;
import tech.easyflow.system.entity.vo.RoleCategoryAccessSnapshot;
import tech.easyflow.system.service.CategoryPermissionService;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import javax.annotation.Resource;
import java.util.Collections;
import java.util.List;
/** /**
* 控制层。 * 控制层。
* *
@@ -17,9 +26,24 @@ import org.springframework.web.bind.annotation.RestController;
@RequestMapping("/api/v1/workflowCategory") @RequestMapping("/api/v1/workflowCategory")
@UsePermission(moduleName = "/api/v1/workflow") @UsePermission(moduleName = "/api/v1/workflow")
public class WorkflowCategoryController extends BaseCurdController<WorkflowCategoryService, WorkflowCategory> { public class WorkflowCategoryController extends BaseCurdController<WorkflowCategoryService, WorkflowCategory> {
@Resource
private CategoryPermissionService categoryPermissionService;
public WorkflowCategoryController(WorkflowCategoryService service) { public WorkflowCategoryController(WorkflowCategoryService service) {
super(service); super(service);
} }
} @GetMapping("visibleList")
public Result<List<WorkflowCategory>> visibleList(WorkflowCategory entity, Boolean asTree, String sortKey, String sortType) {
QueryWrapper queryWrapper = QueryWrapper.create(entity, buildOperators(entity));
RoleCategoryAccessSnapshot access = categoryPermissionService.getCurrentAccess("WORKFLOW");
if (access.isRestricted()) {
if (access.getCategoryIds().isEmpty()) {
return Result.ok(Collections.emptyList());
}
queryWrapper.in(WorkflowCategory::getId, access.getCategoryIds());
}
queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy()));
return Result.ok(service.list(queryWrapper));
}
}

View File

@@ -4,6 +4,7 @@ import cn.dev33.satoken.annotation.SaCheckPermission;
import cn.dev33.satoken.stp.StpUtil; import cn.dev33.satoken.stp.StpUtil;
import cn.hutool.core.io.IoUtil; import cn.hutool.core.io.IoUtil;
import cn.hutool.core.util.IdUtil; import cn.hutool.core.util.IdUtil;
import com.mybatisflex.core.paginate.Page;
import com.easyagents.flow.core.chain.ChainDefinition; import com.easyagents.flow.core.chain.ChainDefinition;
import com.easyagents.flow.core.chain.Parameter; import com.easyagents.flow.core.chain.Parameter;
import com.easyagents.flow.core.chain.runtime.ChainExecutor; import com.easyagents.flow.core.chain.runtime.ChainExecutor;
@@ -12,6 +13,7 @@ import com.mybatisflex.core.query.QueryWrapper;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartFile;
import tech.easyflow.ai.permission.WorkflowVisibilityQueryHelper;
import tech.easyflow.ai.easyagentsflow.entity.ChainInfo; import tech.easyflow.ai.easyagentsflow.entity.ChainInfo;
import tech.easyflow.ai.easyagentsflow.entity.NodeInfo; import tech.easyflow.ai.easyagentsflow.entity.NodeInfo;
import tech.easyflow.ai.easyagentsflow.entity.WorkflowCheckResult; import tech.easyflow.ai.easyagentsflow.entity.WorkflowCheckResult;
@@ -30,6 +32,12 @@ import tech.easyflow.common.satoken.util.SaTokenUtil;
import tech.easyflow.common.web.controller.BaseCurdController; import tech.easyflow.common.web.controller.BaseCurdController;
import tech.easyflow.common.web.exceptions.BusinessException; import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.common.web.jsonbody.JsonBody; import tech.easyflow.common.web.jsonbody.JsonBody;
import tech.easyflow.system.enums.CategoryResourceType;
import tech.easyflow.system.enums.ResourceAction;
import tech.easyflow.system.enums.ResourceLookup;
import tech.easyflow.system.enums.VisibilityScope;
import tech.easyflow.system.permission.resource.RequireResourceAccess;
import tech.easyflow.system.service.ResourceAccessService;
import tech.easyflow.system.service.SysApiKeyService; import tech.easyflow.system.service.SysApiKeyService;
import javax.annotation.Resource; import javax.annotation.Resource;
@@ -67,6 +75,10 @@ public class WorkflowController extends BaseCurdController<WorkflowService, Work
private CodeEngineCapabilityService codeEngineCapabilityService; private CodeEngineCapabilityService codeEngineCapabilityService;
@Resource @Resource
private WorkflowCheckService workflowCheckService; private WorkflowCheckService workflowCheckService;
@Resource
private ResourceAccessService resourceAccessService;
@Resource
private WorkflowVisibilityQueryHelper workflowVisibilityQueryHelper;
public WorkflowController(WorkflowService service, ModelService modelService) { public WorkflowController(WorkflowService service, ModelService modelService) {
super(service); super(service);
@@ -78,6 +90,13 @@ public class WorkflowController extends BaseCurdController<WorkflowService, Work
*/ */
@PostMapping("/singleRun") @PostMapping("/singleRun")
@SaCheckPermission("/api/v1/workflow/save") @SaCheckPermission("/api/v1/workflow/save")
@RequireResourceAccess(
resource = CategoryResourceType.WORKFLOW,
action = ResourceAction.USE,
lookup = ResourceLookup.WORKFLOW_ID,
idExpr = "#workflowId",
denyMessage = "无权限运行工作流"
)
public Result<?> singleRun( public Result<?> singleRun(
@JsonBody(value = "workflowId", required = true) BigInteger workflowId, @JsonBody(value = "workflowId", required = true) BigInteger workflowId,
@JsonBody(value = "nodeId", required = true) String nodeId, @JsonBody(value = "nodeId", required = true) String nodeId,
@@ -96,6 +115,13 @@ public class WorkflowController extends BaseCurdController<WorkflowService, Work
*/ */
@PostMapping("/runAsync") @PostMapping("/runAsync")
@SaCheckPermission("/api/v1/workflow/save") @SaCheckPermission("/api/v1/workflow/save")
@RequireResourceAccess(
resource = CategoryResourceType.WORKFLOW,
action = ResourceAction.USE,
lookup = ResourceLookup.WORKFLOW_ID,
idExpr = "#id",
denyMessage = "无权限运行工作流"
)
public Result<String> runAsync(@JsonBody(value = "id", required = true) BigInteger id, public Result<String> runAsync(@JsonBody(value = "id", required = true) BigInteger id,
@JsonBody("variables") Map<String, Object> variables) { @JsonBody("variables") Map<String, Object> variables) {
if (variables == null) { if (variables == null) {
@@ -117,6 +143,13 @@ public class WorkflowController extends BaseCurdController<WorkflowService, Work
* 获取工作流运行状态 - v2 * 获取工作流运行状态 - v2
*/ */
@PostMapping("/getChainStatus") @PostMapping("/getChainStatus")
@RequireResourceAccess(
resource = CategoryResourceType.WORKFLOW,
action = ResourceAction.USE,
lookup = ResourceLookup.EXEC_KEY,
idExpr = "#executeId",
denyMessage = "无权限访问该执行记录"
)
public Result<ChainInfo> getChainStatus(@JsonBody(value = "executeId") String executeId, public Result<ChainInfo> getChainStatus(@JsonBody(value = "executeId") String executeId,
@JsonBody("nodes") List<NodeInfo> nodes) { @JsonBody("nodes") List<NodeInfo> nodes) {
ChainInfo res = tinyFlowService.getChainStatus(executeId, nodes); ChainInfo res = tinyFlowService.getChainStatus(executeId, nodes);
@@ -128,6 +161,13 @@ public class WorkflowController extends BaseCurdController<WorkflowService, Work
*/ */
@PostMapping("/resume") @PostMapping("/resume")
@SaCheckPermission("/api/v1/workflow/save") @SaCheckPermission("/api/v1/workflow/save")
@RequireResourceAccess(
resource = CategoryResourceType.WORKFLOW,
action = ResourceAction.USE,
lookup = ResourceLookup.EXEC_KEY,
idExpr = "#executeId",
denyMessage = "无权限恢复工作流执行"
)
public Result<Void> resume(@JsonBody(value = "executeId", required = true) String executeId, public Result<Void> resume(@JsonBody(value = "executeId", required = true) String executeId,
@JsonBody("confirmParams") Map<String, Object> confirmParams) { @JsonBody("confirmParams") Map<String, Object> confirmParams) {
chainExecutor.resumeAsync(executeId, confirmParams); chainExecutor.resumeAsync(executeId, confirmParams);
@@ -137,6 +177,10 @@ public class WorkflowController extends BaseCurdController<WorkflowService, Work
@PostMapping("/importWorkFlow") @PostMapping("/importWorkFlow")
@SaCheckPermission("/api/v1/workflow/save") @SaCheckPermission("/api/v1/workflow/save")
public Result<Void> importWorkFlow(Workflow workflow, MultipartFile jsonFile) throws Exception { public Result<Void> importWorkFlow(Workflow workflow, MultipartFile jsonFile) throws Exception {
if (workflow.getId() != null) {
Workflow sourceWorkflow = requireWorkflow(String.valueOf(workflow.getId()));
resourceAccessService.assertAccess(CategoryResourceType.WORKFLOW, sourceWorkflow, ResourceAction.MANAGE, "无权限管理工作流");
}
InputStream is = jsonFile.getInputStream(); InputStream is = jsonFile.getInputStream();
String content = IoUtil.read(is, StandardCharsets.UTF_8); String content = IoUtil.read(is, StandardCharsets.UTF_8);
workflow.setContent(content); workflow.setContent(content);
@@ -147,13 +191,30 @@ public class WorkflowController extends BaseCurdController<WorkflowService, Work
@GetMapping("/exportWorkFlow") @GetMapping("/exportWorkFlow")
@SaCheckPermission("/api/v1/workflow/save") @SaCheckPermission("/api/v1/workflow/save")
@RequireResourceAccess(
resource = CategoryResourceType.WORKFLOW,
action = ResourceAction.READ,
lookup = ResourceLookup.WORKFLOW_ID,
idExpr = "#id",
denyMessage = "无权限访问工作流"
)
public Result<String> exportWorkFlow(BigInteger id) { public Result<String> exportWorkFlow(BigInteger id) {
Workflow workflow = service.getById(id); Workflow workflow = service.getById(id);
if (workflow == null) {
throw new BusinessException("工作流不存在");
}
return Result.ok("", workflow.getContent()); return Result.ok("", workflow.getContent());
} }
@GetMapping("getRunningParameters") @GetMapping("getRunningParameters")
@SaCheckPermission("/api/v1/workflow/query") @SaCheckPermission("/api/v1/workflow/query")
@RequireResourceAccess(
resource = CategoryResourceType.WORKFLOW,
action = ResourceAction.READ,
lookup = ResourceLookup.WORKFLOW_ID,
idExpr = "#id",
denyMessage = "无权限访问工作流"
)
public Result<?> getRunningParameters(@RequestParam BigInteger id) { public Result<?> getRunningParameters(@RequestParam BigInteger id) {
Workflow workflow = service.getById(id); Workflow workflow = service.getById(id);
@@ -186,6 +247,10 @@ public class WorkflowController extends BaseCurdController<WorkflowService, Work
public Result<WorkflowCheckResult> check(@JsonBody("id") BigInteger id, public Result<WorkflowCheckResult> check(@JsonBody("id") BigInteger id,
@JsonBody("content") String content, @JsonBody("content") String content,
@JsonBody(value = "stage", required = true) String stage) { @JsonBody(value = "stage", required = true) String stage) {
if (id != null) {
Workflow workflow = requireWorkflow(String.valueOf(id));
resourceAccessService.assertAccess(CategoryResourceType.WORKFLOW, workflow, ResourceAction.MANAGE, "无权限管理工作流");
}
WorkflowCheckStage checkStage = WorkflowCheckStage.from(stage); WorkflowCheckStage checkStage = WorkflowCheckStage.from(stage);
WorkflowCheckResult checkResult; WorkflowCheckResult checkResult;
if (StringUtils.hasLength(content)) { if (StringUtils.hasLength(content)) {
@@ -199,6 +264,14 @@ public class WorkflowController extends BaseCurdController<WorkflowService, Work
} }
@Override @Override
@GetMapping("detail")
@RequireResourceAccess(
resource = CategoryResourceType.WORKFLOW,
action = ResourceAction.READ,
lookup = ResourceLookup.WORKFLOW_ID,
idExpr = "#id",
denyMessage = "无权限访问工作流"
)
public Result<Workflow> detail(String id) { public Result<Workflow> detail(String id) {
Workflow workflow = service.getDetail(id); Workflow workflow = service.getDetail(id);
return Result.ok(workflow); return Result.ok(workflow);
@@ -206,9 +279,19 @@ public class WorkflowController extends BaseCurdController<WorkflowService, Work
@GetMapping("/copy") @GetMapping("/copy")
@SaCheckPermission("/api/v1/workflow/save") @SaCheckPermission("/api/v1/workflow/save")
@RequireResourceAccess(
resource = CategoryResourceType.WORKFLOW,
action = ResourceAction.READ,
lookup = ResourceLookup.WORKFLOW_ID,
idExpr = "#id",
denyMessage = "无权限访问工作流"
)
public Result<Void> copy(BigInteger id) { public Result<Void> copy(BigInteger id) {
LoginAccount account = SaTokenUtil.getLoginAccount(); LoginAccount account = SaTokenUtil.getLoginAccount();
Workflow workflow = service.getById(id); Workflow workflow = service.getById(id);
if (workflow == null) {
throw new BusinessException("工作流不存在");
}
workflow.setId(null); workflow.setId(null);
workflow.setAlias(IdUtil.fastSimpleUUID()); workflow.setAlias(IdUtil.fastSimpleUUID());
commonFiled(workflow, account.getId(), account.getTenantId(), account.getDeptId()); commonFiled(workflow, account.getId(), account.getTenantId(), account.getDeptId());
@@ -218,6 +301,11 @@ public class WorkflowController extends BaseCurdController<WorkflowService, Work
@Override @Override
protected Result onSaveOrUpdateBefore(Workflow entity, boolean isSave) { protected Result onSaveOrUpdateBefore(Workflow entity, boolean isSave) {
normalizeVisibilityScope(entity, isSave);
if (!isSave && entity.getId() != null) {
Workflow existed = requireWorkflow(String.valueOf(entity.getId()));
resourceAccessService.assertAccess(CategoryResourceType.WORKFLOW, existed, ResourceAction.MANAGE, "无权限管理工作流");
}
if (StringUtils.hasLength(entity.getContent())) { if (StringUtils.hasLength(entity.getContent())) {
workflowCheckService.checkOrThrow(entity.getContent(), WorkflowCheckStage.SAVE, entity.getId()); workflowCheckService.checkOrThrow(entity.getContent(), WorkflowCheckStage.SAVE, entity.getId());
} }
@@ -241,8 +329,26 @@ public class WorkflowController extends BaseCurdController<WorkflowService, Work
return null; return null;
} }
@Override
public Result<List<Workflow>> list(Workflow entity, Boolean asTree, String sortKey, String sortType) {
QueryWrapper queryWrapper = QueryWrapper.create(entity, buildOperators(entity));
workflowVisibilityQueryHelper.applyReadableAccess(queryWrapper);
queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy()));
return Result.ok(service.list(queryWrapper));
}
@Override
protected Page<Workflow> queryPage(Page<Workflow> page, QueryWrapper queryWrapper) {
workflowVisibilityQueryHelper.applyReadableAccess(queryWrapper);
return super.queryPage(page, queryWrapper);
}
@Override @Override
protected Result onRemoveBefore(Collection<Serializable> ids) { protected Result onRemoveBefore(Collection<Serializable> ids) {
for (Serializable id : ids) {
Workflow workflow = requireWorkflow(String.valueOf(id));
resourceAccessService.assertAccess(CategoryResourceType.WORKFLOW, workflow, ResourceAction.MANAGE, "无权限管理工作流");
}
QueryWrapper queryWrapper = QueryWrapper.create(); QueryWrapper queryWrapper = QueryWrapper.create();
queryWrapper.in("workflow_id", ids); queryWrapper.in("workflow_id", ids);
boolean exists = botWorkflowService.exists(queryWrapper); boolean exists = botWorkflowService.exists(queryWrapper);
@@ -251,4 +357,25 @@ public class WorkflowController extends BaseCurdController<WorkflowService, Work
} }
return null; return null;
} }
private void normalizeVisibilityScope(Workflow entity, boolean isSave) {
if (entity == null) {
return;
}
if (!StringUtils.hasLength(entity.getVisibilityScope())) {
if (isSave) {
entity.setVisibilityScope(VisibilityScope.PRIVATE.name());
}
return;
}
entity.setVisibilityScope(VisibilityScope.from(entity.getVisibilityScope()).name());
}
private Workflow requireWorkflow(String idOrAlias) {
Workflow workflow = service.getDetail(idOrAlias);
if (workflow == null) {
throw new BusinessException("工作流不存在");
}
return workflow;
}
} }

View File

@@ -0,0 +1,31 @@
package tech.easyflow.admin.controller.dashboard;
import cn.dev33.satoken.annotation.SaCheckPermission;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import tech.easyflow.admin.model.dashboard.DashboardOverviewQuery;
import tech.easyflow.admin.model.dashboard.DashboardOverviewVo;
import tech.easyflow.admin.service.dashboard.DashboardService;
import tech.easyflow.common.domain.Result;
import tech.easyflow.common.satoken.util.SaTokenUtil;
/**
* 管理端工作台统计接口。
*/
@RestController
@RequestMapping("/api/v1/dashboard")
public class DashboardController {
private final DashboardService dashboardService;
public DashboardController(DashboardService dashboardService) {
this.dashboardService = dashboardService;
}
@GetMapping("/overview")
@SaCheckPermission("/api/v1/dashboard/query")
public Result<DashboardOverviewVo> overview(DashboardOverviewQuery query) {
return Result.ok(dashboardService.getOverview(SaTokenUtil.getLoginAccount(), query));
}
}

View File

@@ -22,6 +22,7 @@ import tech.easyflow.common.web.controller.BaseCurdController;
import tech.easyflow.common.web.jsonbody.JsonBody; import tech.easyflow.common.web.jsonbody.JsonBody;
import tech.easyflow.log.annotation.LogRecord; import tech.easyflow.log.annotation.LogRecord;
import tech.easyflow.system.entity.SysAccount; import tech.easyflow.system.entity.SysAccount;
import tech.easyflow.system.entity.vo.SysAccountBatchActionResultVo;
import tech.easyflow.system.entity.vo.SysAccountImportResultVo; import tech.easyflow.system.entity.vo.SysAccountImportResultVo;
import tech.easyflow.system.service.SysAccountService; import tech.easyflow.system.service.SysAccountService;
import tech.easyflow.system.util.SysPasswordPolicy; import tech.easyflow.system.util.SysPasswordPolicy;
@@ -180,6 +181,28 @@ public class SysAccountController extends BaseCurdController<SysAccountService,
return Result.ok(); return Result.ok();
} }
@PostMapping("/removeBatchWithResult")
@SaCheckPermission("/api/v1/sysAccount/remove")
@LogRecord("批量删除用户")
public Result<SysAccountBatchActionResultVo> removeBatchWithResult(
@JsonBody(value = "ids", required = true) List<BigInteger> ids) {
if (ids == null || ids.isEmpty()) {
return Result.fail("ids不能为空", null);
}
return Result.ok(service.removeBatchWithResult(ids));
}
@PostMapping("/resetPasswordBatch")
@SaCheckPermission("/api/v1/sysAccount/save")
@LogRecord("批量重置用户密码")
public Result<SysAccountBatchActionResultVo> resetPasswordBatch(
@JsonBody(value = "ids", required = true) List<BigInteger> ids) {
if (ids == null || ids.isEmpty()) {
return Result.fail("ids不能为空", null);
}
return Result.ok(service.resetPasswordBatch(ids, SaTokenUtil.getLoginAccount().getId()));
}
@PostMapping("/importExcel") @PostMapping("/importExcel")
@SaCheckPermission("/api/v1/sysAccount/save") @SaCheckPermission("/api/v1/sysAccount/save")
public Result<SysAccountImportResultVo> importExcel(MultipartFile file) { public Result<SysAccountImportResultVo> importExcel(MultipartFile file) {

View File

@@ -0,0 +1,54 @@
package tech.easyflow.admin.controller.system;
import cn.dev33.satoken.annotation.SaCheckPermission;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import tech.easyflow.common.domain.Result;
import tech.easyflow.common.satoken.util.SaTokenUtil;
import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.common.web.jsonbody.JsonBody;
import tech.easyflow.system.entity.vo.SysRoleCategoryScopeDetailVo;
import tech.easyflow.system.service.CategoryPermissionService;
import tech.easyflow.system.service.SysRoleCategoryScopeService;
import javax.annotation.Resource;
import java.math.BigInteger;
@RestController
@RequestMapping("/api/v1/sysRoleCategoryScope")
public class SysRoleCategoryScopeController {
@Resource
private SysRoleCategoryScopeService sysRoleCategoryScopeService;
@Resource
private CategoryPermissionService categoryPermissionService;
@GetMapping("/detail")
@SaCheckPermission("/api/v1/sysRole/query")
public Result<SysRoleCategoryScopeDetailVo> detail(BigInteger roleId) {
SysRoleCategoryScopeDetailVo detail = sysRoleCategoryScopeService.getRoleScopeDetail(roleId);
detail.setEditable(categoryPermissionService.isCurrentSuperAdmin());
return Result.ok(detail);
}
@PostMapping("/save")
@SaCheckPermission("/api/v1/sysRole/save")
public Result<Void> save(@JsonBody SysRoleCategoryScopeDetailVo request) {
assertSuperAdmin();
if (request == null || request.getRoleId() == null) {
throw new BusinessException("角色ID不能为空");
}
BigInteger operatorId = SaTokenUtil.getLoginAccount().getId();
sysRoleCategoryScopeService.saveRoleScopes(request.getRoleId(), request.getScopes(), operatorId);
return Result.ok();
}
private void assertSuperAdmin() {
if (!categoryPermissionService.isCurrentSuperAdmin()) {
throw new BusinessException("仅超级管理员可配置分类权限");
}
}
}

View File

@@ -80,13 +80,13 @@ public class SysRoleController extends BaseCurdController<SysRoleService, SysRol
*/ */
@PostMapping("saveRole") @PostMapping("saveRole")
@SaCheckPermission("/api/v1/sysRole/save") @SaCheckPermission("/api/v1/sysRole/save")
public Result<Void> saveRole(@JsonBody SysRole entity) { public Result<BigInteger> saveRole(@JsonBody SysRole entity) {
LoginAccount loginUser = SaTokenUtil.getLoginAccount(); LoginAccount loginUser = SaTokenUtil.getLoginAccount();
if (entity.getId() == null) { if (entity.getId() == null) {
commonFiled(entity, loginUser.getId(), loginUser.getTenantId(), loginUser.getDeptId()); commonFiled(entity, loginUser.getId(), loginUser.getTenantId(), loginUser.getDeptId());
} }
service.saveRole(entity); service.saveRole(entity);
return Result.ok(); return Result.ok(entity.getId());
} }
@Override @Override
@@ -115,4 +115,4 @@ public class SysRoleController extends BaseCurdController<SysRoleService, SysRol
} }
return super.onRemoveBefore(ids); return super.onRemoveBefore(ids);
} }
} }

View File

@@ -0,0 +1,87 @@
package tech.easyflow.admin.model.dashboard;
/**
* 工作台分布/排行项。
*/
public class DashboardDistributionItemVo {
private String key;
private String label;
private Long value;
private Long userTotal;
private Long activeUserTotal;
private Long botTotal;
private Long workflowTotal;
private Long knowledgeBaseTotal;
public String getKey() {
return key;
}
public void setKey(String key) {
this.key = key;
}
public String getLabel() {
return label;
}
public void setLabel(String label) {
this.label = label;
}
public Long getValue() {
return value;
}
public void setValue(Long value) {
this.value = value;
}
public Long getUserTotal() {
return userTotal;
}
public void setUserTotal(Long userTotal) {
this.userTotal = userTotal;
}
public Long getActiveUserTotal() {
return activeUserTotal;
}
public void setActiveUserTotal(Long activeUserTotal) {
this.activeUserTotal = activeUserTotal;
}
public Long getBotTotal() {
return botTotal;
}
public void setBotTotal(Long botTotal) {
this.botTotal = botTotal;
}
public Long getWorkflowTotal() {
return workflowTotal;
}
public void setWorkflowTotal(Long workflowTotal) {
this.workflowTotal = workflowTotal;
}
public Long getKnowledgeBaseTotal() {
return knowledgeBaseTotal;
}
public void setKnowledgeBaseTotal(Long knowledgeBaseTotal) {
this.knowledgeBaseTotal = knowledgeBaseTotal;
}
}

View File

@@ -0,0 +1,17 @@
package tech.easyflow.admin.model.dashboard;
/**
* 工作台统计查询参数。
*/
public class DashboardOverviewQuery {
private String range;
public String getRange() {
return range;
}
public void setRange(String range) {
this.range = range;
}
}

View File

@@ -0,0 +1,60 @@
package tech.easyflow.admin.model.dashboard;
import java.util.Date;
import java.util.List;
/**
* 工作台总览返回对象。
*/
public class DashboardOverviewVo {
private DashboardSummaryVo summary;
private List<DashboardTrendItemVo> trends;
private List<DashboardDistributionItemVo> distribution;
private DashboardOverviewQuery query;
private Date updatedAt;
public DashboardSummaryVo getSummary() {
return summary;
}
public void setSummary(DashboardSummaryVo summary) {
this.summary = summary;
}
public List<DashboardTrendItemVo> getTrends() {
return trends;
}
public void setTrends(List<DashboardTrendItemVo> trends) {
this.trends = trends;
}
public List<DashboardDistributionItemVo> getDistribution() {
return distribution;
}
public void setDistribution(List<DashboardDistributionItemVo> distribution) {
this.distribution = distribution;
}
public DashboardOverviewQuery getQuery() {
return query;
}
public void setQuery(DashboardOverviewQuery query) {
this.query = query;
}
public Date getUpdatedAt() {
return updatedAt;
}
public void setUpdatedAt(Date updatedAt) {
this.updatedAt = updatedAt;
}
}

View File

@@ -0,0 +1,57 @@
package tech.easyflow.admin.model.dashboard;
/**
* 工作台汇总指标。
*/
public class DashboardSummaryVo {
private Long userTotal;
private Long activeUserTotal;
private Long botTotal;
private Long workflowTotal;
private Long knowledgeBaseTotal;
public Long getUserTotal() {
return userTotal;
}
public void setUserTotal(Long userTotal) {
this.userTotal = userTotal;
}
public Long getActiveUserTotal() {
return activeUserTotal;
}
public void setActiveUserTotal(Long activeUserTotal) {
this.activeUserTotal = activeUserTotal;
}
public Long getBotTotal() {
return botTotal;
}
public void setBotTotal(Long botTotal) {
this.botTotal = botTotal;
}
public Long getWorkflowTotal() {
return workflowTotal;
}
public void setWorkflowTotal(Long workflowTotal) {
this.workflowTotal = workflowTotal;
}
public Long getKnowledgeBaseTotal() {
return knowledgeBaseTotal;
}
public void setKnowledgeBaseTotal(Long knowledgeBaseTotal) {
this.knowledgeBaseTotal = knowledgeBaseTotal;
}
}

View File

@@ -0,0 +1,37 @@
package tech.easyflow.admin.model.dashboard;
/**
* 工作台趋势项。
*/
public class DashboardTrendItemVo {
private String key;
private String label;
private Long activeUserTotal;
public String getKey() {
return key;
}
public void setKey(String key) {
this.key = key;
}
public String getLabel() {
return label;
}
public void setLabel(String label) {
this.label = label;
}
public Long getActiveUserTotal() {
return activeUserTotal;
}
public void setActiveUserTotal(Long activeUserTotal) {
this.activeUserTotal = activeUserTotal;
}
}

View File

@@ -0,0 +1,13 @@
package tech.easyflow.admin.service.dashboard;
import tech.easyflow.admin.model.dashboard.DashboardOverviewQuery;
import tech.easyflow.admin.model.dashboard.DashboardOverviewVo;
import tech.easyflow.common.entity.LoginAccount;
/**
* 工作台统计服务。
*/
public interface DashboardService {
DashboardOverviewVo getOverview(LoginAccount loginAccount, DashboardOverviewQuery query);
}

View File

@@ -0,0 +1,307 @@
package tech.easyflow.admin.service.dashboard.impl;
import com.easyagents.flow.core.chain.ChainStatus;
import com.mybatisflex.core.query.QueryWrapper;
import com.mybatisflex.core.row.Db;
import com.mybatisflex.core.row.Row;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import tech.easyflow.admin.model.dashboard.DashboardDistributionItemVo;
import tech.easyflow.admin.model.dashboard.DashboardOverviewQuery;
import tech.easyflow.admin.model.dashboard.DashboardOverviewVo;
import tech.easyflow.admin.model.dashboard.DashboardSummaryVo;
import tech.easyflow.admin.model.dashboard.DashboardTrendItemVo;
import tech.easyflow.admin.service.dashboard.DashboardService;
import tech.easyflow.common.constant.Constants;
import tech.easyflow.common.entity.LoginAccount;
import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.system.entity.SysAccountRole;
import tech.easyflow.system.entity.SysRole;
import tech.easyflow.system.service.SysAccountRoleService;
import tech.easyflow.system.service.SysRoleService;
import javax.annotation.Resource;
import java.math.BigInteger;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* 工作台统计服务实现。
*/
@Service
public class DashboardServiceImpl implements DashboardService {
private static final ZoneId DEFAULT_ZONE_ID = ZoneId.systemDefault();
@Resource
private SysAccountRoleService sysAccountRoleService;
@Resource
private SysRoleService sysRoleService;
@Override
public DashboardOverviewVo getOverview(LoginAccount loginAccount, DashboardOverviewQuery query) {
DashboardQueryContext context = buildContext(loginAccount, query);
DashboardSummaryVo summary = buildSummary(context);
List<DashboardTrendItemVo> trends = buildTrends(context);
List<DashboardDistributionItemVo> distribution = buildDistribution(context, summary);
DashboardOverviewVo result = new DashboardOverviewVo();
result.setSummary(summary);
result.setTrends(trends);
result.setDistribution(distribution);
DashboardOverviewQuery normalizedQuery = new DashboardOverviewQuery();
normalizedQuery.setRange(context.range);
result.setQuery(normalizedQuery);
result.setUpdatedAt(new Date());
return result;
}
private DashboardSummaryVo buildSummary(DashboardQueryContext context) {
DashboardSummaryVo summary = new DashboardSummaryVo();
summary.setUserTotal(countScopedTable("tb_sys_account", "a", true, context));
summary.setActiveUserTotal(countActiveUsers(context));
summary.setBotTotal(countScopedTable("tb_bot", "b", false, context));
summary.setWorkflowTotal(countScopedTable("tb_workflow", "w", false, context));
summary.setKnowledgeBaseTotal(countScopedTable("tb_document_collection", "d", false, context));
return summary;
}
private List<DashboardTrendItemVo> buildTrends(DashboardQueryContext context) {
List<TimeBucket> buckets = buildBuckets(context.range);
String bucketFormat = "today".equals(context.range) ? "%Y-%m-%d %H:00:00" : "%Y-%m-%d";
Map<String, Long> activeUserMap = queryActiveUserTrend(context, bucketFormat);
List<DashboardTrendItemVo> items = new ArrayList<>(buckets.size());
for (TimeBucket bucket : buckets) {
long activeUserTotal = activeUserMap.getOrDefault(bucket.key, 0L);
DashboardTrendItemVo item = new DashboardTrendItemVo();
item.setKey(bucket.key);
item.setLabel(bucket.label);
item.setActiveUserTotal(activeUserTotal);
items.add(item);
}
return items;
}
private List<DashboardDistributionItemVo> buildDistribution(DashboardQueryContext context, DashboardSummaryVo summary) {
return buildResourceDistribution(summary);
}
private List<DashboardDistributionItemVo> buildResourceDistribution(DashboardSummaryVo summary) {
List<DashboardDistributionItemVo> items = new ArrayList<>();
items.add(buildPlatformItem("userTotal", "用户总量", summary.getUserTotal()));
items.add(buildPlatformItem("activeUserTotal", "活跃用户", summary.getActiveUserTotal()));
items.add(buildPlatformItem("botTotal", "助手数量", summary.getBotTotal()));
items.add(buildPlatformItem("workflowTotal", "工作流数量", summary.getWorkflowTotal()));
items.add(buildPlatformItem("knowledgeBaseTotal", "知识库数量", summary.getKnowledgeBaseTotal()));
return items;
}
private DashboardDistributionItemVo buildPlatformItem(String key, String label, Long value) {
DashboardDistributionItemVo item = new DashboardDistributionItemVo();
item.setKey(key);
item.setLabel(label);
item.setValue(defaultLong(value));
return item;
}
private Map<String, Long> queryActiveUserTrend(DashboardQueryContext context, String bucketFormat) {
StringBuilder sql = new StringBuilder();
List<Object> params = new ArrayList<>();
sql.append("SELECT DATE_FORMAT(l.created, '").append(bucketFormat).append("') AS bucket_key, ")
.append("COUNT(DISTINCT l.account_id) AS total ")
.append("FROM tb_sys_log l ")
.append("INNER JOIN tb_sys_account a ON a.id = l.account_id AND a.is_deleted IS NULL ")
.append("WHERE l.created >= ? AND l.created < ? ");
params.add(toDate(context.startTime));
params.add(toDate(context.endTime));
appendOptionalTenantFilter(sql, params, context.tenantFilterId, "a.tenant_id");
appendOptionalDeptFilter(sql, params, context.deptFilterId, "a.dept_id");
sql.append("GROUP BY bucket_key ORDER BY bucket_key ASC");
Map<String, Long> data = new HashMap<>();
for (Row row : Db.selectListBySql(sql.toString(), params.toArray())) {
data.put(asString(row.get("bucket_key")), asLong(row.get("total")));
}
return data;
}
private long countScopedTable(String tableName, String alias, boolean containsLogicDelete, DashboardQueryContext context) {
StringBuilder sql = new StringBuilder();
List<Object> params = new ArrayList<>();
sql.append("SELECT COUNT(*) AS total FROM ").append(tableName).append(" ").append(alias).append(" WHERE 1 = 1 ");
if (containsLogicDelete) {
sql.append("AND ").append(alias).append(".is_deleted IS NULL ");
}
appendOptionalTenantFilter(sql, params, context.tenantFilterId, alias + ".tenant_id");
appendOptionalDeptFilter(sql, params, context.deptFilterId, alias + ".dept_id");
return queryForLong(sql.toString(), params);
}
private long countActiveUsers(DashboardQueryContext context) {
StringBuilder sql = new StringBuilder();
List<Object> params = new ArrayList<>();
sql.append("SELECT COUNT(DISTINCT l.account_id) AS total ")
.append("FROM tb_sys_log l ")
.append("INNER JOIN tb_sys_account a ON a.id = l.account_id AND a.is_deleted IS NULL ")
.append("WHERE l.created >= ? AND l.created < ? ");
params.add(toDate(context.startTime));
params.add(toDate(context.endTime));
appendOptionalTenantFilter(sql, params, context.tenantFilterId, "a.tenant_id");
appendOptionalDeptFilter(sql, params, context.deptFilterId, "a.dept_id");
return queryForLong(sql.toString(), params);
}
private long queryForLong(String sql, List<Object> params) {
Object result = Db.selectObject(sql, params.toArray());
return asLong(result);
}
private void appendOptionalTenantFilter(StringBuilder sql, List<Object> params, BigInteger tenantId, String columnName) {
if (tenantId != null) {
sql.append(" AND ").append(columnName).append(" = ? ");
params.add(tenantId);
}
}
private void appendOptionalDeptFilter(StringBuilder sql, List<Object> params, BigInteger deptId, String columnName) {
if (deptId != null) {
sql.append(" AND ").append(columnName).append(" = ? ");
params.add(deptId);
}
}
private DashboardQueryContext buildContext(LoginAccount loginAccount, DashboardOverviewQuery query) {
DashboardQueryContext context = new DashboardQueryContext();
context.range = normalizeRange(query == null ? null : query.getRange());
context.superAdmin = isSuperAdmin(loginAccount);
LocalDate today = LocalDate.now(DEFAULT_ZONE_ID);
if ("today".equals(context.range)) {
context.startTime = LocalDateTime.of(today, LocalTime.MIN);
context.endTime = context.startTime.plusDays(1);
} else if ("7d".equals(context.range)) {
context.startTime = LocalDateTime.of(today.minusDays(6), LocalTime.MIN);
context.endTime = LocalDateTime.of(today.plusDays(1), LocalTime.MIN);
} else {
context.startTime = LocalDateTime.of(today.minusDays(29), LocalTime.MIN);
context.endTime = LocalDateTime.of(today.plusDays(1), LocalTime.MIN);
}
context.tenantFilterId = context.superAdmin ? null : loginAccount.getTenantId();
return context;
}
private boolean isSuperAdmin(LoginAccount loginAccount) {
if (loginAccount == null || loginAccount.getId() == null) {
return false;
}
QueryWrapper roleMappingWrapper = QueryWrapper.create();
roleMappingWrapper.eq(SysAccountRole::getAccountId, loginAccount.getId());
List<BigInteger> roleIds = sysAccountRoleService.list(roleMappingWrapper)
.stream()
.map(SysAccountRole::getRoleId)
.collect(Collectors.toList());
if (roleIds.isEmpty()) {
return Constants.SUPER_ADMIN_ID.equals(loginAccount.getId());
}
QueryWrapper roleWrapper = QueryWrapper.create();
roleWrapper.in(SysRole::getId, roleIds);
roleWrapper.eq(SysRole::getRoleKey, Constants.SUPER_ADMIN_ROLE_CODE);
return sysRoleService.count(roleWrapper) > 0;
}
private String normalizeRange(String range) {
if (!StringUtils.hasText(range)) {
return "7d";
}
if ("today".equals(range) || "7d".equals(range) || "30d".equals(range)) {
return range;
}
throw new BusinessException("不支持的时间范围: " + range);
}
private List<TimeBucket> buildBuckets(String range) {
List<TimeBucket> buckets = new ArrayList<>();
LocalDate today = LocalDate.now(DEFAULT_ZONE_ID);
if ("today".equals(range)) {
DateTimeFormatter keyFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:00:00");
DateTimeFormatter labelFormatter = DateTimeFormatter.ofPattern("HH:00");
LocalDateTime start = LocalDateTime.of(today, LocalTime.MIN);
for (int hour = 0; hour < 24; hour++) {
LocalDateTime current = start.plusHours(hour);
buckets.add(new TimeBucket(current.format(keyFormatter), current.format(labelFormatter)));
}
return buckets;
}
int days = "7d".equals(range) ? 7 : 30;
DateTimeFormatter keyFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd");
DateTimeFormatter labelFormatter = DateTimeFormatter.ofPattern("MM-dd");
LocalDate start = today.minusDays(days - 1L);
for (int i = 0; i < days; i++) {
LocalDate current = start.plusDays(i);
buckets.add(new TimeBucket(current.format(keyFormatter), current.format(labelFormatter)));
}
return buckets;
}
private Date toDate(LocalDateTime dateTime) {
return Date.from(dateTime.atZone(DEFAULT_ZONE_ID).toInstant());
}
private long defaultLong(Long value) {
return value == null ? 0L : value;
}
private String asString(Object value) {
return value == null ? "" : String.valueOf(value);
}
private long asLong(Object value) {
if (value == null) {
return 0L;
}
if (value instanceof Number) {
return ((Number) value).longValue();
}
return Long.parseLong(String.valueOf(value));
}
private static class DashboardQueryContext {
private String range;
private BigInteger tenantFilterId;
private BigInteger deptFilterId;
private boolean superAdmin;
private LocalDateTime startTime;
private LocalDateTime endTime;
}
private static class TimeBucket {
private final String key;
private final String label;
private TimeBucket(String key, String label) {
this.key = key;
this.label = label;
}
}
}

View File

@@ -0,0 +1,176 @@
package tech.easyflow.publicapi.controller;
import cn.dev33.satoken.annotation.SaIgnore;
import cn.hutool.core.util.StrUtil;
import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.servlet.http.HttpServletRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import tech.easyflow.ai.invoke.exception.ModelInvokeException;
import tech.easyflow.ai.invoke.mapper.OpenAiProtocolMapper;
import tech.easyflow.ai.invoke.model.UnifiedChatChunk;
import tech.easyflow.ai.invoke.model.UnifiedChatRequest;
import tech.easyflow.ai.invoke.model.UnifiedChatResponse;
import tech.easyflow.ai.invoke.protocol.openai.OpenAiChatCompletionRequest;
import tech.easyflow.ai.invoke.protocol.openai.OpenAiErrorResponse;
import tech.easyflow.ai.invoke.provider.UnifiedChatChunkObserver;
import tech.easyflow.ai.invoke.service.UnifiedModelInvokeService;
import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.system.service.SysApiKeyService;
import javax.annotation.Resource;
import java.io.IOException;
import java.time.Duration;
@SaIgnore
@RestController
@RequestMapping("/v1")
public class PublicModelChatController {
private static final Logger log = LoggerFactory.getLogger(PublicModelChatController.class);
private static final long SSE_TIMEOUT = Duration.ofMinutes(10).toMillis();
@Resource
private SysApiKeyService sysApiKeyService;
@Resource
private UnifiedModelInvokeService unifiedModelInvokeService;
@Resource
private OpenAiProtocolMapper openAiProtocolMapper;
@Resource
private ObjectMapper objectMapper;
/**
* 统一模型调用OpenAI Chat Completions
*/
@PostMapping(
value = "/chat/completions",
produces = {
MediaType.APPLICATION_JSON_VALUE,
MediaType.TEXT_EVENT_STREAM_VALUE
}
)
public Object chatCompletions(@RequestBody String rawBody, HttpServletRequest request) {
try {
String apiKey = resolveApiKey(request.getHeader(HttpHeaders.AUTHORIZATION));
sysApiKeyService.checkApikeyPermission(apiKey, request.getRequestURI());
OpenAiChatCompletionRequest openAiRequest = openAiProtocolMapper.readRequest(rawBody);
UnifiedChatRequest unifiedRequest = openAiProtocolMapper.toUnifiedRequest(openAiRequest);
if (Boolean.TRUE.equals(unifiedRequest.getStream())) {
return createStreamEmitter(unifiedRequest);
}
UnifiedChatResponse response = unifiedModelInvokeService.chat(unifiedRequest);
return buildJsonResponse(openAiProtocolMapper.toOpenAiResponse(response));
} catch (ModelInvokeException e) {
return buildErrorResponse(e);
} catch (BusinessException e) {
return buildErrorResponse(mapBusinessException(e));
} catch (Exception e) {
log.error("chatCompletions unexpected error: {}", e.getMessage(), e);
return buildErrorResponse(new ModelInvokeException(
500,
"统一模型调用失败: " + e.getMessage(),
"api_error",
null,
"internal_error"
));
}
}
private SseEmitter createStreamEmitter(UnifiedChatRequest unifiedRequest) {
SseEmitter emitter = new SseEmitter(SSE_TIMEOUT);
unifiedModelInvokeService.chatStream(unifiedRequest, new UnifiedChatChunkObserver() {
@Override
public void onChunk(UnifiedChatChunk chunk) {
try {
String payload = objectMapper.writeValueAsString(openAiProtocolMapper.toOpenAiChunk(chunk));
emitter.send(SseEmitter.event().data(payload));
} catch (IOException e) {
emitter.completeWithError(e);
}
}
@Override
public void onComplete() {
try {
emitter.send(SseEmitter.event().data("[DONE]"));
emitter.complete();
} catch (IOException e) {
emitter.completeWithError(e);
}
}
@Override
public void onError(Throwable throwable) {
log.error("chatCompletions stream error: {}", throwable.getMessage(), throwable);
emitter.completeWithError(throwable);
}
});
return emitter;
}
private ResponseEntity<String> buildJsonResponse(Object body) {
try {
return ResponseEntity.ok()
.contentType(MediaType.APPLICATION_JSON)
.body(objectMapper.writeValueAsString(body));
} catch (Exception e) {
throw new RuntimeException("序列化响应失败", e);
}
}
private ResponseEntity<String> buildErrorResponse(ModelInvokeException e) {
OpenAiErrorResponse response = new OpenAiErrorResponse();
OpenAiErrorResponse.Error error = new OpenAiErrorResponse.Error();
error.setMessage(e.getMessage());
error.setType(e.getType());
error.setParam(e.getParam());
error.setCode(e.getCode());
response.setError(error);
return ResponseEntity.status(e.getStatus())
.contentType(MediaType.APPLICATION_JSON)
.body(writeJson(response));
}
private String writeJson(Object body) {
try {
return objectMapper.writeValueAsString(body);
} catch (Exception ex) {
log.error("writeJson error: {}", ex.getMessage(), ex);
return "{\"error\":{\"message\":\"响应序列化失败\",\"type\":\"api_error\",\"code\":\"serialization_error\"}}";
}
}
private ModelInvokeException mapBusinessException(BusinessException e) {
String message = StrUtil.blankToDefault(e.getMessage(), "访问令牌校验失败");
if (StrUtil.containsAnyIgnoreCase(message, "apikey 不存在", "apikey 已过期", "已禁用")) {
return ModelInvokeException.unauthorized(message);
}
if (StrUtil.containsAnyIgnoreCase(message, "无权限", "接口不存在")) {
return ModelInvokeException.forbidden(message);
}
return ModelInvokeException.badRequest(message);
}
private String resolveApiKey(String authorizationHeader) {
if (StrUtil.isBlank(authorizationHeader)) {
throw ModelInvokeException.unauthorized("Authorization 不能为空");
}
String trimmed = authorizationHeader.trim();
if (StrUtil.startWithIgnoreCase(trimmed, "Bearer ")) {
trimmed = trimmed.substring(7).trim();
}
if (StrUtil.isBlank(trimmed)) {
throw ModelInvokeException.unauthorized("Authorization 无效");
}
return trimmed;
}
}

View File

@@ -1,11 +1,20 @@
package tech.easyflow.usercenter.controller.ai; package tech.easyflow.usercenter.controller.ai;
import com.mybatisflex.core.query.QueryWrapper;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.GetMapping;
import tech.easyflow.ai.entity.BotCategory; import tech.easyflow.ai.entity.BotCategory;
import tech.easyflow.ai.service.BotCategoryService; import tech.easyflow.ai.service.BotCategoryService;
import tech.easyflow.common.annotation.UsePermission; import tech.easyflow.common.annotation.UsePermission;
import tech.easyflow.common.domain.Result;
import tech.easyflow.common.web.controller.BaseCurdController; import tech.easyflow.common.web.controller.BaseCurdController;
import tech.easyflow.system.entity.vo.RoleCategoryAccessSnapshot;
import tech.easyflow.system.service.CategoryPermissionService;
import javax.annotation.Resource;
import java.util.Collections;
import java.util.List;
/** /**
* bot分类 控制层。 * bot分类 控制层。
@@ -17,7 +26,24 @@ import tech.easyflow.common.web.controller.BaseCurdController;
@RequestMapping("/userCenter/botCategory") @RequestMapping("/userCenter/botCategory")
@UsePermission(moduleName = "/api/v1/bot") @UsePermission(moduleName = "/api/v1/bot")
public class UcBotCategoryController extends BaseCurdController<BotCategoryService, BotCategory> { public class UcBotCategoryController extends BaseCurdController<BotCategoryService, BotCategory> {
@Resource
private CategoryPermissionService categoryPermissionService;
public UcBotCategoryController(BotCategoryService service) { public UcBotCategoryController(BotCategoryService service) {
super(service); super(service);
} }
}
@GetMapping("visibleList")
public Result<List<BotCategory>> visibleList(BotCategory entity, Boolean asTree, String sortKey, String sortType) {
QueryWrapper queryWrapper = QueryWrapper.create(entity, buildOperators(entity));
RoleCategoryAccessSnapshot access = categoryPermissionService.getCurrentAccess("BOT");
if (access.isRestricted()) {
if (access.getCategoryIds().isEmpty()) {
return Result.ok(Collections.emptyList());
}
queryWrapper.in(BotCategory::getId, access.getCategoryIds());
}
queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy()));
return Result.ok(service.list(queryWrapper));
}
}

View File

@@ -3,7 +3,9 @@ package tech.easyflow.usercenter.controller.ai;
import cn.dev33.satoken.annotation.SaCheckPermission; import cn.dev33.satoken.annotation.SaCheckPermission;
import cn.dev33.satoken.annotation.SaIgnore; import cn.dev33.satoken.annotation.SaIgnore;
import cn.dev33.satoken.stp.StpUtil;
import com.alicp.jetcache.Cache; import com.alicp.jetcache.Cache;
import com.mybatisflex.core.paginate.Page;
import com.mybatisflex.core.keygen.impl.SnowFlakeIDKeyGenerator; import com.mybatisflex.core.keygen.impl.SnowFlakeIDKeyGenerator;
import com.mybatisflex.core.query.QueryWrapper; import com.mybatisflex.core.query.QueryWrapper;
import org.slf4j.Logger; import org.slf4j.Logger;
@@ -25,6 +27,8 @@ import tech.easyflow.common.satoken.util.SaTokenUtil;
import tech.easyflow.common.web.controller.BaseCurdController; import tech.easyflow.common.web.controller.BaseCurdController;
import tech.easyflow.common.web.exceptions.BusinessException; import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.common.web.jsonbody.JsonBody; import tech.easyflow.common.web.jsonbody.JsonBody;
import tech.easyflow.system.entity.vo.RoleCategoryAccessSnapshot;
import tech.easyflow.system.service.CategoryPermissionService;
import javax.annotation.Resource; import javax.annotation.Resource;
import java.io.Serializable; import java.io.Serializable;
@@ -34,6 +38,8 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import static tech.easyflow.ai.entity.table.BotTableDef.BOT;
/** /**
* 控制层。 * 控制层。
* *
@@ -70,6 +76,8 @@ public class UcBotController extends BaseCurdController<BotService, Bot> {
private BotPluginService botPluginService; private BotPluginService botPluginService;
@Resource @Resource
private BotConversationService conversationMessageService; private BotConversationService conversationMessageService;
@Resource
private CategoryPermissionService categoryPermissionService;
@GetMapping("/generateConversationId") @GetMapping("/generateConversationId")
public Result<Long> generateConversationId() { public Result<Long> generateConversationId() {
@@ -188,7 +196,11 @@ public class UcBotController extends BaseCurdController<BotService, Bot> {
@GetMapping("getDetail") @GetMapping("getDetail")
@SaIgnore @SaIgnore
public Result<Bot> getDetail(String id) { public Result<Bot> getDetail(String id) {
return Result.ok(botService.getDetail(id)); Bot bot = botService.getDetail(id);
if (bot != null && StpUtil.isLogin()) {
categoryPermissionService.assertCategoryResourceVisible("BOT", bot.getCreatedBy(), bot.getCategoryId(), "无权限访问聊天助手");
}
return Result.ok(bot);
} }
@Override @Override
@@ -198,6 +210,9 @@ public class UcBotController extends BaseCurdController<BotService, Bot> {
if (data == null) { if (data == null) {
return Result.ok(data); return Result.ok(data);
} }
if (StpUtil.isLogin()) {
categoryPermissionService.assertCategoryResourceVisible("BOT", data.getCreatedBy(), data.getCategoryId(), "无权限访问聊天助手");
}
Map<String, Object> llmOptions = data.getModelOptions(); Map<String, Object> llmOptions = data.getModelOptions();
if (llmOptions == null) { if (llmOptions == null) {
@@ -229,6 +244,32 @@ public class UcBotController extends BaseCurdController<BotService, Bot> {
return Result.ok(data); return Result.ok(data);
} }
@Override
public Result<List<Bot>> list(Bot entity, Boolean asTree, String sortKey, String sortType) {
QueryWrapper queryWrapper = QueryWrapper.create(entity, buildOperators(entity));
applyCategoryPermission(queryWrapper);
queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy()));
return Result.ok(service.list(queryWrapper));
}
@Override
protected Page<Bot> queryPage(Page<Bot> page, QueryWrapper queryWrapper) {
applyCategoryPermission(queryWrapper);
return super.queryPage(page, queryWrapper);
}
private void applyCategoryPermission(QueryWrapper queryWrapper) {
RoleCategoryAccessSnapshot access = categoryPermissionService.getCurrentAccess("BOT");
if (!access.isRestricted()) {
return;
}
if (access.getCategoryIds().isEmpty()) {
queryWrapper.eq(Bot::getCreatedBy, access.getAccountId());
return;
}
queryWrapper.and(BOT.CREATED_BY.eq(access.getAccountId()).or(BOT.CATEGORY_ID.in(access.getCategoryIds())));
}
@Override @Override
protected Result<?> onSaveOrUpdateBefore(Bot entity, boolean isSave) { protected Result<?> onSaveOrUpdateBefore(Bot entity, boolean isSave) {

View File

@@ -13,10 +13,15 @@ import tech.easyflow.common.domain.Result;
import tech.easyflow.common.entity.LoginAccount; import tech.easyflow.common.entity.LoginAccount;
import tech.easyflow.common.satoken.util.SaTokenUtil; import tech.easyflow.common.satoken.util.SaTokenUtil;
import tech.easyflow.common.web.controller.BaseCurdController; import tech.easyflow.common.web.controller.BaseCurdController;
import tech.easyflow.system.entity.vo.RoleCategoryAccessSnapshot;
import tech.easyflow.system.service.CategoryPermissionService;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.math.BigInteger; import java.math.BigInteger;
import java.util.Date; import java.util.Date;
import java.util.List;
import static tech.easyflow.ai.entity.table.ResourceTableDef.RESOURCE;
/** /**
* 素材库 * 素材库
@@ -28,6 +33,9 @@ import java.util.Date;
@RequestMapping("/userCenter/resource") @RequestMapping("/userCenter/resource")
@UsePermission(moduleName = "/api/v1/resource") @UsePermission(moduleName = "/api/v1/resource")
public class UcResourceController extends BaseCurdController<ResourceService, Resource> { public class UcResourceController extends BaseCurdController<ResourceService, Resource> {
@javax.annotation.Resource
private CategoryPermissionService categoryPermissionService;
public UcResourceController(ResourceService service) { public UcResourceController(ResourceService service) {
super(service); super(service);
} }
@@ -52,7 +60,36 @@ public class UcResourceController extends BaseCurdController<ResourceService, Re
@Override @Override
protected Page<Resource> queryPage(Page<Resource> page, QueryWrapper queryWrapper) { protected Page<Resource> queryPage(Page<Resource> page, QueryWrapper queryWrapper) {
queryWrapper.eq(Resource::getCreatedBy, SaTokenUtil.getLoginAccount().getId().toString()); applyCategoryPermission(queryWrapper);
return super.queryPage(page, queryWrapper); return super.queryPage(page, queryWrapper);
} }
}
@Override
public Result<List<Resource>> list(Resource entity, Boolean asTree, String sortKey, String sortType) {
QueryWrapper queryWrapper = QueryWrapper.create(entity, buildOperators(entity));
applyCategoryPermission(queryWrapper);
queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy()));
return Result.ok(service.list(queryWrapper));
}
@Override
public Result<Resource> detail(String id) {
Resource resource = service.getById(id);
if (resource != null) {
categoryPermissionService.assertCategoryResourceVisible("RESOURCE", resource.getCreatedBy(), resource.getCategoryId(), "无权限访问素材");
}
return Result.ok(resource);
}
private void applyCategoryPermission(QueryWrapper queryWrapper) {
RoleCategoryAccessSnapshot access = categoryPermissionService.getCurrentAccess("RESOURCE");
if (!access.isRestricted()) {
return;
}
if (access.getCategoryIds().isEmpty()) {
queryWrapper.eq(Resource::getCreatedBy, access.getAccountId());
return;
}
queryWrapper.and(RESOURCE.CREATED_BY.eq(access.getAccountId()).or(RESOURCE.CATEGORY_ID.in(access.getCategoryIds())));
}
}

View File

@@ -1,15 +1,22 @@
package tech.easyflow.usercenter.controller.ai; package tech.easyflow.usercenter.controller.ai;
import com.mybatisflex.core.query.QueryWrapper;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.GetMapping;
import tech.easyflow.ai.entity.WorkflowCategory; import tech.easyflow.ai.entity.WorkflowCategory;
import tech.easyflow.ai.service.WorkflowCategoryService; import tech.easyflow.ai.service.WorkflowCategoryService;
import tech.easyflow.common.annotation.UsePermission; import tech.easyflow.common.annotation.UsePermission;
import tech.easyflow.common.domain.Result; import tech.easyflow.common.domain.Result;
import tech.easyflow.common.web.controller.BaseCurdController; import tech.easyflow.common.web.controller.BaseCurdController;
import tech.easyflow.system.entity.vo.RoleCategoryAccessSnapshot;
import tech.easyflow.system.service.CategoryPermissionService;
import javax.annotation.Resource;
import java.io.Serializable; import java.io.Serializable;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.List;
/** /**
* 工作流分类 * 工作流分类
@@ -21,6 +28,8 @@ import java.util.Collection;
@RequestMapping("/userCenter/workflowCategory") @RequestMapping("/userCenter/workflowCategory")
@UsePermission(moduleName = "/api/v1/workflow") @UsePermission(moduleName = "/api/v1/workflow")
public class UcWorkflowCategoryController extends BaseCurdController<WorkflowCategoryService, WorkflowCategory> { public class UcWorkflowCategoryController extends BaseCurdController<WorkflowCategoryService, WorkflowCategory> {
@Resource
private CategoryPermissionService categoryPermissionService;
public UcWorkflowCategoryController(WorkflowCategoryService service) { public UcWorkflowCategoryController(WorkflowCategoryService service) {
super(service); super(service);
@@ -35,4 +44,18 @@ public class UcWorkflowCategoryController extends BaseCurdController<WorkflowCat
protected Result<?> onRemoveBefore(Collection<Serializable> ids) { protected Result<?> onRemoveBefore(Collection<Serializable> ids) {
return Result.fail("-"); return Result.fail("-");
} }
}
@GetMapping("visibleList")
public Result<List<WorkflowCategory>> visibleList(WorkflowCategory entity, Boolean asTree, String sortKey, String sortType) {
QueryWrapper queryWrapper = QueryWrapper.create(entity, buildOperators(entity));
RoleCategoryAccessSnapshot access = categoryPermissionService.getCurrentAccess("WORKFLOW");
if (access.isRestricted()) {
if (access.getCategoryIds().isEmpty()) {
return Result.ok(Collections.emptyList());
}
queryWrapper.in(WorkflowCategory::getId, access.getCategoryIds());
}
queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy()));
return Result.ok(service.list(queryWrapper));
}
}

View File

@@ -2,11 +2,14 @@ package tech.easyflow.usercenter.controller.ai;
import cn.dev33.satoken.annotation.SaCheckPermission; import cn.dev33.satoken.annotation.SaCheckPermission;
import cn.dev33.satoken.stp.StpUtil; import cn.dev33.satoken.stp.StpUtil;
import com.mybatisflex.core.paginate.Page;
import com.mybatisflex.core.query.QueryWrapper;
import com.easyagents.flow.core.chain.ChainDefinition; import com.easyagents.flow.core.chain.ChainDefinition;
import com.easyagents.flow.core.chain.Parameter; import com.easyagents.flow.core.chain.Parameter;
import com.easyagents.flow.core.chain.runtime.ChainExecutor; import com.easyagents.flow.core.chain.runtime.ChainExecutor;
import com.easyagents.flow.core.parser.ChainParser; import com.easyagents.flow.core.parser.ChainParser;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import tech.easyflow.ai.permission.WorkflowVisibilityQueryHelper;
import tech.easyflow.ai.easyagentsflow.entity.ChainInfo; import tech.easyflow.ai.easyagentsflow.entity.ChainInfo;
import tech.easyflow.ai.easyagentsflow.entity.NodeInfo; import tech.easyflow.ai.easyagentsflow.entity.NodeInfo;
import tech.easyflow.ai.easyagentsflow.entity.WorkflowCheckStage; import tech.easyflow.ai.easyagentsflow.entity.WorkflowCheckStage;
@@ -20,6 +23,10 @@ import tech.easyflow.common.domain.Result;
import tech.easyflow.common.satoken.util.SaTokenUtil; import tech.easyflow.common.satoken.util.SaTokenUtil;
import tech.easyflow.common.web.controller.BaseCurdController; import tech.easyflow.common.web.controller.BaseCurdController;
import tech.easyflow.common.web.jsonbody.JsonBody; import tech.easyflow.common.web.jsonbody.JsonBody;
import tech.easyflow.system.enums.CategoryResourceType;
import tech.easyflow.system.enums.ResourceAction;
import tech.easyflow.system.enums.ResourceLookup;
import tech.easyflow.system.permission.resource.RequireResourceAccess;
import javax.annotation.Resource; import javax.annotation.Resource;
import java.io.Serializable; import java.io.Serializable;
@@ -45,6 +52,8 @@ public class UcWorkflowController extends BaseCurdController<WorkflowService, Wo
private TinyFlowService tinyFlowService; private TinyFlowService tinyFlowService;
@Resource @Resource
private WorkflowCheckService workflowCheckService; private WorkflowCheckService workflowCheckService;
@Resource
private WorkflowVisibilityQueryHelper workflowVisibilityQueryHelper;
public UcWorkflowController(WorkflowService service) { public UcWorkflowController(WorkflowService service) {
super(service); super(service);
@@ -55,6 +64,13 @@ public class UcWorkflowController extends BaseCurdController<WorkflowService, Wo
*/ */
@PostMapping("/singleRun") @PostMapping("/singleRun")
@SaCheckPermission("/api/v1/workflow/save") @SaCheckPermission("/api/v1/workflow/save")
@RequireResourceAccess(
resource = CategoryResourceType.WORKFLOW,
action = ResourceAction.USE,
lookup = ResourceLookup.WORKFLOW_ID,
idExpr = "#workflowId",
denyMessage = "无权限运行工作流"
)
public Result<?> singleRun( public Result<?> singleRun(
@JsonBody(value = "workflowId", required = true) BigInteger workflowId, @JsonBody(value = "workflowId", required = true) BigInteger workflowId,
@JsonBody(value = "nodeId", required = true) String nodeId, @JsonBody(value = "nodeId", required = true) String nodeId,
@@ -73,6 +89,13 @@ public class UcWorkflowController extends BaseCurdController<WorkflowService, Wo
*/ */
@PostMapping("/runAsync") @PostMapping("/runAsync")
@SaCheckPermission("/api/v1/workflow/save") @SaCheckPermission("/api/v1/workflow/save")
@RequireResourceAccess(
resource = CategoryResourceType.WORKFLOW,
action = ResourceAction.USE,
lookup = ResourceLookup.WORKFLOW_ID,
idExpr = "#id",
denyMessage = "无权限运行工作流"
)
public Result<String> runAsync(@JsonBody(value = "id", required = true) BigInteger id, public Result<String> runAsync(@JsonBody(value = "id", required = true) BigInteger id,
@JsonBody("variables") Map<String, Object> variables) { @JsonBody("variables") Map<String, Object> variables) {
if (variables == null) { if (variables == null) {
@@ -94,6 +117,13 @@ public class UcWorkflowController extends BaseCurdController<WorkflowService, Wo
* 获取工作流运行状态 - v2 * 获取工作流运行状态 - v2
*/ */
@PostMapping("/getChainStatus") @PostMapping("/getChainStatus")
@RequireResourceAccess(
resource = CategoryResourceType.WORKFLOW,
action = ResourceAction.USE,
lookup = ResourceLookup.EXEC_KEY,
idExpr = "#executeId",
denyMessage = "无权限访问该执行记录"
)
public Result<ChainInfo> getChainStatus(@JsonBody(value = "executeId") String executeId, public Result<ChainInfo> getChainStatus(@JsonBody(value = "executeId") String executeId,
@JsonBody("nodes") List<NodeInfo> nodes) { @JsonBody("nodes") List<NodeInfo> nodes) {
ChainInfo res = tinyFlowService.getChainStatus(executeId, nodes); ChainInfo res = tinyFlowService.getChainStatus(executeId, nodes);
@@ -105,6 +135,13 @@ public class UcWorkflowController extends BaseCurdController<WorkflowService, Wo
*/ */
@PostMapping("/resume") @PostMapping("/resume")
@SaCheckPermission("/api/v1/workflow/save") @SaCheckPermission("/api/v1/workflow/save")
@RequireResourceAccess(
resource = CategoryResourceType.WORKFLOW,
action = ResourceAction.USE,
lookup = ResourceLookup.EXEC_KEY,
idExpr = "#executeId",
denyMessage = "无权限恢复工作流执行"
)
public Result<Void> resume(@JsonBody(value = "executeId", required = true) String executeId, public Result<Void> resume(@JsonBody(value = "executeId", required = true) String executeId,
@JsonBody("confirmParams") Map<String, Object> confirmParams) { @JsonBody("confirmParams") Map<String, Object> confirmParams) {
chainExecutor.resumeAsync(executeId, confirmParams); chainExecutor.resumeAsync(executeId, confirmParams);
@@ -116,6 +153,13 @@ public class UcWorkflowController extends BaseCurdController<WorkflowService, Wo
*/ */
@GetMapping("getRunningParameters") @GetMapping("getRunningParameters")
@SaCheckPermission("/api/v1/workflow/query") @SaCheckPermission("/api/v1/workflow/query")
@RequireResourceAccess(
resource = CategoryResourceType.WORKFLOW,
action = ResourceAction.READ,
lookup = ResourceLookup.WORKFLOW_ID,
idExpr = "#id",
denyMessage = "无权限访问工作流"
)
public Result<?> getRunningParameters(@RequestParam BigInteger id) { public Result<?> getRunningParameters(@RequestParam BigInteger id) {
Workflow workflow = service.getById(id); Workflow workflow = service.getById(id);
@@ -146,4 +190,32 @@ public class UcWorkflowController extends BaseCurdController<WorkflowService, Wo
protected Result<?> onRemoveBefore(Collection<Serializable> ids) { protected Result<?> onRemoveBefore(Collection<Serializable> ids) {
return Result.fail("-"); return Result.fail("-");
} }
@Override
public Result<List<Workflow>> list(Workflow entity, Boolean asTree, String sortKey, String sortType) {
QueryWrapper queryWrapper = QueryWrapper.create(entity, buildOperators(entity));
workflowVisibilityQueryHelper.applyReadableAccess(queryWrapper);
queryWrapper.orderBy(buildOrderBy(sortKey, sortType, getDefaultOrderBy()));
return Result.ok(service.list(queryWrapper));
}
@Override
protected Page<Workflow> queryPage(Page<Workflow> page, QueryWrapper queryWrapper) {
workflowVisibilityQueryHelper.applyReadableAccess(queryWrapper);
return super.queryPage(page, queryWrapper);
}
@Override
@GetMapping("detail")
@RequireResourceAccess(
resource = CategoryResourceType.WORKFLOW,
action = ResourceAction.READ,
lookup = ResourceLookup.WORKFLOW_ID,
idExpr = "#id",
denyMessage = "无权限访问工作流"
)
public Result<Workflow> detail(String id) {
Workflow workflow = service.getDetail(id);
return Result.ok(workflow);
}
} }

View File

@@ -41,6 +41,10 @@
<groupId>com.easyagents</groupId> <groupId>com.easyagents</groupId>
<artifactId>easy-agents-support</artifactId> <artifactId>easy-agents-support</artifactId>
</dependency> </dependency>
<dependency>
<groupId>com.easyagents</groupId>
<artifactId>easy-agents-spring-boot-starter</artifactId>
</dependency>
<!--使用 <!--使用
enjoy 模板引擎--> enjoy 模板引擎-->
<dependency> <dependency>

View File

@@ -0,0 +1,555 @@
package tech.easyflow.ai.documentimport;
import com.easyagents.rag.core.RagChunk;
import com.easyagents.rag.ingestion.model.AnalysisResult;
import com.easyagents.rag.ingestion.model.StrategyConfig;
import tech.easyflow.ai.entity.Document;
import tech.easyflow.ai.entity.DocumentChunk;
import java.io.Serializable;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
public final class DocumentImportDtos {
private DocumentImportDtos() {
}
public static class FileItem implements Serializable {
private String filePath;
private String fileName;
public String getFilePath() {
return filePath;
}
public void setFilePath(String filePath) {
this.filePath = filePath;
}
public String getFileName() {
return fileName;
}
public void setFileName(String fileName) {
this.fileName = fileName;
}
}
public static class AnalyzeRequest implements Serializable {
private BigInteger knowledgeId;
private List<FileItem> files = new ArrayList<FileItem>();
public BigInteger getKnowledgeId() {
return knowledgeId;
}
public void setKnowledgeId(BigInteger knowledgeId) {
this.knowledgeId = knowledgeId;
}
public List<FileItem> getFiles() {
return files;
}
public void setFiles(List<FileItem> files) {
this.files = files;
}
}
public static class PreviewFileRequest implements Serializable {
private String filePath;
private String fileName;
private StrategyConfig strategyConfig = StrategyConfig.defaults();
public String getFilePath() {
return filePath;
}
public void setFilePath(String filePath) {
this.filePath = filePath;
}
public String getFileName() {
return fileName;
}
public void setFileName(String fileName) {
this.fileName = fileName;
}
public StrategyConfig getStrategyConfig() {
return strategyConfig;
}
public void setStrategyConfig(StrategyConfig strategyConfig) {
this.strategyConfig = strategyConfig;
}
}
public static class PreviewRequest implements Serializable {
private BigInteger knowledgeId;
private List<PreviewFileRequest> files = new ArrayList<PreviewFileRequest>();
public BigInteger getKnowledgeId() {
return knowledgeId;
}
public void setKnowledgeId(BigInteger knowledgeId) {
this.knowledgeId = knowledgeId;
}
public List<PreviewFileRequest> getFiles() {
return files;
}
public void setFiles(List<PreviewFileRequest> files) {
this.files = files;
}
}
public static class CommitRequest implements Serializable {
private BigInteger knowledgeId;
private List<String> previewSessionIds = new ArrayList<String>();
public BigInteger getKnowledgeId() {
return knowledgeId;
}
public void setKnowledgeId(BigInteger knowledgeId) {
this.knowledgeId = knowledgeId;
}
public List<String> getPreviewSessionIds() {
return previewSessionIds;
}
public void setPreviewSessionIds(List<String> previewSessionIds) {
this.previewSessionIds = previewSessionIds;
}
}
public static class SplitterProfileSaveRequest implements Serializable {
private BigInteger knowledgeId;
private String defaultStrategyCode;
private Boolean autoRecommendEnabled;
private String fallbackStrategyCode;
private Map<String, Object> strategyProfiles = new LinkedHashMap<String, Object>();
public BigInteger getKnowledgeId() {
return knowledgeId;
}
public void setKnowledgeId(BigInteger knowledgeId) {
this.knowledgeId = knowledgeId;
}
public String getDefaultStrategyCode() {
return defaultStrategyCode;
}
public void setDefaultStrategyCode(String defaultStrategyCode) {
this.defaultStrategyCode = defaultStrategyCode;
}
public Boolean getAutoRecommendEnabled() {
return autoRecommendEnabled;
}
public void setAutoRecommendEnabled(Boolean autoRecommendEnabled) {
this.autoRecommendEnabled = autoRecommendEnabled;
}
public String getFallbackStrategyCode() {
return fallbackStrategyCode;
}
public void setFallbackStrategyCode(String fallbackStrategyCode) {
this.fallbackStrategyCode = fallbackStrategyCode;
}
public Map<String, Object> getStrategyProfiles() {
return strategyProfiles;
}
public void setStrategyProfiles(Map<String, Object> strategyProfiles) {
this.strategyProfiles = strategyProfiles;
}
}
public static class AnalyzeItem implements Serializable {
private String filePath;
private String fileName;
private AnalysisResult analysis;
private StrategyConfig strategyConfig = StrategyConfig.defaults();
public String getFilePath() {
return filePath;
}
public void setFilePath(String filePath) {
this.filePath = filePath;
}
public String getFileName() {
return fileName;
}
public void setFileName(String fileName) {
this.fileName = fileName;
}
public AnalysisResult getAnalysis() {
return analysis;
}
public void setAnalysis(AnalysisResult analysis) {
this.analysis = analysis;
}
public StrategyConfig getStrategyConfig() {
return strategyConfig;
}
public void setStrategyConfig(StrategyConfig strategyConfig) {
this.strategyConfig = strategyConfig;
}
}
public static class AnalyzeResponse implements Serializable {
private Integer total;
private List<AnalyzeItem> items = new ArrayList<AnalyzeItem>();
public Integer getTotal() {
return total;
}
public void setTotal(Integer total) {
this.total = total;
}
public List<AnalyzeItem> getItems() {
return items;
}
public void setItems(List<AnalyzeItem> items) {
this.items = items;
}
}
public static class PreviewFileResult implements Serializable {
private String previewSessionId;
private String filePath;
private String fileName;
private String strategyCode;
private String strategyLabel;
private AnalysisResult analysis;
private Integer totalChunks;
private Integer totalWarnings;
private List<RagChunk> chunks = new ArrayList<RagChunk>();
public String getPreviewSessionId() {
return previewSessionId;
}
public void setPreviewSessionId(String previewSessionId) {
this.previewSessionId = previewSessionId;
}
public String getFilePath() {
return filePath;
}
public void setFilePath(String filePath) {
this.filePath = filePath;
}
public String getFileName() {
return fileName;
}
public void setFileName(String fileName) {
this.fileName = fileName;
}
public String getStrategyCode() {
return strategyCode;
}
public void setStrategyCode(String strategyCode) {
this.strategyCode = strategyCode;
}
public String getStrategyLabel() {
return strategyLabel;
}
public void setStrategyLabel(String strategyLabel) {
this.strategyLabel = strategyLabel;
}
public AnalysisResult getAnalysis() {
return analysis;
}
public void setAnalysis(AnalysisResult analysis) {
this.analysis = analysis;
}
public Integer getTotalChunks() {
return totalChunks;
}
public void setTotalChunks(Integer totalChunks) {
this.totalChunks = totalChunks;
}
public Integer getTotalWarnings() {
return totalWarnings;
}
public void setTotalWarnings(Integer totalWarnings) {
this.totalWarnings = totalWarnings;
}
public List<RagChunk> getChunks() {
return chunks;
}
public void setChunks(List<RagChunk> chunks) {
this.chunks = chunks;
}
}
public static class PreviewResponse implements Serializable {
private Integer totalFiles;
private Integer totalChunks;
private List<PreviewFileResult> items = new ArrayList<PreviewFileResult>();
public Integer getTotalFiles() {
return totalFiles;
}
public void setTotalFiles(Integer totalFiles) {
this.totalFiles = totalFiles;
}
public Integer getTotalChunks() {
return totalChunks;
}
public void setTotalChunks(Integer totalChunks) {
this.totalChunks = totalChunks;
}
public List<PreviewFileResult> getItems() {
return items;
}
public void setItems(List<PreviewFileResult> items) {
this.items = items;
}
}
public static class CommitFileResult implements Serializable {
private String previewSessionId;
private String fileName;
private Boolean success;
private String reason;
private BigInteger documentId;
private Integer chunkCount;
public String getPreviewSessionId() {
return previewSessionId;
}
public void setPreviewSessionId(String previewSessionId) {
this.previewSessionId = previewSessionId;
}
public String getFileName() {
return fileName;
}
public void setFileName(String fileName) {
this.fileName = fileName;
}
public Boolean getSuccess() {
return success;
}
public void setSuccess(Boolean success) {
this.success = success;
}
public String getReason() {
return reason;
}
public void setReason(String reason) {
this.reason = reason;
}
public BigInteger getDocumentId() {
return documentId;
}
public void setDocumentId(BigInteger documentId) {
this.documentId = documentId;
}
public Integer getChunkCount() {
return chunkCount;
}
public void setChunkCount(Integer chunkCount) {
this.chunkCount = chunkCount;
}
}
public static class CommitResponse implements Serializable {
private Integer totalFiles;
private Integer successCount;
private Integer errorCount;
private List<CommitFileResult> results = new ArrayList<CommitFileResult>();
public Integer getTotalFiles() {
return totalFiles;
}
public void setTotalFiles(Integer totalFiles) {
this.totalFiles = totalFiles;
}
public Integer getSuccessCount() {
return successCount;
}
public void setSuccessCount(Integer successCount) {
this.successCount = successCount;
}
public Integer getErrorCount() {
return errorCount;
}
public void setErrorCount(Integer errorCount) {
this.errorCount = errorCount;
}
public List<CommitFileResult> getResults() {
return results;
}
public void setResults(List<CommitFileResult> results) {
this.results = results;
}
}
public static class PreviewSession implements Serializable {
private String sessionId;
private BigInteger knowledgeId;
private String filePath;
private String fileName;
private String sourceFormat;
private StrategyConfig strategyConfig;
private AnalysisResult analysis;
private Document document;
private List<DocumentChunk> documentChunks = new ArrayList<DocumentChunk>();
private List<RagChunk> previewChunks = new ArrayList<RagChunk>();
private Date createdAt;
public String getSessionId() {
return sessionId;
}
public void setSessionId(String sessionId) {
this.sessionId = sessionId;
}
public BigInteger getKnowledgeId() {
return knowledgeId;
}
public void setKnowledgeId(BigInteger knowledgeId) {
this.knowledgeId = knowledgeId;
}
public String getFilePath() {
return filePath;
}
public void setFilePath(String filePath) {
this.filePath = filePath;
}
public String getFileName() {
return fileName;
}
public void setFileName(String fileName) {
this.fileName = fileName;
}
public String getSourceFormat() {
return sourceFormat;
}
public void setSourceFormat(String sourceFormat) {
this.sourceFormat = sourceFormat;
}
public StrategyConfig getStrategyConfig() {
return strategyConfig;
}
public void setStrategyConfig(StrategyConfig strategyConfig) {
this.strategyConfig = strategyConfig;
}
public AnalysisResult getAnalysis() {
return analysis;
}
public void setAnalysis(AnalysisResult analysis) {
this.analysis = analysis;
}
public Document getDocument() {
return document;
}
public void setDocument(Document document) {
this.document = document;
}
public List<DocumentChunk> getDocumentChunks() {
return documentChunks;
}
public void setDocumentChunks(List<DocumentChunk> documentChunks) {
this.documentChunks = documentChunks;
}
public List<RagChunk> getPreviewChunks() {
return previewChunks;
}
public void setPreviewChunks(List<RagChunk> previewChunks) {
this.previewChunks = previewChunks;
}
public Date getCreatedAt() {
return createdAt;
}
public void setCreatedAt(Date createdAt) {
this.createdAt = createdAt;
}
}
}

View File

@@ -0,0 +1,21 @@
package tech.easyflow.ai.documentimport;
public final class DocumentImportKeys {
private DocumentImportKeys() {
}
public static final String CACHE_KEY_PREFIX = "easyflow:document:import:preview:";
public static final String KEY_SPLITTER_DEFAULT_STRATEGY = "splitter.defaultStrategyCode";
public static final String KEY_SPLITTER_AUTO_RECOMMEND_ENABLED = "splitter.autoRecommendEnabled";
public static final String KEY_SPLITTER_FALLBACK_STRATEGY = "splitter.fallbackStrategyCode";
public static final String KEY_SPLITTER_STRATEGY_PROFILES = "splitter.strategyProfiles";
public static final String KEY_DOCUMENT_STRATEGY_CODE = "splitter.strategyCode";
public static final String KEY_DOCUMENT_STRATEGY_LABEL = "splitter.strategyLabel";
public static final String KEY_DOCUMENT_STRATEGY_SNAPSHOT = "splitter.strategySnapshot";
public static final String KEY_DOCUMENT_ANALYSIS_SUMMARY = "splitter.analysisSummary";
public static final String KEY_DOCUMENT_SOURCE_FILE_EXT = "splitter.sourceFileExt";
public static final String KEY_DOCUMENT_PREVIEW_VERSION = "splitter.previewVersion";
}

View File

@@ -0,0 +1,45 @@
package tech.easyflow.ai.documentimport;
import com.alicp.jetcache.Cache;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
import tech.easyflow.common.web.exceptions.BusinessException;
import java.time.Duration;
import java.util.concurrent.TimeUnit;
import java.util.UUID;
@Service
public class DocumentImportPreviewService {
private static final Duration SESSION_TTL = Duration.ofMinutes(30);
private final Cache<String, Object> defaultCache;
public DocumentImportPreviewService(@Qualifier("defaultCache") Cache<String, Object> defaultCache) {
this.defaultCache = defaultCache;
}
public String put(DocumentImportDtos.PreviewSession session) {
String sessionId = UUID.randomUUID().toString().replace("-", "");
session.setSessionId(sessionId);
defaultCache.put(buildKey(sessionId), session, SESSION_TTL.toMinutes(), TimeUnit.MINUTES);
return sessionId;
}
public DocumentImportDtos.PreviewSession getRequired(String sessionId) {
Object cached = defaultCache.get(buildKey(sessionId));
if (!(cached instanceof DocumentImportDtos.PreviewSession)) {
throw new BusinessException("预览会话已失效,请重新生成预览");
}
return (DocumentImportDtos.PreviewSession) cached;
}
public void remove(String sessionId) {
defaultCache.remove(buildKey(sessionId));
}
private String buildKey(String sessionId) {
return DocumentImportKeys.CACHE_KEY_PREFIX + sessionId;
}
}

View File

@@ -0,0 +1,65 @@
package tech.easyflow.ai.dto;
import java.io.Serializable;
import java.math.BigInteger;
import java.util.List;
public class ModelInvokeConfigDtos {
public static class UpdateRequest implements Serializable {
private static final long serialVersionUID = 1L;
private BigInteger id;
private String invokeCode;
private Boolean publishEnabled;
public BigInteger getId() {
return id;
}
public void setId(BigInteger id) {
this.id = id;
}
public String getInvokeCode() {
return invokeCode;
}
public void setInvokeCode(String invokeCode) {
this.invokeCode = invokeCode;
}
public Boolean getPublishEnabled() {
return publishEnabled;
}
public void setPublishEnabled(Boolean publishEnabled) {
this.publishEnabled = publishEnabled;
}
}
public static class BatchPublishRequest implements Serializable {
private static final long serialVersionUID = 1L;
private List<BigInteger> ids;
private Boolean publishEnabled;
public List<BigInteger> getIds() {
return ids;
}
public void setIds(List<BigInteger> ids) {
this.ids = ids;
}
public Boolean getPublishEnabled() {
return publishEnabled;
}
public void setPublishEnabled(Boolean publishEnabled) {
this.publishEnabled = publishEnabled;
}
}
}

View File

@@ -20,6 +20,7 @@ import tech.easyflow.ai.entity.base.DocumentCollectionBase;
import tech.easyflow.common.util.PropertiesUtil; import tech.easyflow.common.util.PropertiesUtil;
import tech.easyflow.common.util.StringUtil; import tech.easyflow.common.util.StringUtil;
import tech.easyflow.common.web.exceptions.BusinessException; import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.system.permission.resource.VisibilityResource;
import java.math.BigDecimal; import java.math.BigDecimal;
import java.util.Map; import java.util.Map;
@@ -32,7 +33,7 @@ import java.util.Map;
*/ */
@Table("tb_document_collection") @Table("tb_document_collection")
public class DocumentCollection extends DocumentCollectionBase { public class DocumentCollection extends DocumentCollectionBase implements VisibilityResource {
public static final String TYPE_DOCUMENT = "DOCUMENT"; public static final String TYPE_DOCUMENT = "DOCUMENT";
public static final String TYPE_FAQ = "FAQ"; public static final String TYPE_FAQ = "FAQ";
@@ -71,6 +72,10 @@ public class DocumentCollection extends DocumentCollectionBase {
* 是否启用重排模型 * 是否启用重排模型
*/ */
public static final String KEY_RERANK_ENABLE = "rerankEnable"; public static final String KEY_RERANK_ENABLE = "rerankEnable";
public static final String KEY_SPLITTER_DEFAULT_STRATEGY = "splitter.defaultStrategyCode";
public static final String KEY_SPLITTER_AUTO_RECOMMEND_ENABLED = "splitter.autoRecommendEnabled";
public static final String KEY_SPLITTER_FALLBACK_STRATEGY = "splitter.fallbackStrategyCode";
public static final String KEY_SPLITTER_STRATEGY_PROFILES = "splitter.strategyProfiles";
public DocumentStore toDocumentStore() { public DocumentStore toDocumentStore() {
String storeType = this.getVectorStoreType(); String storeType = this.getVectorStoreType();

View File

@@ -4,6 +4,7 @@ import com.easyagents.core.model.chat.tool.Tool;
import com.mybatisflex.annotation.Table; import com.mybatisflex.annotation.Table;
import tech.easyflow.ai.easyagents.tool.WorkflowTool; import tech.easyflow.ai.easyagents.tool.WorkflowTool;
import tech.easyflow.ai.entity.base.WorkflowBase; import tech.easyflow.ai.entity.base.WorkflowBase;
import tech.easyflow.system.permission.resource.VisibilityResource;
/** /**
* 实体类。 * 实体类。
@@ -13,7 +14,7 @@ import tech.easyflow.ai.entity.base.WorkflowBase;
*/ */
@Table("tb_workflow") @Table("tb_workflow")
public class Workflow extends WorkflowBase { public class Workflow extends WorkflowBase implements VisibilityResource {
public Tool toFunction(boolean needEnglishName) { public Tool toFunction(boolean needEnglishName) {
return new WorkflowTool(this, needEnglishName); return new WorkflowTool(this, needEnglishName);

View File

@@ -3,8 +3,10 @@ package tech.easyflow.ai.entity.base;
import com.mybatisflex.annotation.Column; import com.mybatisflex.annotation.Column;
import com.mybatisflex.annotation.Id; import com.mybatisflex.annotation.Id;
import com.mybatisflex.annotation.KeyType; import com.mybatisflex.annotation.KeyType;
import com.mybatisflex.core.handler.FastjsonTypeHandler;
import java.io.Serializable; import java.io.Serializable;
import java.math.BigInteger; import java.math.BigInteger;
import java.util.Map;
public class DocumentChunkBase implements Serializable { public class DocumentChunkBase implements Serializable {
@@ -38,6 +40,12 @@ public class DocumentChunkBase implements Serializable {
@Column(comment = "分割顺序") @Column(comment = "分割顺序")
private Integer sorting; private Integer sorting;
/**
* 扩展元信息
*/
@Column(typeHandler = FastjsonTypeHandler.class, comment = "扩展元信息")
private Map<String, Object> options;
public BigInteger getId() { public BigInteger getId() {
return id; return id;
} }
@@ -78,4 +86,12 @@ public class DocumentChunkBase implements Serializable {
this.sorting = sorting; this.sorting = sorting;
} }
public Map<String, Object> getOptions() {
return options;
}
public void setOptions(Map<String, Object> options) {
this.options = options;
}
} }

View File

@@ -160,6 +160,12 @@ public class DocumentCollectionBase extends DateEntity implements Serializable {
@Column(comment = "分类ID") @Column(comment = "分类ID")
private BigInteger categoryId; private BigInteger categoryId;
/**
* 可见范围
*/
@Column(comment = "可见范围")
private String visibilityScope;
public BigInteger getId() { public BigInteger getId() {
return id; return id;
} }
@@ -352,4 +358,12 @@ public class DocumentCollectionBase extends DateEntity implements Serializable {
this.categoryId = categoryId; this.categoryId = categoryId;
} }
public String getVisibilityScope() {
return visibilityScope;
}
public void setVisibilityScope(String visibilityScope) {
this.visibilityScope = visibilityScope;
}
} }

View File

@@ -152,6 +152,18 @@ public class ModelBase implements Serializable {
@Column(comment = "是否支持tool消息") @Column(comment = "是否支持tool消息")
private Boolean supportToolMessage; private Boolean supportToolMessage;
/**
* 统一模型调用对外标识
*/
@Column(comment = "统一模型调用对外标识")
private String invokeCode;
/**
* 是否开启统一模型调用发布
*/
@Column(comment = "是否开启统一模型调用发布")
private Boolean publishEnabled;
public BigInteger getId() { public BigInteger getId() {
return id; return id;
} }
@@ -336,4 +348,20 @@ public class ModelBase implements Serializable {
this.supportToolMessage = supportToolMessage; this.supportToolMessage = supportToolMessage;
} }
public String getInvokeCode() {
return invokeCode;
}
public void setInvokeCode(String invokeCode) {
this.invokeCode = invokeCode;
}
public Boolean getPublishEnabled() {
return publishEnabled;
}
public void setPublishEnabled(Boolean publishEnabled) {
this.publishEnabled = publishEnabled;
}
} }

View File

@@ -103,6 +103,12 @@ public class WorkflowBase extends DateEntity implements Serializable {
@Column(comment = "分类ID") @Column(comment = "分类ID")
private BigInteger categoryId; private BigInteger categoryId;
/**
* 可见范围
*/
@Column(comment = "可见范围")
private String visibilityScope;
public BigInteger getId() { public BigInteger getId() {
return id; return id;
} }
@@ -223,4 +229,12 @@ public class WorkflowBase extends DateEntity implements Serializable {
this.categoryId = categoryId; this.categoryId = categoryId;
} }
public String getVisibilityScope() {
return visibilityScope;
}
public void setVisibilityScope(String visibilityScope) {
this.visibilityScope = visibilityScope;
}
} }

View File

@@ -0,0 +1,61 @@
package tech.easyflow.ai.invoke.exception;
public class ModelInvokeException extends RuntimeException {
private final int status;
private final String type;
private final String param;
private final String code;
public ModelInvokeException(int status, String message, String type, String param, String code) {
super(message);
this.status = status;
this.type = type;
this.param = param;
this.code = code;
}
public int getStatus() {
return status;
}
public String getType() {
return type;
}
public String getParam() {
return param;
}
public String getCode() {
return code;
}
public static ModelInvokeException badRequest(String message) {
return badRequest(message, null, "invalid_request_error");
}
public static ModelInvokeException badRequest(String message, String param, String code) {
return new ModelInvokeException(400, message, "invalid_request_error", param, code);
}
public static ModelInvokeException unauthorized(String message) {
return new ModelInvokeException(401, message, "authentication_error", null, "unauthorized");
}
public static ModelInvokeException forbidden(String message) {
return new ModelInvokeException(403, message, "permission_error", null, "forbidden");
}
public static ModelInvokeException notFound(String message) {
return new ModelInvokeException(404, message, "not_found_error", "model", "model_not_found");
}
public static ModelInvokeException badGateway(String message) {
return new ModelInvokeException(502, message, "api_error", null, "upstream_bad_gateway");
}
public static ModelInvokeException serviceUnavailable(String message) {
return new ModelInvokeException(503, message, "service_unavailable_error", null, "upstream_unavailable");
}
}

View File

@@ -0,0 +1,651 @@
package tech.easyflow.ai.invoke.mapper;
import cn.hutool.core.util.StrUtil;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.springframework.stereotype.Component;
import tech.easyflow.ai.invoke.exception.ModelInvokeException;
import tech.easyflow.ai.invoke.model.UnifiedChatChunk;
import tech.easyflow.ai.invoke.model.UnifiedChatRequest;
import tech.easyflow.ai.invoke.model.UnifiedChatResponse;
import tech.easyflow.ai.invoke.model.UnifiedChoice;
import tech.easyflow.ai.invoke.model.UnifiedContentPart;
import tech.easyflow.ai.invoke.model.UnifiedImageUrl;
import tech.easyflow.ai.invoke.model.UnifiedMessage;
import tech.easyflow.ai.invoke.model.UnifiedResponseFormat;
import tech.easyflow.ai.invoke.model.UnifiedTool;
import tech.easyflow.ai.invoke.model.UnifiedToolCall;
import tech.easyflow.ai.invoke.model.UnifiedToolCallFunction;
import tech.easyflow.ai.invoke.model.UnifiedToolFunction;
import tech.easyflow.ai.invoke.model.UnifiedUsage;
import tech.easyflow.ai.invoke.protocol.openai.OpenAiChatCompletionChunkResponse;
import tech.easyflow.ai.invoke.protocol.openai.OpenAiChatCompletionRequest;
import tech.easyflow.ai.invoke.protocol.openai.OpenAiChatCompletionResponse;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
@Component
public class OpenAiProtocolMapper {
private static final Set<String> ROOT_FIELDS = Set.of(
"model", "messages", "stream", "temperature", "top_p", "max_tokens",
"seed", "tools", "tool_choice", "response_format"
);
private static final Set<String> MESSAGE_FIELDS = Set.of(
"role", "content", "name", "tool_call_id", "tool_calls"
);
private static final Set<String> CONTENT_PART_FIELDS = Set.of("type", "text", "image_url");
private static final Set<String> IMAGE_URL_FIELDS = Set.of("url", "detail");
private static final Set<String> TOOL_FIELDS = Set.of("type", "function");
private static final Set<String> TOOL_FUNCTION_FIELDS = Set.of("name", "description", "parameters");
private static final Set<String> RESPONSE_FORMAT_FIELDS = Set.of("type", "json_schema");
private static final Set<String> TOOL_CHOICE_FIELDS = Set.of("type", "function");
private final ObjectMapper objectMapper;
public OpenAiProtocolMapper(ObjectMapper objectMapper) {
this.objectMapper = objectMapper;
}
public OpenAiChatCompletionRequest readRequest(String rawBody) {
JsonNode rootNode;
try {
rootNode = objectMapper.readTree(rawBody);
} catch (JsonProcessingException e) {
throw ModelInvokeException.badRequest("请求体不是合法 JSON", null, "invalid_json");
}
if (rootNode == null || !rootNode.isObject()) {
throw ModelInvokeException.badRequest("请求体必须为 JSON 对象", null, "invalid_json");
}
validateAllowedFields(rootNode, ROOT_FIELDS, null);
validateMessages(rootNode.path("messages"));
validateTools(rootNode.path("tools"));
validateToolChoice(rootNode.path("tool_choice"));
validateResponseFormat(rootNode.path("response_format"));
try {
OpenAiChatCompletionRequest request = buildRequest(rootNode);
if (StrUtil.isBlank(request.getModel())) {
throw ModelInvokeException.badRequest("model 不能为空", "model", "model_required");
}
if (request.getMessages() == null || request.getMessages().isEmpty()) {
throw ModelInvokeException.badRequest("messages 不能为空", "messages", "messages_required");
}
return request;
} catch (IllegalArgumentException e) {
throw ModelInvokeException.badRequest(e.getMessage(), "messages", "invalid_message_content");
}
}
private OpenAiChatCompletionRequest buildRequest(JsonNode rootNode) {
OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest();
request.setModel(textValue(rootNode, "model"));
request.setMessages(buildMessages(rootNode.path("messages")));
if (rootNode.has("stream") && !rootNode.get("stream").isNull()) {
request.setStream(rootNode.get("stream").asBoolean());
}
if (rootNode.has("temperature") && !rootNode.get("temperature").isNull()) {
request.setTemperature(rootNode.get("temperature").asDouble());
}
if (rootNode.has("top_p") && !rootNode.get("top_p").isNull()) {
request.setTopP(rootNode.get("top_p").asDouble());
}
if (rootNode.has("max_tokens") && !rootNode.get("max_tokens").isNull()) {
request.setMaxTokens(rootNode.get("max_tokens").asInt());
}
if (rootNode.has("seed") && !rootNode.get("seed").isNull()) {
request.setSeed(rootNode.get("seed").asLong());
}
if (rootNode.has("tools") && rootNode.get("tools").isArray()) {
request.setTools(objectMapper.convertValue(
rootNode.get("tools"),
new TypeReference<>() {
}
));
}
if (rootNode.has("tool_choice") && !rootNode.get("tool_choice").isNull()) {
request.setToolChoice(rootNode.get("tool_choice"));
}
if (rootNode.has("response_format") && rootNode.get("response_format").isObject()) {
request.setResponseFormat(objectMapper.convertValue(
rootNode.get("response_format"),
OpenAiChatCompletionRequest.ResponseFormat.class
));
}
return request;
}
private List<OpenAiChatCompletionRequest.Message> buildMessages(JsonNode messagesNode) {
List<OpenAiChatCompletionRequest.Message> messages = new ArrayList<>();
for (JsonNode messageNode : messagesNode) {
OpenAiChatCompletionRequest.Message message = new OpenAiChatCompletionRequest.Message();
message.setRole(textValue(messageNode, "role"));
if (messageNode.has("content")) {
message.setContentNode(messageNode.get("content"));
}
message.setName(textValue(messageNode, "name"));
message.setToolCallId(textValue(messageNode, "tool_call_id"));
if (messageNode.has("tool_calls") && messageNode.get("tool_calls").isArray()) {
message.setToolCalls(objectMapper.convertValue(
messageNode.get("tool_calls"),
new TypeReference<>() {
}
));
}
messages.add(message);
}
return messages;
}
public UnifiedChatRequest toUnifiedRequest(OpenAiChatCompletionRequest request) {
UnifiedChatRequest unifiedRequest = new UnifiedChatRequest();
unifiedRequest.setModel(request.getModel());
unifiedRequest.setMessages(toUnifiedMessages(request.getMessages()));
unifiedRequest.setStream(Boolean.TRUE.equals(request.getStream()));
unifiedRequest.setTemperature(request.getTemperature());
unifiedRequest.setTopP(request.getTopP());
unifiedRequest.setMaxTokens(request.getMaxTokens());
unifiedRequest.setSeed(request.getSeed());
unifiedRequest.setTools(toUnifiedTools(request.getTools()));
unifiedRequest.setToolChoice(request.getToolChoice());
unifiedRequest.setResponseFormat(toUnifiedResponseFormat(request.getResponseFormat()));
return unifiedRequest;
}
public OpenAiChatCompletionRequest toOpenAiRequest(UnifiedChatRequest request) {
OpenAiChatCompletionRequest openAiRequest = new OpenAiChatCompletionRequest();
openAiRequest.setModel(request.getModel());
openAiRequest.setMessages(toOpenAiMessages(request.getMessages()));
openAiRequest.setStream(request.getStream());
openAiRequest.setTemperature(request.getTemperature());
openAiRequest.setTopP(request.getTopP());
openAiRequest.setMaxTokens(request.getMaxTokens());
openAiRequest.setSeed(request.getSeed());
openAiRequest.setTools(toOpenAiTools(request.getTools()));
openAiRequest.setToolChoice(request.getToolChoice());
openAiRequest.setResponseFormat(toOpenAiResponseFormat(request.getResponseFormat()));
return openAiRequest;
}
public UnifiedChatResponse toUnifiedResponse(OpenAiChatCompletionResponse response) {
UnifiedChatResponse unifiedResponse = new UnifiedChatResponse();
unifiedResponse.setId(response.getId());
unifiedResponse.setObject(response.getObject());
unifiedResponse.setCreated(response.getCreated());
unifiedResponse.setModel(response.getModel());
unifiedResponse.setChoices(toUnifiedChoices(response.getChoices(), false));
unifiedResponse.setUsage(toUnifiedUsage(response.getUsage()));
return unifiedResponse;
}
public UnifiedChatChunk toUnifiedChunk(OpenAiChatCompletionChunkResponse chunk) {
UnifiedChatChunk unifiedChunk = new UnifiedChatChunk();
unifiedChunk.setId(chunk.getId());
unifiedChunk.setObject(chunk.getObject());
unifiedChunk.setCreated(chunk.getCreated());
unifiedChunk.setModel(chunk.getModel());
unifiedChunk.setChoices(toUnifiedChunkChoices(chunk.getChoices()));
unifiedChunk.setUsage(toUnifiedUsage(chunk.getUsage()));
return unifiedChunk;
}
public OpenAiChatCompletionResponse toOpenAiResponse(UnifiedChatResponse response) {
OpenAiChatCompletionResponse openAiResponse = new OpenAiChatCompletionResponse();
openAiResponse.setId(response.getId());
openAiResponse.setObject(response.getObject());
openAiResponse.setCreated(response.getCreated());
openAiResponse.setModel(response.getModel());
openAiResponse.setChoices(toOpenAiChoices(response.getChoices(), false));
openAiResponse.setUsage(toOpenAiUsage(response.getUsage()));
return openAiResponse;
}
public OpenAiChatCompletionChunkResponse toOpenAiChunk(UnifiedChatChunk chunk) {
OpenAiChatCompletionChunkResponse openAiChunk = new OpenAiChatCompletionChunkResponse();
openAiChunk.setId(chunk.getId());
openAiChunk.setObject(chunk.getObject());
openAiChunk.setCreated(chunk.getCreated());
openAiChunk.setModel(chunk.getModel());
openAiChunk.setChoices(toOpenAiChunkChoices(chunk.getChoices()));
openAiChunk.setUsage(toOpenAiUsage(chunk.getUsage()));
return openAiChunk;
}
private void validateMessages(JsonNode messagesNode) {
if (messagesNode == null || !messagesNode.isArray() || messagesNode.isEmpty()) {
throw ModelInvokeException.badRequest("messages 必须为非空数组", "messages", "messages_required");
}
for (int i = 0; i < messagesNode.size(); i++) {
JsonNode messageNode = messagesNode.get(i);
if (!messageNode.isObject()) {
throw ModelInvokeException.badRequest("messages[" + i + "] 必须为对象", "messages", "invalid_message");
}
validateAllowedFields(messageNode, MESSAGE_FIELDS, "messages[" + i + "]");
String role = textValue(messageNode, "role");
if (!Set.of("system", "developer", "user", "assistant", "tool").contains(role)) {
throw ModelInvokeException.badRequest("messages[" + i + "].role 不受支持", "messages", "invalid_message_role");
}
JsonNode contentNode = messageNode.get("content");
if (contentNode != null && !contentNode.isNull() && !contentNode.isTextual() && !contentNode.isArray()) {
throw ModelInvokeException.badRequest("messages[" + i + "].content 仅支持字符串或数组", "messages", "invalid_message_content");
}
if (contentNode != null && contentNode.isArray()) {
validateContentParts(contentNode, "messages[" + i + "].content");
}
}
}
private void validateContentParts(JsonNode contentNode, String path) {
for (int i = 0; i < contentNode.size(); i++) {
JsonNode partNode = contentNode.get(i);
if (!partNode.isObject()) {
throw ModelInvokeException.badRequest(path + "[" + i + "] 必须为对象", "messages", "invalid_content_part");
}
validateAllowedFields(partNode, CONTENT_PART_FIELDS, path + "[" + i + "]");
String type = textValue(partNode, "type");
if (!Set.of("text", "image_url").contains(type)) {
throw ModelInvokeException.badRequest(path + "[" + i + "].type 不受支持", "messages", "content_part_not_supported");
}
if ("text".equals(type) && StrUtil.isBlank(textValue(partNode, "text"))) {
throw ModelInvokeException.badRequest(path + "[" + i + "].text 不能为空", "messages", "content_text_required");
}
if ("image_url".equals(type)) {
JsonNode imageUrlNode = partNode.get("image_url");
if (imageUrlNode == null || !imageUrlNode.isObject()) {
throw ModelInvokeException.badRequest(path + "[" + i + "].image_url 必须为对象", "messages", "image_url_required");
}
validateAllowedFields(imageUrlNode, IMAGE_URL_FIELDS, path + "[" + i + "].image_url");
if (StrUtil.isBlank(textValue(imageUrlNode, "url"))) {
throw ModelInvokeException.badRequest(path + "[" + i + "].image_url.url 不能为空", "messages", "image_url_required");
}
}
}
}
private void validateTools(JsonNode toolsNode) {
if (isAbsentNode(toolsNode)) {
return;
}
if (!toolsNode.isArray()) {
throw ModelInvokeException.badRequest("tools 必须为数组", "tools", "invalid_tools");
}
for (int i = 0; i < toolsNode.size(); i++) {
JsonNode toolNode = toolsNode.get(i);
if (!toolNode.isObject()) {
throw ModelInvokeException.badRequest("tools[" + i + "] 必须为对象", "tools", "invalid_tools");
}
validateAllowedFields(toolNode, TOOL_FIELDS, "tools[" + i + "]");
JsonNode functionNode = toolNode.get("function");
if (functionNode == null || !functionNode.isObject()) {
throw ModelInvokeException.badRequest("tools[" + i + "].function 必须为对象", "tools", "invalid_tool_function");
}
validateAllowedFields(functionNode, TOOL_FUNCTION_FIELDS, "tools[" + i + "].function");
if (StrUtil.isBlank(textValue(functionNode, "name"))) {
throw ModelInvokeException.badRequest("tools[" + i + "].function.name 不能为空", "tools", "tool_name_required");
}
}
}
private void validateToolChoice(JsonNode toolChoiceNode) {
if (isAbsentNode(toolChoiceNode)) {
return;
}
if (toolChoiceNode.isTextual()) {
return;
}
if (!toolChoiceNode.isObject()) {
throw ModelInvokeException.badRequest("tool_choice 仅支持字符串或对象", "tool_choice", "invalid_tool_choice");
}
validateAllowedFields(toolChoiceNode, TOOL_CHOICE_FIELDS, "tool_choice");
}
private void validateResponseFormat(JsonNode responseFormatNode) {
if (isAbsentNode(responseFormatNode)) {
return;
}
if (!responseFormatNode.isObject()) {
throw ModelInvokeException.badRequest("response_format 必须为对象", "response_format", "invalid_response_format");
}
validateAllowedFields(responseFormatNode, RESPONSE_FORMAT_FIELDS, "response_format");
if (StrUtil.isBlank(textValue(responseFormatNode, "type"))) {
throw ModelInvokeException.badRequest("response_format.type 不能为空", "response_format", "response_format_type_required");
}
}
private boolean isAbsentNode(JsonNode node) {
return node == null || node.isNull() || node.isMissingNode();
}
private void validateAllowedFields(JsonNode node, Set<String> allowedFields, String path) {
Iterator<String> fieldNames = node.fieldNames();
Set<String> unsupportedFields = new LinkedHashSet<>();
while (fieldNames.hasNext()) {
String fieldName = fieldNames.next();
if (!allowedFields.contains(fieldName)) {
unsupportedFields.add(path == null ? fieldName : path + "." + fieldName);
}
}
if (!unsupportedFields.isEmpty()) {
String field = unsupportedFields.iterator().next();
throw ModelInvokeException.badRequest("存在未支持字段: " + field, field, "unsupported_field");
}
}
private String textValue(JsonNode node, String fieldName) {
JsonNode fieldNode = node.get(fieldName);
if (fieldNode == null || fieldNode.isNull()) {
return null;
}
return fieldNode.asText();
}
private List<UnifiedMessage> toUnifiedMessages(List<OpenAiChatCompletionRequest.Message> messages) {
if (messages == null) {
return null;
}
List<UnifiedMessage> result = new ArrayList<>();
for (OpenAiChatCompletionRequest.Message message : messages) {
UnifiedMessage unifiedMessage = new UnifiedMessage();
unifiedMessage.setRole(message.getRole());
unifiedMessage.setContent(message.getContent());
unifiedMessage.setContentParts(toUnifiedContentParts(message.getContentParts()));
unifiedMessage.setName(message.getName());
unifiedMessage.setToolCallId(message.getToolCallId());
unifiedMessage.setToolCalls(toUnifiedToolCalls(message.getToolCalls()));
result.add(unifiedMessage);
}
return result;
}
private List<OpenAiChatCompletionRequest.Message> toOpenAiMessages(List<UnifiedMessage> messages) {
if (messages == null) {
return null;
}
List<OpenAiChatCompletionRequest.Message> result = new ArrayList<>();
for (UnifiedMessage message : messages) {
OpenAiChatCompletionRequest.Message openAiMessage = new OpenAiChatCompletionRequest.Message();
openAiMessage.setRole(message.getRole());
if (message.getContentParts() != null && !message.getContentParts().isEmpty()) {
ObjectNode contentNode = objectMapper.createObjectNode();
contentNode.set("content", objectMapper.valueToTree(toOpenAiContentParts(message.getContentParts())));
openAiMessage.setContentNode(contentNode.get("content"));
} else {
ObjectNode contentNode = objectMapper.createObjectNode();
if (message.getContent() == null) {
contentNode.putNull("content");
} else {
contentNode.put("content", message.getContent());
}
openAiMessage.setContentNode(contentNode.get("content"));
}
openAiMessage.setName(message.getName());
openAiMessage.setToolCallId(message.getToolCallId());
openAiMessage.setToolCalls(toOpenAiToolCalls(message.getToolCalls()));
result.add(openAiMessage);
}
return result;
}
private List<UnifiedContentPart> toUnifiedContentParts(List<OpenAiChatCompletionRequest.ContentPart> contentParts) {
if (contentParts == null) {
return null;
}
List<UnifiedContentPart> result = new ArrayList<>();
for (OpenAiChatCompletionRequest.ContentPart contentPart : contentParts) {
UnifiedContentPart unifiedContentPart = new UnifiedContentPart();
unifiedContentPart.setType(contentPart.getType());
unifiedContentPart.setText(contentPart.getText());
if (contentPart.getImageUrl() != null) {
UnifiedImageUrl imageUrl = new UnifiedImageUrl();
imageUrl.setUrl(contentPart.getImageUrl().getUrl());
imageUrl.setDetail(contentPart.getImageUrl().getDetail());
unifiedContentPart.setImageUrl(imageUrl);
}
result.add(unifiedContentPart);
}
return result;
}
private List<OpenAiChatCompletionRequest.ContentPart> toOpenAiContentParts(List<UnifiedContentPart> contentParts) {
if (contentParts == null) {
return null;
}
List<OpenAiChatCompletionRequest.ContentPart> result = new ArrayList<>();
for (UnifiedContentPart contentPart : contentParts) {
OpenAiChatCompletionRequest.ContentPart openAiPart = new OpenAiChatCompletionRequest.ContentPart();
openAiPart.setType(contentPart.getType());
openAiPart.setText(contentPart.getText());
if (contentPart.getImageUrl() != null) {
OpenAiChatCompletionRequest.ImageUrl imageUrl = new OpenAiChatCompletionRequest.ImageUrl();
imageUrl.setUrl(contentPart.getImageUrl().getUrl());
imageUrl.setDetail(contentPart.getImageUrl().getDetail());
openAiPart.setImageUrl(imageUrl);
}
result.add(openAiPart);
}
return result;
}
private List<UnifiedTool> toUnifiedTools(List<OpenAiChatCompletionRequest.Tool> tools) {
if (tools == null) {
return null;
}
List<UnifiedTool> result = new ArrayList<>();
for (OpenAiChatCompletionRequest.Tool tool : tools) {
UnifiedTool unifiedTool = new UnifiedTool();
unifiedTool.setType(tool.getType());
if (tool.getFunction() != null) {
UnifiedToolFunction function = new UnifiedToolFunction();
function.setName(tool.getFunction().getName());
function.setDescription(tool.getFunction().getDescription());
function.setParameters(tool.getFunction().getParameters());
unifiedTool.setFunction(function);
}
result.add(unifiedTool);
}
return result;
}
private List<OpenAiChatCompletionRequest.Tool> toOpenAiTools(List<UnifiedTool> tools) {
if (tools == null) {
return null;
}
List<OpenAiChatCompletionRequest.Tool> result = new ArrayList<>();
for (UnifiedTool tool : tools) {
OpenAiChatCompletionRequest.Tool openAiTool = new OpenAiChatCompletionRequest.Tool();
openAiTool.setType(tool.getType());
if (tool.getFunction() != null) {
OpenAiChatCompletionRequest.ToolFunction function = new OpenAiChatCompletionRequest.ToolFunction();
function.setName(tool.getFunction().getName());
function.setDescription(tool.getFunction().getDescription());
function.setParameters(tool.getFunction().getParameters());
openAiTool.setFunction(function);
}
result.add(openAiTool);
}
return result;
}
private UnifiedResponseFormat toUnifiedResponseFormat(OpenAiChatCompletionRequest.ResponseFormat responseFormat) {
if (responseFormat == null) {
return null;
}
UnifiedResponseFormat unifiedResponseFormat = new UnifiedResponseFormat();
unifiedResponseFormat.setType(responseFormat.getType());
unifiedResponseFormat.setJsonSchema(responseFormat.getJsonSchema());
return unifiedResponseFormat;
}
private OpenAiChatCompletionRequest.ResponseFormat toOpenAiResponseFormat(UnifiedResponseFormat responseFormat) {
if (responseFormat == null) {
return null;
}
OpenAiChatCompletionRequest.ResponseFormat openAiResponseFormat = new OpenAiChatCompletionRequest.ResponseFormat();
openAiResponseFormat.setType(responseFormat.getType());
openAiResponseFormat.setJsonSchema(responseFormat.getJsonSchema());
return openAiResponseFormat;
}
private List<UnifiedChoice> toUnifiedChoices(List<OpenAiChatCompletionResponse.Choice> choices, boolean delta) {
if (choices == null) {
return null;
}
List<UnifiedChoice> result = new ArrayList<>();
for (OpenAiChatCompletionResponse.Choice choice : choices) {
UnifiedChoice unifiedChoice = new UnifiedChoice();
unifiedChoice.setIndex(choice.getIndex());
UnifiedMessage message = new UnifiedMessage();
if (choice.getMessage() != null) {
message.setRole(choice.getMessage().getRole());
message.setContent(choice.getMessage().getContent());
message.setToolCallId(choice.getMessage().getToolCallId());
message.setToolCalls(toUnifiedToolCalls(choice.getMessage().getToolCalls()));
}
if (delta) {
unifiedChoice.setDelta(message);
} else {
unifiedChoice.setMessage(message);
}
unifiedChoice.setFinishReason(choice.getFinishReason());
result.add(unifiedChoice);
}
return result;
}
private List<UnifiedChoice> toUnifiedChunkChoices(List<OpenAiChatCompletionChunkResponse.Choice> choices) {
if (choices == null) {
return null;
}
List<UnifiedChoice> result = new ArrayList<>();
for (OpenAiChatCompletionChunkResponse.Choice choice : choices) {
UnifiedChoice unifiedChoice = new UnifiedChoice();
unifiedChoice.setIndex(choice.getIndex());
UnifiedMessage delta = new UnifiedMessage();
if (choice.getDelta() != null) {
delta.setRole(choice.getDelta().getRole());
delta.setContent(choice.getDelta().getContent());
delta.setToolCalls(toUnifiedToolCalls(choice.getDelta().getToolCalls()));
}
unifiedChoice.setDelta(delta);
unifiedChoice.setFinishReason(choice.getFinishReason());
result.add(unifiedChoice);
}
return result;
}
private List<OpenAiChatCompletionResponse.Choice> toOpenAiChoices(List<UnifiedChoice> choices, boolean delta) {
if (choices == null) {
return null;
}
List<OpenAiChatCompletionResponse.Choice> result = new ArrayList<>();
for (UnifiedChoice choice : choices) {
OpenAiChatCompletionResponse.Choice openAiChoice = new OpenAiChatCompletionResponse.Choice();
openAiChoice.setIndex(choice.getIndex());
UnifiedMessage source = delta ? choice.getDelta() : choice.getMessage();
if (source != null) {
OpenAiChatCompletionResponse.Message message = new OpenAiChatCompletionResponse.Message();
message.setRole(source.getRole());
message.setContent(source.getContent());
message.setToolCallId(source.getToolCallId());
message.setToolCalls(toOpenAiToolCalls(source.getToolCalls()));
openAiChoice.setMessage(message);
}
openAiChoice.setFinishReason(choice.getFinishReason());
result.add(openAiChoice);
}
return result;
}
private List<OpenAiChatCompletionChunkResponse.Choice> toOpenAiChunkChoices(List<UnifiedChoice> choices) {
if (choices == null) {
return null;
}
List<OpenAiChatCompletionChunkResponse.Choice> result = new ArrayList<>();
for (UnifiedChoice choice : choices) {
OpenAiChatCompletionChunkResponse.Choice openAiChoice = new OpenAiChatCompletionChunkResponse.Choice();
openAiChoice.setIndex(choice.getIndex());
if (choice.getDelta() != null) {
OpenAiChatCompletionChunkResponse.Delta delta = new OpenAiChatCompletionChunkResponse.Delta();
delta.setRole(choice.getDelta().getRole());
delta.setContent(choice.getDelta().getContent());
delta.setToolCalls(toOpenAiToolCalls(choice.getDelta().getToolCalls()));
openAiChoice.setDelta(delta);
}
openAiChoice.setFinishReason(choice.getFinishReason());
result.add(openAiChoice);
}
return result;
}
private List<UnifiedToolCall> toUnifiedToolCalls(List<OpenAiChatCompletionResponse.ToolCall> toolCalls) {
if (toolCalls == null) {
return null;
}
List<UnifiedToolCall> result = new ArrayList<>();
for (OpenAiChatCompletionResponse.ToolCall toolCall : toolCalls) {
UnifiedToolCall unifiedToolCall = new UnifiedToolCall();
unifiedToolCall.setIndex(toolCall.getIndex());
unifiedToolCall.setId(toolCall.getId());
unifiedToolCall.setType(toolCall.getType());
if (toolCall.getFunction() != null) {
UnifiedToolCallFunction function = new UnifiedToolCallFunction();
function.setName(toolCall.getFunction().getName());
function.setArguments(toolCall.getFunction().getArguments());
unifiedToolCall.setFunction(function);
}
result.add(unifiedToolCall);
}
return result;
}
private List<OpenAiChatCompletionResponse.ToolCall> toOpenAiToolCalls(List<UnifiedToolCall> toolCalls) {
if (toolCalls == null) {
return null;
}
List<OpenAiChatCompletionResponse.ToolCall> result = new ArrayList<>();
for (UnifiedToolCall toolCall : toolCalls) {
OpenAiChatCompletionResponse.ToolCall openAiToolCall = new OpenAiChatCompletionResponse.ToolCall();
openAiToolCall.setIndex(toolCall.getIndex());
openAiToolCall.setId(toolCall.getId());
openAiToolCall.setType(toolCall.getType());
if (toolCall.getFunction() != null) {
OpenAiChatCompletionResponse.ToolCallFunction function = new OpenAiChatCompletionResponse.ToolCallFunction();
function.setName(toolCall.getFunction().getName());
function.setArguments(toolCall.getFunction().getArguments());
openAiToolCall.setFunction(function);
}
result.add(openAiToolCall);
}
return result;
}
private UnifiedUsage toUnifiedUsage(OpenAiChatCompletionResponse.Usage usage) {
if (usage == null) {
return null;
}
UnifiedUsage unifiedUsage = new UnifiedUsage();
unifiedUsage.setPromptTokens(usage.getPromptTokens());
unifiedUsage.setCompletionTokens(usage.getCompletionTokens());
unifiedUsage.setTotalTokens(usage.getTotalTokens());
return unifiedUsage;
}
private OpenAiChatCompletionResponse.Usage toOpenAiUsage(UnifiedUsage usage) {
if (usage == null) {
return null;
}
OpenAiChatCompletionResponse.Usage openAiUsage = new OpenAiChatCompletionResponse.Usage();
openAiUsage.setPromptTokens(usage.getPromptTokens());
openAiUsage.setCompletionTokens(usage.getCompletionTokens());
openAiUsage.setTotalTokens(usage.getTotalTokens());
return openAiUsage;
}
}

View File

@@ -0,0 +1,61 @@
package tech.easyflow.ai.invoke.model;
import java.util.List;
public class UnifiedChatChunk {
private String id;
private String object;
private Long created;
private String model;
private List<UnifiedChoice> choices;
private UnifiedUsage usage;
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public String getObject() {
return object;
}
public void setObject(String object) {
this.object = object;
}
public Long getCreated() {
return created;
}
public void setCreated(Long created) {
this.created = created;
}
public String getModel() {
return model;
}
public void setModel(String model) {
this.model = model;
}
public List<UnifiedChoice> getChoices() {
return choices;
}
public void setChoices(List<UnifiedChoice> choices) {
this.choices = choices;
}
public UnifiedUsage getUsage() {
return usage;
}
public void setUsage(UnifiedUsage usage) {
this.usage = usage;
}
}

View File

@@ -0,0 +1,99 @@
package tech.easyflow.ai.invoke.model;
import com.fasterxml.jackson.databind.JsonNode;
import java.util.List;
public class UnifiedChatRequest {
private String model;
private List<UnifiedMessage> messages;
private Boolean stream;
private Double temperature;
private Double topP;
private Integer maxTokens;
private Long seed;
private List<UnifiedTool> tools;
private JsonNode toolChoice;
private UnifiedResponseFormat responseFormat;
public String getModel() {
return model;
}
public void setModel(String model) {
this.model = model;
}
public List<UnifiedMessage> getMessages() {
return messages;
}
public void setMessages(List<UnifiedMessage> messages) {
this.messages = messages;
}
public Boolean getStream() {
return stream;
}
public void setStream(Boolean stream) {
this.stream = stream;
}
public Double getTemperature() {
return temperature;
}
public void setTemperature(Double temperature) {
this.temperature = temperature;
}
public Double getTopP() {
return topP;
}
public void setTopP(Double topP) {
this.topP = topP;
}
public Integer getMaxTokens() {
return maxTokens;
}
public void setMaxTokens(Integer maxTokens) {
this.maxTokens = maxTokens;
}
public Long getSeed() {
return seed;
}
public void setSeed(Long seed) {
this.seed = seed;
}
public List<UnifiedTool> getTools() {
return tools;
}
public void setTools(List<UnifiedTool> tools) {
this.tools = tools;
}
public JsonNode getToolChoice() {
return toolChoice;
}
public void setToolChoice(JsonNode toolChoice) {
this.toolChoice = toolChoice;
}
public UnifiedResponseFormat getResponseFormat() {
return responseFormat;
}
public void setResponseFormat(UnifiedResponseFormat responseFormat) {
this.responseFormat = responseFormat;
}
}

View File

@@ -0,0 +1,61 @@
package tech.easyflow.ai.invoke.model;
import java.util.List;
public class UnifiedChatResponse {
private String id;
private String object;
private Long created;
private String model;
private List<UnifiedChoice> choices;
private UnifiedUsage usage;
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public String getObject() {
return object;
}
public void setObject(String object) {
this.object = object;
}
public Long getCreated() {
return created;
}
public void setCreated(Long created) {
this.created = created;
}
public String getModel() {
return model;
}
public void setModel(String model) {
this.model = model;
}
public List<UnifiedChoice> getChoices() {
return choices;
}
public void setChoices(List<UnifiedChoice> choices) {
this.choices = choices;
}
public UnifiedUsage getUsage() {
return usage;
}
public void setUsage(UnifiedUsage usage) {
this.usage = usage;
}
}

View File

@@ -0,0 +1,41 @@
package tech.easyflow.ai.invoke.model;
public class UnifiedChoice {
private Integer index;
private UnifiedMessage message;
private UnifiedMessage delta;
private String finishReason;
public Integer getIndex() {
return index;
}
public void setIndex(Integer index) {
this.index = index;
}
public UnifiedMessage getMessage() {
return message;
}
public void setMessage(UnifiedMessage message) {
this.message = message;
}
public UnifiedMessage getDelta() {
return delta;
}
public void setDelta(UnifiedMessage delta) {
this.delta = delta;
}
public String getFinishReason() {
return finishReason;
}
public void setFinishReason(String finishReason) {
this.finishReason = finishReason;
}
}

View File

@@ -0,0 +1,32 @@
package tech.easyflow.ai.invoke.model;
public class UnifiedContentPart {
private String type;
private String text;
private UnifiedImageUrl imageUrl;
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
public String getText() {
return text;
}
public void setText(String text) {
this.text = text;
}
public UnifiedImageUrl getImageUrl() {
return imageUrl;
}
public void setImageUrl(UnifiedImageUrl imageUrl) {
this.imageUrl = imageUrl;
}
}

View File

@@ -0,0 +1,23 @@
package tech.easyflow.ai.invoke.model;
public class UnifiedImageUrl {
private String url;
private String detail;
public String getUrl() {
return url;
}
public void setUrl(String url) {
this.url = url;
}
public String getDetail() {
return detail;
}
public void setDetail(String detail) {
this.detail = detail;
}
}

View File

@@ -0,0 +1,61 @@
package tech.easyflow.ai.invoke.model;
import java.util.List;
public class UnifiedMessage {
private String role;
private String content;
private List<UnifiedContentPart> contentParts;
private String name;
private String toolCallId;
private List<UnifiedToolCall> toolCalls;
public String getRole() {
return role;
}
public void setRole(String role) {
this.role = role;
}
public String getContent() {
return content;
}
public void setContent(String content) {
this.content = content;
}
public List<UnifiedContentPart> getContentParts() {
return contentParts;
}
public void setContentParts(List<UnifiedContentPart> contentParts) {
this.contentParts = contentParts;
}
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public String getToolCallId() {
return toolCallId;
}
public void setToolCallId(String toolCallId) {
this.toolCallId = toolCallId;
}
public List<UnifiedToolCall> getToolCalls() {
return toolCalls;
}
public void setToolCalls(List<UnifiedToolCall> toolCalls) {
this.toolCalls = toolCalls;
}
}

View File

@@ -0,0 +1,25 @@
package tech.easyflow.ai.invoke.model;
import com.fasterxml.jackson.databind.JsonNode;
public class UnifiedResponseFormat {
private String type;
private JsonNode jsonSchema;
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
public JsonNode getJsonSchema() {
return jsonSchema;
}
public void setJsonSchema(JsonNode jsonSchema) {
this.jsonSchema = jsonSchema;
}
}

View File

@@ -0,0 +1,23 @@
package tech.easyflow.ai.invoke.model;
public class UnifiedTool {
private String type;
private UnifiedToolFunction function;
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
public UnifiedToolFunction getFunction() {
return function;
}
public void setFunction(UnifiedToolFunction function) {
this.function = function;
}
}

View File

@@ -0,0 +1,41 @@
package tech.easyflow.ai.invoke.model;
public class UnifiedToolCall {
private Integer index;
private String id;
private String type;
private UnifiedToolCallFunction function;
public Integer getIndex() {
return index;
}
public void setIndex(Integer index) {
this.index = index;
}
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
public UnifiedToolCallFunction getFunction() {
return function;
}
public void setFunction(UnifiedToolCallFunction function) {
this.function = function;
}
}

View File

@@ -0,0 +1,23 @@
package tech.easyflow.ai.invoke.model;
public class UnifiedToolCallFunction {
private String name;
private String arguments;
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public String getArguments() {
return arguments;
}
public void setArguments(String arguments) {
this.arguments = arguments;
}
}

View File

@@ -0,0 +1,34 @@
package tech.easyflow.ai.invoke.model;
import com.fasterxml.jackson.databind.JsonNode;
public class UnifiedToolFunction {
private String name;
private String description;
private JsonNode parameters;
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public String getDescription() {
return description;
}
public void setDescription(String description) {
this.description = description;
}
public JsonNode getParameters() {
return parameters;
}
public void setParameters(JsonNode parameters) {
this.parameters = parameters;
}
}

View File

@@ -0,0 +1,32 @@
package tech.easyflow.ai.invoke.model;
public class UnifiedUsage {
private Integer promptTokens;
private Integer completionTokens;
private Integer totalTokens;
public Integer getPromptTokens() {
return promptTokens;
}
public void setPromptTokens(Integer promptTokens) {
this.promptTokens = promptTokens;
}
public Integer getCompletionTokens() {
return completionTokens;
}
public void setCompletionTokens(Integer completionTokens) {
this.completionTokens = completionTokens;
}
public Integer getTotalTokens() {
return totalTokens;
}
public void setTotalTokens(Integer totalTokens) {
this.totalTokens = totalTokens;
}
}

View File

@@ -0,0 +1,134 @@
package tech.easyflow.ai.invoke.protocol.openai;
import com.alibaba.fastjson.annotation.JSONField;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
@JsonIgnoreProperties(ignoreUnknown = true)
public class OpenAiChatCompletionChunkResponse {
private String id;
private String object;
private Long created;
private String model;
private List<Choice> choices;
private OpenAiChatCompletionResponse.Usage usage;
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public String getObject() {
return object;
}
public void setObject(String object) {
this.object = object;
}
public Long getCreated() {
return created;
}
public void setCreated(Long created) {
this.created = created;
}
public String getModel() {
return model;
}
public void setModel(String model) {
this.model = model;
}
public List<Choice> getChoices() {
return choices;
}
public void setChoices(List<Choice> choices) {
this.choices = choices;
}
public OpenAiChatCompletionResponse.Usage getUsage() {
return usage;
}
public void setUsage(OpenAiChatCompletionResponse.Usage usage) {
this.usage = usage;
}
@JsonIgnoreProperties(ignoreUnknown = true)
public static class Choice {
private Integer index;
private Delta delta;
@JSONField(name = "finish_reason")
@JsonProperty("finish_reason")
private String finishReason;
public Integer getIndex() {
return index;
}
public void setIndex(Integer index) {
this.index = index;
}
public Delta getDelta() {
return delta;
}
public void setDelta(Delta delta) {
this.delta = delta;
}
public String getFinishReason() {
return finishReason;
}
public void setFinishReason(String finishReason) {
this.finishReason = finishReason;
}
}
@JsonIgnoreProperties(ignoreUnknown = true)
public static class Delta {
private String role;
private String content;
@JSONField(name = "tool_calls")
@JsonProperty("tool_calls")
private List<OpenAiChatCompletionResponse.ToolCall> toolCalls;
public String getRole() {
return role;
}
public void setRole(String role) {
this.role = role;
}
public String getContent() {
return content;
}
public void setContent(String content) {
this.content = content;
}
public List<OpenAiChatCompletionResponse.ToolCall> getToolCalls() {
return toolCalls;
}
public void setToolCalls(List<OpenAiChatCompletionResponse.ToolCall> toolCalls) {
this.toolCalls = toolCalls;
}
}
}

View File

@@ -0,0 +1,335 @@
package tech.easyflow.ai.invoke.protocol.openai;
import com.alibaba.fastjson.annotation.JSONField;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.JsonNode;
import java.util.List;
@JsonIgnoreProperties(ignoreUnknown = true)
public class OpenAiChatCompletionRequest {
private String model;
private List<Message> messages;
private Boolean stream;
private Double temperature;
@JSONField(name = "top_p")
@JsonProperty("top_p")
private Double topP;
@JSONField(name = "max_tokens")
@JsonProperty("max_tokens")
private Integer maxTokens;
private Long seed;
private List<Tool> tools;
@JSONField(name = "tool_choice")
@JsonProperty("tool_choice")
private JsonNode toolChoice;
@JSONField(name = "response_format")
@JsonProperty("response_format")
private ResponseFormat responseFormat;
public String getModel() {
return model;
}
public void setModel(String model) {
this.model = model;
}
public List<Message> getMessages() {
return messages;
}
public void setMessages(List<Message> messages) {
this.messages = messages;
}
public Boolean getStream() {
return stream;
}
public void setStream(Boolean stream) {
this.stream = stream;
}
public Double getTemperature() {
return temperature;
}
public void setTemperature(Double temperature) {
this.temperature = temperature;
}
public Double getTopP() {
return topP;
}
public void setTopP(Double topP) {
this.topP = topP;
}
public Integer getMaxTokens() {
return maxTokens;
}
public void setMaxTokens(Integer maxTokens) {
this.maxTokens = maxTokens;
}
public Long getSeed() {
return seed;
}
public void setSeed(Long seed) {
this.seed = seed;
}
public List<Tool> getTools() {
return tools;
}
public void setTools(List<Tool> tools) {
this.tools = tools;
}
public JsonNode getToolChoice() {
return toolChoice;
}
public void setToolChoice(JsonNode toolChoice) {
this.toolChoice = toolChoice;
}
public ResponseFormat getResponseFormat() {
return responseFormat;
}
public void setResponseFormat(ResponseFormat responseFormat) {
this.responseFormat = responseFormat;
}
@JsonIgnoreProperties(ignoreUnknown = true)
public static class Message {
private String role;
private String content;
private List<ContentPart> contentParts;
private String name;
@JSONField(name = "tool_call_id")
@JsonProperty("tool_call_id")
private String toolCallId;
@JSONField(name = "tool_calls")
@JsonProperty("tool_calls")
private List<OpenAiChatCompletionResponse.ToolCall> toolCalls;
@JsonProperty("content")
public void setContentNode(JsonNode contentNode) {
if (contentNode == null || contentNode.isNull()) {
this.content = null;
this.contentParts = null;
return;
}
if (contentNode.isTextual()) {
this.content = contentNode.asText();
this.contentParts = null;
return;
}
if (contentNode.isArray()) {
this.content = null;
this.contentParts = OpenAiJsonSupport.convertContentParts(contentNode);
return;
}
throw new IllegalArgumentException("message.content 仅支持字符串或数组");
}
@JSONField(name = "content")
@JsonProperty("content")
public Object getContentNode() {
if (contentParts != null) {
return contentParts;
}
return content;
}
public String getRole() {
return role;
}
public void setRole(String role) {
this.role = role;
}
@JsonIgnore
public String getContent() {
return content;
}
@JsonIgnore
public List<ContentPart> getContentParts() {
return contentParts;
}
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public String getToolCallId() {
return toolCallId;
}
public void setToolCallId(String toolCallId) {
this.toolCallId = toolCallId;
}
public List<OpenAiChatCompletionResponse.ToolCall> getToolCalls() {
return toolCalls;
}
public void setToolCalls(List<OpenAiChatCompletionResponse.ToolCall> toolCalls) {
this.toolCalls = toolCalls;
}
}
@JsonIgnoreProperties(ignoreUnknown = true)
public static class ContentPart {
private String type;
private String text;
@JSONField(name = "image_url")
@JsonProperty("image_url")
private ImageUrl imageUrl;
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
public String getText() {
return text;
}
public void setText(String text) {
this.text = text;
}
public ImageUrl getImageUrl() {
return imageUrl;
}
public void setImageUrl(ImageUrl imageUrl) {
this.imageUrl = imageUrl;
}
}
@JsonIgnoreProperties(ignoreUnknown = true)
public static class ImageUrl {
private String url;
private String detail;
public String getUrl() {
return url;
}
public void setUrl(String url) {
this.url = url;
}
public String getDetail() {
return detail;
}
public void setDetail(String detail) {
this.detail = detail;
}
}
@JsonIgnoreProperties(ignoreUnknown = true)
public static class Tool {
private String type;
private ToolFunction function;
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
public ToolFunction getFunction() {
return function;
}
public void setFunction(ToolFunction function) {
this.function = function;
}
}
@JsonIgnoreProperties(ignoreUnknown = true)
public static class ToolFunction {
private String name;
private String description;
private JsonNode parameters;
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public String getDescription() {
return description;
}
public void setDescription(String description) {
this.description = description;
}
public JsonNode getParameters() {
return parameters;
}
public void setParameters(JsonNode parameters) {
this.parameters = parameters;
}
}
@JsonIgnoreProperties(ignoreUnknown = true)
public static class ResponseFormat {
private String type;
@JSONField(name = "json_schema")
@JsonProperty("json_schema")
private JsonNode jsonSchema;
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
public JsonNode getJsonSchema() {
return jsonSchema;
}
public void setJsonSchema(JsonNode jsonSchema) {
this.jsonSchema = jsonSchema;
}
}
}

View File

@@ -0,0 +1,247 @@
package tech.easyflow.ai.invoke.protocol.openai;
import com.alibaba.fastjson.annotation.JSONField;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
@JsonIgnoreProperties(ignoreUnknown = true)
public class OpenAiChatCompletionResponse {
private String id;
private String object;
private Long created;
private String model;
private List<Choice> choices;
private Usage usage;
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public String getObject() {
return object;
}
public void setObject(String object) {
this.object = object;
}
public Long getCreated() {
return created;
}
public void setCreated(Long created) {
this.created = created;
}
public String getModel() {
return model;
}
public void setModel(String model) {
this.model = model;
}
public List<Choice> getChoices() {
return choices;
}
public void setChoices(List<Choice> choices) {
this.choices = choices;
}
public Usage getUsage() {
return usage;
}
public void setUsage(Usage usage) {
this.usage = usage;
}
@JsonIgnoreProperties(ignoreUnknown = true)
public static class Choice {
private Integer index;
private Message message;
@JSONField(name = "finish_reason")
@JsonProperty("finish_reason")
private String finishReason;
public Integer getIndex() {
return index;
}
public void setIndex(Integer index) {
this.index = index;
}
public Message getMessage() {
return message;
}
public void setMessage(Message message) {
this.message = message;
}
public String getFinishReason() {
return finishReason;
}
public void setFinishReason(String finishReason) {
this.finishReason = finishReason;
}
}
@JsonIgnoreProperties(ignoreUnknown = true)
public static class Message {
private String role;
private String content;
@JSONField(name = "tool_call_id")
@JsonProperty("tool_call_id")
private String toolCallId;
@JSONField(name = "tool_calls")
@JsonProperty("tool_calls")
private List<ToolCall> toolCalls;
public String getRole() {
return role;
}
public void setRole(String role) {
this.role = role;
}
public String getContent() {
return content;
}
public void setContent(String content) {
this.content = content;
}
public String getToolCallId() {
return toolCallId;
}
public void setToolCallId(String toolCallId) {
this.toolCallId = toolCallId;
}
public List<ToolCall> getToolCalls() {
return toolCalls;
}
public void setToolCalls(List<ToolCall> toolCalls) {
this.toolCalls = toolCalls;
}
}
@JsonIgnoreProperties(ignoreUnknown = true)
public static class ToolCall {
private Integer index;
private String id;
private String type;
private ToolCallFunction function;
public Integer getIndex() {
return index;
}
public void setIndex(Integer index) {
this.index = index;
}
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
public ToolCallFunction getFunction() {
return function;
}
public void setFunction(ToolCallFunction function) {
this.function = function;
}
}
@JsonIgnoreProperties(ignoreUnknown = true)
public static class ToolCallFunction {
private String name;
private String arguments;
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public String getArguments() {
return arguments;
}
public void setArguments(String arguments) {
this.arguments = arguments;
}
}
@JsonIgnoreProperties(ignoreUnknown = true)
public static class Usage {
@JSONField(name = "prompt_tokens")
@JsonProperty("prompt_tokens")
private Integer promptTokens;
@JSONField(name = "completion_tokens")
@JsonProperty("completion_tokens")
private Integer completionTokens;
@JSONField(name = "total_tokens")
@JsonProperty("total_tokens")
private Integer totalTokens;
public Integer getPromptTokens() {
return promptTokens;
}
public void setPromptTokens(Integer promptTokens) {
this.promptTokens = promptTokens;
}
public Integer getCompletionTokens() {
return completionTokens;
}
public void setCompletionTokens(Integer completionTokens) {
this.completionTokens = completionTokens;
}
public Integer getTotalTokens() {
return totalTokens;
}
public void setTotalTokens(Integer totalTokens) {
this.totalTokens = totalTokens;
}
}
}

View File

@@ -0,0 +1,54 @@
package tech.easyflow.ai.invoke.protocol.openai;
public class OpenAiErrorResponse {
private Error error;
public Error getError() {
return error;
}
public void setError(Error error) {
this.error = error;
}
public static class Error {
private String message;
private String type;
private String param;
private String code;
public String getMessage() {
return message;
}
public void setMessage(String message) {
this.message = message;
}
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
public String getParam() {
return param;
}
public void setParam(String param) {
this.param = param;
}
public String getCode() {
return code;
}
public void setCode(String code) {
this.code = code;
}
}
}

View File

@@ -0,0 +1,20 @@
package tech.easyflow.ai.invoke.protocol.openai;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.List;
final class OpenAiJsonSupport {
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
private OpenAiJsonSupport() {
}
static List<OpenAiChatCompletionRequest.ContentPart> convertContentParts(JsonNode node) {
return OBJECT_MAPPER.convertValue(node, new TypeReference<>() {
});
}
}

View File

@@ -0,0 +1,14 @@
package tech.easyflow.ai.invoke.provider;
import tech.easyflow.ai.entity.Model;
import tech.easyflow.ai.invoke.model.UnifiedChatRequest;
import tech.easyflow.ai.invoke.model.UnifiedChatResponse;
public interface ModelProviderGateway {
boolean supports(String providerType);
UnifiedChatResponse chat(Model model, UnifiedChatRequest request);
void chatStream(Model model, UnifiedChatRequest request, UnifiedChatChunkObserver observer);
}

View File

@@ -0,0 +1,27 @@
package tech.easyflow.ai.invoke.provider;
import org.springframework.stereotype.Component;
import tech.easyflow.ai.invoke.exception.ModelInvokeException;
import java.util.List;
@Component
public class ModelProviderGatewayRegistry {
private final List<ModelProviderGateway> gateways;
public ModelProviderGatewayRegistry(List<ModelProviderGateway> gateways) {
this.gateways = gateways;
}
public ModelProviderGateway getGateway(String providerType) {
return gateways.stream()
.filter(gateway -> gateway.supports(providerType))
.findFirst()
.orElseThrow(() -> ModelInvokeException.badRequest(
"当前 providerType 暂不支持统一模型调用: " + providerType,
"model",
"provider_not_supported"
));
}
}

View File

@@ -0,0 +1,14 @@
package tech.easyflow.ai.invoke.provider;
import tech.easyflow.ai.invoke.model.UnifiedChatChunk;
public interface UnifiedChatChunkObserver {
void onChunk(UnifiedChatChunk chunk);
default void onComplete() {
}
default void onError(Throwable throwable) {
}
}

View File

@@ -0,0 +1,129 @@
package tech.easyflow.ai.invoke.provider.base;
import cn.hutool.core.util.StrUtil;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import okhttp3.MediaType;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okio.BufferedSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import tech.easyflow.ai.entity.Model;
import tech.easyflow.ai.invoke.exception.ModelInvokeException;
import tech.easyflow.ai.invoke.mapper.OpenAiProtocolMapper;
import tech.easyflow.ai.invoke.model.UnifiedChatChunk;
import tech.easyflow.ai.invoke.protocol.openai.OpenAiChatCompletionChunkResponse;
import tech.easyflow.common.util.OkHttpClientUtil;
import java.io.IOException;
import java.util.Objects;
public abstract class AbstractOpenAiCompatibleGateway {
private static final Logger log = LoggerFactory.getLogger(AbstractOpenAiCompatibleGateway.class);
private static final MediaType JSON_TYPE = MediaType.parse("application/json; charset=utf-8");
protected final ObjectMapper objectMapper;
protected final OpenAiProtocolMapper openAiProtocolMapper;
protected final ThreadPoolTaskExecutor sseThreadPool;
protected AbstractOpenAiCompatibleGateway(ObjectMapper objectMapper,
OpenAiProtocolMapper openAiProtocolMapper,
ThreadPoolTaskExecutor sseThreadPool) {
this.objectMapper = objectMapper;
this.openAiProtocolMapper = openAiProtocolMapper;
this.sseThreadPool = sseThreadPool;
}
protected Response executePost(Model model, Object requestBody) {
try {
String body = objectMapper.writeValueAsString(requestBody);
Request request = new Request.Builder()
.url(buildUrl(model))
.addHeader("Authorization", "Bearer " + model.checkAndGetApiKey())
.addHeader("Content-Type", "application/json")
.post(RequestBody.create(body, JSON_TYPE))
.build();
return OkHttpClientUtil.buildDefaultClient().newCall(request).execute();
} catch (ModelInvokeException e) {
throw e;
} catch (Exception e) {
throw ModelInvokeException.badGateway("调用上游模型失败: " + e.getMessage());
}
}
protected String buildUrl(Model model) {
return StrUtil.removeSuffix(model.checkAndGetEndpoint(), "/")
+ "/"
+ StrUtil.removePrefix(model.checkAndGetRequestPath(), "/");
}
protected void validateResponse(Response response) throws IOException {
if (response.isSuccessful()) {
return;
}
String message = extractUpstreamErrorMessage(response);
int code = response.code();
if (code >= 500 || code == 429) {
throw ModelInvokeException.serviceUnavailable(message);
}
throw ModelInvokeException.badGateway(message);
}
protected String extractUpstreamErrorMessage(Response response) throws IOException {
ResponseBody body = response.body();
String bodyString = body == null ? "" : body.string();
if (StrUtil.isBlank(bodyString)) {
return "上游模型调用失败HTTP " + response.code();
}
try {
JsonNode root = objectMapper.readTree(bodyString);
String errorMessage = root.path("error").path("message").asText();
if (StrUtil.isNotBlank(errorMessage)) {
return "上游模型调用失败HTTP " + response.code() + ": " + errorMessage;
}
} catch (Exception ignored) {
}
return "上游模型调用失败HTTP " + response.code() + ": " + bodyString;
}
protected void streamResponse(Response response,
tech.easyflow.ai.invoke.provider.UnifiedChatChunkObserver observer) {
sseThreadPool.execute(() -> {
try (Response closeableResponse = response; ResponseBody body = closeableResponse.body()) {
if (body == null) {
throw new IOException("上游流式响应体为空");
}
BufferedSource source = body.source();
while (!source.exhausted()) {
String line = source.readUtf8Line();
if (line == null) {
break;
}
if (line.isBlank() || !line.startsWith("data:")) {
continue;
}
String payload = line.substring(5).trim();
if (Objects.equals(payload, "[DONE]")) {
observer.onComplete();
return;
}
OpenAiChatCompletionChunkResponse chunkResponse = objectMapper.readValue(
payload,
OpenAiChatCompletionChunkResponse.class
);
UnifiedChatChunk chunk = openAiProtocolMapper.toUnifiedChunk(chunkResponse);
observer.onChunk(chunk);
}
observer.onComplete();
} catch (Exception e) {
log.error("streamResponse error: {}", e.getMessage(), e);
observer.onError(e);
}
});
}
}

View File

@@ -0,0 +1,78 @@
package tech.easyflow.ai.invoke.provider.openai;
import com.fasterxml.jackson.databind.ObjectMapper;
import okhttp3.Response;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.stereotype.Component;
import tech.easyflow.ai.entity.Model;
import tech.easyflow.ai.invoke.mapper.OpenAiProtocolMapper;
import tech.easyflow.ai.invoke.model.UnifiedChatRequest;
import tech.easyflow.ai.invoke.model.UnifiedChatResponse;
import tech.easyflow.ai.invoke.protocol.openai.OpenAiChatCompletionRequest;
import tech.easyflow.ai.invoke.protocol.openai.OpenAiChatCompletionResponse;
import tech.easyflow.ai.invoke.provider.ModelProviderGateway;
import tech.easyflow.ai.invoke.provider.UnifiedChatChunkObserver;
import tech.easyflow.ai.invoke.provider.base.AbstractOpenAiCompatibleGateway;
import java.util.Set;
@Component
public class OpenAiGateway extends AbstractOpenAiCompatibleGateway implements ModelProviderGateway {
private static final Set<String> OPENAI_COMPATIBLE_PROVIDER_TYPES = Set.of(
"openai",
"deepseek",
"aliyun",
"zhipu",
"minimax",
"kimi",
"siliconlow",
"ollama",
"self-hosted"
);
public OpenAiGateway(ObjectMapper objectMapper,
OpenAiProtocolMapper openAiProtocolMapper,
ThreadPoolTaskExecutor sseThreadPool) {
super(objectMapper, openAiProtocolMapper, sseThreadPool);
}
@Override
public boolean supports(String providerType) {
return providerType != null && OPENAI_COMPATIBLE_PROVIDER_TYPES.contains(providerType.toLowerCase());
}
@Override
public UnifiedChatResponse chat(Model model, UnifiedChatRequest request) {
OpenAiChatCompletionRequest openAiRequest = openAiProtocolMapper.toOpenAiRequest(request);
try (Response response = executePost(model, openAiRequest)) {
validateResponse(response);
if (response.body() == null) {
throw tech.easyflow.ai.invoke.exception.ModelInvokeException.badGateway("上游模型响应为空");
}
OpenAiChatCompletionResponse responseBody = objectMapper.readValue(
response.body().string(),
OpenAiChatCompletionResponse.class
);
return openAiProtocolMapper.toUnifiedResponse(responseBody);
} catch (tech.easyflow.ai.invoke.exception.ModelInvokeException e) {
throw e;
} catch (Exception e) {
throw tech.easyflow.ai.invoke.exception.ModelInvokeException.badGateway("调用上游模型失败: " + e.getMessage());
}
}
@Override
public void chatStream(Model model, UnifiedChatRequest request, UnifiedChatChunkObserver observer) {
OpenAiChatCompletionRequest openAiRequest = openAiProtocolMapper.toOpenAiRequest(request);
Response response = executePost(model, openAiRequest);
try {
validateResponse(response);
streamResponse(response, observer);
} catch (RuntimeException | java.io.IOException e) {
response.close();
throw e instanceof RuntimeException ? (RuntimeException) e
: tech.easyflow.ai.invoke.exception.ModelInvokeException.badGateway("调用上游模型失败: " + e.getMessage());
}
}
}

View File

@@ -0,0 +1,12 @@
package tech.easyflow.ai.invoke.service;
import tech.easyflow.ai.invoke.model.UnifiedChatRequest;
import tech.easyflow.ai.invoke.model.UnifiedChatResponse;
import tech.easyflow.ai.invoke.provider.UnifiedChatChunkObserver;
public interface UnifiedModelInvokeService {
UnifiedChatResponse chat(UnifiedChatRequest request);
void chatStream(UnifiedChatRequest request, UnifiedChatChunkObserver observer);
}

View File

@@ -0,0 +1,114 @@
package tech.easyflow.ai.invoke.service.impl;
import cn.hutool.core.util.StrUtil;
import org.springframework.stereotype.Service;
import tech.easyflow.ai.entity.Model;
import tech.easyflow.ai.invoke.exception.ModelInvokeException;
import tech.easyflow.ai.invoke.model.UnifiedChatRequest;
import tech.easyflow.ai.invoke.model.UnifiedChatResponse;
import tech.easyflow.ai.invoke.model.UnifiedContentPart;
import tech.easyflow.ai.invoke.model.UnifiedMessage;
import tech.easyflow.ai.invoke.provider.ModelProviderGateway;
import tech.easyflow.ai.invoke.provider.ModelProviderGatewayRegistry;
import tech.easyflow.ai.invoke.provider.UnifiedChatChunkObserver;
import tech.easyflow.ai.invoke.service.UnifiedModelInvokeService;
import tech.easyflow.ai.service.ModelService;
import java.util.List;
@Service
public class UnifiedModelInvokeServiceImpl implements UnifiedModelInvokeService {
private final ModelService modelService;
private final ModelProviderGatewayRegistry gatewayRegistry;
public UnifiedModelInvokeServiceImpl(ModelService modelService,
ModelProviderGatewayRegistry gatewayRegistry) {
this.modelService = modelService;
this.gatewayRegistry = gatewayRegistry;
}
@Override
public UnifiedChatResponse chat(UnifiedChatRequest request) {
Model model = resolveModel(request);
validateRequestAgainstModel(model, request);
ModelProviderGateway gateway = gatewayRegistry.getGateway(model.getModelProvider().getProviderType());
return gateway.chat(model, request);
}
@Override
public void chatStream(UnifiedChatRequest request, UnifiedChatChunkObserver observer) {
Model model = resolveModel(request);
validateRequestAgainstModel(model, request);
ModelProviderGateway gateway = gatewayRegistry.getGateway(model.getModelProvider().getProviderType());
gateway.chatStream(model, request, observer);
}
private Model resolveModel(UnifiedChatRequest request) {
if (request == null || StrUtil.isBlank(request.getModel())) {
throw ModelInvokeException.badRequest("model 不能为空", "model", "model_required");
}
Model model = modelService.getModelInstanceByInvokeCode(request.getModel());
if (model == null) {
throw ModelInvokeException.notFound("未找到可调用模型: " + request.getModel());
}
if (!Boolean.TRUE.equals(model.getPublishEnabled())) {
throw ModelInvokeException.badRequest("当前模型未开启 API 调用发布", "model", "model_not_published");
}
if (!Model.MODEL_TYPES[0].equals(model.getModelType())) {
throw ModelInvokeException.badRequest("当前模型不是 chatModel无法通过 chat/completions 调用", "model", "model_type_mismatch");
}
if (model.getModelProvider() == null || StrUtil.isBlank(model.getModelProvider().getProviderType())) {
throw ModelInvokeException.badRequest("当前模型缺少 providerType 配置", "model", "provider_type_missing");
}
return model;
}
private void validateRequestAgainstModel(Model model, UnifiedChatRequest request) {
List<UnifiedMessage> messages = request.getMessages();
if (messages == null || messages.isEmpty()) {
throw ModelInvokeException.badRequest("messages 不能为空", "messages", "messages_required");
}
if (hasImageInput(messages)) {
if (!Boolean.TRUE.equals(model.getSupportImage())) {
throw ModelInvokeException.badRequest("当前模型不支持图片输入", "messages", "image_not_supported");
}
if (Boolean.TRUE.equals(model.getSupportImageB64Only()) && hasNonBase64Image(messages)) {
throw ModelInvokeException.badRequest("当前模型仅支持 base64 图片输入", "messages", "image_base64_only");
}
}
if (request.getTools() != null && !request.getTools().isEmpty() && !Boolean.TRUE.equals(model.getSupportTool())) {
throw ModelInvokeException.badRequest("当前模型不支持 tools 参数", "tools", "tool_not_supported");
}
if (hasToolMessage(messages) && !Boolean.TRUE.equals(model.getSupportToolMessage())) {
throw ModelInvokeException.badRequest("当前模型不支持 tool 消息透传", "messages", "tool_message_not_supported");
}
}
private boolean hasImageInput(List<UnifiedMessage> messages) {
return messages.stream()
.map(UnifiedMessage::getContentParts)
.filter(parts -> parts != null && !parts.isEmpty())
.flatMap(List::stream)
.anyMatch(part -> "image_url".equals(part.getType()));
}
private boolean hasNonBase64Image(List<UnifiedMessage> messages) {
return messages.stream()
.map(UnifiedMessage::getContentParts)
.filter(parts -> parts != null && !parts.isEmpty())
.flatMap(List::stream)
.filter(part -> "image_url".equals(part.getType()))
.map(UnifiedContentPart::getImageUrl)
.filter(imageUrl -> imageUrl != null && StrUtil.isNotBlank(imageUrl.getUrl()))
.anyMatch(imageUrl -> !StrUtil.startWithIgnoreCase(imageUrl.getUrl(), "data:"));
}
private boolean hasToolMessage(List<UnifiedMessage> messages) {
return messages.stream().anyMatch(message ->
StrUtil.equals(message.getRole(), "tool")
|| (message.getToolCalls() != null && !message.getToolCalls().isEmpty())
|| StrUtil.isNotBlank(message.getToolCallId())
);
}
}

View File

@@ -0,0 +1,45 @@
package tech.easyflow.ai.permission;
import org.springframework.stereotype.Component;
import tech.easyflow.ai.entity.DocumentChunk;
import tech.easyflow.ai.entity.DocumentCollection;
import tech.easyflow.ai.service.DocumentChunkService;
import tech.easyflow.ai.service.DocumentCollectionService;
import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.system.enums.CategoryResourceType;
import tech.easyflow.system.enums.ResourceLookup;
import tech.easyflow.system.permission.resource.ResolvedResourceAccess;
import tech.easyflow.system.permission.resource.ResourceAccessResolver;
import javax.annotation.Resource;
@Component
public class DocumentChunkIdKnowledgeResourceAccessResolver implements ResourceAccessResolver {
@Resource
private DocumentChunkService documentChunkService;
@Resource
private DocumentCollectionService documentCollectionService;
@Override
public boolean supports(CategoryResourceType resourceType, ResourceLookup lookup) {
return CategoryResourceType.KNOWLEDGE == resourceType && ResourceLookup.DOCUMENT_CHUNK_ID == lookup;
}
@Override
public ResolvedResourceAccess resolve(Object identifier) {
if (identifier == null) {
throw new BusinessException("文档分段不存在");
}
DocumentChunk documentChunk = documentChunkService.getById(String.valueOf(identifier));
if (documentChunk == null) {
throw new BusinessException("文档分段不存在");
}
DocumentCollection collection = documentCollectionService.getById(documentChunk.getDocumentCollectionId());
if (collection == null) {
throw new BusinessException("知识库不存在");
}
return new ResolvedResourceAccess(collection, null);
}
}

View File

@@ -0,0 +1,45 @@
package tech.easyflow.ai.permission;
import org.springframework.stereotype.Component;
import tech.easyflow.ai.entity.Document;
import tech.easyflow.ai.entity.DocumentCollection;
import tech.easyflow.ai.service.DocumentCollectionService;
import tech.easyflow.ai.service.DocumentService;
import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.system.enums.CategoryResourceType;
import tech.easyflow.system.enums.ResourceLookup;
import tech.easyflow.system.permission.resource.ResolvedResourceAccess;
import tech.easyflow.system.permission.resource.ResourceAccessResolver;
import javax.annotation.Resource;
@Component
public class DocumentIdKnowledgeResourceAccessResolver implements ResourceAccessResolver {
@Resource
private DocumentService documentService;
@Resource
private DocumentCollectionService documentCollectionService;
@Override
public boolean supports(CategoryResourceType resourceType, ResourceLookup lookup) {
return CategoryResourceType.KNOWLEDGE == resourceType && ResourceLookup.DOCUMENT_ID == lookup;
}
@Override
public ResolvedResourceAccess resolve(Object identifier) {
if (identifier == null) {
throw new BusinessException("文档不存在");
}
Document document = documentService.getById(String.valueOf(identifier));
if (document == null) {
throw new BusinessException("文档不存在");
}
DocumentCollection collection = documentCollectionService.getById(document.getCollectionId());
if (collection == null) {
throw new BusinessException("知识库不存在");
}
return new ResolvedResourceAccess(collection, null);
}
}

View File

@@ -0,0 +1,45 @@
package tech.easyflow.ai.permission;
import org.springframework.stereotype.Component;
import tech.easyflow.ai.entity.DocumentCollection;
import tech.easyflow.ai.entity.FaqCategory;
import tech.easyflow.ai.service.DocumentCollectionService;
import tech.easyflow.ai.service.FaqCategoryService;
import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.system.enums.CategoryResourceType;
import tech.easyflow.system.enums.ResourceLookup;
import tech.easyflow.system.permission.resource.ResolvedResourceAccess;
import tech.easyflow.system.permission.resource.ResourceAccessResolver;
import javax.annotation.Resource;
@Component
public class FaqCategoryIdKnowledgeResourceAccessResolver implements ResourceAccessResolver {
@Resource
private FaqCategoryService faqCategoryService;
@Resource
private DocumentCollectionService documentCollectionService;
@Override
public boolean supports(CategoryResourceType resourceType, ResourceLookup lookup) {
return CategoryResourceType.KNOWLEDGE == resourceType && ResourceLookup.FAQ_CATEGORY_ID == lookup;
}
@Override
public ResolvedResourceAccess resolve(Object identifier) {
if (identifier == null) {
throw new BusinessException("FAQ分类不存在");
}
FaqCategory faqCategory = faqCategoryService.getById(String.valueOf(identifier));
if (faqCategory == null) {
throw new BusinessException("FAQ分类不存在");
}
DocumentCollection collection = documentCollectionService.getById(faqCategory.getCollectionId());
if (collection == null) {
throw new BusinessException("知识库不存在");
}
return new ResolvedResourceAccess(collection, null);
}
}

View File

@@ -0,0 +1,45 @@
package tech.easyflow.ai.permission;
import org.springframework.stereotype.Component;
import tech.easyflow.ai.entity.DocumentCollection;
import tech.easyflow.ai.entity.FaqItem;
import tech.easyflow.ai.service.DocumentCollectionService;
import tech.easyflow.ai.service.FaqItemService;
import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.system.enums.CategoryResourceType;
import tech.easyflow.system.enums.ResourceLookup;
import tech.easyflow.system.permission.resource.ResolvedResourceAccess;
import tech.easyflow.system.permission.resource.ResourceAccessResolver;
import javax.annotation.Resource;
@Component
public class FaqItemIdKnowledgeResourceAccessResolver implements ResourceAccessResolver {
@Resource
private FaqItemService faqItemService;
@Resource
private DocumentCollectionService documentCollectionService;
@Override
public boolean supports(CategoryResourceType resourceType, ResourceLookup lookup) {
return CategoryResourceType.KNOWLEDGE == resourceType && ResourceLookup.FAQ_ITEM_ID == lookup;
}
@Override
public ResolvedResourceAccess resolve(Object identifier) {
if (identifier == null) {
throw new BusinessException("FAQ不存在");
}
FaqItem faqItem = faqItemService.getById(String.valueOf(identifier));
if (faqItem == null) {
throw new BusinessException("FAQ不存在");
}
DocumentCollection collection = documentCollectionService.getById(faqItem.getCollectionId());
if (collection == null) {
throw new BusinessException("知识库不存在");
}
return new ResolvedResourceAccess(collection, null);
}
}

View File

@@ -0,0 +1,36 @@
package tech.easyflow.ai.permission;
import org.springframework.stereotype.Component;
import tech.easyflow.ai.entity.DocumentCollection;
import tech.easyflow.ai.service.DocumentCollectionService;
import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.system.enums.CategoryResourceType;
import tech.easyflow.system.enums.ResourceLookup;
import tech.easyflow.system.permission.resource.ResolvedResourceAccess;
import tech.easyflow.system.permission.resource.ResourceAccessResolver;
import javax.annotation.Resource;
@Component
public class KnowledgeIdOrSlugResourceAccessResolver implements ResourceAccessResolver {
@Resource
private DocumentCollectionService documentCollectionService;
@Override
public boolean supports(CategoryResourceType resourceType, ResourceLookup lookup) {
return CategoryResourceType.KNOWLEDGE == resourceType && ResourceLookup.KNOWLEDGE_ID_OR_SLUG == lookup;
}
@Override
public ResolvedResourceAccess resolve(Object identifier) {
if (identifier == null) {
throw new BusinessException("知识库不存在");
}
DocumentCollection collection = documentCollectionService.getDetail(String.valueOf(identifier));
if (collection == null) {
throw new BusinessException("知识库不存在");
}
return new ResolvedResourceAccess(collection, null);
}
}

View File

@@ -0,0 +1,36 @@
package tech.easyflow.ai.permission;
import org.springframework.stereotype.Component;
import tech.easyflow.ai.entity.DocumentCollection;
import tech.easyflow.ai.service.DocumentCollectionService;
import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.system.enums.CategoryResourceType;
import tech.easyflow.system.enums.ResourceLookup;
import tech.easyflow.system.permission.resource.ResolvedResourceAccess;
import tech.easyflow.system.permission.resource.ResourceAccessResolver;
import javax.annotation.Resource;
@Component
public class KnowledgeIdResourceAccessResolver implements ResourceAccessResolver {
@Resource
private DocumentCollectionService documentCollectionService;
@Override
public boolean supports(CategoryResourceType resourceType, ResourceLookup lookup) {
return CategoryResourceType.KNOWLEDGE == resourceType && ResourceLookup.KNOWLEDGE_ID == lookup;
}
@Override
public ResolvedResourceAccess resolve(Object identifier) {
if (identifier == null) {
throw new BusinessException("知识库不存在");
}
DocumentCollection collection = documentCollectionService.getById(String.valueOf(identifier));
if (collection == null) {
throw new BusinessException("知识库不存在");
}
return new ResolvedResourceAccess(collection, null);
}
}

View File

@@ -0,0 +1,41 @@
package tech.easyflow.ai.permission;
import tech.easyflow.system.entity.vo.RoleCategoryAccessSnapshot;
import java.math.BigInteger;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Set;
public class KnowledgeReadAccessSnapshot {
private final RoleCategoryAccessSnapshot categoryAccess;
private final Set<BigInteger> readableDeptIds;
public KnowledgeReadAccessSnapshot(RoleCategoryAccessSnapshot categoryAccess, Set<BigInteger> readableDeptIds) {
this.categoryAccess = categoryAccess;
this.readableDeptIds = readableDeptIds == null
? Collections.emptySet()
: Collections.unmodifiableSet(new LinkedHashSet<>(readableDeptIds));
}
public BigInteger getAccountId() {
return categoryAccess == null ? null : categoryAccess.getAccountId();
}
public boolean isSuperAdmin() {
return categoryAccess != null && categoryAccess.isSuperAdmin();
}
public boolean isRestricted() {
return categoryAccess == null || categoryAccess.isRestricted();
}
public Set<BigInteger> getCategoryIds() {
return categoryAccess == null ? Collections.emptySet() : categoryAccess.getCategoryIds();
}
public Set<BigInteger> getReadableDeptIds() {
return readableDeptIds;
}
}

View File

@@ -0,0 +1,101 @@
package tech.easyflow.ai.permission;
import com.mybatisflex.core.query.QueryCondition;
import com.mybatisflex.core.query.QueryWrapper;
import org.springframework.stereotype.Component;
import tech.easyflow.ai.entity.DocumentCollection;
import tech.easyflow.common.entity.LoginAccount;
import tech.easyflow.common.satoken.util.SaTokenUtil;
import tech.easyflow.system.entity.vo.RoleCategoryAccessSnapshot;
import tech.easyflow.system.enums.CategoryResourceType;
import tech.easyflow.system.enums.VisibilityScope;
import tech.easyflow.system.service.CategoryPermissionService;
import tech.easyflow.system.service.SysDeptService;
import javax.annotation.Resource;
import java.math.BigInteger;
import java.util.Collections;
import java.util.Set;
import static tech.easyflow.ai.entity.table.DocumentCollectionTableDef.DOCUMENT_COLLECTION;
@Component
public class KnowledgeVisibilityQueryHelper {
@Resource
private CategoryPermissionService categoryPermissionService;
@Resource
private SysDeptService sysDeptService;
public KnowledgeReadAccessSnapshot getCurrentReadSnapshot() {
RoleCategoryAccessSnapshot categoryAccess = categoryPermissionService.getCurrentAccess(CategoryResourceType.KNOWLEDGE.getCode());
if (categoryAccess.isSuperAdmin()) {
return new KnowledgeReadAccessSnapshot(categoryAccess, Collections.emptySet());
}
LoginAccount loginAccount = SaTokenUtil.getLoginAccount();
Set<BigInteger> deptIds = loginAccount == null
? Collections.emptySet()
: sysDeptService.getSelfAndAncestorDeptIds(loginAccount.getDeptId());
return new KnowledgeReadAccessSnapshot(categoryAccess, deptIds);
}
public void applyReadableAccess(QueryWrapper queryWrapper) {
applyReadableAccess(queryWrapper, getCurrentReadSnapshot());
}
public void applyReadableAccess(QueryWrapper queryWrapper, KnowledgeReadAccessSnapshot snapshot) {
if (snapshot.isSuperAdmin()) {
return;
}
BigInteger accountId = snapshot.getAccountId();
if (accountId == null) {
queryWrapper.eq("id", BigInteger.valueOf(-1));
return;
}
QueryCondition ownerCondition = DOCUMENT_COLLECTION.CREATED_BY.eq(accountId);
if (snapshot.isRestricted() && snapshot.getCategoryIds().isEmpty()) {
queryWrapper.and(ownerCondition);
return;
}
QueryCondition visibilityCondition = DOCUMENT_COLLECTION.VISIBILITY_SCOPE.eq(VisibilityScope.PUBLIC.name());
if (!snapshot.getReadableDeptIds().isEmpty()) {
visibilityCondition = visibilityCondition.or(
DOCUMENT_COLLECTION.VISIBILITY_SCOPE.eq(VisibilityScope.DEPT.name())
.and(DOCUMENT_COLLECTION.DEPT_ID.in(snapshot.getReadableDeptIds()))
);
}
QueryCondition readableCondition = visibilityCondition;
if (snapshot.isRestricted()) {
readableCondition = DOCUMENT_COLLECTION.CATEGORY_ID.in(snapshot.getCategoryIds()).and(visibilityCondition);
}
queryWrapper.and(ownerCondition.or(readableCondition));
}
public boolean canRead(DocumentCollection collection) {
return canRead(collection, getCurrentReadSnapshot());
}
public boolean canRead(DocumentCollection collection, KnowledgeReadAccessSnapshot snapshot) {
if (collection == null) {
return false;
}
if (snapshot.isSuperAdmin()) {
return true;
}
BigInteger accountId = snapshot.getAccountId();
if (accountId != null && accountId.equals(collection.getCreatedBy())) {
return true;
}
if (snapshot.isRestricted() && (collection.getCategoryId() == null || !snapshot.getCategoryIds().contains(collection.getCategoryId()))) {
return false;
}
VisibilityScope scope = VisibilityScope.fromOrDefault(collection.getVisibilityScope(), VisibilityScope.PRIVATE);
if (VisibilityScope.PUBLIC == scope) {
return true;
}
return VisibilityScope.DEPT == scope
&& collection.getDeptId() != null
&& snapshot.getReadableDeptIds().contains(collection.getDeptId());
}
}

View File

@@ -0,0 +1,45 @@
package tech.easyflow.ai.permission;
import org.springframework.stereotype.Component;
import tech.easyflow.ai.entity.Workflow;
import tech.easyflow.ai.entity.WorkflowExecResult;
import tech.easyflow.ai.service.WorkflowExecResultService;
import tech.easyflow.ai.service.WorkflowService;
import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.system.enums.CategoryResourceType;
import tech.easyflow.system.enums.ResourceLookup;
import tech.easyflow.system.permission.resource.ResolvedResourceAccess;
import tech.easyflow.system.permission.resource.ResourceAccessResolver;
import javax.annotation.Resource;
@Component
public class WorkflowExecKeyResourceAccessResolver implements ResourceAccessResolver {
@Resource
private WorkflowExecResultService workflowExecResultService;
@Resource
private WorkflowService workflowService;
@Override
public boolean supports(CategoryResourceType resourceType, ResourceLookup lookup) {
return CategoryResourceType.WORKFLOW == resourceType && ResourceLookup.EXEC_KEY == lookup;
}
@Override
public ResolvedResourceAccess resolve(Object identifier) {
if (identifier == null) {
throw new BusinessException("工作流执行记录不存在");
}
WorkflowExecResult execResult = workflowExecResultService.getByExecKey(String.valueOf(identifier));
if (execResult == null) {
throw new BusinessException("工作流执行记录不存在");
}
Workflow workflow = workflowService.getDetail(String.valueOf(execResult.getWorkflowId()));
if (workflow == null) {
throw new BusinessException("工作流不存在");
}
return new ResolvedResourceAccess(workflow, execResult.getCreatedBy());
}
}

View File

@@ -0,0 +1,36 @@
package tech.easyflow.ai.permission;
import org.springframework.stereotype.Component;
import tech.easyflow.ai.entity.Workflow;
import tech.easyflow.ai.service.WorkflowService;
import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.system.enums.CategoryResourceType;
import tech.easyflow.system.enums.ResourceLookup;
import tech.easyflow.system.permission.resource.ResolvedResourceAccess;
import tech.easyflow.system.permission.resource.ResourceAccessResolver;
import javax.annotation.Resource;
@Component
public class WorkflowIdResourceAccessResolver implements ResourceAccessResolver {
@Resource
private WorkflowService workflowService;
@Override
public boolean supports(CategoryResourceType resourceType, ResourceLookup lookup) {
return CategoryResourceType.WORKFLOW == resourceType && ResourceLookup.WORKFLOW_ID == lookup;
}
@Override
public ResolvedResourceAccess resolve(Object identifier) {
if (identifier == null) {
throw new BusinessException("工作流不存在");
}
Workflow workflow = workflowService.getDetail(String.valueOf(identifier));
if (workflow == null) {
throw new BusinessException("工作流不存在");
}
return new ResolvedResourceAccess(workflow, null);
}
}

View File

@@ -0,0 +1,41 @@
package tech.easyflow.ai.permission;
import tech.easyflow.system.entity.vo.RoleCategoryAccessSnapshot;
import java.math.BigInteger;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Set;
public class WorkflowReadAccessSnapshot {
private final RoleCategoryAccessSnapshot categoryAccess;
private final Set<BigInteger> readableDeptIds;
public WorkflowReadAccessSnapshot(RoleCategoryAccessSnapshot categoryAccess, Set<BigInteger> readableDeptIds) {
this.categoryAccess = categoryAccess;
this.readableDeptIds = readableDeptIds == null
? Collections.emptySet()
: Collections.unmodifiableSet(new LinkedHashSet<>(readableDeptIds));
}
public BigInteger getAccountId() {
return categoryAccess == null ? null : categoryAccess.getAccountId();
}
public boolean isSuperAdmin() {
return categoryAccess != null && categoryAccess.isSuperAdmin();
}
public boolean isRestricted() {
return categoryAccess == null || categoryAccess.isRestricted();
}
public Set<BigInteger> getCategoryIds() {
return categoryAccess == null ? Collections.emptySet() : categoryAccess.getCategoryIds();
}
public Set<BigInteger> getReadableDeptIds() {
return readableDeptIds;
}
}

View File

@@ -0,0 +1,101 @@
package tech.easyflow.ai.permission;
import com.mybatisflex.core.query.QueryCondition;
import com.mybatisflex.core.query.QueryWrapper;
import org.springframework.stereotype.Component;
import tech.easyflow.ai.entity.Workflow;
import tech.easyflow.common.entity.LoginAccount;
import tech.easyflow.common.satoken.util.SaTokenUtil;
import tech.easyflow.system.entity.vo.RoleCategoryAccessSnapshot;
import tech.easyflow.system.enums.CategoryResourceType;
import tech.easyflow.system.enums.VisibilityScope;
import tech.easyflow.system.service.CategoryPermissionService;
import tech.easyflow.system.service.SysDeptService;
import javax.annotation.Resource;
import java.math.BigInteger;
import java.util.Collections;
import java.util.Set;
import static tech.easyflow.ai.entity.table.WorkflowTableDef.WORKFLOW;
@Component
public class WorkflowVisibilityQueryHelper {
@Resource
private CategoryPermissionService categoryPermissionService;
@Resource
private SysDeptService sysDeptService;
public WorkflowReadAccessSnapshot getCurrentReadSnapshot() {
RoleCategoryAccessSnapshot categoryAccess = categoryPermissionService.getCurrentAccess(CategoryResourceType.WORKFLOW.getCode());
if (categoryAccess.isSuperAdmin()) {
return new WorkflowReadAccessSnapshot(categoryAccess, Collections.emptySet());
}
LoginAccount loginAccount = SaTokenUtil.getLoginAccount();
Set<BigInteger> deptIds = loginAccount == null
? Collections.emptySet()
: sysDeptService.getSelfAndAncestorDeptIds(loginAccount.getDeptId());
return new WorkflowReadAccessSnapshot(categoryAccess, deptIds);
}
public void applyReadableAccess(QueryWrapper queryWrapper) {
applyReadableAccess(queryWrapper, getCurrentReadSnapshot());
}
public void applyReadableAccess(QueryWrapper queryWrapper, WorkflowReadAccessSnapshot snapshot) {
if (snapshot.isSuperAdmin()) {
return;
}
BigInteger accountId = snapshot.getAccountId();
if (accountId == null) {
queryWrapper.eq("id", BigInteger.valueOf(-1));
return;
}
QueryCondition ownerCondition = WORKFLOW.CREATED_BY.eq(accountId);
if (snapshot.isRestricted() && snapshot.getCategoryIds().isEmpty()) {
queryWrapper.and(ownerCondition);
return;
}
QueryCondition visibilityCondition = WORKFLOW.VISIBILITY_SCOPE.eq(VisibilityScope.PUBLIC.name());
if (!snapshot.getReadableDeptIds().isEmpty()) {
visibilityCondition = visibilityCondition.or(
WORKFLOW.VISIBILITY_SCOPE.eq(VisibilityScope.DEPT.name())
.and(WORKFLOW.DEPT_ID.in(snapshot.getReadableDeptIds()))
);
}
QueryCondition readableCondition = visibilityCondition;
if (snapshot.isRestricted()) {
readableCondition = WORKFLOW.CATEGORY_ID.in(snapshot.getCategoryIds()).and(visibilityCondition);
}
queryWrapper.and(ownerCondition.or(readableCondition));
}
public boolean canRead(Workflow workflow) {
return canRead(workflow, getCurrentReadSnapshot());
}
public boolean canRead(Workflow workflow, WorkflowReadAccessSnapshot snapshot) {
if (workflow == null) {
return false;
}
if (snapshot.isSuperAdmin()) {
return true;
}
BigInteger accountId = snapshot.getAccountId();
if (accountId != null && accountId.equals(workflow.getCreatedBy())) {
return true;
}
if (snapshot.isRestricted() && (workflow.getCategoryId() == null || !snapshot.getCategoryIds().contains(workflow.getCategoryId()))) {
return false;
}
VisibilityScope scope = VisibilityScope.fromOrDefault(workflow.getVisibilityScope(), VisibilityScope.PRIVATE);
if (VisibilityScope.PUBLIC == scope) {
return true;
}
return VisibilityScope.DEPT == scope
&& workflow.getDeptId() != null
&& snapshot.getReadableDeptIds().contains(workflow.getDeptId());
}
}

View File

@@ -3,6 +3,7 @@ package tech.easyflow.ai.service;
import tech.easyflow.ai.entity.Document; import tech.easyflow.ai.entity.Document;
import com.mybatisflex.core.paginate.Page; import com.mybatisflex.core.paginate.Page;
import com.mybatisflex.core.service.IService; import com.mybatisflex.core.service.IService;
import tech.easyflow.ai.documentimport.DocumentImportDtos;
import tech.easyflow.ai.entity.DocumentChunk; import tech.easyflow.ai.entity.DocumentChunk;
import tech.easyflow.ai.entity.DocumentCollectionSplitParams; import tech.easyflow.ai.entity.DocumentCollectionSplitParams;
import tech.easyflow.common.domain.Result; import tech.easyflow.common.domain.Result;
@@ -25,4 +26,10 @@ public interface DocumentService extends IService<Document> {
Result textSplit(DocumentCollectionSplitParams documentCollectionSplitParams); Result textSplit(DocumentCollectionSplitParams documentCollectionSplitParams);
Result saveTextResult(List<DocumentChunk> documentChunks, Document document); Result saveTextResult(List<DocumentChunk> documentChunks, Document document);
Result<DocumentImportDtos.AnalyzeResponse> analyzeImport(DocumentImportDtos.AnalyzeRequest request);
Result<DocumentImportDtos.PreviewResponse> previewImport(DocumentImportDtos.PreviewRequest request);
Result<DocumentImportDtos.CommitResponse> commitImport(DocumentImportDtos.CommitRequest request);
} }

View File

@@ -24,4 +24,16 @@ public interface ModelService extends IService<Model> {
void removeByEntity(Model entity); void removeByEntity(Model entity);
Model getModelInstance(BigInteger modelId); Model getModelInstance(BigInteger modelId);
Model getModelInstanceByInvokeCode(String invokeCode);
void validateForSaveOrUpdate(Model entity, boolean isSave);
List<Model> listInvokeModels();
List<Model> listSelectableModels(Model entity, Boolean asTree, String sortKey, String sortType);
Model updateInvokeConfig(BigInteger id, String invokeCode, Boolean publishEnabled);
List<Model> batchUpdateInvokePublishStatus(List<BigInteger> ids, Boolean publishEnabled);
} }

View File

@@ -5,7 +5,6 @@ import tech.easyflow.ai.entity.PluginCategory;
import tech.easyflow.ai.entity.PluginCategoryMapping; import tech.easyflow.ai.entity.PluginCategoryMapping;
import java.math.BigInteger; import java.math.BigInteger;
import java.util.ArrayList;
import java.util.List; import java.util.List;
/** /**
@@ -16,7 +15,7 @@ import java.util.List;
*/ */
public interface PluginCategoryMappingService extends IService<PluginCategoryMapping> { public interface PluginCategoryMappingService extends IService<PluginCategoryMapping> {
boolean updateRelation(BigInteger pluginId, ArrayList<BigInteger> categoryIds); boolean updateRelation(BigInteger pluginId, List<BigInteger> categoryIds);
List<PluginCategory> getPluginCategories(BigInteger pluginId); List<PluginCategory> getPluginCategories(BigInteger pluginId);
} }

View File

@@ -14,7 +14,7 @@ import java.util.List;
*/ */
public interface PluginService extends IService<Plugin> { public interface PluginService extends IService<Plugin> {
boolean savePlugin(Plugin plugin); Plugin savePlugin(Plugin plugin);
boolean removePlugin(String id); boolean removePlugin(String id);

View File

@@ -0,0 +1,13 @@
package tech.easyflow.ai.service;
import java.math.BigInteger;
import java.util.Set;
public interface PluginVisibilityService {
Set<BigInteger> getCurrentVisiblePluginIds();
boolean canAccessPlugin(Long createdBy, BigInteger pluginId);
void assertPluginVisible(Long createdBy, BigInteger pluginId, String message);
}

View File

@@ -41,6 +41,7 @@ import tech.easyflow.common.util.UrlEncoderUtil;
import tech.easyflow.common.web.exceptions.BusinessException; import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter; import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter;
import tech.easyflow.core.chat.protocol.sse.ChatSseUtil; import tech.easyflow.core.chat.protocol.sse.ChatSseUtil;
import tech.easyflow.system.service.CategoryPermissionService;
import javax.annotation.Resource; import javax.annotation.Resource;
import java.math.BigInteger; import java.math.BigInteger;
@@ -107,6 +108,8 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
private McpService mcpService; private McpService mcpService;
@Resource(name = "default") @Resource(name = "default")
FileStorageService storageService; FileStorageService storageService;
@Resource
private CategoryPermissionService categoryPermissionService;
@Override @Override
public Bot getDetail(String id) { public Bot getDetail(String id) {
@@ -144,6 +147,13 @@ public class BotServiceImpl extends ServiceImpl<BotMapper, Bot> implements BotSe
if (aiBot == null) { if (aiBot == null) {
return ChatSseUtil.sendSystemError(conversationId, "聊天助手不存在"); return ChatSseUtil.sendSystemError(conversationId, "聊天助手不存在");
} }
if (StpUtil.isLogin()) {
try {
categoryPermissionService.assertCategoryResourceVisible("BOT", aiBot.getCreatedBy(), aiBot.getCategoryId(), "无权限访问聊天助手");
} catch (BusinessException e) {
return ChatSseUtil.sendSystemError(conversationId, e.getMessage());
}
}
if (aiBot.getModelId() == null) { if (aiBot.getModelId() == null) {
return ChatSseUtil.sendSystemError(conversationId, "请配置大模型!"); return ChatSseUtil.sendSystemError(conversationId, "请配置大模型!");
} }

View File

@@ -12,6 +12,12 @@ import com.easyagents.core.model.embedding.EmbeddingOptions;
import com.easyagents.core.store.DocumentStore; import com.easyagents.core.store.DocumentStore;
import com.easyagents.core.store.StoreOptions; import com.easyagents.core.store.StoreOptions;
import com.easyagents.core.store.StoreResult; import com.easyagents.core.store.StoreResult;
import com.easyagents.rag.core.RagChunk;
import com.easyagents.rag.core.RagDefaults;
import com.easyagents.rag.core.RagStrategyCodes;
import com.easyagents.rag.ingestion.RagIngestionService;
import com.easyagents.rag.ingestion.model.AnalysisResult;
import com.easyagents.rag.ingestion.model.StrategyConfig;
import com.easyagents.search.engine.service.DocumentSearcher; import com.easyagents.search.engine.service.DocumentSearcher;
import com.mybatisflex.core.keygen.impl.FlexIDKeyGenerator; import com.mybatisflex.core.keygen.impl.FlexIDKeyGenerator;
import com.mybatisflex.core.paginate.Page; import com.mybatisflex.core.paginate.Page;
@@ -24,6 +30,9 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import tech.easyflow.ai.config.SearcherFactory; import tech.easyflow.ai.config.SearcherFactory;
import tech.easyflow.ai.documentimport.DocumentImportDtos;
import tech.easyflow.ai.documentimport.DocumentImportKeys;
import tech.easyflow.ai.documentimport.DocumentImportPreviewService;
import tech.easyflow.ai.entity.*; import tech.easyflow.ai.entity.*;
import tech.easyflow.ai.mapper.DocumentChunkMapper; import tech.easyflow.ai.mapper.DocumentChunkMapper;
import tech.easyflow.ai.mapper.DocumentMapper; import tech.easyflow.ai.mapper.DocumentMapper;
@@ -42,6 +51,7 @@ import javax.annotation.Resource;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.math.BigInteger; import java.math.BigInteger;
import java.math.BigDecimal;
import java.util.*; import java.util.*;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
@@ -81,6 +91,12 @@ public class DocumentServiceImpl extends ServiceImpl<DocumentMapper, Document> i
@Autowired @Autowired
private SearcherFactory searcherFactory; private SearcherFactory searcherFactory;
@Autowired
private RagIngestionService ragIngestionService;
@Autowired
private DocumentImportPreviewService documentImportPreviewService;
@Override @Override
public Page<Document> getDocumentList(String knowledgeId, int pageSize, int pageNum, String fileName) { public Page<Document> getDocumentList(String knowledgeId, int pageSize, int pageNum, String fileName) {
QueryWrapper queryWrapper=QueryWrapper.create() QueryWrapper queryWrapper=QueryWrapper.create()
@@ -250,23 +266,397 @@ public class DocumentServiceImpl extends ServiceImpl<DocumentMapper, Document> i
return Result.fail(1, "切割结果无有效文本,无法进行向量化"); return Result.fail(1, "切割结果无有效文本,无法进行向量化");
} }
Boolean result = storeDocument(document, validChunks); StoreExecutionContext storeContext = prepareStoreContext(document);
if (result) { storeDocumentChunks(storeContext, validChunks);
this.getMapper().insert(document); try {
AtomicInteger sort = new AtomicInteger(1); persistDocumentWithChunks(document, validChunks);
validChunks.forEach(item -> { updateKnowledgeAfterStore(storeContext);
item.setDocumentCollectionId(document.getCollectionId());
item.setSorting(sort.get());
item.setDocumentId(document.getId());
sort.getAndIncrement();
documentChunkService.save(item);
});
return Result.ok(); return Result.ok();
} catch (Exception e) {
cleanupPersistedDocument(document);
rollbackStoredChunks(storeContext, validChunks);
Log.error("保存文档失败: documentId={}, title={}", document.getId(), document.getTitle(), e);
throw new BusinessException("保存失败:" + e.getMessage());
} }
return Result.fail(1, "保存失败");
} }
protected Boolean storeDocument(Document entity, List<DocumentChunk> documentChunks) { protected Boolean storeDocument(Document entity, List<DocumentChunk> documentChunks) {
StoreExecutionContext storeContext = prepareStoreContext(entity);
storeDocumentChunks(storeContext, documentChunks);
updateKnowledgeAfterStore(storeContext);
return true;
}
@Override
public Result<DocumentImportDtos.AnalyzeResponse> analyzeImport(DocumentImportDtos.AnalyzeRequest request) {
DocumentCollection knowledge = assertDocumentCollection(request.getKnowledgeId());
if (request.getFiles() == null || request.getFiles().isEmpty()) {
throw new BusinessException("请先上传文件");
}
List<DocumentImportDtos.AnalyzeItem> items = new ArrayList<>();
for (DocumentImportDtos.FileItem file : request.getFiles()) {
AnalysisResult analysis = analyzeSingleFile(file.getFilePath(), file.getFileName());
StrategyConfig strategyConfig = resolveStrategyConfig(
knowledge,
null,
analysis
);
DocumentImportDtos.AnalyzeItem item = new DocumentImportDtos.AnalyzeItem();
item.setFilePath(file.getFilePath());
item.setFileName(file.getFileName());
item.setAnalysis(analysis);
item.setStrategyConfig(strategyConfig);
items.add(item);
}
DocumentImportDtos.AnalyzeResponse response = new DocumentImportDtos.AnalyzeResponse();
response.setItems(items);
response.setTotal(items.size());
return Result.ok(response);
}
@Override
public Result<DocumentImportDtos.PreviewResponse> previewImport(DocumentImportDtos.PreviewRequest request) {
DocumentCollection knowledge = assertDocumentCollection(request.getKnowledgeId());
if (request.getFiles() == null || request.getFiles().isEmpty()) {
throw new BusinessException("请先上传文件");
}
List<DocumentImportDtos.PreviewFileResult> items = new ArrayList<>();
int totalChunks = 0;
for (DocumentImportDtos.PreviewFileRequest file : request.getFiles()) {
DocumentImportDtos.PreviewSession session = buildPreviewSession(knowledge, file);
String sessionId = documentImportPreviewService.put(session);
DocumentImportDtos.PreviewFileResult item = new DocumentImportDtos.PreviewFileResult();
item.setPreviewSessionId(sessionId);
item.setFilePath(file.getFilePath());
item.setFileName(file.getFileName());
item.setStrategyCode(session.getStrategyConfig().getStrategyCode());
item.setStrategyLabel(ragIngestionService.toStrategyLabel(session.getStrategyConfig().getStrategyCode()));
item.setAnalysis(session.getAnalysis());
item.setChunks(session.getPreviewChunks());
item.setTotalChunks(session.getPreviewChunks().size());
item.setTotalWarnings(countWarnings(session.getPreviewChunks()));
items.add(item);
totalChunks += session.getPreviewChunks().size();
}
DocumentImportDtos.PreviewResponse response = new DocumentImportDtos.PreviewResponse();
response.setItems(items);
response.setTotalFiles(items.size());
response.setTotalChunks(totalChunks);
return Result.ok(response);
}
@Override
public Result<DocumentImportDtos.CommitResponse> commitImport(DocumentImportDtos.CommitRequest request) {
DocumentCollection knowledge = assertDocumentCollection(request.getKnowledgeId());
if (request.getPreviewSessionIds() == null || request.getPreviewSessionIds().isEmpty()) {
throw new BusinessException("请选择需要提交的预览结果");
}
List<DocumentImportDtos.CommitFileResult> results = new ArrayList<>();
int successCount = 0;
int errorCount = 0;
for (String previewSessionId : request.getPreviewSessionIds()) {
DocumentImportDtos.CommitFileResult result = new DocumentImportDtos.CommitFileResult();
result.setPreviewSessionId(previewSessionId);
try {
DocumentImportDtos.PreviewSession session = documentImportPreviewService.getRequired(previewSessionId);
if (!Objects.equals(session.getKnowledgeId(), knowledge.getId())) {
throw new BusinessException("预览会话与当前知识库不匹配");
}
commitSingleSession(session);
result.setSuccess(true);
result.setFileName(session.getFileName());
result.setDocumentId(session.getDocument().getId());
result.setChunkCount(session.getDocumentChunks().size());
documentImportPreviewService.remove(previewSessionId);
successCount++;
} catch (Exception e) {
result.setSuccess(false);
result.setReason(e.getMessage());
errorCount++;
}
results.add(result);
}
DocumentImportDtos.CommitResponse response = new DocumentImportDtos.CommitResponse();
response.setTotalFiles(results.size());
response.setSuccessCount(successCount);
response.setErrorCount(errorCount);
response.setResults(results);
return Result.ok(response);
}
private void commitSingleSession(DocumentImportDtos.PreviewSession session) {
Document document = session.getDocument();
document.setCreated(new Date());
document.setModified(new Date());
document.setCreatedBy(BigInteger.valueOf(StpUtil.getLoginIdAsLong()));
document.setModifiedBy(BigInteger.valueOf(StpUtil.getLoginIdAsLong()));
for (DocumentChunk chunk : session.getDocumentChunks()) {
chunk.setDocumentId(document.getId());
chunk.setDocumentCollectionId(document.getCollectionId());
}
StoreExecutionContext storeContext = prepareStoreContext(document);
storeDocumentChunks(storeContext, session.getDocumentChunks());
try {
persistDocumentWithChunks(document, session.getDocumentChunks());
updateKnowledgeAfterStore(storeContext);
} catch (Exception e) {
cleanupPersistedDocument(document);
rollbackStoredChunks(storeContext, session.getDocumentChunks());
throw new BusinessException("提交导入失败:" + e.getMessage());
}
}
private DocumentImportDtos.PreviewSession buildPreviewSession(DocumentCollection knowledge,
DocumentImportDtos.PreviewFileRequest fileRequest) {
AnalysisResult analysis = analyzeSingleFile(fileRequest.getFilePath(), fileRequest.getFileName());
StrategyConfig strategyConfig = resolveStrategyConfig(knowledge, fileRequest.getStrategyConfig(), analysis);
List<RagChunk> previewChunks = ragIngestionService.split(analysis, strategyConfig);
if (previewChunks.isEmpty()) {
throw new BusinessException("未生成有效分块,请调整策略后重试");
}
FlexIDKeyGenerator flexIDKeyGenerator = new FlexIDKeyGenerator();
Document document = buildPreviewDocument(flexIDKeyGenerator, knowledge, fileRequest, analysis, strategyConfig);
List<DocumentChunk> documentChunks = buildDocumentChunks(flexIDKeyGenerator, document, previewChunks);
DocumentImportDtos.PreviewSession session = new DocumentImportDtos.PreviewSession();
session.setKnowledgeId(knowledge.getId());
session.setFilePath(fileRequest.getFilePath());
session.setFileName(fileRequest.getFileName());
session.setSourceFormat(analysis.getSourceFormat());
session.setStrategyConfig(strategyConfig);
session.setAnalysis(analysis);
session.setDocument(document);
session.setDocumentChunks(documentChunks);
session.setPreviewChunks(previewChunks);
session.setCreatedAt(new Date());
return session;
}
private Document buildPreviewDocument(FlexIDKeyGenerator flexIDKeyGenerator,
DocumentCollection knowledge,
DocumentImportDtos.PreviewFileRequest fileRequest,
AnalysisResult analysis,
StrategyConfig strategyConfig) {
Document document = new Document();
document.setId(new BigInteger(String.valueOf(flexIDKeyGenerator.generate(document, null))));
document.setCollectionId(knowledge.getId());
document.setDocumentType(analysis.getSourceFormat());
document.setDocumentPath(fileRequest.getFilePath());
document.setTitle(fileRequest.getFileName());
document.setContent(analysis.getNormalizedContent());
document.setCreated(new Date());
document.setModified(new Date());
document.setModifiedBy(BigInteger.valueOf(StpUtil.getLoginIdAsLong()));
Map<String, Object> options = new LinkedHashMap<>();
options.put(DocumentImportKeys.KEY_DOCUMENT_STRATEGY_CODE, strategyConfig.getStrategyCode());
options.put(DocumentImportKeys.KEY_DOCUMENT_STRATEGY_LABEL, ragIngestionService.toStrategyLabel(strategyConfig.getStrategyCode()));
options.put(DocumentImportKeys.KEY_DOCUMENT_STRATEGY_SNAPSHOT, strategyConfigToMap(strategyConfig));
options.put(DocumentImportKeys.KEY_DOCUMENT_ANALYSIS_SUMMARY, analysis.getFeatures());
options.put(DocumentImportKeys.KEY_DOCUMENT_SOURCE_FILE_EXT, analysis.getSourceFormat());
options.put(DocumentImportKeys.KEY_DOCUMENT_PREVIEW_VERSION, "v1");
document.setOptions(options);
return document;
}
private List<DocumentChunk> buildDocumentChunks(FlexIDKeyGenerator flexIDKeyGenerator,
Document document,
List<RagChunk> previewChunks) {
List<DocumentChunk> chunks = new ArrayList<>();
for (int i = 0; i < previewChunks.size(); i++) {
RagChunk previewChunk = previewChunks.get(i);
DocumentChunk chunk = new DocumentChunk();
chunk.setId(new BigInteger(String.valueOf(flexIDKeyGenerator.generate(chunk, null))));
chunk.setDocumentId(document.getId());
chunk.setDocumentCollectionId(document.getCollectionId());
chunk.setContent(previewChunk.getContent());
chunk.setSorting(i + 1);
Map<String, Object> options = new LinkedHashMap<>(previewChunk.getOptions());
options.put("chunkType", previewChunk.getChunkType());
options.put("sourceLabel", previewChunk.getSourceLabel());
options.put("headingPath", previewChunk.getHeadingPath());
options.put("charCount", previewChunk.getCharCount());
options.put("tokenEstimate", previewChunk.getTokenEstimate());
options.put("qaQuestion", previewChunk.getQuestion());
options.put("qaAnswer", previewChunk.getAnswer());
options.put("partNo", previewChunk.getPartNo());
options.put("partTotal", previewChunk.getPartTotal());
options.put("warnings", previewChunk.getWarnings());
chunk.setOptions(options);
chunks.add(chunk);
}
return chunks;
}
private AnalysisResult analyzeSingleFile(String filePath, String fileName) {
String fileExt = normalizeFileExtension(fileName, filePath);
assertSupportedImportFile(fileExt);
String content = readFileContent(filePath, fileName);
return ragIngestionService.analyze(content, fileExt);
}
private String readFileContent(String filePath, String fileName) {
try (InputStream inputStream = storageService.readStream(filePath)) {
return File2TextUtil.readFromStream(inputStream, fileName, null);
} catch (IOException e) {
Log.error("读取导入文件失败: filePath={}, fileName={}", filePath, fileName, e);
throw new BusinessException("文件解析失败:" + e.getMessage());
}
}
private void assertSupportedImportFile(String fileExt) {
if (!Arrays.asList("pdf", "docx", "txt", "md").contains(fileExt)) {
throw new BusinessException("当前仅支持 pdf/docx/txt/md 文档导入");
}
}
private String normalizeFileExtension(String fileName, String filePath) {
String target = StringUtil.hasText(fileName) ? fileName : filePath;
String ext = FileUtil.getFileTypeByExtension(target);
return ext == null ? "" : ext.toLowerCase(Locale.ROOT);
}
private DocumentCollection assertDocumentCollection(BigInteger knowledgeId) {
DocumentCollection knowledge = knowledgeService.getById(knowledgeId);
if (knowledge == null) {
throw new BusinessException("知识库不存在");
}
if (knowledge.isFaqCollection()) {
throw new BusinessException("FAQ知识库不支持文档上传");
}
return knowledge;
}
private StrategyConfig resolveStrategyConfig(DocumentCollection knowledge,
StrategyConfig requestConfig,
AnalysisResult analysisResult) {
Map<String, Object> options = knowledge.getOptions() == null
? Collections.emptyMap()
: knowledge.getOptions();
String recommended = analysisResult.getRecommendedStrategyCode();
String defaultStrategyCode = asString(options.get(DocumentImportKeys.KEY_SPLITTER_DEFAULT_STRATEGY));
String fallbackStrategyCode = asString(options.get(DocumentImportKeys.KEY_SPLITTER_FALLBACK_STRATEGY));
Boolean autoRecommendEnabled = asBoolean(options.get(DocumentImportKeys.KEY_SPLITTER_AUTO_RECOMMEND_ENABLED), true);
StrategyConfig config = readProfileConfig(options, defaultStrategyCode);
if (config == null) {
config = StrategyConfig.defaults();
}
String requestedStrategyCode = requestConfig == null ? null : requestConfig.getStrategyCode();
String strategyCode = StringUtil.hasText(requestedStrategyCode)
? requestedStrategyCode
: config.getStrategyCode();
if (!StringUtil.hasText(strategyCode) || RagStrategyCodes.AUTO.equals(strategyCode)) {
strategyCode = Boolean.TRUE.equals(autoRecommendEnabled)
? recommended
: (StringUtil.hasText(defaultStrategyCode) ? defaultStrategyCode : recommended);
}
if (!StringUtil.hasText(strategyCode)) {
strategyCode = StringUtil.hasText(fallbackStrategyCode)
? fallbackStrategyCode
: RagStrategyCodes.PARAGRAPH_LENGTH;
}
StrategyConfig profileConfig = readProfileConfig(options, strategyCode);
if (profileConfig != null) {
mergeStrategyConfig(config, profileConfig);
}
if (requestConfig != null) {
mergeStrategyConfig(config, requestConfig);
}
config.setStrategyCode(strategyCode);
if (config.getChunkSize() == null || config.getChunkSize() <= 0) {
config.setChunkSize(RagDefaults.CHUNK_SIZE);
}
if (config.getOverlapSize() == null || config.getOverlapSize() < 0) {
config.setOverlapSize(RagDefaults.OVERLAP_SIZE);
}
if (config.getMdSplitterLevel() == null || config.getMdSplitterLevel() <= 0) {
config.setMdSplitterLevel(RagDefaults.MD_SPLITTER_LEVEL);
}
return config;
}
@SuppressWarnings("unchecked")
private StrategyConfig readProfileConfig(Map<String, Object> options, String strategyCode) {
if (!StringUtil.hasText(strategyCode)) {
return null;
}
Object profileObject = options.get(DocumentImportKeys.KEY_SPLITTER_STRATEGY_PROFILES);
if (!(profileObject instanceof Map)) {
return null;
}
Object strategyObject = ((Map<String, Object>) profileObject).get(strategyCode);
if (!(strategyObject instanceof Map)) {
return null;
}
Map<String, Object> rawProfile = (Map<String, Object>) strategyObject;
StrategyConfig config = StrategyConfig.defaults();
config.setStrategyCode(strategyCode);
config.setChunkSize(asInteger(rawProfile.get("chunkSize"), config.getChunkSize()));
config.setOverlapSize(asInteger(rawProfile.get("overlapSize"), config.getOverlapSize()));
config.setRegex(asString(rawProfile.get("regex")));
config.setRowsPerChunk(asInteger(rawProfile.get("rowsPerChunk"), config.getRowsPerChunk()));
config.setMdSplitterLevel(asInteger(rawProfile.get("mdSplitterLevel"), config.getMdSplitterLevel()));
return config;
}
private void mergeStrategyConfig(StrategyConfig target, StrategyConfig source) {
if (source == null) {
return;
}
if (StringUtil.hasText(source.getStrategyCode())) {
target.setStrategyCode(source.getStrategyCode());
}
if (source.getChunkSize() != null) {
target.setChunkSize(source.getChunkSize());
}
if (source.getOverlapSize() != null) {
target.setOverlapSize(source.getOverlapSize());
}
if (StringUtil.hasText(source.getRegex())) {
target.setRegex(source.getRegex());
}
if (source.getRowsPerChunk() != null) {
target.setRowsPerChunk(source.getRowsPerChunk());
}
if (source.getMdSplitterLevel() != null) {
target.setMdSplitterLevel(source.getMdSplitterLevel());
}
}
private Map<String, Object> strategyConfigToMap(StrategyConfig strategyConfig) {
Map<String, Object> map = new LinkedHashMap<>();
map.put("strategyCode", strategyConfig.getStrategyCode());
map.put("chunkSize", strategyConfig.getChunkSize());
map.put("overlapSize", strategyConfig.getOverlapSize());
map.put("regex", strategyConfig.getRegex());
map.put("rowsPerChunk", strategyConfig.getRowsPerChunk());
map.put("mdSplitterLevel", strategyConfig.getMdSplitterLevel());
return map;
}
private int countWarnings(List<RagChunk> chunks) {
int total = 0;
for (RagChunk chunk : chunks) {
total += chunk.getWarnings() == null ? 0 : chunk.getWarnings().size();
}
return total;
}
private StoreExecutionContext prepareStoreContext(Document entity) {
DocumentCollection knowledge = knowledgeService.getById(entity.getCollectionId()); DocumentCollection knowledge = knowledgeService.getById(entity.getCollectionId());
if (knowledge == null) { if (knowledge == null) {
throw new BusinessException("知识库不存在"); throw new BusinessException("知识库不存在");
@@ -274,23 +664,22 @@ public class DocumentServiceImpl extends ServiceImpl<DocumentMapper, Document> i
if (knowledge.isFaqCollection()) { if (knowledge.isFaqCollection()) {
throw new BusinessException("FAQ知识库不支持文档上传"); throw new BusinessException("FAQ知识库不支持文档上传");
} }
DocumentStore documentStore = null;
DocumentStore documentStore;
try { try {
documentStore = knowledge.toDocumentStore(); documentStore = knowledge.toDocumentStore();
} catch (Exception e) { } catch (Exception e) {
Log.error(e.getMessage()); Log.error("向量库配置错误: knowledgeId={}", knowledge.getId(), e);
throw new BusinessException("向量数据库配置错误"); throw new BusinessException("向量数据库配置错误");
} }
if (documentStore == null) { if (documentStore == null) {
throw new BusinessException("向量数据库配置错误"); throw new BusinessException("向量数据库配置错误");
} }
// 设置向量模型
Model model = modelService.getModelInstance(knowledge.getVectorEmbedModelId()); Model model = modelService.getModelInstance(knowledge.getVectorEmbedModelId());
if (model == null) { if (model == null) {
throw new BusinessException("该知识库未配置大模型"); throw new BusinessException("该知识库未配置大模型");
} }
// 设置向量模型
EmbeddingModel embeddingModel = model.toEmbeddingModel(); EmbeddingModel embeddingModel = model.toEmbeddingModel();
documentStore.setEmbeddingModel(embeddingModel); documentStore.setEmbeddingModel(embeddingModel);
@@ -300,46 +689,152 @@ public class DocumentServiceImpl extends ServiceImpl<DocumentMapper, Document> i
embeddingOptions.setDimensions(knowledge.getDimensionOfVectorModel()); embeddingOptions.setDimensions(knowledge.getDimensionOfVectorModel());
options.setEmbeddingOptions(embeddingOptions); options.setEmbeddingOptions(embeddingOptions);
options.setIndexName(options.getCollectionName()); options.setIndexName(options.getCollectionName());
DocumentSearcher searcher = null;
if (knowledge.isSearchEngineEnabled()) {
searcher = searcherFactory.getSearcher((String) knowledge.getOptionsByKey(KEY_SEARCH_ENGINE_TYPE));
}
return new StoreExecutionContext(knowledge, model, embeddingModel, documentStore, options, searcher);
}
private void storeDocumentChunks(StoreExecutionContext storeContext, List<DocumentChunk> documentChunks) {
List<com.easyagents.core.document.Document> documents = new ArrayList<>(); List<com.easyagents.core.document.Document> documents = new ArrayList<>();
documentChunks.forEach(item -> { for (DocumentChunk item : documentChunks) {
com.easyagents.core.document.Document document = new com.easyagents.core.document.Document(); com.easyagents.core.document.Document document = new com.easyagents.core.document.Document();
document.setId(item.getId()); document.setId(item.getId());
document.setContent(item.getContent()); document.setContent(item.getContent());
documents.add(document); documents.add(document);
} }
);
StoreResult result = null; StoreResult result;
try { try {
result = documentStore.store(documents, options); result = storeContext.documentStore.store(documents, storeContext.options);
} catch (Exception e) { } catch (Exception e) {
Log.error("Vector store failed: knowledgeId={}, collection={}, chunkCount={}", Log.error("Vector store failed: knowledgeId={}, collection={}, chunkCount={}",
knowledge.getId(), options.getCollectionName(), documents.size(), e); storeContext.knowledge.getId(),
storeContext.options.getCollectionName(),
documents.size(),
e);
throw new BusinessException("向量过程中发生错误,错误信息为:" + e.getMessage()); throw new BusinessException("向量过程中发生错误,错误信息为:" + e.getMessage());
} }
if (result == null || !result.isSuccess()) { if (result == null || !result.isSuccess()) {
Log.error("DocumentStore.store failed: " + result); Log.error("DocumentStore.store failed: {}", result);
throw new BusinessException("DocumentStore.store failed"); throw new BusinessException("DocumentStore.store failed");
} }
if (knowledge.isSearchEngineEnabled()) { if (storeContext.searcher != null) {
// 获取搜索引擎 for (com.easyagents.core.document.Document document : documents) {
DocumentSearcher searcher = searcherFactory.getSearcher((String) knowledge.getOptionsByKey(KEY_SEARCH_ENGINE_TYPE)); storeContext.searcher.addDocument(document);
// 添加到搜索引擎 }
documents.forEach(searcher::addDocument);
} }
}
private void rollbackStoredChunks(StoreExecutionContext storeContext, List<DocumentChunk> documentChunks) {
try {
List<BigInteger> chunkIds = new ArrayList<>();
for (DocumentChunk chunk : documentChunks) {
chunkIds.add(chunk.getId());
}
storeContext.documentStore.delete(chunkIds, storeContext.options);
if (storeContext.searcher != null) {
for (BigInteger chunkId : chunkIds) {
storeContext.searcher.deleteDocument(chunkId);
}
}
} catch (Exception e) {
Log.error("回滚向量文档失败: knowledgeId={}", storeContext.knowledge.getId(), e);
}
}
private void updateKnowledgeAfterStore(StoreExecutionContext storeContext) {
DocumentCollection documentCollection = new DocumentCollection(); DocumentCollection documentCollection = new DocumentCollection();
documentCollection.setId(entity.getCollectionId()); documentCollection.setId(storeContext.knowledge.getId());
Map<String, Object> knowledgeOptions = knowledge.getOptions(); Map<String, Object> knowledgeOptions = storeContext.knowledge.getOptions() == null
? new HashMap<>()
: new HashMap<>(storeContext.knowledge.getOptions());
knowledgeOptions.put(KEY_CAN_UPDATE_EMBEDDING_MODEL, false); knowledgeOptions.put(KEY_CAN_UPDATE_EMBEDDING_MODEL, false);
documentCollection.setOptions(knowledgeOptions); documentCollection.setOptions(knowledgeOptions);
knowledgeService.updateById(documentCollection); knowledgeService.updateById(documentCollection);
if (knowledge.getDimensionOfVectorModel() == null) {
int dimension = Model.getEmbeddingDimension(embeddingModel); if (storeContext.knowledge.getDimensionOfVectorModel() == null) {
knowledge.setDimensionOfVectorModel(dimension); int dimension = Model.getEmbeddingDimension(storeContext.embeddingModel);
knowledgeService.updateById(knowledge); DocumentCollection update = new DocumentCollection();
update.setId(storeContext.knowledge.getId());
update.setDimensionOfVectorModel(dimension);
knowledgeService.updateById(update);
}
}
private void persistDocumentWithChunks(Document document, List<DocumentChunk> chunks) {
this.getMapper().insert(document);
AtomicInteger sort = new AtomicInteger(1);
for (DocumentChunk item : chunks) {
item.setDocumentCollectionId(document.getCollectionId());
item.setDocumentId(document.getId());
item.setSorting(sort.getAndIncrement());
documentChunkService.save(item);
}
}
private void cleanupPersistedDocument(Document document) {
if (document == null || document.getId() == null) {
return;
}
documentChunkMapper.deleteByQuery(QueryWrapper.create().eq(DocumentChunk::getDocumentId, document.getId()));
this.getMapper().deleteById(document.getId());
}
private String asString(Object value) {
return value == null ? null : String.valueOf(value);
}
private Integer asInteger(Object value, Integer defaultValue) {
if (value == null) {
return defaultValue;
}
if (value instanceof Number) {
return ((Number) value).intValue();
}
if (value instanceof String && StringUtil.hasText((String) value)) {
return Integer.parseInt((String) value);
}
return defaultValue;
}
private Boolean asBoolean(Object value, boolean defaultValue) {
if (value == null) {
return defaultValue;
}
if (value instanceof Boolean) {
return (Boolean) value;
}
if (value instanceof Number) {
return ((Number) value).intValue() != 0;
}
return Boolean.parseBoolean(String.valueOf(value));
}
private static class StoreExecutionContext {
private final DocumentCollection knowledge;
private final Model model;
private final EmbeddingModel embeddingModel;
private final DocumentStore documentStore;
private final StoreOptions options;
private final DocumentSearcher searcher;
private StoreExecutionContext(DocumentCollection knowledge,
Model model,
EmbeddingModel embeddingModel,
DocumentStore documentStore,
StoreOptions options,
DocumentSearcher searcher) {
this.knowledge = knowledge;
this.model = model;
this.embeddingModel = embeddingModel;
this.documentStore = documentStore;
this.options = options;
this.searcher = searcher;
} }
return true;
} }
public DocumentSplitter getDocumentSplitter(DocumentCollectionSplitParams params) { public DocumentSplitter getDocumentSplitter(DocumentCollectionSplitParams params) {

View File

@@ -10,17 +10,22 @@ import com.easyagents.core.model.embedding.EmbeddingModel;
import com.easyagents.core.model.rerank.RerankModel; import com.easyagents.core.model.rerank.RerankModel;
import com.easyagents.core.store.VectorData; import com.easyagents.core.store.VectorData;
import com.mybatisflex.core.query.QueryWrapper; import com.mybatisflex.core.query.QueryWrapper;
import com.mybatisflex.core.util.StringUtil;
import com.mybatisflex.spring.service.impl.ServiceImpl; import com.mybatisflex.spring.service.impl.ServiceImpl;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import tech.easyflow.ai.entity.Model; import tech.easyflow.ai.entity.Model;
import tech.easyflow.ai.entity.ModelProvider; import tech.easyflow.ai.entity.ModelProvider;
import tech.easyflow.ai.mapper.ModelMapper; import tech.easyflow.ai.mapper.ModelMapper;
import tech.easyflow.ai.service.ModelProviderService; import tech.easyflow.ai.service.ModelProviderService;
import tech.easyflow.ai.service.ModelService; import tech.easyflow.ai.service.ModelService;
import tech.easyflow.common.tree.Tree;
import tech.easyflow.common.util.SqlOperatorsUtil;
import tech.easyflow.common.util.SqlUtil;
import tech.easyflow.common.web.exceptions.BusinessException; import tech.easyflow.common.web.exceptions.BusinessException;
import javax.annotation.Resource; import javax.annotation.Resource;
@@ -182,6 +187,152 @@ public class ModelServiceImpl extends ServiceImpl<ModelMapper, Model> implements
throw new BusinessException("模型ID不能为空"); throw new BusinessException("模型ID不能为空");
} }
Model model = modelMapper.selectOneWithRelationsById(modelId); Model model = modelMapper.selectOneWithRelationsById(modelId);
return fillProviderDefaults(model);
}
@Override
public Model getModelInstanceByInvokeCode(String invokeCode) {
if (StrUtil.isBlank(invokeCode)) {
throw new BusinessException("invokeCode不能为空");
}
QueryWrapper queryWrapper = QueryWrapper.create().eq(Model::getInvokeCode, invokeCode.trim());
Model model = modelMapper.selectOneWithRelationsByQuery(queryWrapper);
return fillProviderDefaults(model);
}
@Override
public void validateForSaveOrUpdate(Model entity, boolean isSave) {
if (entity == null) {
throw new BusinessException("模型配置不能为空");
}
if (entity.getPublishEnabled() == null) {
entity.setPublishEnabled(Boolean.FALSE);
}
String originalInvokeCode = StrUtil.trim(entity.getInvokeCode());
String invokeCode = originalInvokeCode;
boolean autoGeneratedInvokeCode = StrUtil.isBlank(invokeCode);
if (autoGeneratedInvokeCode) {
invokeCode = buildDefaultInvokeCode(entity.getModelName());
}
if (Boolean.TRUE.equals(entity.getPublishEnabled())) {
if (StrUtil.isBlank(invokeCode)) {
throw new BusinessException("开启 API 调用前,请先配置 invokeCode");
}
if (!Model.MODEL_TYPES[0].equals(entity.getModelType())) {
throw new BusinessException("只有聊天模型支持开启 API 调用");
}
}
if (StrUtil.isBlank(invokeCode)) {
entity.setInvokeCode(null);
return;
}
QueryWrapper queryWrapper = QueryWrapper.create().eq(Model::getInvokeCode, invokeCode);
if (!isSave && entity.getId() != null) {
queryWrapper.ne(Model::getId, entity.getId());
}
boolean duplicated = modelMapper.selectCountByQuery(queryWrapper) > 0;
if (duplicated && autoGeneratedInvokeCode && !Boolean.TRUE.equals(entity.getPublishEnabled())) {
entity.setInvokeCode(null);
return;
}
if (duplicated) {
throw new BusinessException("invokeCode 已存在,请更换后重试");
}
entity.setInvokeCode(invokeCode);
}
@Override
public List<Model> listInvokeModels() {
QueryWrapper queryWrapper = QueryWrapper.create().eq(Model::getModelType, Model.MODEL_TYPES[0]);
return modelMapper.selectListWithRelationsByQuery(queryWrapper).stream()
.map(this::fillProviderDefaults)
.collect(Collectors.toList());
}
@Override
public List<Model> listSelectableModels(Model entity, Boolean asTree, String sortKey, String sortType) {
QueryWrapper queryWrapper = QueryWrapper.create(
entity,
entity == null ? com.mybatisflex.core.query.SqlOperators.empty() : SqlOperatorsUtil.build(entity.getClass())
);
queryWrapper.orderBy(buildOrderBy(sortKey, sortType));
List<Model> list = Tree.tryToTree(modelMapper.selectListWithRelationsByQuery(queryWrapper), asTree);
list.forEach(this::decorateModelTitle);
return list;
}
@Override
public Model updateInvokeConfig(BigInteger id, String invokeCode, Boolean publishEnabled) {
Model existing = getModelInstance(id);
if (existing == null) {
throw new BusinessException("模型不存在");
}
if (!Model.MODEL_TYPES[0].equals(existing.getModelType())) {
throw new BusinessException("只有聊天模型支持统一网关配置");
}
existing.setInvokeCode(invokeCode);
existing.setPublishEnabled(Boolean.TRUE.equals(publishEnabled));
validateForSaveOrUpdate(existing, false);
modelMapper.update(existing);
return getModelInstance(id);
}
@Override
@Transactional(rollbackFor = Exception.class)
public List<Model> batchUpdateInvokePublishStatus(List<BigInteger> ids, Boolean publishEnabled) {
if (CollectionUtils.isEmpty(ids)) {
throw new BusinessException("请选择要操作的模型");
}
List<BigInteger> uniqueIds = new ArrayList<>(new LinkedHashSet<>(ids));
List<Model> updatedModels = new ArrayList<>(uniqueIds.size());
for (BigInteger id : uniqueIds) {
updatedModels.add(updateInvokeConfig(id, null, publishEnabled));
}
return updatedModels;
}
private void decorateModelTitle(Model model) {
if (model == null) {
return;
}
String providerName = Optional.ofNullable(model.getModelProvider())
.map(ModelProvider::getProviderName)
.orElse("-");
model.setTitle(providerName + "/" + model.getTitle());
}
private String buildOrderBy(String sortKey, String sortType) {
sortKey = StringUtil.camelToUnderline(sortKey);
return SqlUtil.buildOrderBy(sortKey, sortType, "id desc");
}
private String buildDefaultInvokeCode(String modelName) {
if (StrUtil.isBlank(modelName)) {
return null;
}
String normalized = modelName.trim()
.replaceAll("[^A-Za-z0-9._:-]+", "-")
.replaceAll("-{2,}", "-")
.replaceFirst("^[^A-Za-z0-9]+", "");
if (StrUtil.isBlank(normalized)) {
return null;
}
if (normalized.length() > 128) {
normalized = normalized.substring(0, 128);
}
if (normalized.length() == 1) {
normalized = normalized + "-model";
}
return normalized;
}
private Model fillProviderDefaults(Model model) {
if (model == null) { if (model == null) {
return null; return null;
} }

View File

@@ -2,18 +2,23 @@ package tech.easyflow.ai.service.impl;
import com.mybatisflex.core.query.QueryWrapper; import com.mybatisflex.core.query.QueryWrapper;
import com.mybatisflex.spring.service.impl.ServiceImpl; import com.mybatisflex.spring.service.impl.ServiceImpl;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import tech.easyflow.ai.entity.PluginCategory; import tech.easyflow.ai.entity.PluginCategory;
import tech.easyflow.ai.entity.PluginCategoryMapping; import tech.easyflow.ai.entity.PluginCategoryMapping;
import tech.easyflow.ai.mapper.PluginCategoryMapper; import tech.easyflow.ai.mapper.PluginCategoryMapper;
import tech.easyflow.ai.mapper.PluginCategoryMappingMapper; import tech.easyflow.ai.mapper.PluginCategoryMappingMapper;
import tech.easyflow.ai.service.PluginCategoryMappingService; import tech.easyflow.ai.service.PluginCategoryMappingService;
import org.springframework.stereotype.Service;
import tech.easyflow.common.web.exceptions.BusinessException; import tech.easyflow.common.web.exceptions.BusinessException;
import javax.annotation.Resource; import javax.annotation.Resource;
import java.math.BigInteger; import java.math.BigInteger;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
/** /**
* 服务层实现。 * 服务层实现。
@@ -31,45 +36,46 @@ public class PluginCategoryMappingServiceImpl extends ServiceImpl<PluginCategory
private PluginCategoryMapper pluginCategoryMapper; private PluginCategoryMapper pluginCategoryMapper;
@Override @Override
public boolean updateRelation(BigInteger pluginId, ArrayList<BigInteger> categoryIds) { @Transactional
if (categoryIds == null){ public boolean updateRelation(BigInteger pluginId, List<BigInteger> categoryIds) {
QueryWrapper queryWrapper = QueryWrapper.create().select("*") if (pluginId == null) {
.from("tb_plugin_category_mapping") throw new BusinessException("插件ID不能为空");
.where("plugin_id = ?", pluginId);
int delete = relationMapper.deleteByQuery(queryWrapper);
if (delete <= 0){
throw new BusinessException("删除失败");
}
return true;
} }
for (BigInteger categoryId : categoryIds) {
QueryWrapper queryWrapper = QueryWrapper.create().select("*") List<BigInteger> targetCategoryIds = categoryIds == null
? Collections.emptyList()
: categoryIds.stream()
.filter(java.util.Objects::nonNull)
.distinct()
.collect(Collectors.toList());
QueryWrapper currentRelationQuery = QueryWrapper.create().select("category_id")
.from("tb_plugin_category_mapping")
.where("plugin_id = ?", pluginId);
List<BigInteger> currentCategoryIds = relationMapper.selectListByQueryAs(currentRelationQuery, BigInteger.class);
Set<BigInteger> currentCategoryIdSet = new LinkedHashSet<>(currentCategoryIds);
Set<BigInteger> targetCategoryIdSet = new LinkedHashSet<>(targetCategoryIds);
Set<BigInteger> categoryIdsToDelete = new LinkedHashSet<>(currentCategoryIdSet);
categoryIdsToDelete.removeAll(targetCategoryIdSet);
if (!categoryIdsToDelete.isEmpty()) {
QueryWrapper deleteQuery = QueryWrapper.create()
.from("tb_plugin_category_mapping") .from("tb_plugin_category_mapping")
.where("plugin_id = ?", pluginId) .where("plugin_id = ?", pluginId)
.where("category_id = ?", categoryId); .in("category_id", new ArrayList<>(categoryIdsToDelete));
PluginCategoryMapping selectedOneByQuery = relationMapper.selectOneByQuery(queryWrapper); relationMapper.deleteByQuery(deleteQuery);
}
Set<BigInteger> categoryIdsToInsert = new LinkedHashSet<>(targetCategoryIdSet);
categoryIdsToInsert.removeAll(currentCategoryIdSet);
for (BigInteger categoryId : categoryIdsToInsert) {
PluginCategoryMapping pluginCategoryMapping = new PluginCategoryMapping(); PluginCategoryMapping pluginCategoryMapping = new PluginCategoryMapping();
pluginCategoryMapping.setCategoryId(categoryId); pluginCategoryMapping.setCategoryId(categoryId);
pluginCategoryMapping.setPluginId(pluginId); pluginCategoryMapping.setPluginId(pluginId);
if (selectedOneByQuery == null) { int insert = relationMapper.insert(pluginCategoryMapping);
int insert = relationMapper.insert(pluginCategoryMapping); if (insert <= 0) {
if (insert <= 0) { throw new BusinessException("新增分类关联失败");
throw new BusinessException("新增失败");
}
} else {
QueryWrapper queryWrapperUpdate = QueryWrapper.create().select("*")
.from("tb_plugin_category_mapping")
.where("plugin_id = ?", pluginId);
PluginCategoryMapping selectedOne = relationMapper.selectOneByQuery(queryWrapper);
if (selectedOne != null){
continue;
}
int update = relationMapper.updateByQuery(pluginCategoryMapping, queryWrapperUpdate);
if (update <= 0){
throw new BusinessException("更新失败");
}
} }
} }
return true; return true;
} }

View File

@@ -1,29 +1,39 @@
package tech.easyflow.ai.service.impl; package tech.easyflow.ai.service.impl;
import cn.hutool.core.collection.CollectionUtil;
import com.mybatisflex.core.paginate.Page; import com.mybatisflex.core.paginate.Page;
import com.mybatisflex.core.query.QueryWrapper; import com.mybatisflex.core.query.QueryWrapper;
import com.mybatisflex.spring.service.impl.ServiceImpl; import com.mybatisflex.spring.service.impl.ServiceImpl;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import tech.easyflow.ai.entity.*; import tech.easyflow.ai.entity.*;
import tech.easyflow.ai.mapper.PluginCategoryMappingMapper; import tech.easyflow.ai.mapper.PluginCategoryMappingMapper;
import tech.easyflow.ai.mapper.PluginMapper; import tech.easyflow.ai.mapper.PluginMapper;
import tech.easyflow.ai.service.BotPluginService; import tech.easyflow.ai.service.BotPluginService;
import tech.easyflow.ai.service.PluginService;
import org.springframework.stereotype.Service;
import tech.easyflow.ai.service.PluginItemService; import tech.easyflow.ai.service.PluginItemService;
import tech.easyflow.ai.service.PluginService;
import tech.easyflow.ai.service.PluginVisibilityService;
import tech.easyflow.common.domain.Result; import tech.easyflow.common.domain.Result;
import tech.easyflow.common.web.exceptions.BusinessException; import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.system.entity.vo.RoleCategoryAccessSnapshot;
import tech.easyflow.system.service.CategoryPermissionService;
import javax.annotation.Resource; import javax.annotation.Resource;
import java.math.BigInteger; import java.math.BigInteger;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.Date; import java.util.Date;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static tech.easyflow.ai.entity.table.PluginTableDef.PLUGIN;
/** /**
* 服务层实现。 * 服务层实现。
* *
@@ -46,15 +56,19 @@ public class PluginServiceImpl extends ServiceImpl<PluginMapper, Plugin> impleme
@Resource @Resource
private PluginItemService pluginItemService; private PluginItemService pluginItemService;
@Resource
private CategoryPermissionService categoryPermissionService;
@Resource
private PluginVisibilityService pluginVisibilityService;
@Override @Override
public boolean savePlugin(Plugin plugin) { public Plugin savePlugin(Plugin plugin) {
plugin.setCreated(new Date()); plugin.setCreated(new Date());
int insert = pluginMapper.insert(plugin); int insert = pluginMapper.insert(plugin);
if (insert <= 0) { if (insert <= 0) {
throw new BusinessException("保存失败"); throw new BusinessException("保存失败");
} }
return true; return plugin;
} }
@Override @Override
@@ -104,22 +118,39 @@ public class PluginServiceImpl extends ServiceImpl<PluginMapper, Plugin> impleme
@Override @Override
public Result<Page<Plugin>> pageByCategory(Long pageNumber, Long pageSize, int category) { public Result<Page<Plugin>> pageByCategory(Long pageNumber, Long pageSize, int category) {
// 通过分类查询插件 RoleCategoryAccessSnapshot access = categoryPermissionService.getCurrentAccess("PLUGIN");
QueryWrapper queryWrapper = QueryWrapper.create().select(PluginCategoryMapping::getPluginId) QueryWrapper queryWrapper = QueryWrapper.create().select(PluginCategoryMapping::getPluginId)
.eq(PluginCategoryMapping::getCategoryId, category); .eq(PluginCategoryMapping::getCategoryId, category);
// 分页查询该分类中的插件 List<BigInteger> allCategoryPluginIds = pluginCategoryMappingMapper.selectListByQueryAs(queryWrapper, BigInteger.class)
Page<BigInteger> pagePluginIds = pluginCategoryMappingMapper.paginateAs(new Page<>(pageNumber, pageSize), queryWrapper, BigInteger.class); .stream()
Page<PluginCategoryMapping> paginateCategories = pluginCategoryMappingMapper.paginate(pageNumber, pageSize, queryWrapper); .filter(item -> item != null)
List<Plugin> plugins = Collections.emptyList(); .collect(Collectors.toCollection(ArrayList::new));
if (paginateCategories.getRecords().isEmpty()) { if (CollectionUtil.isEmpty(allCategoryPluginIds)) {
return Result.ok(new Page<>(plugins, pageNumber, pageSize, paginateCategories.getTotalRow())); return Result.ok(new Page<>(Collections.emptyList(), pageNumber, pageSize, 0L));
} }
List<BigInteger> pluginIds = pagePluginIds.getRecords();
// 查询对应的插件信息 List<BigInteger> orderedCategoryPluginIds = new ArrayList<>(new LinkedHashSet<>(allCategoryPluginIds));
QueryWrapper queryPluginWrapper = QueryWrapper.create().select() List<BigInteger> visiblePluginIds = orderedCategoryPluginIds;
.in(Plugin::getId, pluginIds); if (access.isRestricted()) {
plugins = pluginMapper.selectListByQuery(queryPluginWrapper); Set<BigInteger> visiblePluginIdSet = new LinkedHashSet<>(pluginVisibilityService.getCurrentVisiblePluginIds());
Page<Plugin> aiPluginPage = new Page<>(plugins, pageNumber, pageSize, paginateCategories.getTotalRow()); List<BigInteger> creatorPluginIds = queryCreatorPluginIds(orderedCategoryPluginIds, access.getAccountIdAsLong());
LinkedHashSet<BigInteger> mergedVisibleIds = orderedCategoryPluginIds.stream()
.filter(pluginId -> visiblePluginIdSet.contains(pluginId) || creatorPluginIds.contains(pluginId))
.collect(Collectors.toCollection(LinkedHashSet::new));
visiblePluginIds = new ArrayList<>(mergedVisibleIds);
}
if (visiblePluginIds.isEmpty()) {
return Result.ok(new Page<>(Collections.emptyList(), pageNumber, pageSize, 0L));
}
int fromIndex = Math.max(0, Math.toIntExact((pageNumber - 1) * pageSize));
if (fromIndex >= visiblePluginIds.size()) {
return Result.ok(new Page<>(Collections.emptyList(), pageNumber, pageSize, visiblePluginIds.size()));
}
int toIndex = Math.min(visiblePluginIds.size(), Math.toIntExact(fromIndex + pageSize));
List<BigInteger> currentPagePluginIds = new ArrayList<>(visiblePluginIds.subList(fromIndex, toIndex));
List<Plugin> plugins = queryPluginsByIds(currentPagePluginIds);
Page<Plugin> aiPluginPage = new Page<>(plugins, pageNumber, pageSize, visiblePluginIds.size());
return Result.ok(aiPluginPage); return Result.ok(aiPluginPage);
} }
@@ -129,5 +160,37 @@ public class PluginServiceImpl extends ServiceImpl<PluginMapper, Plugin> impleme
return true; return true;
} }
private List<BigInteger> queryCreatorPluginIds(List<BigInteger> pluginIds, Long creatorId) {
if (CollectionUtil.isEmpty(pluginIds) || creatorId == null) {
return Collections.emptyList();
}
QueryWrapper creatorPluginWrapper = QueryWrapper.create().select(Plugin::getId)
.in(Plugin::getId, pluginIds)
.eq(Plugin::getCreatedBy, creatorId);
return pluginMapper.selectListByQueryAs(creatorPluginWrapper, BigInteger.class);
}
private List<Plugin> queryPluginsByIds(List<BigInteger> pluginIds) {
if (CollectionUtil.isEmpty(pluginIds)) {
return Collections.emptyList();
}
QueryWrapper queryPluginWrapper = QueryWrapper.create().select().in(Plugin::getId, pluginIds);
List<Plugin> plugins = pluginMapper.selectListByQuery(queryPluginWrapper);
Map<BigInteger, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(
Plugin::getId,
item -> item,
(left, right) -> left,
LinkedHashMap::new
));
List<Plugin> orderedPlugins = new ArrayList<>();
for (BigInteger pluginId : pluginIds) {
Plugin plugin = pluginMap.get(pluginId);
if (plugin != null) {
orderedPlugins.add(plugin);
}
}
return orderedPlugins;
}
} }

View File

@@ -0,0 +1,69 @@
package tech.easyflow.ai.service.impl;
import cn.hutool.core.collection.CollectionUtil;
import com.mybatisflex.core.query.QueryWrapper;
import org.springframework.stereotype.Service;
import tech.easyflow.ai.entity.PluginCategoryMapping;
import tech.easyflow.ai.mapper.PluginCategoryMappingMapper;
import tech.easyflow.ai.service.PluginVisibilityService;
import tech.easyflow.common.web.exceptions.BusinessException;
import tech.easyflow.system.entity.vo.RoleCategoryAccessSnapshot;
import tech.easyflow.system.service.CategoryPermissionService;
import javax.annotation.Resource;
import java.math.BigInteger;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
@Service
public class PluginVisibilityServiceImpl implements PluginVisibilityService {
@Resource
private CategoryPermissionService categoryPermissionService;
@Resource
private PluginCategoryMappingMapper pluginCategoryMappingMapper;
@Override
public Set<BigInteger> getCurrentVisiblePluginIds() {
RoleCategoryAccessSnapshot snapshot = categoryPermissionService.getCurrentAccess("PLUGIN");
if (!snapshot.isRestricted() || CollectionUtil.isEmpty(snapshot.getCategoryIds())) {
return Collections.emptySet();
}
QueryWrapper mappingWrapper = QueryWrapper.create()
.select(PluginCategoryMapping::getPluginId)
.in(PluginCategoryMapping::getCategoryId, snapshot.getCategoryIds());
List<BigInteger> pluginIds = pluginCategoryMappingMapper.selectListByQueryAs(mappingWrapper, BigInteger.class);
return new LinkedHashSet<>(pluginIds);
}
@Override
public boolean canAccessPlugin(Long createdBy, BigInteger pluginId) {
RoleCategoryAccessSnapshot snapshot = categoryPermissionService.getCurrentAccess("PLUGIN");
if (!snapshot.isRestricted()) {
return true;
}
if (createdBy != null && snapshot.getAccountId() != null
&& snapshot.getAccountId().equals(BigInteger.valueOf(createdBy))) {
return true;
}
if (CollectionUtil.isEmpty(snapshot.getCategoryIds())) {
return false;
}
QueryWrapper mappingWrapper = QueryWrapper.create()
.select(PluginCategoryMapping::getPluginId)
.eq(PluginCategoryMapping::getPluginId, pluginId)
.in(PluginCategoryMapping::getCategoryId, snapshot.getCategoryIds());
List<BigInteger> pluginIds = pluginCategoryMappingMapper.selectListByQueryAs(mappingWrapper, BigInteger.class);
return CollectionUtil.isNotEmpty(pluginIds);
}
@Override
public void assertPluginVisible(Long createdBy, BigInteger pluginId, String message) {
if (!canAccessPlugin(createdBy, pluginId)) {
throw new BusinessException(message == null ? "无权限访问该资源" : message);
}
}
}

View File

@@ -0,0 +1,192 @@
package tech.easyflow.ai.invoke.mapper;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.Assert;
import org.junit.Test;
import tech.easyflow.ai.invoke.exception.ModelInvokeException;
import tech.easyflow.ai.invoke.model.UnifiedChatRequest;
import tech.easyflow.ai.invoke.model.UnifiedChatResponse;
import tech.easyflow.ai.invoke.model.UnifiedChoice;
import tech.easyflow.ai.invoke.model.UnifiedContentPart;
import tech.easyflow.ai.invoke.model.UnifiedImageUrl;
import tech.easyflow.ai.invoke.model.UnifiedMessage;
import tech.easyflow.ai.invoke.model.UnifiedResponseFormat;
import tech.easyflow.ai.invoke.model.UnifiedTool;
import tech.easyflow.ai.invoke.model.UnifiedToolCall;
import tech.easyflow.ai.invoke.model.UnifiedToolCallFunction;
import tech.easyflow.ai.invoke.model.UnifiedToolFunction;
import tech.easyflow.ai.invoke.model.UnifiedUsage;
import tech.easyflow.ai.invoke.protocol.openai.OpenAiChatCompletionRequest;
import tech.easyflow.ai.invoke.protocol.openai.OpenAiChatCompletionResponse;
import java.util.Collections;
import java.util.List;
public class OpenAiProtocolMapperTest {
private final OpenAiProtocolMapper mapper = new OpenAiProtocolMapper(new ObjectMapper());
@Test
public void shouldParseTextAndImageRequest() {
String rawBody = """
{
"model": "gpt-4-1-prod",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "帮我看下这张图"
},
{
"type": "image_url",
"image_url": {
"url": "data:image/png;base64,AAAA",
"detail": "high"
}
}
]
}
],
"stream": false,
"temperature": 0.2,
"top_p": 0.8,
"max_tokens": 512,
"seed": 7,
"tools": [
{
"type": "function",
"function": {
"name": "query_weather",
"description": "query weather",
"parameters": {
"type": "object"
}
}
}
],
"tool_choice": "auto",
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "weather_schema"
}
}
}
""";
OpenAiChatCompletionRequest request = mapper.readRequest(rawBody);
UnifiedChatRequest unifiedRequest = mapper.toUnifiedRequest(request);
Assert.assertEquals("gpt-4-1-prod", unifiedRequest.getModel());
Assert.assertEquals(Long.valueOf(7), unifiedRequest.getSeed());
Assert.assertEquals(Integer.valueOf(512), unifiedRequest.getMaxTokens());
Assert.assertNotNull(unifiedRequest.getTools());
Assert.assertEquals(1, unifiedRequest.getTools().size());
Assert.assertEquals("query_weather", unifiedRequest.getTools().get(0).getFunction().getName());
Assert.assertEquals("json_schema", unifiedRequest.getResponseFormat().getType());
Assert.assertEquals("weather_schema", unifiedRequest.getResponseFormat().getJsonSchema().get("name").asText());
UnifiedMessage message = unifiedRequest.getMessages().get(0);
Assert.assertEquals("user", message.getRole());
Assert.assertNotNull(message.getContentParts());
Assert.assertEquals(2, message.getContentParts().size());
UnifiedContentPart imagePart = message.getContentParts().get(1);
Assert.assertEquals("image_url", imagePart.getType());
Assert.assertEquals("data:image/png;base64,AAAA", imagePart.getImageUrl().getUrl());
Assert.assertEquals("high", imagePart.getImageUrl().getDetail());
}
@Test
public void shouldRejectUnsupportedRootField() {
String rawBody = """
{
"model": "gpt-4-1-prod",
"messages": [
{
"role": "user",
"content": "hello"
}
],
"n": 2
}
""";
ModelInvokeException exception = Assert.assertThrows(
ModelInvokeException.class,
() -> mapper.readRequest(rawBody)
);
Assert.assertEquals(400, exception.getStatus());
Assert.assertEquals("unsupported_field", exception.getCode());
}
@Test
public void shouldAllowMissingOptionalFields() {
String rawBody = """
{
"model": "deepseek-chat",
"messages": [
{
"role": "user",
"content": "你好,介绍一下你自己。"
}
]
}
""";
OpenAiChatCompletionRequest request = mapper.readRequest(rawBody);
UnifiedChatRequest unifiedRequest = mapper.toUnifiedRequest(request);
Assert.assertEquals("deepseek-chat", unifiedRequest.getModel());
Assert.assertNotNull(unifiedRequest.getMessages());
Assert.assertEquals(1, unifiedRequest.getMessages().size());
Assert.assertNull(unifiedRequest.getTools());
Assert.assertNull(unifiedRequest.getToolChoice());
Assert.assertNull(unifiedRequest.getResponseFormat());
}
@Test
public void shouldMapToolCallsAndUsageInResponse() {
UnifiedChatResponse response = new UnifiedChatResponse();
response.setId("chatcmpl-1");
response.setObject("chat.completion");
response.setCreated(123L);
response.setModel("gpt-4-1-prod");
UnifiedToolCallFunction toolCallFunction = new UnifiedToolCallFunction();
toolCallFunction.setName("query_weather");
toolCallFunction.setArguments("{\"city\":\"shanghai\"}");
UnifiedToolCall toolCall = new UnifiedToolCall();
toolCall.setId("call_1");
toolCall.setType("function");
toolCall.setFunction(toolCallFunction);
UnifiedMessage message = new UnifiedMessage();
message.setRole("assistant");
message.setContent(null);
message.setToolCalls(Collections.singletonList(toolCall));
UnifiedChoice choice = new UnifiedChoice();
choice.setIndex(0);
choice.setMessage(message);
choice.setFinishReason("tool_calls");
UnifiedUsage usage = new UnifiedUsage();
usage.setPromptTokens(12);
usage.setCompletionTokens(34);
usage.setTotalTokens(46);
response.setChoices(List.of(choice));
response.setUsage(usage);
OpenAiChatCompletionResponse openAiResponse = mapper.toOpenAiResponse(response);
Assert.assertEquals("chatcmpl-1", openAiResponse.getId());
Assert.assertEquals("tool_calls", openAiResponse.getChoices().get(0).getFinishReason());
Assert.assertEquals("query_weather", openAiResponse.getChoices().get(0).getMessage().getToolCalls().get(0).getFunction().getName());
Assert.assertEquals(Integer.valueOf(46), openAiResponse.getUsage().getTotalTokens());
}
}

Some files were not shown because too many files have changed in this diff Show More