From f57544daa21ca759e632698fc18ca4373d51672f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AD=90=E9=BB=98?= <925456043@qq.com> Date: Sun, 5 Apr 2026 20:22:59 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=B8=8B=E6=B2=89=E7=9F=A5=E8=AF=86?= =?UTF-8?q?=E5=BA=93=E6=A3=80=E7=B4=A2=E7=BC=96=E6=8E=92=E8=83=BD=E5=8A=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 rag retrieval 核心协议、RRF 融合与相关度归一化 - 支持关键词检索按 knowledgeId 过滤并补充 ES/Lucene 单测 - 扩展 KnowledgeNode 检索模式与 Milvus 检索参数透传 --- .../flow/core/node/KnowledgeNode.java | 10 + .../core/parser/impl/KnowledgeNodeParser.java | 1 + .../easy-agents-rag-retrieval/pom.xml | 9 + .../rag/retrieval/FusionStrategy.java | 8 + .../easyagents/rag/retrieval/HitSource.java | 17 ++ .../rag/retrieval/KeywordRetriever.java | 8 + .../com/easyagents/rag/retrieval/RagHit.java | 198 ++++++++++++++++++ .../easyagents/rag/retrieval/RagQuery.java | 53 +++++ .../rag/retrieval/RagRetrievalExecutor.java | 168 +++++++++++++++ .../retrieval/RagRetrievalMetadataKeys.java | 16 ++ .../rag/retrieval/RagRetrievalResult.java | 26 +++ .../rag/retrieval/RagScoreNormalizer.java | 117 +++++++++++ .../rag/retrieval/RetrievalMode.java | 19 ++ .../rag/retrieval/RrfFusionStrategy.java | 121 +++++++++++ .../rag/retrieval/VectorRetriever.java | 8 + .../retrieval/RagRetrievalExecutorTest.java | 111 ++++++++++ .../rag/retrieval/RagScoreNormalizerTest.java | 69 ++++++ .../rag/retrieval/RrfFusionStrategyTest.java | 46 ++++ .../easy-agents-search-engine-es/pom.xml | 5 + .../easyagents/engine/es/ElasticSearcher.java | 131 +++++++++--- .../es/ElasticSearcherQueryBuilderTest.java | 54 +++++ .../easy-agents-search-engine-lucene/pom.xml | 5 + .../search/engine/lucene/LuceneSearcher.java | 31 ++- .../engine/lucene/LuceneSearcherTest.java | 45 ++++ .../engine/service/DocumentSearcher.java | 8 +- .../service/KeywordSearchMetadataKeys.java | 9 + .../engine/service/KeywordSearchRequest.java | 39 ++++ .../store/milvus/MilvusVectorStore.java | 11 + 28 files changed, 1309 insertions(+), 34 deletions(-) create mode 100644 easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/FusionStrategy.java create mode 100644 easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/HitSource.java create mode 100644 easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/KeywordRetriever.java create mode 100644 easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagHit.java create mode 100644 easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagQuery.java create mode 100644 easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagRetrievalExecutor.java create mode 100644 easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagRetrievalMetadataKeys.java create mode 100644 easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagRetrievalResult.java create mode 100644 easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagScoreNormalizer.java create mode 100644 easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RetrievalMode.java create mode 100644 easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RrfFusionStrategy.java create mode 100644 easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/VectorRetriever.java create mode 100644 easy-agents-rag/easy-agents-rag-retrieval/src/test/java/com/easyagents/rag/retrieval/RagRetrievalExecutorTest.java create mode 100644 easy-agents-rag/easy-agents-rag-retrieval/src/test/java/com/easyagents/rag/retrieval/RagScoreNormalizerTest.java create mode 100644 easy-agents-rag/easy-agents-rag-retrieval/src/test/java/com/easyagents/rag/retrieval/RrfFusionStrategyTest.java create mode 100644 easy-agents-search-engine/easy-agents-search-engine-es/src/test/java/com/easyagents/engine/es/ElasticSearcherQueryBuilderTest.java create mode 100644 easy-agents-search-engine/easy-agents-search-engine-lucene/src/test/java/com/easyagents/search/engine/lucene/LuceneSearcherTest.java create mode 100644 easy-agents-search-engine/easy-agents-search-engine-service/src/main/java/com/easyagents/search/engine/service/KeywordSearchMetadataKeys.java create mode 100644 easy-agents-search-engine/easy-agents-search-engine-service/src/main/java/com/easyagents/search/engine/service/KeywordSearchRequest.java diff --git a/easy-agents-flow/src/main/java/com/easyagents/flow/core/node/KnowledgeNode.java b/easy-agents-flow/src/main/java/com/easyagents/flow/core/node/KnowledgeNode.java index 1b9c7ae..eb07bd8 100644 --- a/easy-agents-flow/src/main/java/com/easyagents/flow/core/node/KnowledgeNode.java +++ b/easy-agents-flow/src/main/java/com/easyagents/flow/core/node/KnowledgeNode.java @@ -35,6 +35,7 @@ public class KnowledgeNode extends BaseNode { private Object knowledgeId; private String keyword; private String limit; + private String retrievalMode = "HYBRID"; public Object getKnowledgeId() { return knowledgeId; @@ -60,6 +61,14 @@ public class KnowledgeNode extends BaseNode { this.limit = limit; } + public String getRetrievalMode() { + return retrievalMode; + } + + public void setRetrievalMode(String retrievalMode) { + this.retrievalMode = StringUtil.hasText(retrievalMode) ? retrievalMode : "HYBRID"; + } + @Override public Map execute(Chain chain) { Map argsMap = chain.getState().resolveParameters(this); @@ -90,6 +99,7 @@ public class KnowledgeNode extends BaseNode { "knowledgeId=" + knowledgeId + ", keyword='" + keyword + '\'' + ", limit='" + limit + '\'' + + ", retrievalMode='" + retrievalMode + '\'' + ", parameters=" + parameters + ", outputDefs=" + outputDefs + ", id='" + id + '\'' + diff --git a/easy-agents-flow/src/main/java/com/easyagents/flow/core/parser/impl/KnowledgeNodeParser.java b/easy-agents-flow/src/main/java/com/easyagents/flow/core/parser/impl/KnowledgeNodeParser.java index 60ee056..906e88e 100644 --- a/easy-agents-flow/src/main/java/com/easyagents/flow/core/parser/impl/KnowledgeNodeParser.java +++ b/easy-agents-flow/src/main/java/com/easyagents/flow/core/parser/impl/KnowledgeNodeParser.java @@ -27,6 +27,7 @@ public class KnowledgeNodeParser extends BaseNodeParser { knowledgeNode.setKnowledgeId(data.get("knowledgeId")); knowledgeNode.setLimit(data.getString("limit")); knowledgeNode.setKeyword(data.getString("keyword")); + knowledgeNode.setRetrievalMode(data.getString("retrievalMode")); return knowledgeNode; } diff --git a/easy-agents-rag/easy-agents-rag-retrieval/pom.xml b/easy-agents-rag/easy-agents-rag-retrieval/pom.xml index d7ebb38..914ee08 100644 --- a/easy-agents-rag/easy-agents-rag-retrieval/pom.xml +++ b/easy-agents-rag/easy-agents-rag-retrieval/pom.xml @@ -24,6 +24,10 @@ com.easyagents easy-agents-core + + com.easyagents + easy-agents-search-engine-service + com.easyagents easy-agents-rag-core @@ -32,5 +36,10 @@ com.easyagents easy-agents-rag-enhance + + junit + junit + test + diff --git a/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/FusionStrategy.java b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/FusionStrategy.java new file mode 100644 index 0000000..41146ff --- /dev/null +++ b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/FusionStrategy.java @@ -0,0 +1,8 @@ +package com.easyagents.rag.retrieval; + +import java.util.List; + +public interface FusionStrategy { + + List fuse(List vectorHits, List keywordHits, int topK); +} diff --git a/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/HitSource.java b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/HitSource.java new file mode 100644 index 0000000..dd723a7 --- /dev/null +++ b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/HitSource.java @@ -0,0 +1,17 @@ +package com.easyagents.rag.retrieval; + +public enum HitSource { + VECTOR, + KEYWORD, + BOTH; + + public static HitSource merge(HitSource current, HitSource incoming) { + if (current == null) { + return incoming; + } + if (incoming == null || current == incoming) { + return current; + } + return BOTH; + } +} diff --git a/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/KeywordRetriever.java b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/KeywordRetriever.java new file mode 100644 index 0000000..0ef219e --- /dev/null +++ b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/KeywordRetriever.java @@ -0,0 +1,8 @@ +package com.easyagents.rag.retrieval; + +import java.util.List; + +public interface KeywordRetriever { + + List retrieve(RagQuery query); +} diff --git a/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagHit.java b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagHit.java new file mode 100644 index 0000000..c1f8b88 --- /dev/null +++ b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagHit.java @@ -0,0 +1,198 @@ +package com.easyagents.rag.retrieval; + +import com.easyagents.core.document.Document; + +import java.util.HashMap; +import java.util.Map; + +public class RagHit { + + private Object documentId; + private String title; + private String content; + private Double score; + private HitSource hitSource; + private Double vectorScore; + private Double keywordScore; + private Integer rank; + private Map metadata = new HashMap(); + + public static RagHit fromDocument(Document document) { + if (document == null) { + return null; + } + RagHit hit = new RagHit(); + hit.setDocumentId(document.getId()); + hit.setTitle(document.getTitle()); + hit.setContent(document.getContent()); + hit.setScore(document.getScore()); + hit.setMetadata(document.getMetadataMap()); + + Object hitSource = hit.getMetadata().get(RagRetrievalMetadataKeys.HIT_SOURCE); + if (hitSource instanceof String) { + hit.setHitSource(HitSource.valueOf(String.valueOf(hitSource))); + } + hit.setVectorScore(asDouble(hit.getMetadata().get(RagRetrievalMetadataKeys.VECTOR_SCORE))); + hit.setKeywordScore(asDouble(hit.getMetadata().get(RagRetrievalMetadataKeys.KEYWORD_SCORE))); + hit.setRank(asInteger(hit.getMetadata().get(RagRetrievalMetadataKeys.FINAL_RANK))); + return hit; + } + + public static RagHit fromDocument(Document document, HitSource defaultSource) { + RagHit hit = fromDocument(document); + if (hit != null && hit.getHitSource() == null) { + hit.setHitSource(defaultSource); + } + if (hit != null && defaultSource == HitSource.VECTOR && hit.getVectorScore() == null) { + hit.setVectorScore(document.getScore()); + } + if (hit != null && defaultSource == HitSource.KEYWORD && hit.getKeywordScore() == null) { + hit.setKeywordScore(document.getScore()); + } + return hit; + } + + public Document toDocument() { + Document document = new Document(); + document.setId(documentId); + document.setTitle(title); + document.setContent(content); + document.setScore(score); + Map metadataMap = metadata == null + ? new HashMap() + : new HashMap(metadata); + if (hitSource != null) { + metadataMap.put(RagRetrievalMetadataKeys.HIT_SOURCE, hitSource.name()); + } + if (vectorScore != null) { + metadataMap.put(RagRetrievalMetadataKeys.VECTOR_SCORE, vectorScore); + } + if (keywordScore != null) { + metadataMap.put(RagRetrievalMetadataKeys.KEYWORD_SCORE, keywordScore); + } + if (!metadataMap.containsKey(RagRetrievalMetadataKeys.FUSION_SCORE) + && hitSource == HitSource.BOTH + && score != null) { + metadataMap.put(RagRetrievalMetadataKeys.FUSION_SCORE, score); + } + if (rank != null) { + metadataMap.put(RagRetrievalMetadataKeys.FINAL_RANK, rank); + } + document.setMetadataMap(metadataMap); + return document; + } + + public RagHit copy() { + RagHit copy = new RagHit(); + copy.setDocumentId(documentId); + copy.setTitle(title); + copy.setContent(content); + copy.setScore(score); + copy.setHitSource(hitSource); + copy.setVectorScore(vectorScore); + copy.setKeywordScore(keywordScore); + copy.setRank(rank); + copy.setMetadata(metadata); + return copy; + } + + public Object getDocumentId() { + return documentId; + } + + public void setDocumentId(Object documentId) { + this.documentId = documentId; + } + + public String getTitle() { + return title; + } + + public void setTitle(String title) { + this.title = title; + } + + public String getContent() { + return content; + } + + public void setContent(String content) { + this.content = content; + } + + public Double getScore() { + return score; + } + + public void setScore(Double score) { + this.score = score; + } + + public HitSource getHitSource() { + return hitSource; + } + + public void setHitSource(HitSource hitSource) { + this.hitSource = hitSource; + } + + public Double getVectorScore() { + return vectorScore; + } + + public void setVectorScore(Double vectorScore) { + this.vectorScore = vectorScore; + } + + public Double getKeywordScore() { + return keywordScore; + } + + public void setKeywordScore(Double keywordScore) { + this.keywordScore = keywordScore; + } + + public Integer getRank() { + return rank; + } + + public void setRank(Integer rank) { + this.rank = rank; + } + + public Map getMetadata() { + return metadata; + } + + public void setMetadata(Map metadata) { + this.metadata = metadata == null ? new HashMap() : new HashMap(metadata); + } + + private static Double asDouble(Object value) { + if (value instanceof Number) { + return ((Number) value).doubleValue(); + } + if (value instanceof String) { + try { + return Double.valueOf((String) value); + } catch (NumberFormatException ignore) { + return null; + } + } + return null; + } + + private static Integer asInteger(Object value) { + if (value instanceof Number) { + return ((Number) value).intValue(); + } + if (value instanceof String) { + try { + return Integer.valueOf((String) value); + } catch (NumberFormatException ignore) { + return null; + } + } + return null; + } +} diff --git a/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagQuery.java b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagQuery.java new file mode 100644 index 0000000..31fa8f6 --- /dev/null +++ b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagQuery.java @@ -0,0 +1,53 @@ +package com.easyagents.rag.retrieval; + +import java.util.HashMap; +import java.util.Map; + +public class RagQuery { + + private String query; + private RetrievalMode retrievalMode = RetrievalMode.HYBRID; + private Integer topK = 10; + private Double minScore; + private Map attributes = new HashMap(); + + public String getQuery() { + return query; + } + + public void setQuery(String query) { + this.query = query; + } + + public RetrievalMode getRetrievalMode() { + return retrievalMode; + } + + public void setRetrievalMode(RetrievalMode retrievalMode) { + this.retrievalMode = retrievalMode == null ? RetrievalMode.HYBRID : retrievalMode; + } + + public Integer getTopK() { + return topK; + } + + public void setTopK(Integer topK) { + this.topK = topK == null || topK <= 0 ? 10 : topK; + } + + public Double getMinScore() { + return minScore; + } + + public void setMinScore(Double minScore) { + this.minScore = minScore; + } + + public Map getAttributes() { + return attributes; + } + + public void setAttributes(Map attributes) { + this.attributes = attributes == null ? new HashMap() : attributes; + } +} diff --git a/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagRetrievalExecutor.java b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagRetrievalExecutor.java new file mode 100644 index 0000000..d8c45a4 --- /dev/null +++ b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagRetrievalExecutor.java @@ -0,0 +1,168 @@ +package com.easyagents.rag.retrieval; + +import com.easyagents.core.document.Document; +import com.easyagents.core.model.rerank.RerankModel; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +public class RagRetrievalExecutor { + + private final VectorRetriever vectorRetriever; + private final KeywordRetriever keywordRetriever; + private final FusionStrategy fusionStrategy; + + public RagRetrievalExecutor(VectorRetriever vectorRetriever, + KeywordRetriever keywordRetriever, + FusionStrategy fusionStrategy) { + this.vectorRetriever = vectorRetriever; + this.keywordRetriever = keywordRetriever; + this.fusionStrategy = fusionStrategy == null ? new RrfFusionStrategy() : fusionStrategy; + } + + public RagRetrievalResult retrieve(RagQuery query) { + RagQuery effectiveQuery = query == null ? new RagQuery() : query; + RetrievalMode mode = effectiveQuery.getRetrievalMode() == null + ? RetrievalMode.HYBRID + : effectiveQuery.getRetrievalMode(); + if (mode == RetrievalMode.VECTOR) { + return buildResult(normalizeHits(vectorRetriever == null + ? Collections.emptyList() + : vectorRetriever.retrieve(effectiveQuery), HitSource.VECTOR), effectiveQuery.getTopK()); + } + if (mode == RetrievalMode.KEYWORD) { + return buildResult(normalizeHits(keywordRetriever == null + ? Collections.emptyList() + : keywordRetriever.retrieve(effectiveQuery), HitSource.KEYWORD), effectiveQuery.getTopK()); + } + + CompletableFuture> vectorFuture = CompletableFuture.supplyAsync(new java.util.function.Supplier>() { + @Override + public List get() { + return vectorRetriever == null ? Collections.emptyList() : vectorRetriever.retrieve(effectiveQuery); + } + }); + CompletableFuture> keywordFuture = CompletableFuture.supplyAsync(new java.util.function.Supplier>() { + @Override + public List get() { + return keywordRetriever == null ? Collections.emptyList() : keywordRetriever.retrieve(effectiveQuery); + } + }); + + List vectorHits = normalizeHits(vectorFuture.join(), HitSource.VECTOR); + List keywordHits = normalizeHits(keywordFuture.join(), HitSource.KEYWORD); + return buildResult(fusionStrategy.fuse(vectorHits, keywordHits, effectiveQuery.getTopK()), effectiveQuery.getTopK()); + } + + public RagRetrievalResult rerank(String query, List hits, RerankModel rerankModel, int topK) { + if (rerankModel == null || hits == null || hits.isEmpty()) { + return buildResult(hits, topK); + } + List rerankInput = new ArrayList(hits.size()); + Map hitMap = new LinkedHashMap(); + for (RagHit hit : hits) { + if (hit == null || hit.getDocumentId() == null) { + continue; + } + rerankInput.add(hit.toDocument()); + hitMap.put(String.valueOf(hit.getDocumentId()), hit.copy()); + } + List rerankedDocuments = rerankModel.rerank(query, rerankInput); + List rerankedHits = new ArrayList(); + if (rerankedDocuments != null) { + for (Document rerankedDocument : rerankedDocuments) { + if (rerankedDocument == null || rerankedDocument.getId() == null) { + continue; + } + RagHit original = hitMap.get(String.valueOf(rerankedDocument.getId())); + if (original == null) { + continue; + } + original.setContent(rerankedDocument.getContent()); + original.setTitle(rerankedDocument.getTitle()); + original.setScore(rerankedDocument.getScore()); + if (rerankedDocument.getScore() != null) { + original.getMetadata().put(RagRetrievalMetadataKeys.RERANK_SCORE, rerankedDocument.getScore()); + } + rerankedHits.add(original); + } + } + sortByScore(rerankedHits); + assignRank(rerankedHits); + return buildResult(rerankedHits, topK); + } + + private RagRetrievalResult buildResult(List hits, int topK) { + List effectiveHits = hits == null ? new ArrayList() : new ArrayList(hits); + if (topK > 0 && effectiveHits.size() > topK) { + effectiveHits = new ArrayList(effectiveHits.subList(0, topK)); + } + assignRank(effectiveHits); + RagRetrievalResult result = new RagRetrievalResult(); + result.setHits(effectiveHits); + result.setTotal(effectiveHits.size()); + return result; + } + + private List normalizeHits(List hits, HitSource defaultSource) { + List normalized = new ArrayList(); + if (hits == null) { + return normalized; + } + for (RagHit hit : hits) { + if (hit == null || hit.getDocumentId() == null) { + continue; + } + RagHit copy = hit.copy(); + if (copy.getHitSource() == null) { + copy.setHitSource(defaultSource); + } + if (defaultSource == HitSource.VECTOR && copy.getVectorScore() == null) { + copy.setVectorScore(copy.getScore()); + } + if (defaultSource == HitSource.KEYWORD && copy.getKeywordScore() == null) { + copy.setKeywordScore(copy.getScore()); + } + normalized.add(copy); + } + assignRank(normalized); + return normalized; + } + + private void sortByScore(List hits) { + Collections.sort(hits, new Comparator() { + @Override + public int compare(RagHit left, RagHit right) { + return Double.compare(right.getScore() == null ? 0D : right.getScore(), + left.getScore() == null ? 0D : left.getScore()); + } + }); + } + + private void assignRank(List hits) { + for (int i = 0; i < hits.size(); i++) { + RagHit hit = hits.get(i); + hit.setRank(i + 1); + hit.getMetadata().put(RagRetrievalMetadataKeys.FINAL_RANK, i + 1); + if (hit.getHitSource() != null) { + hit.getMetadata().put(RagRetrievalMetadataKeys.HIT_SOURCE, hit.getHitSource().name()); + } + if (hit.getVectorScore() != null) { + hit.getMetadata().put(RagRetrievalMetadataKeys.VECTOR_SCORE, hit.getVectorScore()); + } + if (hit.getKeywordScore() != null) { + hit.getMetadata().put(RagRetrievalMetadataKeys.KEYWORD_SCORE, hit.getKeywordScore()); + } + if (!hit.getMetadata().containsKey(RagRetrievalMetadataKeys.RERANK_SCORE) + && hit.getHitSource() == HitSource.BOTH + && hit.getScore() != null) { + hit.getMetadata().put(RagRetrievalMetadataKeys.FUSION_SCORE, hit.getScore()); + } + } + } +} diff --git a/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagRetrievalMetadataKeys.java b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagRetrievalMetadataKeys.java new file mode 100644 index 0000000..bcdce17 --- /dev/null +++ b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagRetrievalMetadataKeys.java @@ -0,0 +1,16 @@ +package com.easyagents.rag.retrieval; + +public final class RagRetrievalMetadataKeys { + + private RagRetrievalMetadataKeys() { + } + + public static final String HIT_SOURCE = "hitSource"; + public static final String VECTOR_SCORE = "vectorScore"; + public static final String KEYWORD_SCORE = "keywordScore"; + public static final String FUSION_SCORE = "fusionScore"; + public static final String RERANK_SCORE = "rerankScore"; + public static final String VECTOR_RANK = "vectorRank"; + public static final String KEYWORD_RANK = "keywordRank"; + public static final String FINAL_RANK = "rank"; +} diff --git a/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagRetrievalResult.java b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagRetrievalResult.java new file mode 100644 index 0000000..33acdd3 --- /dev/null +++ b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagRetrievalResult.java @@ -0,0 +1,26 @@ +package com.easyagents.rag.retrieval; + +import java.util.ArrayList; +import java.util.List; + +public class RagRetrievalResult { + + private List hits = new ArrayList(); + private Integer total; + + public List getHits() { + return hits; + } + + public void setHits(List hits) { + this.hits = hits == null ? new ArrayList() : hits; + } + + public Integer getTotal() { + return total; + } + + public void setTotal(Integer total) { + this.total = total; + } +} diff --git a/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagScoreNormalizer.java b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagScoreNormalizer.java new file mode 100644 index 0000000..44dc243 --- /dev/null +++ b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RagScoreNormalizer.java @@ -0,0 +1,117 @@ +package com.easyagents.rag.retrieval; + +import com.easyagents.core.document.Document; + +import java.util.ArrayList; +import java.util.List; + +public final class RagScoreNormalizer { + + private RagScoreNormalizer() { + } + + public static void normalize(List documents, RetrievalMode retrievalMode, boolean reranked) { + if (documents == null || documents.isEmpty()) { + return; + } + if (reranked) { + normalizeRerankScores(documents); + return; + } + + RetrievalMode effectiveMode = retrievalMode == null ? RetrievalMode.HYBRID : retrievalMode; + if (effectiveMode == RetrievalMode.VECTOR) { + for (Document document : documents) { + document.setScore(clamp01(readRawScore(document, RagRetrievalMetadataKeys.VECTOR_SCORE, document == null ? null : document.getScore()))); + } + return; + } + if (effectiveMode == RetrievalMode.KEYWORD) { + for (Document document : documents) { + Double rawScore = readRawScore(document, RagRetrievalMetadataKeys.KEYWORD_SCORE, document == null ? null : document.getScore()); + if (rawScore == null || rawScore <= 0D) { + document.setScore(0D); + } else { + document.setScore(clamp01(rawScore / (1D + rawScore))); + } + } + return; + } + + double maxHybridRrfScore = 2D / (RrfFusionStrategy.DEFAULT_RRF_K + 1D); + for (Document document : documents) { + Double rawScore = readRawScore(document, RagRetrievalMetadataKeys.FUSION_SCORE, document == null ? null : document.getScore()); + if (rawScore == null || rawScore <= 0D) { + document.setScore(0D); + } else { + document.setScore(clamp01(rawScore / maxHybridRrfScore)); + } + } + } + + private static void normalizeRerankScores(List documents) { + List rawScores = new ArrayList(documents.size()); + boolean allPresent = true; + Double min = null; + Double max = null; + for (Document document : documents) { + Double rawScore = readRawScore(document, RagRetrievalMetadataKeys.RERANK_SCORE, document == null ? null : document.getScore()); + rawScores.add(rawScore); + if (rawScore == null) { + allPresent = false; + continue; + } + min = min == null ? rawScore : Math.min(min, rawScore); + max = max == null ? rawScore : Math.max(max, rawScore); + } + + if (allPresent && min != null && max != null && Double.compare(max, min) != 0) { + for (int i = 0; i < documents.size(); i++) { + Double rawScore = rawScores.get(i); + documents.get(i).setScore(clamp01((rawScore - min) / (max - min))); + } + return; + } + + if (documents.size() == 1) { + documents.get(0).setScore(1D); + return; + } + + int size = documents.size(); + for (int i = 0; i < size; i++) { + documents.get(i).setScore(clamp01(1D - ((double) i / (double) (size - 1)))); + } + } + + private static Double readRawScore(Document document, String metadataKey, Double fallback) { + if (document == null) { + return null; + } + Object metadataValue = document.getMetadata(metadataKey); + if (metadataValue instanceof Number) { + return ((Number) metadataValue).doubleValue(); + } + if (metadataValue instanceof String) { + try { + return Double.valueOf((String) metadataValue); + } catch (NumberFormatException ignore) { + return fallback; + } + } + return fallback; + } + + private static double clamp01(Double value) { + if (value == null || value.isNaN() || value.isInfinite()) { + return 0D; + } + if (value < 0D) { + return 0D; + } + if (value > 1D) { + return 1D; + } + return value; + } +} diff --git a/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RetrievalMode.java b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RetrievalMode.java new file mode 100644 index 0000000..2663883 --- /dev/null +++ b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RetrievalMode.java @@ -0,0 +1,19 @@ +package com.easyagents.rag.retrieval; + +public enum RetrievalMode { + VECTOR, + KEYWORD, + HYBRID; + + public static RetrievalMode from(String value) { + if (value == null || value.trim().isEmpty()) { + return HYBRID; + } + for (RetrievalMode mode : values()) { + if (mode.name().equalsIgnoreCase(value.trim())) { + return mode; + } + } + throw new IllegalArgumentException("Unsupported retrieval mode: " + value); + } +} diff --git a/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RrfFusionStrategy.java b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RrfFusionStrategy.java new file mode 100644 index 0000000..d2b3297 --- /dev/null +++ b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/RrfFusionStrategy.java @@ -0,0 +1,121 @@ +package com.easyagents.rag.retrieval; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +public class RrfFusionStrategy implements FusionStrategy { + + public static final int DEFAULT_RRF_K = 60; + + private final int rrfK; + + public RrfFusionStrategy() { + this(DEFAULT_RRF_K); + } + + public RrfFusionStrategy(int rrfK) { + this.rrfK = rrfK <= 0 ? DEFAULT_RRF_K : rrfK; + } + + @Override + public List fuse(List vectorHits, List keywordHits, int topK) { + Map mergedHits = new LinkedHashMap(); + Map fusionScores = new LinkedHashMap(); + + mergeHits(mergedHits, fusionScores, vectorHits, HitSource.VECTOR); + mergeHits(mergedHits, fusionScores, keywordHits, HitSource.KEYWORD); + + List results = new ArrayList(mergedHits.values()); + Collections.sort(results, new Comparator() { + @Override + public int compare(RagHit left, RagHit right) { + int scoreCompare = Double.compare(nullSafeScore(right.getScore()), nullSafeScore(left.getScore())); + if (scoreCompare != 0) { + return scoreCompare; + } + return String.valueOf(left.getDocumentId()).compareTo(String.valueOf(right.getDocumentId())); + } + }); + assignRank(results); + if (topK > 0 && results.size() > topK) { + return new ArrayList(results.subList(0, topK)); + } + return results; + } + + private void mergeHits(Map mergedHits, + Map fusionScores, + List incomingHits, + HitSource source) { + if (incomingHits == null) { + return; + } + for (int index = 0; index < incomingHits.size(); index++) { + RagHit incoming = incomingHits.get(index); + if (incoming == null || incoming.getDocumentId() == null) { + continue; + } + String key = String.valueOf(incoming.getDocumentId()); + RagHit merged = mergedHits.get(key); + if (merged == null) { + merged = incoming.copy(); + merged.setHitSource(source); + mergedHits.put(key, merged); + fusionScores.put(key, 0D); + } else { + mergeBaseFields(merged, incoming); + merged.setHitSource(HitSource.merge(merged.getHitSource(), source)); + } + + int rank = index + 1; + if (source == HitSource.VECTOR) { + merged.setVectorScore(incoming.getVectorScore() != null ? incoming.getVectorScore() : incoming.getScore()); + merged.getMetadata().put(RagRetrievalMetadataKeys.VECTOR_RANK, rank); + } else { + merged.setKeywordScore(incoming.getKeywordScore() != null ? incoming.getKeywordScore() : incoming.getScore()); + merged.getMetadata().put(RagRetrievalMetadataKeys.KEYWORD_RANK, rank); + } + double fusedScore = fusionScores.get(key) + (1D / (rrfK + rank)); + fusionScores.put(key, fusedScore); + merged.setScore(fusedScore); + } + } + + private void mergeBaseFields(RagHit target, RagHit incoming) { + if ((target.getTitle() == null || target.getTitle().isEmpty()) && incoming.getTitle() != null) { + target.setTitle(incoming.getTitle()); + } + if ((target.getContent() == null || target.getContent().isEmpty()) && incoming.getContent() != null) { + target.setContent(incoming.getContent()); + } + if (incoming.getMetadata() != null && !incoming.getMetadata().isEmpty()) { + target.getMetadata().putAll(incoming.getMetadata()); + } + } + + private void assignRank(List results) { + for (int i = 0; i < results.size(); i++) { + RagHit hit = results.get(i); + hit.setRank(i + 1); + hit.getMetadata().put(RagRetrievalMetadataKeys.FINAL_RANK, i + 1); + hit.getMetadata().put(RagRetrievalMetadataKeys.HIT_SOURCE, hit.getHitSource() == null ? null : hit.getHitSource().name()); + if (!hit.getMetadata().containsKey(RagRetrievalMetadataKeys.RERANK_SCORE) && hit.getScore() != null) { + hit.getMetadata().put(RagRetrievalMetadataKeys.FUSION_SCORE, hit.getScore()); + } + if (hit.getVectorScore() != null) { + hit.getMetadata().put(RagRetrievalMetadataKeys.VECTOR_SCORE, hit.getVectorScore()); + } + if (hit.getKeywordScore() != null) { + hit.getMetadata().put(RagRetrievalMetadataKeys.KEYWORD_SCORE, hit.getKeywordScore()); + } + } + } + + private double nullSafeScore(Double value) { + return value == null ? 0D : value; + } +} diff --git a/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/VectorRetriever.java b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/VectorRetriever.java new file mode 100644 index 0000000..fb668fa --- /dev/null +++ b/easy-agents-rag/easy-agents-rag-retrieval/src/main/java/com/easyagents/rag/retrieval/VectorRetriever.java @@ -0,0 +1,8 @@ +package com.easyagents.rag.retrieval; + +import java.util.List; + +public interface VectorRetriever { + + List retrieve(RagQuery query); +} diff --git a/easy-agents-rag/easy-agents-rag-retrieval/src/test/java/com/easyagents/rag/retrieval/RagRetrievalExecutorTest.java b/easy-agents-rag/easy-agents-rag-retrieval/src/test/java/com/easyagents/rag/retrieval/RagRetrievalExecutorTest.java new file mode 100644 index 0000000..a1158af --- /dev/null +++ b/easy-agents-rag/easy-agents-rag-retrieval/src/test/java/com/easyagents/rag/retrieval/RagRetrievalExecutorTest.java @@ -0,0 +1,111 @@ +package com.easyagents.rag.retrieval; + +import com.easyagents.core.document.Document; +import com.easyagents.core.model.rerank.RerankModel; +import com.easyagents.core.model.rerank.RerankOptions; +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class RagRetrievalExecutorTest { + + @Test + public void shouldOnlyCallVectorRetrieverInVectorMode() { + final List invoked = new ArrayList(); + RagRetrievalExecutor executor = new RagRetrievalExecutor( + new VectorRetriever() { + @Override + public List retrieve(RagQuery query) { + invoked.add("vector"); + return Arrays.asList(hit("1", "vector", 0.9D, HitSource.VECTOR)); + } + }, + new KeywordRetriever() { + @Override + public List retrieve(RagQuery query) { + invoked.add("keyword"); + return Arrays.asList(hit("2", "keyword", 5D, HitSource.KEYWORD)); + } + }, + new RrfFusionStrategy() + ); + + RagQuery query = new RagQuery(); + query.setRetrievalMode(RetrievalMode.VECTOR); + query.setTopK(10); + + RagRetrievalResult result = executor.retrieve(query); + + Assert.assertEquals(Arrays.asList("vector"), invoked); + Assert.assertEquals(1, result.getHits().size()); + Assert.assertEquals(HitSource.VECTOR, result.getHits().get(0).getHitSource()); + } + + @Test + public void shouldRerankAfterHybridFusionAndKeepHitSource() { + RagRetrievalExecutor executor = new RagRetrievalExecutor( + new VectorRetriever() { + @Override + public List retrieve(RagQuery query) { + return Arrays.asList( + hit("1", "alpha", 0.9D, HitSource.VECTOR), + hit("2", "beta", 0.8D, HitSource.VECTOR) + ); + } + }, + new KeywordRetriever() { + @Override + public List retrieve(RagQuery query) { + return Arrays.asList( + hit("2", "beta", 4.1D, HitSource.KEYWORD), + hit("3", "gamma", 3.9D, HitSource.KEYWORD) + ); + } + }, + new RrfFusionStrategy() + ); + + RagQuery query = new RagQuery(); + query.setRetrievalMode(RetrievalMode.HYBRID); + query.setTopK(10); + + RagRetrievalResult retrieved = executor.retrieve(query); + RagRetrievalResult reranked = executor.rerank("query", retrieved.getHits(), new ReverseRerankModel(), 10); + + Assert.assertEquals("3", String.valueOf(reranked.getHits().get(0).getDocumentId())); + Assert.assertEquals("2", String.valueOf(reranked.getHits().get(2).getDocumentId())); + Assert.assertEquals(HitSource.BOTH, reranked.getHits().get(2).getHitSource()); + } + + private RagHit hit(String id, String content, double score, HitSource source) { + RagHit hit = new RagHit(); + hit.setDocumentId(id); + hit.setContent(content); + hit.setScore(score); + hit.setHitSource(source); + if (source == HitSource.VECTOR) { + hit.setVectorScore(score); + } else if (source == HitSource.KEYWORD) { + hit.setKeywordScore(score); + } + return hit; + } + + private static class ReverseRerankModel implements RerankModel { + + @Override + public List rerank(String query, List documents, RerankOptions options) { + List result = new ArrayList(); + double score = documents.size(); + for (int i = documents.size() - 1; i >= 0; i--) { + Document document = documents.get(i); + document.setScore(score--); + result.add(document); + } + return result; + } + } +} diff --git a/easy-agents-rag/easy-agents-rag-retrieval/src/test/java/com/easyagents/rag/retrieval/RagScoreNormalizerTest.java b/easy-agents-rag/easy-agents-rag-retrieval/src/test/java/com/easyagents/rag/retrieval/RagScoreNormalizerTest.java new file mode 100644 index 0000000..ede4414 --- /dev/null +++ b/easy-agents-rag/easy-agents-rag-retrieval/src/test/java/com/easyagents/rag/retrieval/RagScoreNormalizerTest.java @@ -0,0 +1,69 @@ +package com.easyagents.rag.retrieval; + +import com.easyagents.core.document.Document; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; + +public class RagScoreNormalizerTest { + + @Test + public void shouldNormalizeKeywordScoresToZeroAndOneRange() { + Document first = document(1, 9D, RagRetrievalMetadataKeys.KEYWORD_SCORE); + Document second = document(2, 0D, RagRetrievalMetadataKeys.KEYWORD_SCORE); + + RagScoreNormalizer.normalize(Arrays.asList(first, second), RetrievalMode.KEYWORD, false); + + Assert.assertEquals(0.9D, first.getScore(), 0.0001D); + Assert.assertEquals(0D, second.getScore(), 0.0001D); + } + + @Test + public void shouldNormalizeHybridFusionScoreByRrfUpperBound() { + Document document = document(1, 2D / (RrfFusionStrategy.DEFAULT_RRF_K + 1D), RagRetrievalMetadataKeys.FUSION_SCORE); + + RagScoreNormalizer.normalize(Arrays.asList(document), RetrievalMode.HYBRID, false); + + Assert.assertEquals(1D, document.getScore(), 0.0001D); + } + + @Test + public void shouldNormalizeRerankScoresByMinMax() { + List documents = Arrays.asList( + document(1, 10D, RagRetrievalMetadataKeys.RERANK_SCORE), + document(2, 20D, RagRetrievalMetadataKeys.RERANK_SCORE), + document(3, 30D, RagRetrievalMetadataKeys.RERANK_SCORE) + ); + + RagScoreNormalizer.normalize(documents, RetrievalMode.HYBRID, true); + + Assert.assertEquals(0D, documents.get(0).getScore(), 0.0001D); + Assert.assertEquals(0.5D, documents.get(1).getScore(), 0.0001D); + Assert.assertEquals(1D, documents.get(2).getScore(), 0.0001D); + } + + @Test + public void shouldFallbackToRankBasedNormalizationWhenRerankScoresAreEqual() { + List documents = Arrays.asList( + document(1, 5D, RagRetrievalMetadataKeys.RERANK_SCORE), + document(2, 5D, RagRetrievalMetadataKeys.RERANK_SCORE), + document(3, 5D, RagRetrievalMetadataKeys.RERANK_SCORE) + ); + + RagScoreNormalizer.normalize(documents, RetrievalMode.HYBRID, true); + + Assert.assertEquals(1D, documents.get(0).getScore(), 0.0001D); + Assert.assertEquals(0.5D, documents.get(1).getScore(), 0.0001D); + Assert.assertEquals(0D, documents.get(2).getScore(), 0.0001D); + } + + private Document document(Object id, Double score, String metadataKey) { + Document document = new Document(); + document.setId(id); + document.setScore(score); + document.addMetadata(metadataKey, score); + return document; + } +} diff --git a/easy-agents-rag/easy-agents-rag-retrieval/src/test/java/com/easyagents/rag/retrieval/RrfFusionStrategyTest.java b/easy-agents-rag/easy-agents-rag-retrieval/src/test/java/com/easyagents/rag/retrieval/RrfFusionStrategyTest.java new file mode 100644 index 0000000..ae3b010 --- /dev/null +++ b/easy-agents-rag/easy-agents-rag-retrieval/src/test/java/com/easyagents/rag/retrieval/RrfFusionStrategyTest.java @@ -0,0 +1,46 @@ +package com.easyagents.rag.retrieval; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; + +public class RrfFusionStrategyTest { + + @Test + public void shouldMarkDuplicateHitsAsBothAndRankFirst() { + RrfFusionStrategy strategy = new RrfFusionStrategy(); + + RagHit vectorOnly = hit("1", 0.91D, HitSource.VECTOR); + RagHit bothVector = hit("2", 0.88D, HitSource.VECTOR); + RagHit bothKeyword = hit("2", 5.2D, HitSource.KEYWORD); + RagHit keywordOnly = hit("3", 4.8D, HitSource.KEYWORD); + + List fused = strategy.fuse( + Arrays.asList(vectorOnly, bothVector), + Arrays.asList(bothKeyword, keywordOnly), + 10 + ); + + Assert.assertEquals(3, fused.size()); + Assert.assertEquals("2", String.valueOf(fused.get(0).getDocumentId())); + Assert.assertEquals(HitSource.BOTH, fused.get(0).getHitSource()); + Assert.assertEquals("1", String.valueOf(fused.get(1).getDocumentId())); + Assert.assertEquals("3", String.valueOf(fused.get(2).getDocumentId())); + } + + private RagHit hit(String id, double score, HitSource source) { + RagHit hit = new RagHit(); + hit.setDocumentId(id); + hit.setContent("doc-" + id); + hit.setScore(score); + hit.setHitSource(source); + if (source == HitSource.VECTOR) { + hit.setVectorScore(score); + } else if (source == HitSource.KEYWORD) { + hit.setKeywordScore(score); + } + return hit; + } +} diff --git a/easy-agents-search-engine/easy-agents-search-engine-es/pom.xml b/easy-agents-search-engine/easy-agents-search-engine-es/pom.xml index 6f18ae3..3f17e44 100644 --- a/easy-agents-search-engine/easy-agents-search-engine-es/pom.xml +++ b/easy-agents-search-engine/easy-agents-search-engine-es/pom.xml @@ -37,6 +37,11 @@ jackson-databind 2.15.2 + + junit + junit + test + diff --git a/easy-agents-search-engine/easy-agents-search-engine-es/src/main/java/com/easyagents/engine/es/ElasticSearcher.java b/easy-agents-search-engine/easy-agents-search-engine-es/src/main/java/com/easyagents/engine/es/ElasticSearcher.java index ab9b0c6..7f7f07a 100644 --- a/easy-agents-search-engine/easy-agents-search-engine-es/src/main/java/com/easyagents/engine/es/ElasticSearcher.java +++ b/easy-agents-search-engine/easy-agents-search-engine-es/src/main/java/com/easyagents/engine/es/ElasticSearcher.java @@ -4,12 +4,15 @@ import co.elastic.clients.elasticsearch.ElasticsearchClient; import co.elastic.clients.elasticsearch.core.*; import co.elastic.clients.elasticsearch.core.bulk.BulkOperation; import co.elastic.clients.elasticsearch.core.bulk.IndexOperation; +import co.elastic.clients.elasticsearch.core.search.SourceConfig; import co.elastic.clients.json.JsonData; import co.elastic.clients.json.jackson.JacksonJsonpMapper; import co.elastic.clients.transport.ElasticsearchTransport; import co.elastic.clients.transport.rest_client.RestClientTransport; import com.easyagents.core.document.Document; import com.easyagents.search.engine.service.DocumentSearcher; +import com.easyagents.search.engine.service.KeywordSearchMetadataKeys; +import com.easyagents.search.engine.service.KeywordSearchRequest; import org.apache.http.HttpHost; import org.apache.http.auth.AuthScope; import org.apache.http.auth.UsernamePasswordCredentials; @@ -88,13 +91,7 @@ public class ElasticSearcher implements DocumentSearcher { transport = new RestClientTransport(restClient, new JacksonJsonpMapper()); ElasticsearchClient client = new ElasticsearchClient(transport); - Map source = new HashMap<>(); - source.put("id", document.getId()); - source.put("content", document.getContent()); - if (document.getTitle() != null) { - source.put("title", document.getTitle()); - } - + Map source = buildSource(document); String documentId = document.getId().toString(); IndexOperation indexOp = IndexOperation.of(i -> i .index(esConfig.getIndexName()) @@ -116,7 +113,7 @@ public class ElasticSearcher implements DocumentSearcher { } @Override - public List searchDocuments(String keyword, int count) { + public List searchDocuments(KeywordSearchRequest request) { RestClient restClient = null; ElasticsearchTransport transport = null; @@ -125,21 +122,16 @@ public class ElasticSearcher implements DocumentSearcher { transport = new RestClientTransport(restClient, new JacksonJsonpMapper()); ElasticsearchClient client = new ElasticsearchClient(transport); - SearchRequest request = SearchRequest.of(s -> s - .index(esConfig.getIndexName()) - .size(count) - .query(q -> q - .match(m -> m - .field("title") - .field("content") - .query(keyword) - ) - ) - ); - - SearchResponse response = client.search(request, Document.class); + SearchResponse response = client.search(buildSearchRequest(request), Map.class); List results = new ArrayList<>(); - response.hits().hits().forEach(hit -> results.add(hit.source())); + response.hits().hits().forEach(hit -> { + Map source = hit.source(); + Document document = toDocument(hit.id(), source, hit.score()); + if (document == null) { + return; + } + results.add(document); + }); return results; } catch (Exception e) { @@ -193,14 +185,17 @@ public class ElasticSearcher implements DocumentSearcher { transport = new RestClientTransport(restClient, new JacksonJsonpMapper()); ElasticsearchClient client = new ElasticsearchClient(transport); - UpdateRequest request = UpdateRequest.of(u -> u + UpdateRequest, Map> request = UpdateRequest.of(u -> u .index(esConfig.getIndexName()) .id(document.getId().toString()) - .doc(document) + .doc(buildSource(document)) ); - UpdateResponse response = client.update(request, Object.class); - return response.result() == co.elastic.clients.elasticsearch._types.Result.Updated; + @SuppressWarnings("unchecked") + Class> documentClass = (Class>) (Class) Map.class; + UpdateResponse> response = client.update(request, documentClass); + return response.result() == co.elastic.clients.elasticsearch._types.Result.Updated + || response.result() == co.elastic.clients.elasticsearch._types.Result.NoOp; } catch (Exception e) { LOG.error("Error updating document with id: " + document.getId(), e); return false; @@ -220,4 +215,88 @@ public class ElasticSearcher implements DocumentSearcher { } } } + + @SuppressWarnings("unchecked") + private Document toDocument(String hitId, Map source, Double score) { + if (source == null || source.isEmpty()) { + return null; + } + + Document document = new Document(); + Object id = source.get("id"); + document.setId(id != null ? id : hitId); + + Object title = source.get("title"); + if (title != null) { + document.setTitle(String.valueOf(title)); + } + + Object content = source.get("content"); + if (content != null) { + document.setContent(String.valueOf(content)); + } + + Object metadataMap = source.get("metadataMap"); + if (metadataMap instanceof Map) { + document.setMetadataMap(new HashMap<>((Map) metadataMap)); + } + + document.setScore(score); + return document; + } + + Map buildSource(Document document) { + Map source = new HashMap(); + source.put("id", document.getId()); + source.put("content", document.getContent()); + if (document.getTitle() != null) { + source.put("title", document.getTitle()); + } + if (document.getMetadataMap() != null && !document.getMetadataMap().isEmpty()) { + source.put("metadataMap", new HashMap(document.getMetadataMap())); + Object knowledgeId = document.getMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID); + if (knowledgeId != null) { + source.put(KeywordSearchMetadataKeys.KNOWLEDGE_ID, String.valueOf(knowledgeId)); + } + } + return source; + } + + SearchRequest buildSearchRequest(KeywordSearchRequest request) { + KeywordSearchRequest effectiveRequest = request == null ? new KeywordSearchRequest() : request; + return SearchRequest.of(s -> s + .index(esConfig.getIndexName()) + .size(effectiveRequest.getCount()) + .source(SourceConfig.of(sc -> sc.filter(f -> f.includes("id", "title", "content", "metadataMap")))) + .query(q -> q.bool(b -> { + b.must(m -> m.multiMatch(mm -> mm + .query(effectiveRequest.getKeyword()) + .fields("title", "content") + )); + if (effectiveRequest.getKnowledgeId() != null && !effectiveRequest.getKnowledgeId().trim().isEmpty()) { + b.filter(f -> f.term(t -> t + .field(KeywordSearchMetadataKeys.KNOWLEDGE_ID) + .value(v -> v.stringValue(effectiveRequest.getKnowledgeId().trim())) + )); + } + return b; + })) + ); + } + + public boolean checkAvailable() { + RestClient restClient = null; + ElasticsearchTransport transport = null; + try { + restClient = buildRestClient(); + transport = new RestClientTransport(restClient, new JacksonJsonpMapper()); + ElasticsearchClient client = new ElasticsearchClient(transport); + return client.info() != null; + } catch (Exception e) { + LOG.error("Elasticsearch availability check failed", e); + return false; + } finally { + closeResources(transport, restClient); + } + } } diff --git a/easy-agents-search-engine/easy-agents-search-engine-es/src/test/java/com/easyagents/engine/es/ElasticSearcherQueryBuilderTest.java b/easy-agents-search-engine/easy-agents-search-engine-es/src/test/java/com/easyagents/engine/es/ElasticSearcherQueryBuilderTest.java new file mode 100644 index 0000000..31b929d --- /dev/null +++ b/easy-agents-search-engine/easy-agents-search-engine-es/src/test/java/com/easyagents/engine/es/ElasticSearcherQueryBuilderTest.java @@ -0,0 +1,54 @@ +package com.easyagents.engine.es; + +import co.elastic.clients.elasticsearch.core.SearchRequest; +import com.easyagents.core.document.Document; +import com.easyagents.search.engine.service.KeywordSearchMetadataKeys; +import com.easyagents.search.engine.service.KeywordSearchRequest; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Map; + +public class ElasticSearcherQueryBuilderTest { + + @Test + public void shouldBuildSearchRequestWithMultiMatchAndKnowledgeFilter() { + ElasticSearcher searcher = new ElasticSearcher(config()); + KeywordSearchRequest request = KeywordSearchRequest.of("客服", 5); + request.setKnowledgeId("100"); + + SearchRequest searchRequest = searcher.buildSearchRequest(request); + + Assert.assertEquals(5, searchRequest.size().intValue()); + Assert.assertNotNull(searchRequest.query().bool()); + Assert.assertEquals(1, searchRequest.query().bool().must().size()); + Assert.assertNotNull(searchRequest.query().bool().must().get(0).multiMatch()); + Assert.assertEquals(2, searchRequest.query().bool().must().get(0).multiMatch().fields().size()); + Assert.assertEquals(1, searchRequest.query().bool().filter().size()); + Assert.assertEquals("knowledgeId", searchRequest.query().bool().filter().get(0).term().field()); + } + + @Test + public void shouldExtractKnowledgeIdToTopLevelSource() { + ElasticSearcher searcher = new ElasticSearcher(config()); + Document document = new Document(); + document.setId("1"); + document.setTitle("title"); + document.setContent("content"); + document.addMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID, "100"); + + Map source = searcher.buildSource(document); + + Assert.assertEquals("100", source.get(KeywordSearchMetadataKeys.KNOWLEDGE_ID)); + Assert.assertTrue(source.get("metadataMap") instanceof Map); + } + + private ESConfig config() { + ESConfig config = new ESConfig(); + config.setHost("http://127.0.0.1:9200"); + config.setUserName("elastic"); + config.setPassword("elastic"); + config.setIndexName("easyflow"); + return config; + } +} diff --git a/easy-agents-search-engine/easy-agents-search-engine-lucene/pom.xml b/easy-agents-search-engine/easy-agents-search-engine-lucene/pom.xml index 62623ce..a95d57e 100644 --- a/easy-agents-search-engine/easy-agents-search-engine-lucene/pom.xml +++ b/easy-agents-search-engine/easy-agents-search-engine-lucene/pom.xml @@ -51,5 +51,10 @@ com.easyagents easy-agents-search-engine-service + + junit + junit + test + diff --git a/easy-agents-search-engine/easy-agents-search-engine-lucene/src/main/java/com/easyagents/search/engine/lucene/LuceneSearcher.java b/easy-agents-search-engine/easy-agents-search-engine-lucene/src/main/java/com/easyagents/search/engine/lucene/LuceneSearcher.java index ab246ab..5edd49e 100644 --- a/easy-agents-search-engine/easy-agents-search-engine-lucene/src/main/java/com/easyagents/search/engine/lucene/LuceneSearcher.java +++ b/easy-agents-search-engine/easy-agents-search-engine-lucene/src/main/java/com/easyagents/search/engine/lucene/LuceneSearcher.java @@ -17,6 +17,8 @@ package com.easyagents.search.engine.lucene; import com.easyagents.core.document.Document; import com.easyagents.search.engine.service.DocumentSearcher; +import com.easyagents.search.engine.service.KeywordSearchMetadataKeys; +import com.easyagents.search.engine.service.KeywordSearchRequest; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.document.Field; import org.apache.lucene.document.StringField; @@ -78,7 +80,7 @@ public class LuceneSearcher implements DocumentSearcher { if (document.getTitle() != null) { luceneDoc.add(new TextField("title", document.getTitle(), Field.Store.YES)); } - + appendKnowledgeId(document, luceneDoc); indexWriter.addDocument(luceneDoc); indexWriter.commit(); @@ -127,7 +129,7 @@ public class LuceneSearcher implements DocumentSearcher { if (document.getTitle() != null) { luceneDoc.add(new TextField("title", document.getTitle(), Field.Store.YES)); } - + appendKnowledgeId(document, luceneDoc); indexWriter.updateDocument(term, luceneDoc); indexWriter.commit(); return true; @@ -140,18 +142,21 @@ public class LuceneSearcher implements DocumentSearcher { } @Override - public List searchDocuments(String keyword, int count) { + public List searchDocuments(KeywordSearchRequest request) { List results = new ArrayList<>(); try (IndexReader reader = DirectoryReader.open(directory)) { IndexSearcher searcher = new IndexSearcher(reader); - Query query = buildQuery(keyword); - TopDocs topDocs = searcher.search(query, count); + Query query = buildQuery(request); + TopDocs topDocs = searcher.search(query, request == null ? 10 : request.getCount()); for (ScoreDoc scoreDoc : topDocs.scoreDocs) { org.apache.lucene.document.Document doc = searcher.doc(scoreDoc.doc); Document resultDoc = new Document(); resultDoc.setId(doc.get("id")); resultDoc.setContent(doc.get("content")); resultDoc.setTitle(doc.get("title")); + if (doc.get(KeywordSearchMetadataKeys.KNOWLEDGE_ID) != null) { + resultDoc.addMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID, doc.get(KeywordSearchMetadataKeys.KNOWLEDGE_ID)); + } resultDoc.setScore((double) scoreDoc.score); @@ -164,9 +169,10 @@ public class LuceneSearcher implements DocumentSearcher { return results; } - private static Query buildQuery(String keyword) { + Query buildQuery(KeywordSearchRequest request) { try { Analyzer analyzer = createAnalyzer(); + String keyword = request == null ? null : request.getKeyword(); QueryParser titleQueryParser = new QueryParser("title", analyzer); Query titleQuery = titleQueryParser.parse(keyword); @@ -179,6 +185,9 @@ public class LuceneSearcher implements DocumentSearcher { BooleanQuery.Builder builder = new BooleanQuery.Builder(); builder.add(titleBooleanClause) .add(contentBooleanClause); + if (request != null && request.getKnowledgeId() != null && !request.getKnowledgeId().trim().isEmpty()) { + builder.add(new TermQuery(new Term(KeywordSearchMetadataKeys.KNOWLEDGE_ID, request.getKnowledgeId().trim())), BooleanClause.Occur.MUST); + } return builder.build(); } catch (ParseException e) { LOG.error(e.toString(), e); @@ -200,6 +209,16 @@ public class LuceneSearcher implements DocumentSearcher { return new JcsegAnalyzer(ISegment.Type.NLP, config, DictionaryFactory.createSingletonDictionary(config)); } + private void appendKnowledgeId(Document document, org.apache.lucene.document.Document luceneDoc) { + if (document == null || document.getMetadataMap() == null) { + return; + } + Object knowledgeId = document.getMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID); + if (knowledgeId != null) { + luceneDoc.add(new StringField(KeywordSearchMetadataKeys.KNOWLEDGE_ID, String.valueOf(knowledgeId), Field.Store.YES)); + } + } + public void close(IndexWriter indexWriter) { try { if (indexWriter != null) { diff --git a/easy-agents-search-engine/easy-agents-search-engine-lucene/src/test/java/com/easyagents/search/engine/lucene/LuceneSearcherTest.java b/easy-agents-search-engine/easy-agents-search-engine-lucene/src/test/java/com/easyagents/search/engine/lucene/LuceneSearcherTest.java new file mode 100644 index 0000000..4f21723 --- /dev/null +++ b/easy-agents-search-engine/easy-agents-search-engine-lucene/src/test/java/com/easyagents/search/engine/lucene/LuceneSearcherTest.java @@ -0,0 +1,45 @@ +package com.easyagents.search.engine.lucene; + +import com.easyagents.core.document.Document; +import com.easyagents.search.engine.service.KeywordSearchMetadataKeys; +import com.easyagents.search.engine.service.KeywordSearchRequest; +import org.junit.Assert; +import org.junit.Test; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; + +public class LuceneSearcherTest { + + @Test + public void shouldFilterByKnowledgeIdAndSearchTitleAndContent() throws Exception { + Path tempDir = Files.createTempDirectory("lucene-searcher-test"); + LuceneConfig config = new LuceneConfig(); + config.setIndexDirPath(tempDir.toString()); + LuceneSearcher searcher = new LuceneSearcher(config); + + Document first = new Document(); + first.setId("1"); + first.setTitle("客服标题"); + first.setContent("这里没有关键字"); + first.addMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID, "100"); + + Document second = new Document(); + second.setId("2"); + second.setTitle("别的知识库"); + second.setContent("客服内容"); + second.addMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID, "200"); + + Assert.assertTrue(searcher.addDocument(first)); + Assert.assertTrue(searcher.addDocument(second)); + + KeywordSearchRequest request = KeywordSearchRequest.of("客服", 10); + request.setKnowledgeId("100"); + List results = searcher.searchDocuments(request); + + Assert.assertEquals(1, results.size()); + Assert.assertEquals("1", String.valueOf(results.get(0).getId())); + Assert.assertEquals("100", String.valueOf(results.get(0).getMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID))); + } +} diff --git a/easy-agents-search-engine/easy-agents-search-engine-service/src/main/java/com/easyagents/search/engine/service/DocumentSearcher.java b/easy-agents-search-engine/easy-agents-search-engine-service/src/main/java/com/easyagents/search/engine/service/DocumentSearcher.java index 99b4cdc..00e84c6 100644 --- a/easy-agents-search-engine/easy-agents-search-engine-service/src/main/java/com/easyagents/search/engine/service/DocumentSearcher.java +++ b/easy-agents-search-engine/easy-agents-search-engine-service/src/main/java/com/easyagents/search/engine/service/DocumentSearcher.java @@ -28,8 +28,12 @@ public interface DocumentSearcher { boolean updateDocument(Document document); default List searchDocuments(String keyword) { - return searchDocuments(keyword, 10); + return searchDocuments(KeywordSearchRequest.of(keyword, 10)); } - List searchDocuments(String keyword, int count); + default List searchDocuments(String keyword, int count) { + return searchDocuments(KeywordSearchRequest.of(keyword, count)); + } + + List searchDocuments(KeywordSearchRequest request); } diff --git a/easy-agents-search-engine/easy-agents-search-engine-service/src/main/java/com/easyagents/search/engine/service/KeywordSearchMetadataKeys.java b/easy-agents-search-engine/easy-agents-search-engine-service/src/main/java/com/easyagents/search/engine/service/KeywordSearchMetadataKeys.java new file mode 100644 index 0000000..97ffc79 --- /dev/null +++ b/easy-agents-search-engine/easy-agents-search-engine-service/src/main/java/com/easyagents/search/engine/service/KeywordSearchMetadataKeys.java @@ -0,0 +1,9 @@ +package com.easyagents.search.engine.service; + +public final class KeywordSearchMetadataKeys { + + private KeywordSearchMetadataKeys() { + } + + public static final String KNOWLEDGE_ID = "knowledgeId"; +} diff --git a/easy-agents-search-engine/easy-agents-search-engine-service/src/main/java/com/easyagents/search/engine/service/KeywordSearchRequest.java b/easy-agents-search-engine/easy-agents-search-engine-service/src/main/java/com/easyagents/search/engine/service/KeywordSearchRequest.java new file mode 100644 index 0000000..c68e2a9 --- /dev/null +++ b/easy-agents-search-engine/easy-agents-search-engine-service/src/main/java/com/easyagents/search/engine/service/KeywordSearchRequest.java @@ -0,0 +1,39 @@ +package com.easyagents.search.engine.service; + +public class KeywordSearchRequest { + + private String keyword; + private int count = 10; + private String knowledgeId; + + public static KeywordSearchRequest of(String keyword, int count) { + KeywordSearchRequest request = new KeywordSearchRequest(); + request.setKeyword(keyword); + request.setCount(count); + return request; + } + + public String getKeyword() { + return keyword; + } + + public void setKeyword(String keyword) { + this.keyword = keyword; + } + + public int getCount() { + return count; + } + + public void setCount(int count) { + this.count = count <= 0 ? 10 : count; + } + + public String getKnowledgeId() { + return knowledgeId; + } + + public void setKnowledgeId(String knowledgeId) { + this.knowledgeId = knowledgeId; + } +} diff --git a/easy-agents-store/easy-agents-store-milvus/src/main/java/com/easyagents/store/milvus/MilvusVectorStore.java b/easy-agents-store/easy-agents-store-milvus/src/main/java/com/easyagents/store/milvus/MilvusVectorStore.java index 802c773..aadd67e 100644 --- a/easy-agents-store/easy-agents-store-milvus/src/main/java/com/easyagents/store/milvus/MilvusVectorStore.java +++ b/easy-agents-store/easy-agents-store-milvus/src/main/java/com/easyagents/store/milvus/MilvusVectorStore.java @@ -535,4 +535,15 @@ public class MilvusVectorStore extends DocumentStore { public MilvusClientV2 getClient() { return client; } + + public boolean checkAvailable() { + try { + return client.hasCollection(HasCollectionReq.builder() + .collectionName("__milvus_boot_probe__") + .build()) != null; + } catch (Exception e) { + LOG.warn("Milvus availability check failed. message={}", e.getMessage()); + return false; + } + } }