初始化
This commit is contained in:
@@ -0,0 +1,77 @@
|
||||
package tech.easyflow.ai.easyagents;
|
||||
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
|
||||
public class CustomMultipartFile implements MultipartFile {
|
||||
|
||||
private final byte[] content;
|
||||
private final String name;
|
||||
private final String originalFilename;
|
||||
private final String contentType;
|
||||
|
||||
public CustomMultipartFile(byte[] content, String name, String originalFilename, String contentType) {
|
||||
this.content = content;
|
||||
this.name = name;
|
||||
this.originalFilename = originalFilename;
|
||||
this.contentType = contentType;
|
||||
}
|
||||
|
||||
// 从 InputStream 构建 CustomMultipartFile
|
||||
public static CustomMultipartFile fromInputStream(InputStream inputStream, String name, String originalFilename, String contentType) throws IOException {
|
||||
// 将 InputStream 转为字节数组
|
||||
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
|
||||
byte[] buffer = new byte[1024];
|
||||
int bytesRead;
|
||||
while ((bytesRead = inputStream.read(buffer)) != -1) {
|
||||
outputStream.write(buffer, 0, bytesRead);
|
||||
}
|
||||
return new CustomMultipartFile(outputStream.toByteArray(), name, originalFilename, contentType);
|
||||
}
|
||||
|
||||
// 实现 MultipartFile 接口的抽象方法
|
||||
@Override
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getOriginalFilename() {
|
||||
return originalFilename;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getContentType() {
|
||||
return contentType;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isEmpty() {
|
||||
return content == null || content.length == 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getSize() {
|
||||
return content.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] getBytes() throws IOException {
|
||||
return content;
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputStream getInputStream() throws IOException {
|
||||
return new ByteArrayInputStream(content);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void transferTo(java.io.File dest) throws IOException, IllegalStateException {
|
||||
// 如需保存到文件,可实现此方法
|
||||
throw new UnsupportedOperationException("transferTo is not implemented");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
package tech.easyflow.ai.easyagents.listener;
|
||||
|
||||
import com.easyagents.core.message.AiMessage;
|
||||
import com.easyagents.core.message.ToolMessage;
|
||||
import com.easyagents.core.model.chat.ChatModel;
|
||||
import com.easyagents.core.model.chat.ChatOptions;
|
||||
import com.easyagents.core.model.chat.StreamResponseListener;
|
||||
import com.easyagents.core.model.chat.response.AiMessageResponse;
|
||||
import com.easyagents.core.model.client.StreamContext;
|
||||
import com.easyagents.core.prompt.MemoryPrompt;
|
||||
import org.apache.catalina.connector.ClientAbortException;
|
||||
import tech.easyflow.core.chat.protocol.ChatDomain;
|
||||
import tech.easyflow.core.chat.protocol.ChatEnvelope;
|
||||
import tech.easyflow.core.chat.protocol.ChatType;
|
||||
import tech.easyflow.core.chat.protocol.MessageRole;
|
||||
import tech.easyflow.core.chat.protocol.payload.ErrorPayload;
|
||||
import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class ChatStreamListener implements StreamResponseListener {
|
||||
|
||||
private final String conversationId;
|
||||
private final ChatModel chatModel;
|
||||
private final MemoryPrompt memoryPrompt;
|
||||
private final ChatSseEmitter sseEmitter;
|
||||
private final ChatOptions chatOptions;
|
||||
// 核心标记:是否允许执行onStop业务逻辑(仅最后一次无后续工具调用时为true)
|
||||
private boolean canStop = true;
|
||||
// 辅助标记:是否进入过工具调用(避免重复递归判断)
|
||||
private boolean hasToolCall = false;
|
||||
|
||||
public ChatStreamListener(String conversationId, ChatModel chatModel, MemoryPrompt memoryPrompt, ChatSseEmitter sseEmitter, ChatOptions chatOptions) {
|
||||
this.conversationId = conversationId;
|
||||
this.chatModel = chatModel;
|
||||
this.memoryPrompt = memoryPrompt;
|
||||
this.sseEmitter = sseEmitter;
|
||||
this.chatOptions = chatOptions;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onStart(StreamContext context) {
|
||||
StreamResponseListener.super.onStart(context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onMessage(StreamContext context, AiMessageResponse aiMessageResponse) {
|
||||
try {
|
||||
AiMessage aiMessage = aiMessageResponse.getMessage();
|
||||
if (aiMessage == null) {
|
||||
return;
|
||||
}
|
||||
if (aiMessage.isFinalDelta() && aiMessageResponse.hasToolCalls()) {
|
||||
this.canStop = false; // 工具调用期间,禁止执行onStop
|
||||
this.hasToolCall = true; // 标记已进入过工具调用
|
||||
aiMessage.setContent(null);
|
||||
memoryPrompt.addMessage(aiMessage);
|
||||
List<ToolMessage> toolMessages = aiMessageResponse.executeToolCallsAndGetToolMessages();
|
||||
for (ToolMessage toolMessage : toolMessages) {
|
||||
memoryPrompt.addMessage(toolMessage);
|
||||
}
|
||||
chatModel.chatStream(memoryPrompt, this, chatOptions);
|
||||
} else {
|
||||
if (this.hasToolCall) {
|
||||
this.canStop = true;
|
||||
}
|
||||
String reasoningContent = aiMessage.getReasoningContent();
|
||||
if (reasoningContent != null && !reasoningContent.isEmpty()) {
|
||||
sendChatEnvelope(sseEmitter, reasoningContent, ChatType.THINKING);
|
||||
} else {
|
||||
String delta = aiMessage.getContent();
|
||||
if (delta != null && !delta.isEmpty()) {
|
||||
sendChatEnvelope(sseEmitter, delta, ChatType.MESSAGE);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onStop(StreamContext context) {
|
||||
// 仅当canStop为true(最后一次无后续工具调用的响应)时,执行业务逻辑
|
||||
if (this.canStop) {
|
||||
System.out.println("onStop");
|
||||
if (context.getThrowable() != null) {
|
||||
sendSystemError(sseEmitter, context.getThrowable().getMessage());
|
||||
return;
|
||||
}
|
||||
memoryPrompt.addMessage(context.getFullMessage());
|
||||
ChatEnvelope<Map<String, String>> chatEnvelope = new ChatEnvelope<>();
|
||||
chatEnvelope.setDomain(ChatDomain.SYSTEM);
|
||||
sseEmitter.sendDone(chatEnvelope);
|
||||
StreamResponseListener.super.onStop(context);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(StreamContext context, Throwable throwable) {
|
||||
if (throwable != null) {
|
||||
throwable.printStackTrace();
|
||||
sendSystemError(sseEmitter, throwable.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
private void sendChatEnvelope(ChatSseEmitter sseEmitter, String deltaContent, ChatType chatType) throws IOException {
|
||||
if (deltaContent == null || deltaContent.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
ChatEnvelope<Map<String, String>> chatEnvelope = new ChatEnvelope<>();
|
||||
chatEnvelope.setDomain(ChatDomain.LLM);
|
||||
chatEnvelope.setType(chatType);
|
||||
|
||||
Map<String, String> deltaMap = new LinkedHashMap<>();
|
||||
deltaMap.put("conversation_id", this.conversationId);
|
||||
deltaMap.put("role", MessageRole.ASSISTANT.getValue());
|
||||
deltaMap.put("delta", deltaContent);
|
||||
chatEnvelope.setPayload(deltaMap);
|
||||
|
||||
sseEmitter.send(chatEnvelope);
|
||||
}
|
||||
|
||||
public void sendSystemError(ChatSseEmitter sseEmitter,
|
||||
String message) {
|
||||
ChatEnvelope<ErrorPayload> envelope = new ChatEnvelope<>();
|
||||
ErrorPayload payload = new ErrorPayload();
|
||||
payload.setMessage(message);
|
||||
payload.setCode("SYSTEM_ERROR");
|
||||
payload.setRetryable(false);
|
||||
envelope.setPayload(payload);
|
||||
envelope.setDomain(ChatDomain.SYSTEM);
|
||||
envelope.setType(ChatType.ERROR);
|
||||
sseEmitter.sendError(envelope);
|
||||
sseEmitter.complete();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package tech.easyflow.ai.easyagents.listener;
|
||||
|
||||
import com.easyagents.core.model.chat.StreamResponseListener;
|
||||
import com.easyagents.core.model.chat.response.AiMessageResponse;
|
||||
import com.easyagents.core.model.client.StreamContext;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
import tech.easyflow.common.util.StringUtil;
|
||||
import tech.easyflow.core.chat.protocol.ChatDomain;
|
||||
import tech.easyflow.core.chat.protocol.ChatEnvelope;
|
||||
import tech.easyflow.core.chat.protocol.ChatType;
|
||||
import tech.easyflow.core.chat.protocol.MessageRole;
|
||||
import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* 系统提示词优化监听器
|
||||
*/
|
||||
public class PromptChoreChatStreamListener implements StreamResponseListener {
|
||||
|
||||
private final ChatSseEmitter sseEmitter;
|
||||
|
||||
public PromptChoreChatStreamListener(ChatSseEmitter sseEmitter) {
|
||||
this.sseEmitter = sseEmitter;
|
||||
}
|
||||
@Override
|
||||
public void onStart(StreamContext context) {
|
||||
StreamResponseListener.super.onStart(context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onMessage(StreamContext context, AiMessageResponse response) {
|
||||
String content = response.getMessage().getContent();
|
||||
if (content != null) {
|
||||
String delta = response.getMessage().getContent();
|
||||
if (StringUtil.hasText(delta)) {
|
||||
ChatEnvelope<Map<String, String>> chatEnvelope = new ChatEnvelope<>();
|
||||
chatEnvelope.setDomain(ChatDomain.LLM);
|
||||
chatEnvelope.setType(ChatType.MESSAGE);
|
||||
Map<String, String> deletaMap = new HashMap<>();
|
||||
deletaMap.put("delta", delta);
|
||||
deletaMap.put("role", MessageRole.ASSISTANT.getValue());
|
||||
chatEnvelope.setPayload(deletaMap);
|
||||
sseEmitter.send(chatEnvelope);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onStop(StreamContext context) {
|
||||
System.out.println("onStop");
|
||||
sseEmitter.sendDone(new ChatEnvelope<>());
|
||||
StreamResponseListener.super.onStop(context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(StreamContext context, Throwable throwable) {
|
||||
StreamResponseListener.super.onFailure(context, throwable);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package tech.easyflow.ai.easyagents.memory;
|
||||
|
||||
import com.easyagents.core.memory.ChatMemory;
|
||||
import com.easyagents.core.message.Message;
|
||||
import com.mybatisflex.core.query.QueryWrapper;
|
||||
import tech.easyflow.ai.entity.BotMessage;
|
||||
import tech.easyflow.ai.service.BotMessageService;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
|
||||
public class BotMessageMemory implements ChatMemory {
|
||||
private final BigInteger botId;
|
||||
private final BigInteger accountId;
|
||||
private final BigInteger conversationId;
|
||||
private final BotMessageService messageService;
|
||||
|
||||
public BotMessageMemory(BigInteger botId, BigInteger accountId, BigInteger conversationId,
|
||||
BotMessageService messageService) {
|
||||
this.botId = botId;
|
||||
this.accountId = accountId;
|
||||
this.conversationId = conversationId;
|
||||
this.messageService = messageService;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Message> getMessages(int count) {
|
||||
List<BotMessage> sysAiMessages = messageService.list(QueryWrapper.create()
|
||||
.eq(BotMessage::getBotId, botId, true)
|
||||
.eq(BotMessage::getAccountId, accountId, true)
|
||||
.eq(BotMessage::getConversationId, conversationId, true)
|
||||
.orderBy(BotMessage::getCreated, true)
|
||||
.limit(count)
|
||||
);
|
||||
|
||||
if (sysAiMessages == null || sysAiMessages.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
List<Message> messages = new ArrayList<>(sysAiMessages.size());
|
||||
for (BotMessage botMessage : sysAiMessages) {
|
||||
Message message = botMessage.getContentAsMessage();
|
||||
if (message != null) messages.add(message);
|
||||
}
|
||||
return messages;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void addMessage(Message message) {
|
||||
|
||||
BotMessage dbMessage = new BotMessage();
|
||||
dbMessage.setCreated(new Date());
|
||||
dbMessage.setBotId(botId);
|
||||
dbMessage.setAccountId(accountId);
|
||||
dbMessage.setConversationId(conversationId);
|
||||
dbMessage.setContentAndRole(message);
|
||||
dbMessage.setModified(new Date());
|
||||
messageService.save(dbMessage);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void clear() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object id() {
|
||||
return botId;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package tech.easyflow.ai.easyagents.memory;
|
||||
|
||||
import com.easyagents.core.memory.DefaultChatMemory;
|
||||
import com.easyagents.core.message.*;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.serializer.SerializerFeature;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
import tech.easyflow.ai.entity.BotMessage;
|
||||
import tech.easyflow.core.chat.protocol.ChatDomain;
|
||||
import tech.easyflow.core.chat.protocol.ChatEnvelope;
|
||||
import tech.easyflow.core.chat.protocol.ChatType;
|
||||
import tech.easyflow.core.chat.protocol.MessageRole;
|
||||
import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.math.BigInteger;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
|
||||
public class DefaultBotMessageMemory extends DefaultChatMemory {
|
||||
|
||||
private final ChatSseEmitter sseEmitter;
|
||||
|
||||
private final List<Map<String, String>> messages;
|
||||
public DefaultBotMessageMemory(BigInteger conversationId, ChatSseEmitter sseEmitter, List<Map<String, String>> messages) {
|
||||
super(conversationId);
|
||||
this.sseEmitter = sseEmitter;
|
||||
this.messages = messages;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Message> getMessages(int count) {
|
||||
List<Message> list = new ArrayList<>(messages.size());
|
||||
for (Map<String, String> msg : messages) {
|
||||
BotMessage botMessage = new BotMessage();
|
||||
botMessage.setRole(msg.get("role"));
|
||||
botMessage.setContent(msg.get("content"));
|
||||
Message message = botMessage.getContentAsMessage();
|
||||
list.add(message);
|
||||
}
|
||||
List<Message> collect = list.stream()
|
||||
.limit(count)
|
||||
.collect(Collectors.toList());
|
||||
return collect;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addMessage(Message message) {
|
||||
BotMessage dbMessage = new BotMessage();
|
||||
ChatEnvelope<Map<String, String>> chatEnvelope = new ChatEnvelope<>();
|
||||
String jsonMessage = JSON.toJSONString(message, SerializerFeature.WriteClassName);
|
||||
if (message instanceof AiMessage) {
|
||||
dbMessage.setRole(MessageRole.ASSISTANT.getValue());
|
||||
|
||||
} else if (message instanceof UserMessage) {
|
||||
dbMessage.setRole(MessageRole.USER.getValue());
|
||||
} else if (message instanceof SystemMessage) {
|
||||
dbMessage.setRole(MessageRole.SYSTEM.getValue());
|
||||
} else if (message instanceof ToolMessage) {
|
||||
dbMessage.setRole(MessageRole.TOOL.getValue());
|
||||
}
|
||||
Map<String, String> res = new HashMap<>();
|
||||
res.put("role", dbMessage.getRole());
|
||||
res.put("content", jsonMessage);
|
||||
chatEnvelope.setType(ChatType.MESSAGE);
|
||||
chatEnvelope.setPayload(res);
|
||||
chatEnvelope.setDomain(ChatDomain.SYSTEM);
|
||||
if (dbMessage.getRole().equals(MessageRole.USER.getValue())) {
|
||||
messages.remove(messages.size() - 1);
|
||||
}
|
||||
sseEmitter.sendMessageNeedSave(chatEnvelope);
|
||||
messages.add(res);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
package tech.easyflow.ai.easyagents.memory;
|
||||
|
||||
import com.easyagents.core.memory.DefaultChatMemory;
|
||||
import com.easyagents.core.message.Message;
|
||||
import tech.easyflow.core.chat.protocol.sse.ChatSseEmitter;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class PublicBotMessageMemory extends DefaultChatMemory {
|
||||
private final ChatSseEmitter sseEmitter;
|
||||
private List<Message> messages = new ArrayList<>();
|
||||
|
||||
public PublicBotMessageMemory(ChatSseEmitter sseEmitter, List<Message> messages ) {
|
||||
this.messages = messages;
|
||||
this.sseEmitter = sseEmitter;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Message> getMessages(int count) {
|
||||
return messages.stream()
|
||||
.limit(count)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addMessage(Message message) {
|
||||
this.messages.add(message);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package tech.easyflow.ai.easyagents.tool;
|
||||
|
||||
import com.easyagents.core.document.Document;
|
||||
import com.easyagents.core.model.chat.tool.BaseTool;
|
||||
import com.easyagents.core.model.chat.tool.Parameter;
|
||||
import tech.easyflow.ai.entity.DocumentCollection;
|
||||
import tech.easyflow.ai.service.DocumentCollectionService;
|
||||
import tech.easyflow.common.util.SpringContextUtil;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class DocumentCollectionTool extends BaseTool {
|
||||
|
||||
private BigInteger knowledgeId;
|
||||
|
||||
public DocumentCollectionTool() {
|
||||
}
|
||||
|
||||
public DocumentCollectionTool(DocumentCollection documentCollection, boolean needEnglishName) {
|
||||
this.knowledgeId = documentCollection.getId();
|
||||
if (needEnglishName) {
|
||||
this.name = documentCollection.getEnglishName();
|
||||
} else {
|
||||
this.name = documentCollection.getTitle();
|
||||
}
|
||||
this.description = documentCollection.getDescription();
|
||||
this.parameters = getDefaultParameters();
|
||||
}
|
||||
|
||||
|
||||
public Parameter[] getDefaultParameters() {
|
||||
Parameter parameter = new Parameter();
|
||||
parameter.setName("input");
|
||||
parameter.setDescription("要查询的相关知识");
|
||||
parameter.setType("string");
|
||||
parameter.setRequired(true);
|
||||
return new Parameter[]{parameter};
|
||||
}
|
||||
|
||||
public BigInteger getKnowledgeId() {
|
||||
return knowledgeId;
|
||||
}
|
||||
|
||||
public void setKnowledgeId(BigInteger knowledgeId) {
|
||||
this.knowledgeId = knowledgeId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object invoke(Map<String, Object> argsMap) {
|
||||
|
||||
DocumentCollectionService knowledgeService = SpringContextUtil.getBean(DocumentCollectionService.class);
|
||||
List<Document> documents = knowledgeService.search(this.knowledgeId, (String) argsMap.get("input"));
|
||||
|
||||
StringBuilder sb = new StringBuilder();
|
||||
if (documents != null) {
|
||||
for (Document document : documents) {
|
||||
sb.append(document.getContent());
|
||||
}
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package tech.easyflow.ai.easyagents.tool;
|
||||
|
||||
import com.easyagents.core.model.chat.tool.BaseTool;
|
||||
import com.easyagents.core.model.chat.tool.Tool;
|
||||
import com.easyagents.mcp.client.McpClientManager;
|
||||
import tech.easyflow.ai.entity.Mcp;
|
||||
import tech.easyflow.ai.service.McpService;
|
||||
import tech.easyflow.ai.service.impl.McpServiceImpl;
|
||||
import tech.easyflow.common.util.SpringContextUtil;
|
||||
import tech.easyflow.common.util.StringUtil;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
|
||||
public class McpTool extends BaseTool {
|
||||
private BigInteger mcpId;
|
||||
|
||||
@Override
|
||||
public Object invoke(Map<String, Object> argsMap) {
|
||||
return runMcp(this.mcpId, argsMap);
|
||||
}
|
||||
|
||||
public Object runMcp(BigInteger mcpId, Map<String, Object> argsMap) {
|
||||
|
||||
McpService mcpService = SpringContextUtil.getBean(McpService.class);
|
||||
Mcp mcp = mcpService.getMapper().selectOneById(mcpId);
|
||||
String serverName = McpServiceImpl.getFirstMcpServerName(mcp.getConfigJson());
|
||||
if (StringUtil.hasText(serverName)) {
|
||||
McpClientManager mcpClientManager = McpClientManager.getInstance();
|
||||
Tool mcpTool = mcpClientManager.getMcpTool(serverName, this.name);
|
||||
Object result = mcpTool.invoke(argsMap);
|
||||
return result;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
public void setMcpId(BigInteger mcpId) {
|
||||
this.mcpId = mcpId;
|
||||
}
|
||||
|
||||
public BigInteger getMcpId() {
|
||||
return mcpId;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,368 @@
|
||||
package tech.easyflow.ai.easyagents.tool;
|
||||
|
||||
import cn.hutool.core.io.FileTypeUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.json.JSONObject;
|
||||
import com.easyagents.core.model.chat.tool.BaseTool;
|
||||
import com.easyagents.core.model.chat.tool.Parameter;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.mybatisflex.core.query.QueryWrapper;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import tech.easyflow.ai.easyagents.CustomMultipartFile;
|
||||
import tech.easyflow.ai.entity.Plugin;
|
||||
import tech.easyflow.ai.entity.PluginItem;
|
||||
import tech.easyflow.ai.mapper.PluginMapper;
|
||||
import tech.easyflow.ai.service.PluginItemService;
|
||||
import tech.easyflow.common.ai.plugin.NestedParamConverter;
|
||||
import tech.easyflow.common.ai.plugin.PluginHttpClient;
|
||||
import tech.easyflow.common.ai.plugin.PluginParam;
|
||||
import tech.easyflow.common.ai.plugin.PluginParamConverter;
|
||||
import tech.easyflow.common.filestorage.FileStorageManager;
|
||||
import tech.easyflow.common.filestorage.FileStorageService;
|
||||
import tech.easyflow.common.util.SpringContextUtil;
|
||||
|
||||
import java.io.*;
|
||||
import java.lang.reflect.Array;
|
||||
import java.math.BigInteger;
|
||||
import java.util.*;
|
||||
|
||||
public class PluginTool extends BaseTool {
|
||||
|
||||
// 插件工具id
|
||||
private BigInteger pluginToolId;
|
||||
private String name;
|
||||
private String description;
|
||||
private Parameter[] parameters;
|
||||
private static final Logger logger = LoggerFactory.getLogger(PluginTool.class);
|
||||
|
||||
public PluginTool() {
|
||||
|
||||
}
|
||||
|
||||
public PluginTool(PluginItem pluginItem) {
|
||||
this.name = pluginItem.getEnglishName();
|
||||
this.description = pluginItem.getDescription();
|
||||
this.pluginToolId = pluginItem.getId();
|
||||
this.parameters = getDefaultParameters(pluginItem.getInputData());
|
||||
}
|
||||
|
||||
public BigInteger getPluginToolId() {
|
||||
return pluginToolId;
|
||||
}
|
||||
|
||||
public void setPluginToolId(BigInteger pluginToolId) {
|
||||
this.pluginToolId = pluginToolId;
|
||||
}
|
||||
|
||||
public void setName(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
public void setDescription(String description) {
|
||||
this.description = description;
|
||||
}
|
||||
|
||||
public void setParameters(Parameter[] parameters) {
|
||||
this.parameters = parameters;
|
||||
}
|
||||
|
||||
private Plugin getAiPlugin(BigInteger pluginId) {
|
||||
QueryWrapper queryWrapper = QueryWrapper.create()
|
||||
.select("*")
|
||||
.from("tb_plugin")
|
||||
.where("id = ?", pluginId);
|
||||
PluginMapper pluginMapper = SpringContextUtil.getBean(PluginMapper.class);
|
||||
Plugin plugin1 = pluginMapper.selectOneByQuery(queryWrapper);
|
||||
return plugin1;
|
||||
}
|
||||
|
||||
private Parameter[] getDefaultParameters(String inputData) {
|
||||
PluginItemService pluginToolService = SpringContextUtil.getBean(PluginItemService.class);
|
||||
QueryWrapper queryAiPluginToolWrapper = QueryWrapper.create()
|
||||
.select("*")
|
||||
.from("tb_plugin_item")
|
||||
.where("id = ? ", this.pluginToolId);
|
||||
PluginItem pluginItem = pluginToolService.getMapper().selectOneByQuery(queryAiPluginToolWrapper);
|
||||
List<Map<String, Object>> dataList = null;
|
||||
if (pluginItem == null || pluginItem.getInputData() == null){
|
||||
dataList = getDataList(inputData);
|
||||
} else {
|
||||
dataList = getDataList(pluginItem.getInputData());
|
||||
}
|
||||
Parameter[] params = new Parameter[dataList.size()];
|
||||
for (int i = 0; i < dataList.size(); i++) {
|
||||
Map<String, Object> item = dataList.get(i);
|
||||
Parameter parameter = new Parameter();
|
||||
parameter.setName((String) item.get("name"));
|
||||
parameter.setDescription((String) item.get("description"));
|
||||
parameter.setRequired((boolean) item.get("required"));
|
||||
String type = (String) item.get("type");
|
||||
if (type != null) {
|
||||
parameter.setType(type.toLowerCase());
|
||||
}
|
||||
params[i] = parameter;
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
// 转换输入参数
|
||||
private List<Map<String, Object>> getDataList(String jsonArray){
|
||||
List<Map<String, Object>> dataList;
|
||||
if (jsonArray == null) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
try {
|
||||
dataList = new ObjectMapper().readValue(
|
||||
jsonArray,
|
||||
new TypeReference<List<Map<String, Object>>>(){}
|
||||
);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return dataList;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getDescription() {
|
||||
return description;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Parameter[] getParameters() {
|
||||
return parameters;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object invoke(Map<String, Object> argsMap) {
|
||||
return runPluginTool(argsMap, null, this.pluginToolId);
|
||||
}
|
||||
|
||||
public Object runPluginTool(Map<String, Object> argsMap, String inputData, BigInteger pluginId){
|
||||
PluginItemService pluginToolService = SpringContextUtil.getBean(PluginItemService.class);
|
||||
QueryWrapper queryAiPluginToolWrapper = QueryWrapper.create()
|
||||
.select("*")
|
||||
.from("tb_plugin_item")
|
||||
.where("id = ? ", pluginId);
|
||||
PluginItem pluginItem = pluginToolService.getMapper().selectOneByQuery(queryAiPluginToolWrapper);
|
||||
String method = pluginItem.getRequestMethod().toUpperCase();
|
||||
Plugin plugin = getAiPlugin(pluginItem.getPluginId());
|
||||
|
||||
String url;
|
||||
if (!StrUtil.isEmpty(pluginItem.getBasePath())) {
|
||||
url = plugin.getBaseUrl()+ pluginItem.getBasePath();
|
||||
} else {
|
||||
url = plugin.getBaseUrl()+"/"+ pluginItem.getName();
|
||||
}
|
||||
|
||||
List<Map<String, Object>> headers = getDataList(plugin.getHeaders());
|
||||
Map<String, Object> headersMap = new HashMap<>();
|
||||
for (Map<String, Object> header : headers) {
|
||||
headersMap.put((String) header.get("label"), header.get("value"));
|
||||
}
|
||||
List<PluginParam> params = new ArrayList<>();
|
||||
|
||||
String authType = plugin.getAuthType();
|
||||
if (!StrUtil.isEmpty(authType) && "apiKey".equals(plugin.getAuthType())){
|
||||
if ("headers".equals(plugin.getPosition())){
|
||||
headersMap.put(plugin.getTokenKey(), plugin.getTokenValue());
|
||||
} else {
|
||||
PluginParam pluginParam = new PluginParam();
|
||||
pluginParam.setName(plugin.getTokenKey());
|
||||
pluginParam.setDefaultValue(plugin.getTokenValue());
|
||||
pluginParam.setEnabled(true);
|
||||
pluginParam.setRequired(true);
|
||||
pluginParam.setMethod("query");
|
||||
params.add(pluginParam);
|
||||
}
|
||||
}
|
||||
List<PluginParam> pluginParams = null;
|
||||
// 前端点击试运行传过来的参数
|
||||
if (inputData != null && !inputData.isEmpty()){
|
||||
pluginParams = PluginParamConverter.convertFromJson(inputData);
|
||||
// 大模型命中funcation_call 调用参数
|
||||
} else {
|
||||
pluginParams = PluginParamConverter.convertFromJson(pluginItem.getInputData());
|
||||
}
|
||||
|
||||
// 准备存放不同位置的参数
|
||||
List<PluginParam> queryParams = new ArrayList<>();
|
||||
List<PluginParam> bodyParams = new ArrayList<>();
|
||||
List<PluginParam> headerParams = new ArrayList<>();
|
||||
List<PluginParam> pathParams = new ArrayList<>();
|
||||
Map<String, Object> nestedParams = NestedParamConverter.convertToNestedParamMap(pluginParams);
|
||||
|
||||
// 遍历嵌套参数
|
||||
for (Map.Entry<String, Object> entry : nestedParams.entrySet()) {
|
||||
String paramName = entry.getKey();
|
||||
|
||||
// 获取原始参数定义
|
||||
PluginParam originalParam = findOriginalParam(pluginParams, paramName);
|
||||
if (originalParam == null || !originalParam.isEnabled()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// 创建参数副本以避免修改原始定义
|
||||
PluginParam requestParam = new PluginParam();
|
||||
requestParam.setName(originalParam.getName());
|
||||
requestParam.setDescription(originalParam.getDescription());
|
||||
requestParam.setRequired(originalParam.isRequired());
|
||||
|
||||
requestParam.setEnabled(originalParam.isEnabled());
|
||||
requestParam.setMethod(originalParam.getMethod());
|
||||
requestParam.setChildren(originalParam.getChildren());
|
||||
// 优先级: argsMap值 < 参数默认值
|
||||
if (argsMap != null && argsMap.containsKey(paramName)) {
|
||||
// 1. 优先检查是否有有效的默认值
|
||||
if (hasValidDefaultValue(originalParam.getDefaultValue())) {
|
||||
// 使用默认值
|
||||
requestParam.setDefaultValue(originalParam.getDefaultValue());
|
||||
} else {
|
||||
// 使用大模型返回的值
|
||||
requestParam.setDefaultValue(argsMap.get(paramName));
|
||||
}
|
||||
} else if (hasValidDefaultValue(originalParam.getDefaultValue())) {
|
||||
// 2. 没有传参但默认值有效时使用默认值
|
||||
// 如果是文件类型
|
||||
if (originalParam.getType().equals("File")){
|
||||
try {
|
||||
FileStorageService fileStorageService = SpringContextUtil.getBean(FileStorageManager.class);
|
||||
InputStream inputStream = fileStorageService.readStream((String)originalParam.getDefaultValue());
|
||||
requestParam.setType("MultipartFile");
|
||||
byte[] bytes = inputStreamToBytes(inputStream);
|
||||
String contentType = FileTypeUtil.getType(new ByteArrayInputStream(bytes));
|
||||
String fileUrl = (String) originalParam.getDefaultValue();
|
||||
int lastSlashIndex = fileUrl.lastIndexOf("/");
|
||||
String fileName = fileUrl.substring(lastSlashIndex + 1);
|
||||
requestParam.setDefaultValue(new CustomMultipartFile(bytes, originalParam.getName(), fileName, contentType));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
} else {
|
||||
requestParam.setType(originalParam.getType());
|
||||
requestParam.setDefaultValue(originalParam.getDefaultValue());
|
||||
}
|
||||
}
|
||||
// 根据method分类参数
|
||||
switch (originalParam.getMethod().toLowerCase()) {
|
||||
case "query":
|
||||
queryParams.add(requestParam);
|
||||
break;
|
||||
case "body":
|
||||
bodyParams.add(requestParam);
|
||||
break;
|
||||
case "header":
|
||||
headerParams.add(requestParam);
|
||||
break;
|
||||
case "path":
|
||||
pathParams.add(requestParam);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// 合并所有参数
|
||||
List<PluginParam> allParams = new ArrayList<>();
|
||||
allParams.addAll(pathParams);
|
||||
allParams.addAll(queryParams);
|
||||
allParams.addAll(bodyParams);
|
||||
allParams.addAll(headerParams);
|
||||
allParams.addAll(params);
|
||||
|
||||
// 发送请求
|
||||
JSONObject result = PluginHttpClient.sendRequest(url, method, headersMap, allParams);
|
||||
if (result.get("error") != null){
|
||||
logger.error("插件调用失败");
|
||||
logger.error(result.get("error").toString());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// 辅助方法:根据参数名查找原始参数定义
|
||||
private PluginParam findOriginalParam(List<PluginParam> params, String name) {
|
||||
for (PluginParam param : params) {
|
||||
if (name.equals(param.getName())) {
|
||||
return param;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
// 添加辅助方法判断默认值是否有效
|
||||
private boolean hasValidDefaultValue(Object defaultValue) {
|
||||
if (defaultValue == null) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 字符串类型检查
|
||||
if (defaultValue instanceof CharSequence) {
|
||||
return !((CharSequence) defaultValue).toString().trim().isEmpty();
|
||||
}
|
||||
|
||||
// 集合/数组类型检查
|
||||
if (defaultValue instanceof Collection) {
|
||||
return !((Collection<?>) defaultValue).isEmpty();
|
||||
}
|
||||
if (defaultValue instanceof Map) {
|
||||
return !((Map<?, ?>) defaultValue).isEmpty();
|
||||
}
|
||||
if (defaultValue.getClass().isArray()) {
|
||||
return Array.getLength(defaultValue) > 0;
|
||||
}
|
||||
|
||||
// 其他类型直接认为有效
|
||||
return true;
|
||||
}
|
||||
|
||||
private void processParamWithChildren(Map<String, Object> paramDef, Map<String, Object> argsMap, List<PluginParam> params) {
|
||||
boolean enabled = (boolean) paramDef.get("enabled");
|
||||
if (!enabled){
|
||||
return;
|
||||
}
|
||||
String paramName = (String) paramDef.get("name");
|
||||
PluginParam pluginParam = new PluginParam();
|
||||
pluginParam.setName(paramName);
|
||||
pluginParam.setDescription((String) paramDef.get("description"));
|
||||
pluginParam.setRequired((boolean) paramDef.get("required"));
|
||||
pluginParam.setType((String) paramDef.get("type"));
|
||||
pluginParam.setEnabled((boolean) paramDef.get("enabled"));
|
||||
pluginParam.setMethod((String) paramDef.get("method"));
|
||||
|
||||
// 如果用户传了值,就用用户的值;否则用默认值
|
||||
if (paramDef.get("defaultValue") != null && !"".equals(paramDef.get("defaultValue"))) {
|
||||
pluginParam.setDefaultValue(paramDef.get("defaultValue"));
|
||||
} else if (argsMap != null && paramDef.get("name").equals(paramName) && paramDef.get("defaultValue") != null) {
|
||||
pluginParam.setDefaultValue(argsMap.get(paramName));
|
||||
}
|
||||
|
||||
params.add(pluginParam);
|
||||
|
||||
// 处理 children
|
||||
List<Map<String, Object>> children = (List<Map<String, Object>>) paramDef.get("children");
|
||||
if (children != null) {
|
||||
for (Map<String, Object> child : children) {
|
||||
processParamWithChildren(child, argsMap, params);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static byte[] inputStreamToBytes(InputStream inputStream) throws IOException {
|
||||
ByteArrayOutputStream buffer = new ByteArrayOutputStream();
|
||||
int nRead;
|
||||
byte[] data = new byte[1024]; // 1KB缓冲区
|
||||
|
||||
while ((nRead = inputStream.read(data, 0, data.length)) != -1) {
|
||||
buffer.write(data, 0, nRead);
|
||||
}
|
||||
|
||||
buffer.flush();
|
||||
return buffer.toByteArray();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package tech.easyflow.ai.easyagents.tool;
|
||||
|
||||
import com.easyagents.core.model.chat.tool.BaseTool;
|
||||
import com.easyagents.core.model.chat.tool.Parameter;
|
||||
import com.easyagents.flow.core.chain.ChainDefinition;
|
||||
import com.easyagents.flow.core.chain.DataType;
|
||||
import com.easyagents.flow.core.chain.runtime.ChainExecutor;
|
||||
import tech.easyflow.ai.entity.Workflow;
|
||||
import tech.easyflow.common.util.SpringContextUtil;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class WorkflowTool extends BaseTool {
|
||||
|
||||
private BigInteger workflowId;
|
||||
|
||||
public WorkflowTool() {
|
||||
}
|
||||
|
||||
public WorkflowTool(Workflow workflow, boolean needEnglishName) {
|
||||
this.workflowId = workflow.getId();
|
||||
if (needEnglishName) {
|
||||
this.name = workflow.getEnglishName();
|
||||
} else {
|
||||
this.name = workflow.getTitle();
|
||||
}
|
||||
this.description = workflow.getDescription();
|
||||
this.parameters = toParameters(workflow);
|
||||
}
|
||||
|
||||
|
||||
static Parameter[] toParameters(Workflow workflow) {
|
||||
ChainExecutor executor = SpringContextUtil.getBean(ChainExecutor.class);
|
||||
ChainDefinition definition = executor.getDefinitionRepository().getChainDefinitionById(workflow.getId().toString());
|
||||
List<com.easyagents.flow.core.chain.Parameter> parameterDefs = definition.getStartParameters();
|
||||
if (parameterDefs == null || parameterDefs.isEmpty()) {
|
||||
return new Parameter[0];
|
||||
}
|
||||
|
||||
Parameter[] parameters = new Parameter[parameterDefs.size()];
|
||||
for (int i = 0; i < parameterDefs.size(); i++) {
|
||||
com.easyagents.flow.core.chain.Parameter parameterDef = parameterDefs.get(i);
|
||||
Parameter parameter = new Parameter();
|
||||
parameter.setName(parameterDef.getName());
|
||||
parameter.setDescription(parameterDef.getDescription());
|
||||
DataType dataType = parameterDef.getDataType();
|
||||
if (dataType == null) dataType = DataType.String;
|
||||
parameter.setType(dataType.toString());
|
||||
parameter.setRequired(parameterDef.isRequired());
|
||||
parameters[i] = parameter;
|
||||
}
|
||||
return parameters;
|
||||
}
|
||||
|
||||
public BigInteger getWorkflowId() {
|
||||
return workflowId;
|
||||
}
|
||||
|
||||
public void setWorkflowId(BigInteger workflowId) {
|
||||
this.workflowId = workflowId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object invoke(Map<String, Object> argsMap) {
|
||||
ChainExecutor executor = SpringContextUtil.getBean(ChainExecutor.class);
|
||||
return executor.execute(workflowId.toString(), argsMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "AiWorkflowFunction{" +
|
||||
"workflowId=" + workflowId +
|
||||
", name='" + name + '\'' +
|
||||
", description='" + description + '\'' +
|
||||
", parameters=" + Arrays.toString(parameters) +
|
||||
'}';
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user