feat: 下沉知识库检索编排能力
- 新增 rag retrieval 核心协议、RRF 融合与相关度归一化 - 支持关键词检索按 knowledgeId 过滤并补充 ES/Lucene 单测 - 扩展 KnowledgeNode 检索模式与 Milvus 检索参数透传
This commit is contained in:
@@ -35,6 +35,7 @@ public class KnowledgeNode extends BaseNode {
|
|||||||
private Object knowledgeId;
|
private Object knowledgeId;
|
||||||
private String keyword;
|
private String keyword;
|
||||||
private String limit;
|
private String limit;
|
||||||
|
private String retrievalMode = "HYBRID";
|
||||||
|
|
||||||
public Object getKnowledgeId() {
|
public Object getKnowledgeId() {
|
||||||
return knowledgeId;
|
return knowledgeId;
|
||||||
@@ -60,6 +61,14 @@ public class KnowledgeNode extends BaseNode {
|
|||||||
this.limit = limit;
|
this.limit = limit;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public String getRetrievalMode() {
|
||||||
|
return retrievalMode;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setRetrievalMode(String retrievalMode) {
|
||||||
|
this.retrievalMode = StringUtil.hasText(retrievalMode) ? retrievalMode : "HYBRID";
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<String, Object> execute(Chain chain) {
|
public Map<String, Object> execute(Chain chain) {
|
||||||
Map<String, Object> argsMap = chain.getState().resolveParameters(this);
|
Map<String, Object> argsMap = chain.getState().resolveParameters(this);
|
||||||
@@ -90,6 +99,7 @@ public class KnowledgeNode extends BaseNode {
|
|||||||
"knowledgeId=" + knowledgeId +
|
"knowledgeId=" + knowledgeId +
|
||||||
", keyword='" + keyword + '\'' +
|
", keyword='" + keyword + '\'' +
|
||||||
", limit='" + limit + '\'' +
|
", limit='" + limit + '\'' +
|
||||||
|
", retrievalMode='" + retrievalMode + '\'' +
|
||||||
", parameters=" + parameters +
|
", parameters=" + parameters +
|
||||||
", outputDefs=" + outputDefs +
|
", outputDefs=" + outputDefs +
|
||||||
", id='" + id + '\'' +
|
", id='" + id + '\'' +
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ public class KnowledgeNodeParser extends BaseNodeParser<KnowledgeNode> {
|
|||||||
knowledgeNode.setKnowledgeId(data.get("knowledgeId"));
|
knowledgeNode.setKnowledgeId(data.get("knowledgeId"));
|
||||||
knowledgeNode.setLimit(data.getString("limit"));
|
knowledgeNode.setLimit(data.getString("limit"));
|
||||||
knowledgeNode.setKeyword(data.getString("keyword"));
|
knowledgeNode.setKeyword(data.getString("keyword"));
|
||||||
|
knowledgeNode.setRetrievalMode(data.getString("retrievalMode"));
|
||||||
|
|
||||||
return knowledgeNode;
|
return knowledgeNode;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,10 @@
|
|||||||
<groupId>com.easyagents</groupId>
|
<groupId>com.easyagents</groupId>
|
||||||
<artifactId>easy-agents-core</artifactId>
|
<artifactId>easy-agents-core</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.easyagents</groupId>
|
||||||
|
<artifactId>easy-agents-search-engine-service</artifactId>
|
||||||
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.easyagents</groupId>
|
<groupId>com.easyagents</groupId>
|
||||||
<artifactId>easy-agents-rag-core</artifactId>
|
<artifactId>easy-agents-rag-core</artifactId>
|
||||||
@@ -32,5 +36,10 @@
|
|||||||
<groupId>com.easyagents</groupId>
|
<groupId>com.easyagents</groupId>
|
||||||
<artifactId>easy-agents-rag-enhance</artifactId>
|
<artifactId>easy-agents-rag-enhance</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>junit</groupId>
|
||||||
|
<artifactId>junit</artifactId>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</project>
|
</project>
|
||||||
|
|||||||
@@ -0,0 +1,8 @@
|
|||||||
|
package com.easyagents.rag.retrieval;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public interface FusionStrategy {
|
||||||
|
|
||||||
|
List<RagHit> fuse(List<RagHit> vectorHits, List<RagHit> keywordHits, int topK);
|
||||||
|
}
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
package com.easyagents.rag.retrieval;
|
||||||
|
|
||||||
|
public enum HitSource {
|
||||||
|
VECTOR,
|
||||||
|
KEYWORD,
|
||||||
|
BOTH;
|
||||||
|
|
||||||
|
public static HitSource merge(HitSource current, HitSource incoming) {
|
||||||
|
if (current == null) {
|
||||||
|
return incoming;
|
||||||
|
}
|
||||||
|
if (incoming == null || current == incoming) {
|
||||||
|
return current;
|
||||||
|
}
|
||||||
|
return BOTH;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
package com.easyagents.rag.retrieval;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public interface KeywordRetriever {
|
||||||
|
|
||||||
|
List<RagHit> retrieve(RagQuery query);
|
||||||
|
}
|
||||||
@@ -0,0 +1,198 @@
|
|||||||
|
package com.easyagents.rag.retrieval;
|
||||||
|
|
||||||
|
import com.easyagents.core.document.Document;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class RagHit {
|
||||||
|
|
||||||
|
private Object documentId;
|
||||||
|
private String title;
|
||||||
|
private String content;
|
||||||
|
private Double score;
|
||||||
|
private HitSource hitSource;
|
||||||
|
private Double vectorScore;
|
||||||
|
private Double keywordScore;
|
||||||
|
private Integer rank;
|
||||||
|
private Map<String, Object> metadata = new HashMap<String, Object>();
|
||||||
|
|
||||||
|
public static RagHit fromDocument(Document document) {
|
||||||
|
if (document == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
RagHit hit = new RagHit();
|
||||||
|
hit.setDocumentId(document.getId());
|
||||||
|
hit.setTitle(document.getTitle());
|
||||||
|
hit.setContent(document.getContent());
|
||||||
|
hit.setScore(document.getScore());
|
||||||
|
hit.setMetadata(document.getMetadataMap());
|
||||||
|
|
||||||
|
Object hitSource = hit.getMetadata().get(RagRetrievalMetadataKeys.HIT_SOURCE);
|
||||||
|
if (hitSource instanceof String) {
|
||||||
|
hit.setHitSource(HitSource.valueOf(String.valueOf(hitSource)));
|
||||||
|
}
|
||||||
|
hit.setVectorScore(asDouble(hit.getMetadata().get(RagRetrievalMetadataKeys.VECTOR_SCORE)));
|
||||||
|
hit.setKeywordScore(asDouble(hit.getMetadata().get(RagRetrievalMetadataKeys.KEYWORD_SCORE)));
|
||||||
|
hit.setRank(asInteger(hit.getMetadata().get(RagRetrievalMetadataKeys.FINAL_RANK)));
|
||||||
|
return hit;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static RagHit fromDocument(Document document, HitSource defaultSource) {
|
||||||
|
RagHit hit = fromDocument(document);
|
||||||
|
if (hit != null && hit.getHitSource() == null) {
|
||||||
|
hit.setHitSource(defaultSource);
|
||||||
|
}
|
||||||
|
if (hit != null && defaultSource == HitSource.VECTOR && hit.getVectorScore() == null) {
|
||||||
|
hit.setVectorScore(document.getScore());
|
||||||
|
}
|
||||||
|
if (hit != null && defaultSource == HitSource.KEYWORD && hit.getKeywordScore() == null) {
|
||||||
|
hit.setKeywordScore(document.getScore());
|
||||||
|
}
|
||||||
|
return hit;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Document toDocument() {
|
||||||
|
Document document = new Document();
|
||||||
|
document.setId(documentId);
|
||||||
|
document.setTitle(title);
|
||||||
|
document.setContent(content);
|
||||||
|
document.setScore(score);
|
||||||
|
Map<String, Object> metadataMap = metadata == null
|
||||||
|
? new HashMap<String, Object>()
|
||||||
|
: new HashMap<String, Object>(metadata);
|
||||||
|
if (hitSource != null) {
|
||||||
|
metadataMap.put(RagRetrievalMetadataKeys.HIT_SOURCE, hitSource.name());
|
||||||
|
}
|
||||||
|
if (vectorScore != null) {
|
||||||
|
metadataMap.put(RagRetrievalMetadataKeys.VECTOR_SCORE, vectorScore);
|
||||||
|
}
|
||||||
|
if (keywordScore != null) {
|
||||||
|
metadataMap.put(RagRetrievalMetadataKeys.KEYWORD_SCORE, keywordScore);
|
||||||
|
}
|
||||||
|
if (!metadataMap.containsKey(RagRetrievalMetadataKeys.FUSION_SCORE)
|
||||||
|
&& hitSource == HitSource.BOTH
|
||||||
|
&& score != null) {
|
||||||
|
metadataMap.put(RagRetrievalMetadataKeys.FUSION_SCORE, score);
|
||||||
|
}
|
||||||
|
if (rank != null) {
|
||||||
|
metadataMap.put(RagRetrievalMetadataKeys.FINAL_RANK, rank);
|
||||||
|
}
|
||||||
|
document.setMetadataMap(metadataMap);
|
||||||
|
return document;
|
||||||
|
}
|
||||||
|
|
||||||
|
public RagHit copy() {
|
||||||
|
RagHit copy = new RagHit();
|
||||||
|
copy.setDocumentId(documentId);
|
||||||
|
copy.setTitle(title);
|
||||||
|
copy.setContent(content);
|
||||||
|
copy.setScore(score);
|
||||||
|
copy.setHitSource(hitSource);
|
||||||
|
copy.setVectorScore(vectorScore);
|
||||||
|
copy.setKeywordScore(keywordScore);
|
||||||
|
copy.setRank(rank);
|
||||||
|
copy.setMetadata(metadata);
|
||||||
|
return copy;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Object getDocumentId() {
|
||||||
|
return documentId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setDocumentId(Object documentId) {
|
||||||
|
this.documentId = documentId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getTitle() {
|
||||||
|
return title;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setTitle(String title) {
|
||||||
|
this.title = title;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getContent() {
|
||||||
|
return content;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setContent(String content) {
|
||||||
|
this.content = content;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getScore() {
|
||||||
|
return score;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setScore(Double score) {
|
||||||
|
this.score = score;
|
||||||
|
}
|
||||||
|
|
||||||
|
public HitSource getHitSource() {
|
||||||
|
return hitSource;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setHitSource(HitSource hitSource) {
|
||||||
|
this.hitSource = hitSource;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getVectorScore() {
|
||||||
|
return vectorScore;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setVectorScore(Double vectorScore) {
|
||||||
|
this.vectorScore = vectorScore;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getKeywordScore() {
|
||||||
|
return keywordScore;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setKeywordScore(Double keywordScore) {
|
||||||
|
this.keywordScore = keywordScore;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getRank() {
|
||||||
|
return rank;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setRank(Integer rank) {
|
||||||
|
this.rank = rank;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<String, Object> getMetadata() {
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setMetadata(Map<String, Object> metadata) {
|
||||||
|
this.metadata = metadata == null ? new HashMap<String, Object>() : new HashMap<String, Object>(metadata);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Double asDouble(Object value) {
|
||||||
|
if (value instanceof Number) {
|
||||||
|
return ((Number) value).doubleValue();
|
||||||
|
}
|
||||||
|
if (value instanceof String) {
|
||||||
|
try {
|
||||||
|
return Double.valueOf((String) value);
|
||||||
|
} catch (NumberFormatException ignore) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Integer asInteger(Object value) {
|
||||||
|
if (value instanceof Number) {
|
||||||
|
return ((Number) value).intValue();
|
||||||
|
}
|
||||||
|
if (value instanceof String) {
|
||||||
|
try {
|
||||||
|
return Integer.valueOf((String) value);
|
||||||
|
} catch (NumberFormatException ignore) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
package com.easyagents.rag.retrieval;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class RagQuery {
|
||||||
|
|
||||||
|
private String query;
|
||||||
|
private RetrievalMode retrievalMode = RetrievalMode.HYBRID;
|
||||||
|
private Integer topK = 10;
|
||||||
|
private Double minScore;
|
||||||
|
private Map<String, Object> attributes = new HashMap<String, Object>();
|
||||||
|
|
||||||
|
public String getQuery() {
|
||||||
|
return query;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setQuery(String query) {
|
||||||
|
this.query = query;
|
||||||
|
}
|
||||||
|
|
||||||
|
public RetrievalMode getRetrievalMode() {
|
||||||
|
return retrievalMode;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setRetrievalMode(RetrievalMode retrievalMode) {
|
||||||
|
this.retrievalMode = retrievalMode == null ? RetrievalMode.HYBRID : retrievalMode;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getTopK() {
|
||||||
|
return topK;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setTopK(Integer topK) {
|
||||||
|
this.topK = topK == null || topK <= 0 ? 10 : topK;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getMinScore() {
|
||||||
|
return minScore;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setMinScore(Double minScore) {
|
||||||
|
this.minScore = minScore;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<String, Object> getAttributes() {
|
||||||
|
return attributes;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setAttributes(Map<String, Object> attributes) {
|
||||||
|
this.attributes = attributes == null ? new HashMap<String, Object>() : attributes;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,168 @@
|
|||||||
|
package com.easyagents.rag.retrieval;
|
||||||
|
|
||||||
|
import com.easyagents.core.document.Document;
|
||||||
|
import com.easyagents.core.model.rerank.RerankModel;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.CompletableFuture;
|
||||||
|
|
||||||
|
public class RagRetrievalExecutor {
|
||||||
|
|
||||||
|
private final VectorRetriever vectorRetriever;
|
||||||
|
private final KeywordRetriever keywordRetriever;
|
||||||
|
private final FusionStrategy fusionStrategy;
|
||||||
|
|
||||||
|
public RagRetrievalExecutor(VectorRetriever vectorRetriever,
|
||||||
|
KeywordRetriever keywordRetriever,
|
||||||
|
FusionStrategy fusionStrategy) {
|
||||||
|
this.vectorRetriever = vectorRetriever;
|
||||||
|
this.keywordRetriever = keywordRetriever;
|
||||||
|
this.fusionStrategy = fusionStrategy == null ? new RrfFusionStrategy() : fusionStrategy;
|
||||||
|
}
|
||||||
|
|
||||||
|
public RagRetrievalResult retrieve(RagQuery query) {
|
||||||
|
RagQuery effectiveQuery = query == null ? new RagQuery() : query;
|
||||||
|
RetrievalMode mode = effectiveQuery.getRetrievalMode() == null
|
||||||
|
? RetrievalMode.HYBRID
|
||||||
|
: effectiveQuery.getRetrievalMode();
|
||||||
|
if (mode == RetrievalMode.VECTOR) {
|
||||||
|
return buildResult(normalizeHits(vectorRetriever == null
|
||||||
|
? Collections.<RagHit>emptyList()
|
||||||
|
: vectorRetriever.retrieve(effectiveQuery), HitSource.VECTOR), effectiveQuery.getTopK());
|
||||||
|
}
|
||||||
|
if (mode == RetrievalMode.KEYWORD) {
|
||||||
|
return buildResult(normalizeHits(keywordRetriever == null
|
||||||
|
? Collections.<RagHit>emptyList()
|
||||||
|
: keywordRetriever.retrieve(effectiveQuery), HitSource.KEYWORD), effectiveQuery.getTopK());
|
||||||
|
}
|
||||||
|
|
||||||
|
CompletableFuture<List<RagHit>> vectorFuture = CompletableFuture.supplyAsync(new java.util.function.Supplier<List<RagHit>>() {
|
||||||
|
@Override
|
||||||
|
public List<RagHit> get() {
|
||||||
|
return vectorRetriever == null ? Collections.<RagHit>emptyList() : vectorRetriever.retrieve(effectiveQuery);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
CompletableFuture<List<RagHit>> keywordFuture = CompletableFuture.supplyAsync(new java.util.function.Supplier<List<RagHit>>() {
|
||||||
|
@Override
|
||||||
|
public List<RagHit> get() {
|
||||||
|
return keywordRetriever == null ? Collections.<RagHit>emptyList() : keywordRetriever.retrieve(effectiveQuery);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
List<RagHit> vectorHits = normalizeHits(vectorFuture.join(), HitSource.VECTOR);
|
||||||
|
List<RagHit> keywordHits = normalizeHits(keywordFuture.join(), HitSource.KEYWORD);
|
||||||
|
return buildResult(fusionStrategy.fuse(vectorHits, keywordHits, effectiveQuery.getTopK()), effectiveQuery.getTopK());
|
||||||
|
}
|
||||||
|
|
||||||
|
public RagRetrievalResult rerank(String query, List<RagHit> hits, RerankModel rerankModel, int topK) {
|
||||||
|
if (rerankModel == null || hits == null || hits.isEmpty()) {
|
||||||
|
return buildResult(hits, topK);
|
||||||
|
}
|
||||||
|
List<Document> rerankInput = new ArrayList<Document>(hits.size());
|
||||||
|
Map<String, RagHit> hitMap = new LinkedHashMap<String, RagHit>();
|
||||||
|
for (RagHit hit : hits) {
|
||||||
|
if (hit == null || hit.getDocumentId() == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
rerankInput.add(hit.toDocument());
|
||||||
|
hitMap.put(String.valueOf(hit.getDocumentId()), hit.copy());
|
||||||
|
}
|
||||||
|
List<Document> rerankedDocuments = rerankModel.rerank(query, rerankInput);
|
||||||
|
List<RagHit> rerankedHits = new ArrayList<RagHit>();
|
||||||
|
if (rerankedDocuments != null) {
|
||||||
|
for (Document rerankedDocument : rerankedDocuments) {
|
||||||
|
if (rerankedDocument == null || rerankedDocument.getId() == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
RagHit original = hitMap.get(String.valueOf(rerankedDocument.getId()));
|
||||||
|
if (original == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
original.setContent(rerankedDocument.getContent());
|
||||||
|
original.setTitle(rerankedDocument.getTitle());
|
||||||
|
original.setScore(rerankedDocument.getScore());
|
||||||
|
if (rerankedDocument.getScore() != null) {
|
||||||
|
original.getMetadata().put(RagRetrievalMetadataKeys.RERANK_SCORE, rerankedDocument.getScore());
|
||||||
|
}
|
||||||
|
rerankedHits.add(original);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sortByScore(rerankedHits);
|
||||||
|
assignRank(rerankedHits);
|
||||||
|
return buildResult(rerankedHits, topK);
|
||||||
|
}
|
||||||
|
|
||||||
|
private RagRetrievalResult buildResult(List<RagHit> hits, int topK) {
|
||||||
|
List<RagHit> effectiveHits = hits == null ? new ArrayList<RagHit>() : new ArrayList<RagHit>(hits);
|
||||||
|
if (topK > 0 && effectiveHits.size() > topK) {
|
||||||
|
effectiveHits = new ArrayList<RagHit>(effectiveHits.subList(0, topK));
|
||||||
|
}
|
||||||
|
assignRank(effectiveHits);
|
||||||
|
RagRetrievalResult result = new RagRetrievalResult();
|
||||||
|
result.setHits(effectiveHits);
|
||||||
|
result.setTotal(effectiveHits.size());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<RagHit> normalizeHits(List<RagHit> hits, HitSource defaultSource) {
|
||||||
|
List<RagHit> normalized = new ArrayList<RagHit>();
|
||||||
|
if (hits == null) {
|
||||||
|
return normalized;
|
||||||
|
}
|
||||||
|
for (RagHit hit : hits) {
|
||||||
|
if (hit == null || hit.getDocumentId() == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
RagHit copy = hit.copy();
|
||||||
|
if (copy.getHitSource() == null) {
|
||||||
|
copy.setHitSource(defaultSource);
|
||||||
|
}
|
||||||
|
if (defaultSource == HitSource.VECTOR && copy.getVectorScore() == null) {
|
||||||
|
copy.setVectorScore(copy.getScore());
|
||||||
|
}
|
||||||
|
if (defaultSource == HitSource.KEYWORD && copy.getKeywordScore() == null) {
|
||||||
|
copy.setKeywordScore(copy.getScore());
|
||||||
|
}
|
||||||
|
normalized.add(copy);
|
||||||
|
}
|
||||||
|
assignRank(normalized);
|
||||||
|
return normalized;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void sortByScore(List<RagHit> hits) {
|
||||||
|
Collections.sort(hits, new Comparator<RagHit>() {
|
||||||
|
@Override
|
||||||
|
public int compare(RagHit left, RagHit right) {
|
||||||
|
return Double.compare(right.getScore() == null ? 0D : right.getScore(),
|
||||||
|
left.getScore() == null ? 0D : left.getScore());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assignRank(List<RagHit> hits) {
|
||||||
|
for (int i = 0; i < hits.size(); i++) {
|
||||||
|
RagHit hit = hits.get(i);
|
||||||
|
hit.setRank(i + 1);
|
||||||
|
hit.getMetadata().put(RagRetrievalMetadataKeys.FINAL_RANK, i + 1);
|
||||||
|
if (hit.getHitSource() != null) {
|
||||||
|
hit.getMetadata().put(RagRetrievalMetadataKeys.HIT_SOURCE, hit.getHitSource().name());
|
||||||
|
}
|
||||||
|
if (hit.getVectorScore() != null) {
|
||||||
|
hit.getMetadata().put(RagRetrievalMetadataKeys.VECTOR_SCORE, hit.getVectorScore());
|
||||||
|
}
|
||||||
|
if (hit.getKeywordScore() != null) {
|
||||||
|
hit.getMetadata().put(RagRetrievalMetadataKeys.KEYWORD_SCORE, hit.getKeywordScore());
|
||||||
|
}
|
||||||
|
if (!hit.getMetadata().containsKey(RagRetrievalMetadataKeys.RERANK_SCORE)
|
||||||
|
&& hit.getHitSource() == HitSource.BOTH
|
||||||
|
&& hit.getScore() != null) {
|
||||||
|
hit.getMetadata().put(RagRetrievalMetadataKeys.FUSION_SCORE, hit.getScore());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
package com.easyagents.rag.retrieval;
|
||||||
|
|
||||||
|
public final class RagRetrievalMetadataKeys {
|
||||||
|
|
||||||
|
private RagRetrievalMetadataKeys() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public static final String HIT_SOURCE = "hitSource";
|
||||||
|
public static final String VECTOR_SCORE = "vectorScore";
|
||||||
|
public static final String KEYWORD_SCORE = "keywordScore";
|
||||||
|
public static final String FUSION_SCORE = "fusionScore";
|
||||||
|
public static final String RERANK_SCORE = "rerankScore";
|
||||||
|
public static final String VECTOR_RANK = "vectorRank";
|
||||||
|
public static final String KEYWORD_RANK = "keywordRank";
|
||||||
|
public static final String FINAL_RANK = "rank";
|
||||||
|
}
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
package com.easyagents.rag.retrieval;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class RagRetrievalResult {
|
||||||
|
|
||||||
|
private List<RagHit> hits = new ArrayList<RagHit>();
|
||||||
|
private Integer total;
|
||||||
|
|
||||||
|
public List<RagHit> getHits() {
|
||||||
|
return hits;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setHits(List<RagHit> hits) {
|
||||||
|
this.hits = hits == null ? new ArrayList<RagHit>() : hits;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getTotal() {
|
||||||
|
return total;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setTotal(Integer total) {
|
||||||
|
this.total = total;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,117 @@
|
|||||||
|
package com.easyagents.rag.retrieval;
|
||||||
|
|
||||||
|
import com.easyagents.core.document.Document;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public final class RagScoreNormalizer {
|
||||||
|
|
||||||
|
private RagScoreNormalizer() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void normalize(List<Document> documents, RetrievalMode retrievalMode, boolean reranked) {
|
||||||
|
if (documents == null || documents.isEmpty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (reranked) {
|
||||||
|
normalizeRerankScores(documents);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
RetrievalMode effectiveMode = retrievalMode == null ? RetrievalMode.HYBRID : retrievalMode;
|
||||||
|
if (effectiveMode == RetrievalMode.VECTOR) {
|
||||||
|
for (Document document : documents) {
|
||||||
|
document.setScore(clamp01(readRawScore(document, RagRetrievalMetadataKeys.VECTOR_SCORE, document == null ? null : document.getScore())));
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (effectiveMode == RetrievalMode.KEYWORD) {
|
||||||
|
for (Document document : documents) {
|
||||||
|
Double rawScore = readRawScore(document, RagRetrievalMetadataKeys.KEYWORD_SCORE, document == null ? null : document.getScore());
|
||||||
|
if (rawScore == null || rawScore <= 0D) {
|
||||||
|
document.setScore(0D);
|
||||||
|
} else {
|
||||||
|
document.setScore(clamp01(rawScore / (1D + rawScore)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
double maxHybridRrfScore = 2D / (RrfFusionStrategy.DEFAULT_RRF_K + 1D);
|
||||||
|
for (Document document : documents) {
|
||||||
|
Double rawScore = readRawScore(document, RagRetrievalMetadataKeys.FUSION_SCORE, document == null ? null : document.getScore());
|
||||||
|
if (rawScore == null || rawScore <= 0D) {
|
||||||
|
document.setScore(0D);
|
||||||
|
} else {
|
||||||
|
document.setScore(clamp01(rawScore / maxHybridRrfScore));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void normalizeRerankScores(List<Document> documents) {
|
||||||
|
List<Double> rawScores = new ArrayList<Double>(documents.size());
|
||||||
|
boolean allPresent = true;
|
||||||
|
Double min = null;
|
||||||
|
Double max = null;
|
||||||
|
for (Document document : documents) {
|
||||||
|
Double rawScore = readRawScore(document, RagRetrievalMetadataKeys.RERANK_SCORE, document == null ? null : document.getScore());
|
||||||
|
rawScores.add(rawScore);
|
||||||
|
if (rawScore == null) {
|
||||||
|
allPresent = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
min = min == null ? rawScore : Math.min(min, rawScore);
|
||||||
|
max = max == null ? rawScore : Math.max(max, rawScore);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (allPresent && min != null && max != null && Double.compare(max, min) != 0) {
|
||||||
|
for (int i = 0; i < documents.size(); i++) {
|
||||||
|
Double rawScore = rawScores.get(i);
|
||||||
|
documents.get(i).setScore(clamp01((rawScore - min) / (max - min)));
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (documents.size() == 1) {
|
||||||
|
documents.get(0).setScore(1D);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int size = documents.size();
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
documents.get(i).setScore(clamp01(1D - ((double) i / (double) (size - 1))));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Double readRawScore(Document document, String metadataKey, Double fallback) {
|
||||||
|
if (document == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
Object metadataValue = document.getMetadata(metadataKey);
|
||||||
|
if (metadataValue instanceof Number) {
|
||||||
|
return ((Number) metadataValue).doubleValue();
|
||||||
|
}
|
||||||
|
if (metadataValue instanceof String) {
|
||||||
|
try {
|
||||||
|
return Double.valueOf((String) metadataValue);
|
||||||
|
} catch (NumberFormatException ignore) {
|
||||||
|
return fallback;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fallback;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static double clamp01(Double value) {
|
||||||
|
if (value == null || value.isNaN() || value.isInfinite()) {
|
||||||
|
return 0D;
|
||||||
|
}
|
||||||
|
if (value < 0D) {
|
||||||
|
return 0D;
|
||||||
|
}
|
||||||
|
if (value > 1D) {
|
||||||
|
return 1D;
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
package com.easyagents.rag.retrieval;
|
||||||
|
|
||||||
|
public enum RetrievalMode {
|
||||||
|
VECTOR,
|
||||||
|
KEYWORD,
|
||||||
|
HYBRID;
|
||||||
|
|
||||||
|
public static RetrievalMode from(String value) {
|
||||||
|
if (value == null || value.trim().isEmpty()) {
|
||||||
|
return HYBRID;
|
||||||
|
}
|
||||||
|
for (RetrievalMode mode : values()) {
|
||||||
|
if (mode.name().equalsIgnoreCase(value.trim())) {
|
||||||
|
return mode;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
throw new IllegalArgumentException("Unsupported retrieval mode: " + value);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,121 @@
|
|||||||
|
package com.easyagents.rag.retrieval;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class RrfFusionStrategy implements FusionStrategy {
|
||||||
|
|
||||||
|
public static final int DEFAULT_RRF_K = 60;
|
||||||
|
|
||||||
|
private final int rrfK;
|
||||||
|
|
||||||
|
public RrfFusionStrategy() {
|
||||||
|
this(DEFAULT_RRF_K);
|
||||||
|
}
|
||||||
|
|
||||||
|
public RrfFusionStrategy(int rrfK) {
|
||||||
|
this.rrfK = rrfK <= 0 ? DEFAULT_RRF_K : rrfK;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<RagHit> fuse(List<RagHit> vectorHits, List<RagHit> keywordHits, int topK) {
|
||||||
|
Map<String, RagHit> mergedHits = new LinkedHashMap<String, RagHit>();
|
||||||
|
Map<String, Double> fusionScores = new LinkedHashMap<String, Double>();
|
||||||
|
|
||||||
|
mergeHits(mergedHits, fusionScores, vectorHits, HitSource.VECTOR);
|
||||||
|
mergeHits(mergedHits, fusionScores, keywordHits, HitSource.KEYWORD);
|
||||||
|
|
||||||
|
List<RagHit> results = new ArrayList<RagHit>(mergedHits.values());
|
||||||
|
Collections.sort(results, new Comparator<RagHit>() {
|
||||||
|
@Override
|
||||||
|
public int compare(RagHit left, RagHit right) {
|
||||||
|
int scoreCompare = Double.compare(nullSafeScore(right.getScore()), nullSafeScore(left.getScore()));
|
||||||
|
if (scoreCompare != 0) {
|
||||||
|
return scoreCompare;
|
||||||
|
}
|
||||||
|
return String.valueOf(left.getDocumentId()).compareTo(String.valueOf(right.getDocumentId()));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
assignRank(results);
|
||||||
|
if (topK > 0 && results.size() > topK) {
|
||||||
|
return new ArrayList<RagHit>(results.subList(0, topK));
|
||||||
|
}
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void mergeHits(Map<String, RagHit> mergedHits,
|
||||||
|
Map<String, Double> fusionScores,
|
||||||
|
List<RagHit> incomingHits,
|
||||||
|
HitSource source) {
|
||||||
|
if (incomingHits == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (int index = 0; index < incomingHits.size(); index++) {
|
||||||
|
RagHit incoming = incomingHits.get(index);
|
||||||
|
if (incoming == null || incoming.getDocumentId() == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
String key = String.valueOf(incoming.getDocumentId());
|
||||||
|
RagHit merged = mergedHits.get(key);
|
||||||
|
if (merged == null) {
|
||||||
|
merged = incoming.copy();
|
||||||
|
merged.setHitSource(source);
|
||||||
|
mergedHits.put(key, merged);
|
||||||
|
fusionScores.put(key, 0D);
|
||||||
|
} else {
|
||||||
|
mergeBaseFields(merged, incoming);
|
||||||
|
merged.setHitSource(HitSource.merge(merged.getHitSource(), source));
|
||||||
|
}
|
||||||
|
|
||||||
|
int rank = index + 1;
|
||||||
|
if (source == HitSource.VECTOR) {
|
||||||
|
merged.setVectorScore(incoming.getVectorScore() != null ? incoming.getVectorScore() : incoming.getScore());
|
||||||
|
merged.getMetadata().put(RagRetrievalMetadataKeys.VECTOR_RANK, rank);
|
||||||
|
} else {
|
||||||
|
merged.setKeywordScore(incoming.getKeywordScore() != null ? incoming.getKeywordScore() : incoming.getScore());
|
||||||
|
merged.getMetadata().put(RagRetrievalMetadataKeys.KEYWORD_RANK, rank);
|
||||||
|
}
|
||||||
|
double fusedScore = fusionScores.get(key) + (1D / (rrfK + rank));
|
||||||
|
fusionScores.put(key, fusedScore);
|
||||||
|
merged.setScore(fusedScore);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void mergeBaseFields(RagHit target, RagHit incoming) {
|
||||||
|
if ((target.getTitle() == null || target.getTitle().isEmpty()) && incoming.getTitle() != null) {
|
||||||
|
target.setTitle(incoming.getTitle());
|
||||||
|
}
|
||||||
|
if ((target.getContent() == null || target.getContent().isEmpty()) && incoming.getContent() != null) {
|
||||||
|
target.setContent(incoming.getContent());
|
||||||
|
}
|
||||||
|
if (incoming.getMetadata() != null && !incoming.getMetadata().isEmpty()) {
|
||||||
|
target.getMetadata().putAll(incoming.getMetadata());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assignRank(List<RagHit> results) {
|
||||||
|
for (int i = 0; i < results.size(); i++) {
|
||||||
|
RagHit hit = results.get(i);
|
||||||
|
hit.setRank(i + 1);
|
||||||
|
hit.getMetadata().put(RagRetrievalMetadataKeys.FINAL_RANK, i + 1);
|
||||||
|
hit.getMetadata().put(RagRetrievalMetadataKeys.HIT_SOURCE, hit.getHitSource() == null ? null : hit.getHitSource().name());
|
||||||
|
if (!hit.getMetadata().containsKey(RagRetrievalMetadataKeys.RERANK_SCORE) && hit.getScore() != null) {
|
||||||
|
hit.getMetadata().put(RagRetrievalMetadataKeys.FUSION_SCORE, hit.getScore());
|
||||||
|
}
|
||||||
|
if (hit.getVectorScore() != null) {
|
||||||
|
hit.getMetadata().put(RagRetrievalMetadataKeys.VECTOR_SCORE, hit.getVectorScore());
|
||||||
|
}
|
||||||
|
if (hit.getKeywordScore() != null) {
|
||||||
|
hit.getMetadata().put(RagRetrievalMetadataKeys.KEYWORD_SCORE, hit.getKeywordScore());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private double nullSafeScore(Double value) {
|
||||||
|
return value == null ? 0D : value;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
package com.easyagents.rag.retrieval;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public interface VectorRetriever {
|
||||||
|
|
||||||
|
List<RagHit> retrieve(RagQuery query);
|
||||||
|
}
|
||||||
@@ -0,0 +1,111 @@
|
|||||||
|
package com.easyagents.rag.retrieval;
|
||||||
|
|
||||||
|
import com.easyagents.core.document.Document;
|
||||||
|
import com.easyagents.core.model.rerank.RerankModel;
|
||||||
|
import com.easyagents.core.model.rerank.RerankOptions;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class RagRetrievalExecutorTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldOnlyCallVectorRetrieverInVectorMode() {
|
||||||
|
final List<String> invoked = new ArrayList<String>();
|
||||||
|
RagRetrievalExecutor executor = new RagRetrievalExecutor(
|
||||||
|
new VectorRetriever() {
|
||||||
|
@Override
|
||||||
|
public List<RagHit> retrieve(RagQuery query) {
|
||||||
|
invoked.add("vector");
|
||||||
|
return Arrays.asList(hit("1", "vector", 0.9D, HitSource.VECTOR));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
new KeywordRetriever() {
|
||||||
|
@Override
|
||||||
|
public List<RagHit> retrieve(RagQuery query) {
|
||||||
|
invoked.add("keyword");
|
||||||
|
return Arrays.asList(hit("2", "keyword", 5D, HitSource.KEYWORD));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
new RrfFusionStrategy()
|
||||||
|
);
|
||||||
|
|
||||||
|
RagQuery query = new RagQuery();
|
||||||
|
query.setRetrievalMode(RetrievalMode.VECTOR);
|
||||||
|
query.setTopK(10);
|
||||||
|
|
||||||
|
RagRetrievalResult result = executor.retrieve(query);
|
||||||
|
|
||||||
|
Assert.assertEquals(Arrays.asList("vector"), invoked);
|
||||||
|
Assert.assertEquals(1, result.getHits().size());
|
||||||
|
Assert.assertEquals(HitSource.VECTOR, result.getHits().get(0).getHitSource());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldRerankAfterHybridFusionAndKeepHitSource() {
|
||||||
|
RagRetrievalExecutor executor = new RagRetrievalExecutor(
|
||||||
|
new VectorRetriever() {
|
||||||
|
@Override
|
||||||
|
public List<RagHit> retrieve(RagQuery query) {
|
||||||
|
return Arrays.asList(
|
||||||
|
hit("1", "alpha", 0.9D, HitSource.VECTOR),
|
||||||
|
hit("2", "beta", 0.8D, HitSource.VECTOR)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
new KeywordRetriever() {
|
||||||
|
@Override
|
||||||
|
public List<RagHit> retrieve(RagQuery query) {
|
||||||
|
return Arrays.asList(
|
||||||
|
hit("2", "beta", 4.1D, HitSource.KEYWORD),
|
||||||
|
hit("3", "gamma", 3.9D, HitSource.KEYWORD)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
new RrfFusionStrategy()
|
||||||
|
);
|
||||||
|
|
||||||
|
RagQuery query = new RagQuery();
|
||||||
|
query.setRetrievalMode(RetrievalMode.HYBRID);
|
||||||
|
query.setTopK(10);
|
||||||
|
|
||||||
|
RagRetrievalResult retrieved = executor.retrieve(query);
|
||||||
|
RagRetrievalResult reranked = executor.rerank("query", retrieved.getHits(), new ReverseRerankModel(), 10);
|
||||||
|
|
||||||
|
Assert.assertEquals("3", String.valueOf(reranked.getHits().get(0).getDocumentId()));
|
||||||
|
Assert.assertEquals("2", String.valueOf(reranked.getHits().get(2).getDocumentId()));
|
||||||
|
Assert.assertEquals(HitSource.BOTH, reranked.getHits().get(2).getHitSource());
|
||||||
|
}
|
||||||
|
|
||||||
|
private RagHit hit(String id, String content, double score, HitSource source) {
|
||||||
|
RagHit hit = new RagHit();
|
||||||
|
hit.setDocumentId(id);
|
||||||
|
hit.setContent(content);
|
||||||
|
hit.setScore(score);
|
||||||
|
hit.setHitSource(source);
|
||||||
|
if (source == HitSource.VECTOR) {
|
||||||
|
hit.setVectorScore(score);
|
||||||
|
} else if (source == HitSource.KEYWORD) {
|
||||||
|
hit.setKeywordScore(score);
|
||||||
|
}
|
||||||
|
return hit;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class ReverseRerankModel implements RerankModel {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Document> rerank(String query, List<Document> documents, RerankOptions options) {
|
||||||
|
List<Document> result = new ArrayList<Document>();
|
||||||
|
double score = documents.size();
|
||||||
|
for (int i = documents.size() - 1; i >= 0; i--) {
|
||||||
|
Document document = documents.get(i);
|
||||||
|
document.setScore(score--);
|
||||||
|
result.add(document);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
package com.easyagents.rag.retrieval;
|
||||||
|
|
||||||
|
import com.easyagents.core.document.Document;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class RagScoreNormalizerTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldNormalizeKeywordScoresToZeroAndOneRange() {
|
||||||
|
Document first = document(1, 9D, RagRetrievalMetadataKeys.KEYWORD_SCORE);
|
||||||
|
Document second = document(2, 0D, RagRetrievalMetadataKeys.KEYWORD_SCORE);
|
||||||
|
|
||||||
|
RagScoreNormalizer.normalize(Arrays.asList(first, second), RetrievalMode.KEYWORD, false);
|
||||||
|
|
||||||
|
Assert.assertEquals(0.9D, first.getScore(), 0.0001D);
|
||||||
|
Assert.assertEquals(0D, second.getScore(), 0.0001D);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldNormalizeHybridFusionScoreByRrfUpperBound() {
|
||||||
|
Document document = document(1, 2D / (RrfFusionStrategy.DEFAULT_RRF_K + 1D), RagRetrievalMetadataKeys.FUSION_SCORE);
|
||||||
|
|
||||||
|
RagScoreNormalizer.normalize(Arrays.asList(document), RetrievalMode.HYBRID, false);
|
||||||
|
|
||||||
|
Assert.assertEquals(1D, document.getScore(), 0.0001D);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldNormalizeRerankScoresByMinMax() {
|
||||||
|
List<Document> documents = Arrays.asList(
|
||||||
|
document(1, 10D, RagRetrievalMetadataKeys.RERANK_SCORE),
|
||||||
|
document(2, 20D, RagRetrievalMetadataKeys.RERANK_SCORE),
|
||||||
|
document(3, 30D, RagRetrievalMetadataKeys.RERANK_SCORE)
|
||||||
|
);
|
||||||
|
|
||||||
|
RagScoreNormalizer.normalize(documents, RetrievalMode.HYBRID, true);
|
||||||
|
|
||||||
|
Assert.assertEquals(0D, documents.get(0).getScore(), 0.0001D);
|
||||||
|
Assert.assertEquals(0.5D, documents.get(1).getScore(), 0.0001D);
|
||||||
|
Assert.assertEquals(1D, documents.get(2).getScore(), 0.0001D);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldFallbackToRankBasedNormalizationWhenRerankScoresAreEqual() {
|
||||||
|
List<Document> documents = Arrays.asList(
|
||||||
|
document(1, 5D, RagRetrievalMetadataKeys.RERANK_SCORE),
|
||||||
|
document(2, 5D, RagRetrievalMetadataKeys.RERANK_SCORE),
|
||||||
|
document(3, 5D, RagRetrievalMetadataKeys.RERANK_SCORE)
|
||||||
|
);
|
||||||
|
|
||||||
|
RagScoreNormalizer.normalize(documents, RetrievalMode.HYBRID, true);
|
||||||
|
|
||||||
|
Assert.assertEquals(1D, documents.get(0).getScore(), 0.0001D);
|
||||||
|
Assert.assertEquals(0.5D, documents.get(1).getScore(), 0.0001D);
|
||||||
|
Assert.assertEquals(0D, documents.get(2).getScore(), 0.0001D);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Document document(Object id, Double score, String metadataKey) {
|
||||||
|
Document document = new Document();
|
||||||
|
document.setId(id);
|
||||||
|
document.setScore(score);
|
||||||
|
document.addMetadata(metadataKey, score);
|
||||||
|
return document;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
package com.easyagents.rag.retrieval;
|
||||||
|
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class RrfFusionStrategyTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldMarkDuplicateHitsAsBothAndRankFirst() {
|
||||||
|
RrfFusionStrategy strategy = new RrfFusionStrategy();
|
||||||
|
|
||||||
|
RagHit vectorOnly = hit("1", 0.91D, HitSource.VECTOR);
|
||||||
|
RagHit bothVector = hit("2", 0.88D, HitSource.VECTOR);
|
||||||
|
RagHit bothKeyword = hit("2", 5.2D, HitSource.KEYWORD);
|
||||||
|
RagHit keywordOnly = hit("3", 4.8D, HitSource.KEYWORD);
|
||||||
|
|
||||||
|
List<RagHit> fused = strategy.fuse(
|
||||||
|
Arrays.asList(vectorOnly, bothVector),
|
||||||
|
Arrays.asList(bothKeyword, keywordOnly),
|
||||||
|
10
|
||||||
|
);
|
||||||
|
|
||||||
|
Assert.assertEquals(3, fused.size());
|
||||||
|
Assert.assertEquals("2", String.valueOf(fused.get(0).getDocumentId()));
|
||||||
|
Assert.assertEquals(HitSource.BOTH, fused.get(0).getHitSource());
|
||||||
|
Assert.assertEquals("1", String.valueOf(fused.get(1).getDocumentId()));
|
||||||
|
Assert.assertEquals("3", String.valueOf(fused.get(2).getDocumentId()));
|
||||||
|
}
|
||||||
|
|
||||||
|
private RagHit hit(String id, double score, HitSource source) {
|
||||||
|
RagHit hit = new RagHit();
|
||||||
|
hit.setDocumentId(id);
|
||||||
|
hit.setContent("doc-" + id);
|
||||||
|
hit.setScore(score);
|
||||||
|
hit.setHitSource(source);
|
||||||
|
if (source == HitSource.VECTOR) {
|
||||||
|
hit.setVectorScore(score);
|
||||||
|
} else if (source == HitSource.KEYWORD) {
|
||||||
|
hit.setKeywordScore(score);
|
||||||
|
}
|
||||||
|
return hit;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -37,6 +37,11 @@
|
|||||||
<artifactId>jackson-databind</artifactId>
|
<artifactId>jackson-databind</artifactId>
|
||||||
<version>2.15.2</version> <!-- 或与Elasticsearch客户端兼容的版本 -->
|
<version>2.15.2</version> <!-- 或与Elasticsearch客户端兼容的版本 -->
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>junit</groupId>
|
||||||
|
<artifactId>junit</artifactId>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
</project>
|
</project>
|
||||||
|
|||||||
@@ -4,12 +4,15 @@ import co.elastic.clients.elasticsearch.ElasticsearchClient;
|
|||||||
import co.elastic.clients.elasticsearch.core.*;
|
import co.elastic.clients.elasticsearch.core.*;
|
||||||
import co.elastic.clients.elasticsearch.core.bulk.BulkOperation;
|
import co.elastic.clients.elasticsearch.core.bulk.BulkOperation;
|
||||||
import co.elastic.clients.elasticsearch.core.bulk.IndexOperation;
|
import co.elastic.clients.elasticsearch.core.bulk.IndexOperation;
|
||||||
|
import co.elastic.clients.elasticsearch.core.search.SourceConfig;
|
||||||
import co.elastic.clients.json.JsonData;
|
import co.elastic.clients.json.JsonData;
|
||||||
import co.elastic.clients.json.jackson.JacksonJsonpMapper;
|
import co.elastic.clients.json.jackson.JacksonJsonpMapper;
|
||||||
import co.elastic.clients.transport.ElasticsearchTransport;
|
import co.elastic.clients.transport.ElasticsearchTransport;
|
||||||
import co.elastic.clients.transport.rest_client.RestClientTransport;
|
import co.elastic.clients.transport.rest_client.RestClientTransport;
|
||||||
import com.easyagents.core.document.Document;
|
import com.easyagents.core.document.Document;
|
||||||
import com.easyagents.search.engine.service.DocumentSearcher;
|
import com.easyagents.search.engine.service.DocumentSearcher;
|
||||||
|
import com.easyagents.search.engine.service.KeywordSearchMetadataKeys;
|
||||||
|
import com.easyagents.search.engine.service.KeywordSearchRequest;
|
||||||
import org.apache.http.HttpHost;
|
import org.apache.http.HttpHost;
|
||||||
import org.apache.http.auth.AuthScope;
|
import org.apache.http.auth.AuthScope;
|
||||||
import org.apache.http.auth.UsernamePasswordCredentials;
|
import org.apache.http.auth.UsernamePasswordCredentials;
|
||||||
@@ -88,13 +91,7 @@ public class ElasticSearcher implements DocumentSearcher {
|
|||||||
transport = new RestClientTransport(restClient, new JacksonJsonpMapper());
|
transport = new RestClientTransport(restClient, new JacksonJsonpMapper());
|
||||||
ElasticsearchClient client = new ElasticsearchClient(transport);
|
ElasticsearchClient client = new ElasticsearchClient(transport);
|
||||||
|
|
||||||
Map<String, Object> source = new HashMap<>();
|
Map<String, Object> source = buildSource(document);
|
||||||
source.put("id", document.getId());
|
|
||||||
source.put("content", document.getContent());
|
|
||||||
if (document.getTitle() != null) {
|
|
||||||
source.put("title", document.getTitle());
|
|
||||||
}
|
|
||||||
|
|
||||||
String documentId = document.getId().toString();
|
String documentId = document.getId().toString();
|
||||||
IndexOperation<?> indexOp = IndexOperation.of(i -> i
|
IndexOperation<?> indexOp = IndexOperation.of(i -> i
|
||||||
.index(esConfig.getIndexName())
|
.index(esConfig.getIndexName())
|
||||||
@@ -116,7 +113,7 @@ public class ElasticSearcher implements DocumentSearcher {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Document> searchDocuments(String keyword, int count) {
|
public List<Document> searchDocuments(KeywordSearchRequest request) {
|
||||||
RestClient restClient = null;
|
RestClient restClient = null;
|
||||||
ElasticsearchTransport transport = null;
|
ElasticsearchTransport transport = null;
|
||||||
|
|
||||||
@@ -125,21 +122,16 @@ public class ElasticSearcher implements DocumentSearcher {
|
|||||||
transport = new RestClientTransport(restClient, new JacksonJsonpMapper());
|
transport = new RestClientTransport(restClient, new JacksonJsonpMapper());
|
||||||
ElasticsearchClient client = new ElasticsearchClient(transport);
|
ElasticsearchClient client = new ElasticsearchClient(transport);
|
||||||
|
|
||||||
SearchRequest request = SearchRequest.of(s -> s
|
SearchResponse<Map> response = client.search(buildSearchRequest(request), Map.class);
|
||||||
.index(esConfig.getIndexName())
|
|
||||||
.size(count)
|
|
||||||
.query(q -> q
|
|
||||||
.match(m -> m
|
|
||||||
.field("title")
|
|
||||||
.field("content")
|
|
||||||
.query(keyword)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
);
|
|
||||||
|
|
||||||
SearchResponse<Document> response = client.search(request, Document.class);
|
|
||||||
List<Document> results = new ArrayList<>();
|
List<Document> results = new ArrayList<>();
|
||||||
response.hits().hits().forEach(hit -> results.add(hit.source()));
|
response.hits().hits().forEach(hit -> {
|
||||||
|
Map source = hit.source();
|
||||||
|
Document document = toDocument(hit.id(), source, hit.score());
|
||||||
|
if (document == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
results.add(document);
|
||||||
|
});
|
||||||
return results;
|
return results;
|
||||||
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@@ -193,14 +185,17 @@ public class ElasticSearcher implements DocumentSearcher {
|
|||||||
transport = new RestClientTransport(restClient, new JacksonJsonpMapper());
|
transport = new RestClientTransport(restClient, new JacksonJsonpMapper());
|
||||||
ElasticsearchClient client = new ElasticsearchClient(transport);
|
ElasticsearchClient client = new ElasticsearchClient(transport);
|
||||||
|
|
||||||
UpdateRequest<Document, Object> request = UpdateRequest.of(u -> u
|
UpdateRequest<Map<String, Object>, Map<String, Object>> request = UpdateRequest.of(u -> u
|
||||||
.index(esConfig.getIndexName())
|
.index(esConfig.getIndexName())
|
||||||
.id(document.getId().toString())
|
.id(document.getId().toString())
|
||||||
.doc(document)
|
.doc(buildSource(document))
|
||||||
);
|
);
|
||||||
|
|
||||||
UpdateResponse<Document> response = client.update(request, Object.class);
|
@SuppressWarnings("unchecked")
|
||||||
return response.result() == co.elastic.clients.elasticsearch._types.Result.Updated;
|
Class<Map<String, Object>> documentClass = (Class<Map<String, Object>>) (Class<?>) Map.class;
|
||||||
|
UpdateResponse<Map<String, Object>> response = client.update(request, documentClass);
|
||||||
|
return response.result() == co.elastic.clients.elasticsearch._types.Result.Updated
|
||||||
|
|| response.result() == co.elastic.clients.elasticsearch._types.Result.NoOp;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
LOG.error("Error updating document with id: " + document.getId(), e);
|
LOG.error("Error updating document with id: " + document.getId(), e);
|
||||||
return false;
|
return false;
|
||||||
@@ -220,4 +215,88 @@ public class ElasticSearcher implements DocumentSearcher {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
private Document toDocument(String hitId, Map source, Double score) {
|
||||||
|
if (source == null || source.isEmpty()) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
Document document = new Document();
|
||||||
|
Object id = source.get("id");
|
||||||
|
document.setId(id != null ? id : hitId);
|
||||||
|
|
||||||
|
Object title = source.get("title");
|
||||||
|
if (title != null) {
|
||||||
|
document.setTitle(String.valueOf(title));
|
||||||
|
}
|
||||||
|
|
||||||
|
Object content = source.get("content");
|
||||||
|
if (content != null) {
|
||||||
|
document.setContent(String.valueOf(content));
|
||||||
|
}
|
||||||
|
|
||||||
|
Object metadataMap = source.get("metadataMap");
|
||||||
|
if (metadataMap instanceof Map<?, ?>) {
|
||||||
|
document.setMetadataMap(new HashMap<>((Map<String, Object>) metadataMap));
|
||||||
|
}
|
||||||
|
|
||||||
|
document.setScore(score);
|
||||||
|
return document;
|
||||||
|
}
|
||||||
|
|
||||||
|
Map<String, Object> buildSource(Document document) {
|
||||||
|
Map<String, Object> source = new HashMap<String, Object>();
|
||||||
|
source.put("id", document.getId());
|
||||||
|
source.put("content", document.getContent());
|
||||||
|
if (document.getTitle() != null) {
|
||||||
|
source.put("title", document.getTitle());
|
||||||
|
}
|
||||||
|
if (document.getMetadataMap() != null && !document.getMetadataMap().isEmpty()) {
|
||||||
|
source.put("metadataMap", new HashMap<String, Object>(document.getMetadataMap()));
|
||||||
|
Object knowledgeId = document.getMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID);
|
||||||
|
if (knowledgeId != null) {
|
||||||
|
source.put(KeywordSearchMetadataKeys.KNOWLEDGE_ID, String.valueOf(knowledgeId));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return source;
|
||||||
|
}
|
||||||
|
|
||||||
|
SearchRequest buildSearchRequest(KeywordSearchRequest request) {
|
||||||
|
KeywordSearchRequest effectiveRequest = request == null ? new KeywordSearchRequest() : request;
|
||||||
|
return SearchRequest.of(s -> s
|
||||||
|
.index(esConfig.getIndexName())
|
||||||
|
.size(effectiveRequest.getCount())
|
||||||
|
.source(SourceConfig.of(sc -> sc.filter(f -> f.includes("id", "title", "content", "metadataMap"))))
|
||||||
|
.query(q -> q.bool(b -> {
|
||||||
|
b.must(m -> m.multiMatch(mm -> mm
|
||||||
|
.query(effectiveRequest.getKeyword())
|
||||||
|
.fields("title", "content")
|
||||||
|
));
|
||||||
|
if (effectiveRequest.getKnowledgeId() != null && !effectiveRequest.getKnowledgeId().trim().isEmpty()) {
|
||||||
|
b.filter(f -> f.term(t -> t
|
||||||
|
.field(KeywordSearchMetadataKeys.KNOWLEDGE_ID)
|
||||||
|
.value(v -> v.stringValue(effectiveRequest.getKnowledgeId().trim()))
|
||||||
|
));
|
||||||
|
}
|
||||||
|
return b;
|
||||||
|
}))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean checkAvailable() {
|
||||||
|
RestClient restClient = null;
|
||||||
|
ElasticsearchTransport transport = null;
|
||||||
|
try {
|
||||||
|
restClient = buildRestClient();
|
||||||
|
transport = new RestClientTransport(restClient, new JacksonJsonpMapper());
|
||||||
|
ElasticsearchClient client = new ElasticsearchClient(transport);
|
||||||
|
return client.info() != null;
|
||||||
|
} catch (Exception e) {
|
||||||
|
LOG.error("Elasticsearch availability check failed", e);
|
||||||
|
return false;
|
||||||
|
} finally {
|
||||||
|
closeResources(transport, restClient);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,54 @@
|
|||||||
|
package com.easyagents.engine.es;
|
||||||
|
|
||||||
|
import co.elastic.clients.elasticsearch.core.SearchRequest;
|
||||||
|
import com.easyagents.core.document.Document;
|
||||||
|
import com.easyagents.search.engine.service.KeywordSearchMetadataKeys;
|
||||||
|
import com.easyagents.search.engine.service.KeywordSearchRequest;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class ElasticSearcherQueryBuilderTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldBuildSearchRequestWithMultiMatchAndKnowledgeFilter() {
|
||||||
|
ElasticSearcher searcher = new ElasticSearcher(config());
|
||||||
|
KeywordSearchRequest request = KeywordSearchRequest.of("客服", 5);
|
||||||
|
request.setKnowledgeId("100");
|
||||||
|
|
||||||
|
SearchRequest searchRequest = searcher.buildSearchRequest(request);
|
||||||
|
|
||||||
|
Assert.assertEquals(5, searchRequest.size().intValue());
|
||||||
|
Assert.assertNotNull(searchRequest.query().bool());
|
||||||
|
Assert.assertEquals(1, searchRequest.query().bool().must().size());
|
||||||
|
Assert.assertNotNull(searchRequest.query().bool().must().get(0).multiMatch());
|
||||||
|
Assert.assertEquals(2, searchRequest.query().bool().must().get(0).multiMatch().fields().size());
|
||||||
|
Assert.assertEquals(1, searchRequest.query().bool().filter().size());
|
||||||
|
Assert.assertEquals("knowledgeId", searchRequest.query().bool().filter().get(0).term().field());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldExtractKnowledgeIdToTopLevelSource() {
|
||||||
|
ElasticSearcher searcher = new ElasticSearcher(config());
|
||||||
|
Document document = new Document();
|
||||||
|
document.setId("1");
|
||||||
|
document.setTitle("title");
|
||||||
|
document.setContent("content");
|
||||||
|
document.addMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID, "100");
|
||||||
|
|
||||||
|
Map<String, Object> source = searcher.buildSource(document);
|
||||||
|
|
||||||
|
Assert.assertEquals("100", source.get(KeywordSearchMetadataKeys.KNOWLEDGE_ID));
|
||||||
|
Assert.assertTrue(source.get("metadataMap") instanceof Map);
|
||||||
|
}
|
||||||
|
|
||||||
|
private ESConfig config() {
|
||||||
|
ESConfig config = new ESConfig();
|
||||||
|
config.setHost("http://127.0.0.1:9200");
|
||||||
|
config.setUserName("elastic");
|
||||||
|
config.setPassword("elastic");
|
||||||
|
config.setIndexName("easyflow");
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -51,5 +51,10 @@
|
|||||||
<groupId>com.easyagents</groupId>
|
<groupId>com.easyagents</groupId>
|
||||||
<artifactId>easy-agents-search-engine-service</artifactId>
|
<artifactId>easy-agents-search-engine-service</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>junit</groupId>
|
||||||
|
<artifactId>junit</artifactId>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</project>
|
</project>
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ package com.easyagents.search.engine.lucene;
|
|||||||
|
|
||||||
import com.easyagents.core.document.Document;
|
import com.easyagents.core.document.Document;
|
||||||
import com.easyagents.search.engine.service.DocumentSearcher;
|
import com.easyagents.search.engine.service.DocumentSearcher;
|
||||||
|
import com.easyagents.search.engine.service.KeywordSearchMetadataKeys;
|
||||||
|
import com.easyagents.search.engine.service.KeywordSearchRequest;
|
||||||
import org.apache.lucene.analysis.Analyzer;
|
import org.apache.lucene.analysis.Analyzer;
|
||||||
import org.apache.lucene.document.Field;
|
import org.apache.lucene.document.Field;
|
||||||
import org.apache.lucene.document.StringField;
|
import org.apache.lucene.document.StringField;
|
||||||
@@ -78,7 +80,7 @@ public class LuceneSearcher implements DocumentSearcher {
|
|||||||
if (document.getTitle() != null) {
|
if (document.getTitle() != null) {
|
||||||
luceneDoc.add(new TextField("title", document.getTitle(), Field.Store.YES));
|
luceneDoc.add(new TextField("title", document.getTitle(), Field.Store.YES));
|
||||||
}
|
}
|
||||||
|
appendKnowledgeId(document, luceneDoc);
|
||||||
|
|
||||||
indexWriter.addDocument(luceneDoc);
|
indexWriter.addDocument(luceneDoc);
|
||||||
indexWriter.commit();
|
indexWriter.commit();
|
||||||
@@ -127,7 +129,7 @@ public class LuceneSearcher implements DocumentSearcher {
|
|||||||
if (document.getTitle() != null) {
|
if (document.getTitle() != null) {
|
||||||
luceneDoc.add(new TextField("title", document.getTitle(), Field.Store.YES));
|
luceneDoc.add(new TextField("title", document.getTitle(), Field.Store.YES));
|
||||||
}
|
}
|
||||||
|
appendKnowledgeId(document, luceneDoc);
|
||||||
indexWriter.updateDocument(term, luceneDoc);
|
indexWriter.updateDocument(term, luceneDoc);
|
||||||
indexWriter.commit();
|
indexWriter.commit();
|
||||||
return true;
|
return true;
|
||||||
@@ -140,18 +142,21 @@ public class LuceneSearcher implements DocumentSearcher {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Document> searchDocuments(String keyword, int count) {
|
public List<Document> searchDocuments(KeywordSearchRequest request) {
|
||||||
List<Document> results = new ArrayList<>();
|
List<Document> results = new ArrayList<>();
|
||||||
try (IndexReader reader = DirectoryReader.open(directory)) {
|
try (IndexReader reader = DirectoryReader.open(directory)) {
|
||||||
IndexSearcher searcher = new IndexSearcher(reader);
|
IndexSearcher searcher = new IndexSearcher(reader);
|
||||||
Query query = buildQuery(keyword);
|
Query query = buildQuery(request);
|
||||||
TopDocs topDocs = searcher.search(query, count);
|
TopDocs topDocs = searcher.search(query, request == null ? 10 : request.getCount());
|
||||||
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
|
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
|
||||||
org.apache.lucene.document.Document doc = searcher.doc(scoreDoc.doc);
|
org.apache.lucene.document.Document doc = searcher.doc(scoreDoc.doc);
|
||||||
Document resultDoc = new Document();
|
Document resultDoc = new Document();
|
||||||
resultDoc.setId(doc.get("id"));
|
resultDoc.setId(doc.get("id"));
|
||||||
resultDoc.setContent(doc.get("content"));
|
resultDoc.setContent(doc.get("content"));
|
||||||
resultDoc.setTitle(doc.get("title"));
|
resultDoc.setTitle(doc.get("title"));
|
||||||
|
if (doc.get(KeywordSearchMetadataKeys.KNOWLEDGE_ID) != null) {
|
||||||
|
resultDoc.addMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID, doc.get(KeywordSearchMetadataKeys.KNOWLEDGE_ID));
|
||||||
|
}
|
||||||
|
|
||||||
resultDoc.setScore((double) scoreDoc.score);
|
resultDoc.setScore((double) scoreDoc.score);
|
||||||
|
|
||||||
@@ -164,9 +169,10 @@ public class LuceneSearcher implements DocumentSearcher {
|
|||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Query buildQuery(String keyword) {
|
Query buildQuery(KeywordSearchRequest request) {
|
||||||
try {
|
try {
|
||||||
Analyzer analyzer = createAnalyzer();
|
Analyzer analyzer = createAnalyzer();
|
||||||
|
String keyword = request == null ? null : request.getKeyword();
|
||||||
|
|
||||||
QueryParser titleQueryParser = new QueryParser("title", analyzer);
|
QueryParser titleQueryParser = new QueryParser("title", analyzer);
|
||||||
Query titleQuery = titleQueryParser.parse(keyword);
|
Query titleQuery = titleQueryParser.parse(keyword);
|
||||||
@@ -179,6 +185,9 @@ public class LuceneSearcher implements DocumentSearcher {
|
|||||||
BooleanQuery.Builder builder = new BooleanQuery.Builder();
|
BooleanQuery.Builder builder = new BooleanQuery.Builder();
|
||||||
builder.add(titleBooleanClause)
|
builder.add(titleBooleanClause)
|
||||||
.add(contentBooleanClause);
|
.add(contentBooleanClause);
|
||||||
|
if (request != null && request.getKnowledgeId() != null && !request.getKnowledgeId().trim().isEmpty()) {
|
||||||
|
builder.add(new TermQuery(new Term(KeywordSearchMetadataKeys.KNOWLEDGE_ID, request.getKnowledgeId().trim())), BooleanClause.Occur.MUST);
|
||||||
|
}
|
||||||
return builder.build();
|
return builder.build();
|
||||||
} catch (ParseException e) {
|
} catch (ParseException e) {
|
||||||
LOG.error(e.toString(), e);
|
LOG.error(e.toString(), e);
|
||||||
@@ -200,6 +209,16 @@ public class LuceneSearcher implements DocumentSearcher {
|
|||||||
return new JcsegAnalyzer(ISegment.Type.NLP, config, DictionaryFactory.createSingletonDictionary(config));
|
return new JcsegAnalyzer(ISegment.Type.NLP, config, DictionaryFactory.createSingletonDictionary(config));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void appendKnowledgeId(Document document, org.apache.lucene.document.Document luceneDoc) {
|
||||||
|
if (document == null || document.getMetadataMap() == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Object knowledgeId = document.getMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID);
|
||||||
|
if (knowledgeId != null) {
|
||||||
|
luceneDoc.add(new StringField(KeywordSearchMetadataKeys.KNOWLEDGE_ID, String.valueOf(knowledgeId), Field.Store.YES));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public void close(IndexWriter indexWriter) {
|
public void close(IndexWriter indexWriter) {
|
||||||
try {
|
try {
|
||||||
if (indexWriter != null) {
|
if (indexWriter != null) {
|
||||||
|
|||||||
@@ -0,0 +1,45 @@
|
|||||||
|
package com.easyagents.search.engine.lucene;
|
||||||
|
|
||||||
|
import com.easyagents.core.document.Document;
|
||||||
|
import com.easyagents.search.engine.service.KeywordSearchMetadataKeys;
|
||||||
|
import com.easyagents.search.engine.service.KeywordSearchRequest;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class LuceneSearcherTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldFilterByKnowledgeIdAndSearchTitleAndContent() throws Exception {
|
||||||
|
Path tempDir = Files.createTempDirectory("lucene-searcher-test");
|
||||||
|
LuceneConfig config = new LuceneConfig();
|
||||||
|
config.setIndexDirPath(tempDir.toString());
|
||||||
|
LuceneSearcher searcher = new LuceneSearcher(config);
|
||||||
|
|
||||||
|
Document first = new Document();
|
||||||
|
first.setId("1");
|
||||||
|
first.setTitle("客服标题");
|
||||||
|
first.setContent("这里没有关键字");
|
||||||
|
first.addMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID, "100");
|
||||||
|
|
||||||
|
Document second = new Document();
|
||||||
|
second.setId("2");
|
||||||
|
second.setTitle("别的知识库");
|
||||||
|
second.setContent("客服内容");
|
||||||
|
second.addMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID, "200");
|
||||||
|
|
||||||
|
Assert.assertTrue(searcher.addDocument(first));
|
||||||
|
Assert.assertTrue(searcher.addDocument(second));
|
||||||
|
|
||||||
|
KeywordSearchRequest request = KeywordSearchRequest.of("客服", 10);
|
||||||
|
request.setKnowledgeId("100");
|
||||||
|
List<Document> results = searcher.searchDocuments(request);
|
||||||
|
|
||||||
|
Assert.assertEquals(1, results.size());
|
||||||
|
Assert.assertEquals("1", String.valueOf(results.get(0).getId()));
|
||||||
|
Assert.assertEquals("100", String.valueOf(results.get(0).getMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID)));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -28,8 +28,12 @@ public interface DocumentSearcher {
|
|||||||
boolean updateDocument(Document document);
|
boolean updateDocument(Document document);
|
||||||
|
|
||||||
default List<Document> searchDocuments(String keyword) {
|
default List<Document> searchDocuments(String keyword) {
|
||||||
return searchDocuments(keyword, 10);
|
return searchDocuments(KeywordSearchRequest.of(keyword, 10));
|
||||||
}
|
}
|
||||||
|
|
||||||
List<Document> searchDocuments(String keyword, int count);
|
default List<Document> searchDocuments(String keyword, int count) {
|
||||||
|
return searchDocuments(KeywordSearchRequest.of(keyword, count));
|
||||||
|
}
|
||||||
|
|
||||||
|
List<Document> searchDocuments(KeywordSearchRequest request);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,9 @@
|
|||||||
|
package com.easyagents.search.engine.service;
|
||||||
|
|
||||||
|
public final class KeywordSearchMetadataKeys {
|
||||||
|
|
||||||
|
private KeywordSearchMetadataKeys() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public static final String KNOWLEDGE_ID = "knowledgeId";
|
||||||
|
}
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
package com.easyagents.search.engine.service;
|
||||||
|
|
||||||
|
public class KeywordSearchRequest {
|
||||||
|
|
||||||
|
private String keyword;
|
||||||
|
private int count = 10;
|
||||||
|
private String knowledgeId;
|
||||||
|
|
||||||
|
public static KeywordSearchRequest of(String keyword, int count) {
|
||||||
|
KeywordSearchRequest request = new KeywordSearchRequest();
|
||||||
|
request.setKeyword(keyword);
|
||||||
|
request.setCount(count);
|
||||||
|
return request;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getKeyword() {
|
||||||
|
return keyword;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setKeyword(String keyword) {
|
||||||
|
this.keyword = keyword;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getCount() {
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setCount(int count) {
|
||||||
|
this.count = count <= 0 ? 10 : count;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getKnowledgeId() {
|
||||||
|
return knowledgeId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setKnowledgeId(String knowledgeId) {
|
||||||
|
this.knowledgeId = knowledgeId;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -535,4 +535,15 @@ public class MilvusVectorStore extends DocumentStore {
|
|||||||
public MilvusClientV2 getClient() {
|
public MilvusClientV2 getClient() {
|
||||||
return client;
|
return client;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public boolean checkAvailable() {
|
||||||
|
try {
|
||||||
|
return client.hasCollection(HasCollectionReq.builder()
|
||||||
|
.collectionName("__milvus_boot_probe__")
|
||||||
|
.build()) != null;
|
||||||
|
} catch (Exception e) {
|
||||||
|
LOG.warn("Milvus availability check failed. message={}", e.getMessage());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user