feat: 下沉知识库检索编排能力
- 新增 rag retrieval 核心协议、RRF 融合与相关度归一化 - 支持关键词检索按 knowledgeId 过滤并补充 ES/Lucene 单测 - 扩展 KnowledgeNode 检索模式与 Milvus 检索参数透传
This commit is contained in:
@@ -35,6 +35,7 @@ public class KnowledgeNode extends BaseNode {
|
||||
private Object knowledgeId;
|
||||
private String keyword;
|
||||
private String limit;
|
||||
private String retrievalMode = "HYBRID";
|
||||
|
||||
public Object getKnowledgeId() {
|
||||
return knowledgeId;
|
||||
@@ -60,6 +61,14 @@ public class KnowledgeNode extends BaseNode {
|
||||
this.limit = limit;
|
||||
}
|
||||
|
||||
public String getRetrievalMode() {
|
||||
return retrievalMode;
|
||||
}
|
||||
|
||||
public void setRetrievalMode(String retrievalMode) {
|
||||
this.retrievalMode = StringUtil.hasText(retrievalMode) ? retrievalMode : "HYBRID";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> execute(Chain chain) {
|
||||
Map<String, Object> argsMap = chain.getState().resolveParameters(this);
|
||||
@@ -90,6 +99,7 @@ public class KnowledgeNode extends BaseNode {
|
||||
"knowledgeId=" + knowledgeId +
|
||||
", keyword='" + keyword + '\'' +
|
||||
", limit='" + limit + '\'' +
|
||||
", retrievalMode='" + retrievalMode + '\'' +
|
||||
", parameters=" + parameters +
|
||||
", outputDefs=" + outputDefs +
|
||||
", id='" + id + '\'' +
|
||||
|
||||
@@ -27,6 +27,7 @@ public class KnowledgeNodeParser extends BaseNodeParser<KnowledgeNode> {
|
||||
knowledgeNode.setKnowledgeId(data.get("knowledgeId"));
|
||||
knowledgeNode.setLimit(data.getString("limit"));
|
||||
knowledgeNode.setKeyword(data.getString("keyword"));
|
||||
knowledgeNode.setRetrievalMode(data.getString("retrievalMode"));
|
||||
|
||||
return knowledgeNode;
|
||||
}
|
||||
|
||||
@@ -24,6 +24,10 @@
|
||||
<groupId>com.easyagents</groupId>
|
||||
<artifactId>easy-agents-core</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.easyagents</groupId>
|
||||
<artifactId>easy-agents-search-engine-service</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.easyagents</groupId>
|
||||
<artifactId>easy-agents-rag-core</artifactId>
|
||||
@@ -32,5 +36,10 @@
|
||||
<groupId>com.easyagents</groupId>
|
||||
<artifactId>easy-agents-rag-enhance</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
package com.easyagents.rag.retrieval;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface FusionStrategy {
|
||||
|
||||
List<RagHit> fuse(List<RagHit> vectorHits, List<RagHit> keywordHits, int topK);
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.easyagents.rag.retrieval;
|
||||
|
||||
public enum HitSource {
|
||||
VECTOR,
|
||||
KEYWORD,
|
||||
BOTH;
|
||||
|
||||
public static HitSource merge(HitSource current, HitSource incoming) {
|
||||
if (current == null) {
|
||||
return incoming;
|
||||
}
|
||||
if (incoming == null || current == incoming) {
|
||||
return current;
|
||||
}
|
||||
return BOTH;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package com.easyagents.rag.retrieval;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface KeywordRetriever {
|
||||
|
||||
List<RagHit> retrieve(RagQuery query);
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
package com.easyagents.rag.retrieval;
|
||||
|
||||
import com.easyagents.core.document.Document;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class RagHit {
|
||||
|
||||
private Object documentId;
|
||||
private String title;
|
||||
private String content;
|
||||
private Double score;
|
||||
private HitSource hitSource;
|
||||
private Double vectorScore;
|
||||
private Double keywordScore;
|
||||
private Integer rank;
|
||||
private Map<String, Object> metadata = new HashMap<String, Object>();
|
||||
|
||||
public static RagHit fromDocument(Document document) {
|
||||
if (document == null) {
|
||||
return null;
|
||||
}
|
||||
RagHit hit = new RagHit();
|
||||
hit.setDocumentId(document.getId());
|
||||
hit.setTitle(document.getTitle());
|
||||
hit.setContent(document.getContent());
|
||||
hit.setScore(document.getScore());
|
||||
hit.setMetadata(document.getMetadataMap());
|
||||
|
||||
Object hitSource = hit.getMetadata().get(RagRetrievalMetadataKeys.HIT_SOURCE);
|
||||
if (hitSource instanceof String) {
|
||||
hit.setHitSource(HitSource.valueOf(String.valueOf(hitSource)));
|
||||
}
|
||||
hit.setVectorScore(asDouble(hit.getMetadata().get(RagRetrievalMetadataKeys.VECTOR_SCORE)));
|
||||
hit.setKeywordScore(asDouble(hit.getMetadata().get(RagRetrievalMetadataKeys.KEYWORD_SCORE)));
|
||||
hit.setRank(asInteger(hit.getMetadata().get(RagRetrievalMetadataKeys.FINAL_RANK)));
|
||||
return hit;
|
||||
}
|
||||
|
||||
public static RagHit fromDocument(Document document, HitSource defaultSource) {
|
||||
RagHit hit = fromDocument(document);
|
||||
if (hit != null && hit.getHitSource() == null) {
|
||||
hit.setHitSource(defaultSource);
|
||||
}
|
||||
if (hit != null && defaultSource == HitSource.VECTOR && hit.getVectorScore() == null) {
|
||||
hit.setVectorScore(document.getScore());
|
||||
}
|
||||
if (hit != null && defaultSource == HitSource.KEYWORD && hit.getKeywordScore() == null) {
|
||||
hit.setKeywordScore(document.getScore());
|
||||
}
|
||||
return hit;
|
||||
}
|
||||
|
||||
public Document toDocument() {
|
||||
Document document = new Document();
|
||||
document.setId(documentId);
|
||||
document.setTitle(title);
|
||||
document.setContent(content);
|
||||
document.setScore(score);
|
||||
Map<String, Object> metadataMap = metadata == null
|
||||
? new HashMap<String, Object>()
|
||||
: new HashMap<String, Object>(metadata);
|
||||
if (hitSource != null) {
|
||||
metadataMap.put(RagRetrievalMetadataKeys.HIT_SOURCE, hitSource.name());
|
||||
}
|
||||
if (vectorScore != null) {
|
||||
metadataMap.put(RagRetrievalMetadataKeys.VECTOR_SCORE, vectorScore);
|
||||
}
|
||||
if (keywordScore != null) {
|
||||
metadataMap.put(RagRetrievalMetadataKeys.KEYWORD_SCORE, keywordScore);
|
||||
}
|
||||
if (!metadataMap.containsKey(RagRetrievalMetadataKeys.FUSION_SCORE)
|
||||
&& hitSource == HitSource.BOTH
|
||||
&& score != null) {
|
||||
metadataMap.put(RagRetrievalMetadataKeys.FUSION_SCORE, score);
|
||||
}
|
||||
if (rank != null) {
|
||||
metadataMap.put(RagRetrievalMetadataKeys.FINAL_RANK, rank);
|
||||
}
|
||||
document.setMetadataMap(metadataMap);
|
||||
return document;
|
||||
}
|
||||
|
||||
public RagHit copy() {
|
||||
RagHit copy = new RagHit();
|
||||
copy.setDocumentId(documentId);
|
||||
copy.setTitle(title);
|
||||
copy.setContent(content);
|
||||
copy.setScore(score);
|
||||
copy.setHitSource(hitSource);
|
||||
copy.setVectorScore(vectorScore);
|
||||
copy.setKeywordScore(keywordScore);
|
||||
copy.setRank(rank);
|
||||
copy.setMetadata(metadata);
|
||||
return copy;
|
||||
}
|
||||
|
||||
public Object getDocumentId() {
|
||||
return documentId;
|
||||
}
|
||||
|
||||
public void setDocumentId(Object documentId) {
|
||||
this.documentId = documentId;
|
||||
}
|
||||
|
||||
public String getTitle() {
|
||||
return title;
|
||||
}
|
||||
|
||||
public void setTitle(String title) {
|
||||
this.title = title;
|
||||
}
|
||||
|
||||
public String getContent() {
|
||||
return content;
|
||||
}
|
||||
|
||||
public void setContent(String content) {
|
||||
this.content = content;
|
||||
}
|
||||
|
||||
public Double getScore() {
|
||||
return score;
|
||||
}
|
||||
|
||||
public void setScore(Double score) {
|
||||
this.score = score;
|
||||
}
|
||||
|
||||
public HitSource getHitSource() {
|
||||
return hitSource;
|
||||
}
|
||||
|
||||
public void setHitSource(HitSource hitSource) {
|
||||
this.hitSource = hitSource;
|
||||
}
|
||||
|
||||
public Double getVectorScore() {
|
||||
return vectorScore;
|
||||
}
|
||||
|
||||
public void setVectorScore(Double vectorScore) {
|
||||
this.vectorScore = vectorScore;
|
||||
}
|
||||
|
||||
public Double getKeywordScore() {
|
||||
return keywordScore;
|
||||
}
|
||||
|
||||
public void setKeywordScore(Double keywordScore) {
|
||||
this.keywordScore = keywordScore;
|
||||
}
|
||||
|
||||
public Integer getRank() {
|
||||
return rank;
|
||||
}
|
||||
|
||||
public void setRank(Integer rank) {
|
||||
this.rank = rank;
|
||||
}
|
||||
|
||||
public Map<String, Object> getMetadata() {
|
||||
return metadata;
|
||||
}
|
||||
|
||||
public void setMetadata(Map<String, Object> metadata) {
|
||||
this.metadata = metadata == null ? new HashMap<String, Object>() : new HashMap<String, Object>(metadata);
|
||||
}
|
||||
|
||||
private static Double asDouble(Object value) {
|
||||
if (value instanceof Number) {
|
||||
return ((Number) value).doubleValue();
|
||||
}
|
||||
if (value instanceof String) {
|
||||
try {
|
||||
return Double.valueOf((String) value);
|
||||
} catch (NumberFormatException ignore) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static Integer asInteger(Object value) {
|
||||
if (value instanceof Number) {
|
||||
return ((Number) value).intValue();
|
||||
}
|
||||
if (value instanceof String) {
|
||||
try {
|
||||
return Integer.valueOf((String) value);
|
||||
} catch (NumberFormatException ignore) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package com.easyagents.rag.retrieval;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class RagQuery {
|
||||
|
||||
private String query;
|
||||
private RetrievalMode retrievalMode = RetrievalMode.HYBRID;
|
||||
private Integer topK = 10;
|
||||
private Double minScore;
|
||||
private Map<String, Object> attributes = new HashMap<String, Object>();
|
||||
|
||||
public String getQuery() {
|
||||
return query;
|
||||
}
|
||||
|
||||
public void setQuery(String query) {
|
||||
this.query = query;
|
||||
}
|
||||
|
||||
public RetrievalMode getRetrievalMode() {
|
||||
return retrievalMode;
|
||||
}
|
||||
|
||||
public void setRetrievalMode(RetrievalMode retrievalMode) {
|
||||
this.retrievalMode = retrievalMode == null ? RetrievalMode.HYBRID : retrievalMode;
|
||||
}
|
||||
|
||||
public Integer getTopK() {
|
||||
return topK;
|
||||
}
|
||||
|
||||
public void setTopK(Integer topK) {
|
||||
this.topK = topK == null || topK <= 0 ? 10 : topK;
|
||||
}
|
||||
|
||||
public Double getMinScore() {
|
||||
return minScore;
|
||||
}
|
||||
|
||||
public void setMinScore(Double minScore) {
|
||||
this.minScore = minScore;
|
||||
}
|
||||
|
||||
public Map<String, Object> getAttributes() {
|
||||
return attributes;
|
||||
}
|
||||
|
||||
public void setAttributes(Map<String, Object> attributes) {
|
||||
this.attributes = attributes == null ? new HashMap<String, Object>() : attributes;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,168 @@
|
||||
package com.easyagents.rag.retrieval;
|
||||
|
||||
import com.easyagents.core.document.Document;
|
||||
import com.easyagents.core.model.rerank.RerankModel;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
||||
public class RagRetrievalExecutor {
|
||||
|
||||
private final VectorRetriever vectorRetriever;
|
||||
private final KeywordRetriever keywordRetriever;
|
||||
private final FusionStrategy fusionStrategy;
|
||||
|
||||
public RagRetrievalExecutor(VectorRetriever vectorRetriever,
|
||||
KeywordRetriever keywordRetriever,
|
||||
FusionStrategy fusionStrategy) {
|
||||
this.vectorRetriever = vectorRetriever;
|
||||
this.keywordRetriever = keywordRetriever;
|
||||
this.fusionStrategy = fusionStrategy == null ? new RrfFusionStrategy() : fusionStrategy;
|
||||
}
|
||||
|
||||
public RagRetrievalResult retrieve(RagQuery query) {
|
||||
RagQuery effectiveQuery = query == null ? new RagQuery() : query;
|
||||
RetrievalMode mode = effectiveQuery.getRetrievalMode() == null
|
||||
? RetrievalMode.HYBRID
|
||||
: effectiveQuery.getRetrievalMode();
|
||||
if (mode == RetrievalMode.VECTOR) {
|
||||
return buildResult(normalizeHits(vectorRetriever == null
|
||||
? Collections.<RagHit>emptyList()
|
||||
: vectorRetriever.retrieve(effectiveQuery), HitSource.VECTOR), effectiveQuery.getTopK());
|
||||
}
|
||||
if (mode == RetrievalMode.KEYWORD) {
|
||||
return buildResult(normalizeHits(keywordRetriever == null
|
||||
? Collections.<RagHit>emptyList()
|
||||
: keywordRetriever.retrieve(effectiveQuery), HitSource.KEYWORD), effectiveQuery.getTopK());
|
||||
}
|
||||
|
||||
CompletableFuture<List<RagHit>> vectorFuture = CompletableFuture.supplyAsync(new java.util.function.Supplier<List<RagHit>>() {
|
||||
@Override
|
||||
public List<RagHit> get() {
|
||||
return vectorRetriever == null ? Collections.<RagHit>emptyList() : vectorRetriever.retrieve(effectiveQuery);
|
||||
}
|
||||
});
|
||||
CompletableFuture<List<RagHit>> keywordFuture = CompletableFuture.supplyAsync(new java.util.function.Supplier<List<RagHit>>() {
|
||||
@Override
|
||||
public List<RagHit> get() {
|
||||
return keywordRetriever == null ? Collections.<RagHit>emptyList() : keywordRetriever.retrieve(effectiveQuery);
|
||||
}
|
||||
});
|
||||
|
||||
List<RagHit> vectorHits = normalizeHits(vectorFuture.join(), HitSource.VECTOR);
|
||||
List<RagHit> keywordHits = normalizeHits(keywordFuture.join(), HitSource.KEYWORD);
|
||||
return buildResult(fusionStrategy.fuse(vectorHits, keywordHits, effectiveQuery.getTopK()), effectiveQuery.getTopK());
|
||||
}
|
||||
|
||||
public RagRetrievalResult rerank(String query, List<RagHit> hits, RerankModel rerankModel, int topK) {
|
||||
if (rerankModel == null || hits == null || hits.isEmpty()) {
|
||||
return buildResult(hits, topK);
|
||||
}
|
||||
List<Document> rerankInput = new ArrayList<Document>(hits.size());
|
||||
Map<String, RagHit> hitMap = new LinkedHashMap<String, RagHit>();
|
||||
for (RagHit hit : hits) {
|
||||
if (hit == null || hit.getDocumentId() == null) {
|
||||
continue;
|
||||
}
|
||||
rerankInput.add(hit.toDocument());
|
||||
hitMap.put(String.valueOf(hit.getDocumentId()), hit.copy());
|
||||
}
|
||||
List<Document> rerankedDocuments = rerankModel.rerank(query, rerankInput);
|
||||
List<RagHit> rerankedHits = new ArrayList<RagHit>();
|
||||
if (rerankedDocuments != null) {
|
||||
for (Document rerankedDocument : rerankedDocuments) {
|
||||
if (rerankedDocument == null || rerankedDocument.getId() == null) {
|
||||
continue;
|
||||
}
|
||||
RagHit original = hitMap.get(String.valueOf(rerankedDocument.getId()));
|
||||
if (original == null) {
|
||||
continue;
|
||||
}
|
||||
original.setContent(rerankedDocument.getContent());
|
||||
original.setTitle(rerankedDocument.getTitle());
|
||||
original.setScore(rerankedDocument.getScore());
|
||||
if (rerankedDocument.getScore() != null) {
|
||||
original.getMetadata().put(RagRetrievalMetadataKeys.RERANK_SCORE, rerankedDocument.getScore());
|
||||
}
|
||||
rerankedHits.add(original);
|
||||
}
|
||||
}
|
||||
sortByScore(rerankedHits);
|
||||
assignRank(rerankedHits);
|
||||
return buildResult(rerankedHits, topK);
|
||||
}
|
||||
|
||||
private RagRetrievalResult buildResult(List<RagHit> hits, int topK) {
|
||||
List<RagHit> effectiveHits = hits == null ? new ArrayList<RagHit>() : new ArrayList<RagHit>(hits);
|
||||
if (topK > 0 && effectiveHits.size() > topK) {
|
||||
effectiveHits = new ArrayList<RagHit>(effectiveHits.subList(0, topK));
|
||||
}
|
||||
assignRank(effectiveHits);
|
||||
RagRetrievalResult result = new RagRetrievalResult();
|
||||
result.setHits(effectiveHits);
|
||||
result.setTotal(effectiveHits.size());
|
||||
return result;
|
||||
}
|
||||
|
||||
private List<RagHit> normalizeHits(List<RagHit> hits, HitSource defaultSource) {
|
||||
List<RagHit> normalized = new ArrayList<RagHit>();
|
||||
if (hits == null) {
|
||||
return normalized;
|
||||
}
|
||||
for (RagHit hit : hits) {
|
||||
if (hit == null || hit.getDocumentId() == null) {
|
||||
continue;
|
||||
}
|
||||
RagHit copy = hit.copy();
|
||||
if (copy.getHitSource() == null) {
|
||||
copy.setHitSource(defaultSource);
|
||||
}
|
||||
if (defaultSource == HitSource.VECTOR && copy.getVectorScore() == null) {
|
||||
copy.setVectorScore(copy.getScore());
|
||||
}
|
||||
if (defaultSource == HitSource.KEYWORD && copy.getKeywordScore() == null) {
|
||||
copy.setKeywordScore(copy.getScore());
|
||||
}
|
||||
normalized.add(copy);
|
||||
}
|
||||
assignRank(normalized);
|
||||
return normalized;
|
||||
}
|
||||
|
||||
private void sortByScore(List<RagHit> hits) {
|
||||
Collections.sort(hits, new Comparator<RagHit>() {
|
||||
@Override
|
||||
public int compare(RagHit left, RagHit right) {
|
||||
return Double.compare(right.getScore() == null ? 0D : right.getScore(),
|
||||
left.getScore() == null ? 0D : left.getScore());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private void assignRank(List<RagHit> hits) {
|
||||
for (int i = 0; i < hits.size(); i++) {
|
||||
RagHit hit = hits.get(i);
|
||||
hit.setRank(i + 1);
|
||||
hit.getMetadata().put(RagRetrievalMetadataKeys.FINAL_RANK, i + 1);
|
||||
if (hit.getHitSource() != null) {
|
||||
hit.getMetadata().put(RagRetrievalMetadataKeys.HIT_SOURCE, hit.getHitSource().name());
|
||||
}
|
||||
if (hit.getVectorScore() != null) {
|
||||
hit.getMetadata().put(RagRetrievalMetadataKeys.VECTOR_SCORE, hit.getVectorScore());
|
||||
}
|
||||
if (hit.getKeywordScore() != null) {
|
||||
hit.getMetadata().put(RagRetrievalMetadataKeys.KEYWORD_SCORE, hit.getKeywordScore());
|
||||
}
|
||||
if (!hit.getMetadata().containsKey(RagRetrievalMetadataKeys.RERANK_SCORE)
|
||||
&& hit.getHitSource() == HitSource.BOTH
|
||||
&& hit.getScore() != null) {
|
||||
hit.getMetadata().put(RagRetrievalMetadataKeys.FUSION_SCORE, hit.getScore());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.easyagents.rag.retrieval;
|
||||
|
||||
public final class RagRetrievalMetadataKeys {
|
||||
|
||||
private RagRetrievalMetadataKeys() {
|
||||
}
|
||||
|
||||
public static final String HIT_SOURCE = "hitSource";
|
||||
public static final String VECTOR_SCORE = "vectorScore";
|
||||
public static final String KEYWORD_SCORE = "keywordScore";
|
||||
public static final String FUSION_SCORE = "fusionScore";
|
||||
public static final String RERANK_SCORE = "rerankScore";
|
||||
public static final String VECTOR_RANK = "vectorRank";
|
||||
public static final String KEYWORD_RANK = "keywordRank";
|
||||
public static final String FINAL_RANK = "rank";
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package com.easyagents.rag.retrieval;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class RagRetrievalResult {
|
||||
|
||||
private List<RagHit> hits = new ArrayList<RagHit>();
|
||||
private Integer total;
|
||||
|
||||
public List<RagHit> getHits() {
|
||||
return hits;
|
||||
}
|
||||
|
||||
public void setHits(List<RagHit> hits) {
|
||||
this.hits = hits == null ? new ArrayList<RagHit>() : hits;
|
||||
}
|
||||
|
||||
public Integer getTotal() {
|
||||
return total;
|
||||
}
|
||||
|
||||
public void setTotal(Integer total) {
|
||||
this.total = total;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
package com.easyagents.rag.retrieval;
|
||||
|
||||
import com.easyagents.core.document.Document;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public final class RagScoreNormalizer {
|
||||
|
||||
private RagScoreNormalizer() {
|
||||
}
|
||||
|
||||
public static void normalize(List<Document> documents, RetrievalMode retrievalMode, boolean reranked) {
|
||||
if (documents == null || documents.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
if (reranked) {
|
||||
normalizeRerankScores(documents);
|
||||
return;
|
||||
}
|
||||
|
||||
RetrievalMode effectiveMode = retrievalMode == null ? RetrievalMode.HYBRID : retrievalMode;
|
||||
if (effectiveMode == RetrievalMode.VECTOR) {
|
||||
for (Document document : documents) {
|
||||
document.setScore(clamp01(readRawScore(document, RagRetrievalMetadataKeys.VECTOR_SCORE, document == null ? null : document.getScore())));
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (effectiveMode == RetrievalMode.KEYWORD) {
|
||||
for (Document document : documents) {
|
||||
Double rawScore = readRawScore(document, RagRetrievalMetadataKeys.KEYWORD_SCORE, document == null ? null : document.getScore());
|
||||
if (rawScore == null || rawScore <= 0D) {
|
||||
document.setScore(0D);
|
||||
} else {
|
||||
document.setScore(clamp01(rawScore / (1D + rawScore)));
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
double maxHybridRrfScore = 2D / (RrfFusionStrategy.DEFAULT_RRF_K + 1D);
|
||||
for (Document document : documents) {
|
||||
Double rawScore = readRawScore(document, RagRetrievalMetadataKeys.FUSION_SCORE, document == null ? null : document.getScore());
|
||||
if (rawScore == null || rawScore <= 0D) {
|
||||
document.setScore(0D);
|
||||
} else {
|
||||
document.setScore(clamp01(rawScore / maxHybridRrfScore));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void normalizeRerankScores(List<Document> documents) {
|
||||
List<Double> rawScores = new ArrayList<Double>(documents.size());
|
||||
boolean allPresent = true;
|
||||
Double min = null;
|
||||
Double max = null;
|
||||
for (Document document : documents) {
|
||||
Double rawScore = readRawScore(document, RagRetrievalMetadataKeys.RERANK_SCORE, document == null ? null : document.getScore());
|
||||
rawScores.add(rawScore);
|
||||
if (rawScore == null) {
|
||||
allPresent = false;
|
||||
continue;
|
||||
}
|
||||
min = min == null ? rawScore : Math.min(min, rawScore);
|
||||
max = max == null ? rawScore : Math.max(max, rawScore);
|
||||
}
|
||||
|
||||
if (allPresent && min != null && max != null && Double.compare(max, min) != 0) {
|
||||
for (int i = 0; i < documents.size(); i++) {
|
||||
Double rawScore = rawScores.get(i);
|
||||
documents.get(i).setScore(clamp01((rawScore - min) / (max - min)));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (documents.size() == 1) {
|
||||
documents.get(0).setScore(1D);
|
||||
return;
|
||||
}
|
||||
|
||||
int size = documents.size();
|
||||
for (int i = 0; i < size; i++) {
|
||||
documents.get(i).setScore(clamp01(1D - ((double) i / (double) (size - 1))));
|
||||
}
|
||||
}
|
||||
|
||||
private static Double readRawScore(Document document, String metadataKey, Double fallback) {
|
||||
if (document == null) {
|
||||
return null;
|
||||
}
|
||||
Object metadataValue = document.getMetadata(metadataKey);
|
||||
if (metadataValue instanceof Number) {
|
||||
return ((Number) metadataValue).doubleValue();
|
||||
}
|
||||
if (metadataValue instanceof String) {
|
||||
try {
|
||||
return Double.valueOf((String) metadataValue);
|
||||
} catch (NumberFormatException ignore) {
|
||||
return fallback;
|
||||
}
|
||||
}
|
||||
return fallback;
|
||||
}
|
||||
|
||||
private static double clamp01(Double value) {
|
||||
if (value == null || value.isNaN() || value.isInfinite()) {
|
||||
return 0D;
|
||||
}
|
||||
if (value < 0D) {
|
||||
return 0D;
|
||||
}
|
||||
if (value > 1D) {
|
||||
return 1D;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.easyagents.rag.retrieval;
|
||||
|
||||
public enum RetrievalMode {
|
||||
VECTOR,
|
||||
KEYWORD,
|
||||
HYBRID;
|
||||
|
||||
public static RetrievalMode from(String value) {
|
||||
if (value == null || value.trim().isEmpty()) {
|
||||
return HYBRID;
|
||||
}
|
||||
for (RetrievalMode mode : values()) {
|
||||
if (mode.name().equalsIgnoreCase(value.trim())) {
|
||||
return mode;
|
||||
}
|
||||
}
|
||||
throw new IllegalArgumentException("Unsupported retrieval mode: " + value);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
package com.easyagents.rag.retrieval;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class RrfFusionStrategy implements FusionStrategy {
|
||||
|
||||
public static final int DEFAULT_RRF_K = 60;
|
||||
|
||||
private final int rrfK;
|
||||
|
||||
public RrfFusionStrategy() {
|
||||
this(DEFAULT_RRF_K);
|
||||
}
|
||||
|
||||
public RrfFusionStrategy(int rrfK) {
|
||||
this.rrfK = rrfK <= 0 ? DEFAULT_RRF_K : rrfK;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<RagHit> fuse(List<RagHit> vectorHits, List<RagHit> keywordHits, int topK) {
|
||||
Map<String, RagHit> mergedHits = new LinkedHashMap<String, RagHit>();
|
||||
Map<String, Double> fusionScores = new LinkedHashMap<String, Double>();
|
||||
|
||||
mergeHits(mergedHits, fusionScores, vectorHits, HitSource.VECTOR);
|
||||
mergeHits(mergedHits, fusionScores, keywordHits, HitSource.KEYWORD);
|
||||
|
||||
List<RagHit> results = new ArrayList<RagHit>(mergedHits.values());
|
||||
Collections.sort(results, new Comparator<RagHit>() {
|
||||
@Override
|
||||
public int compare(RagHit left, RagHit right) {
|
||||
int scoreCompare = Double.compare(nullSafeScore(right.getScore()), nullSafeScore(left.getScore()));
|
||||
if (scoreCompare != 0) {
|
||||
return scoreCompare;
|
||||
}
|
||||
return String.valueOf(left.getDocumentId()).compareTo(String.valueOf(right.getDocumentId()));
|
||||
}
|
||||
});
|
||||
assignRank(results);
|
||||
if (topK > 0 && results.size() > topK) {
|
||||
return new ArrayList<RagHit>(results.subList(0, topK));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
private void mergeHits(Map<String, RagHit> mergedHits,
|
||||
Map<String, Double> fusionScores,
|
||||
List<RagHit> incomingHits,
|
||||
HitSource source) {
|
||||
if (incomingHits == null) {
|
||||
return;
|
||||
}
|
||||
for (int index = 0; index < incomingHits.size(); index++) {
|
||||
RagHit incoming = incomingHits.get(index);
|
||||
if (incoming == null || incoming.getDocumentId() == null) {
|
||||
continue;
|
||||
}
|
||||
String key = String.valueOf(incoming.getDocumentId());
|
||||
RagHit merged = mergedHits.get(key);
|
||||
if (merged == null) {
|
||||
merged = incoming.copy();
|
||||
merged.setHitSource(source);
|
||||
mergedHits.put(key, merged);
|
||||
fusionScores.put(key, 0D);
|
||||
} else {
|
||||
mergeBaseFields(merged, incoming);
|
||||
merged.setHitSource(HitSource.merge(merged.getHitSource(), source));
|
||||
}
|
||||
|
||||
int rank = index + 1;
|
||||
if (source == HitSource.VECTOR) {
|
||||
merged.setVectorScore(incoming.getVectorScore() != null ? incoming.getVectorScore() : incoming.getScore());
|
||||
merged.getMetadata().put(RagRetrievalMetadataKeys.VECTOR_RANK, rank);
|
||||
} else {
|
||||
merged.setKeywordScore(incoming.getKeywordScore() != null ? incoming.getKeywordScore() : incoming.getScore());
|
||||
merged.getMetadata().put(RagRetrievalMetadataKeys.KEYWORD_RANK, rank);
|
||||
}
|
||||
double fusedScore = fusionScores.get(key) + (1D / (rrfK + rank));
|
||||
fusionScores.put(key, fusedScore);
|
||||
merged.setScore(fusedScore);
|
||||
}
|
||||
}
|
||||
|
||||
private void mergeBaseFields(RagHit target, RagHit incoming) {
|
||||
if ((target.getTitle() == null || target.getTitle().isEmpty()) && incoming.getTitle() != null) {
|
||||
target.setTitle(incoming.getTitle());
|
||||
}
|
||||
if ((target.getContent() == null || target.getContent().isEmpty()) && incoming.getContent() != null) {
|
||||
target.setContent(incoming.getContent());
|
||||
}
|
||||
if (incoming.getMetadata() != null && !incoming.getMetadata().isEmpty()) {
|
||||
target.getMetadata().putAll(incoming.getMetadata());
|
||||
}
|
||||
}
|
||||
|
||||
private void assignRank(List<RagHit> results) {
|
||||
for (int i = 0; i < results.size(); i++) {
|
||||
RagHit hit = results.get(i);
|
||||
hit.setRank(i + 1);
|
||||
hit.getMetadata().put(RagRetrievalMetadataKeys.FINAL_RANK, i + 1);
|
||||
hit.getMetadata().put(RagRetrievalMetadataKeys.HIT_SOURCE, hit.getHitSource() == null ? null : hit.getHitSource().name());
|
||||
if (!hit.getMetadata().containsKey(RagRetrievalMetadataKeys.RERANK_SCORE) && hit.getScore() != null) {
|
||||
hit.getMetadata().put(RagRetrievalMetadataKeys.FUSION_SCORE, hit.getScore());
|
||||
}
|
||||
if (hit.getVectorScore() != null) {
|
||||
hit.getMetadata().put(RagRetrievalMetadataKeys.VECTOR_SCORE, hit.getVectorScore());
|
||||
}
|
||||
if (hit.getKeywordScore() != null) {
|
||||
hit.getMetadata().put(RagRetrievalMetadataKeys.KEYWORD_SCORE, hit.getKeywordScore());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private double nullSafeScore(Double value) {
|
||||
return value == null ? 0D : value;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package com.easyagents.rag.retrieval;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface VectorRetriever {
|
||||
|
||||
List<RagHit> retrieve(RagQuery query);
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
package com.easyagents.rag.retrieval;
|
||||
|
||||
import com.easyagents.core.document.Document;
|
||||
import com.easyagents.core.model.rerank.RerankModel;
|
||||
import com.easyagents.core.model.rerank.RerankOptions;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
public class RagRetrievalExecutorTest {
|
||||
|
||||
@Test
|
||||
public void shouldOnlyCallVectorRetrieverInVectorMode() {
|
||||
final List<String> invoked = new ArrayList<String>();
|
||||
RagRetrievalExecutor executor = new RagRetrievalExecutor(
|
||||
new VectorRetriever() {
|
||||
@Override
|
||||
public List<RagHit> retrieve(RagQuery query) {
|
||||
invoked.add("vector");
|
||||
return Arrays.asList(hit("1", "vector", 0.9D, HitSource.VECTOR));
|
||||
}
|
||||
},
|
||||
new KeywordRetriever() {
|
||||
@Override
|
||||
public List<RagHit> retrieve(RagQuery query) {
|
||||
invoked.add("keyword");
|
||||
return Arrays.asList(hit("2", "keyword", 5D, HitSource.KEYWORD));
|
||||
}
|
||||
},
|
||||
new RrfFusionStrategy()
|
||||
);
|
||||
|
||||
RagQuery query = new RagQuery();
|
||||
query.setRetrievalMode(RetrievalMode.VECTOR);
|
||||
query.setTopK(10);
|
||||
|
||||
RagRetrievalResult result = executor.retrieve(query);
|
||||
|
||||
Assert.assertEquals(Arrays.asList("vector"), invoked);
|
||||
Assert.assertEquals(1, result.getHits().size());
|
||||
Assert.assertEquals(HitSource.VECTOR, result.getHits().get(0).getHitSource());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldRerankAfterHybridFusionAndKeepHitSource() {
|
||||
RagRetrievalExecutor executor = new RagRetrievalExecutor(
|
||||
new VectorRetriever() {
|
||||
@Override
|
||||
public List<RagHit> retrieve(RagQuery query) {
|
||||
return Arrays.asList(
|
||||
hit("1", "alpha", 0.9D, HitSource.VECTOR),
|
||||
hit("2", "beta", 0.8D, HitSource.VECTOR)
|
||||
);
|
||||
}
|
||||
},
|
||||
new KeywordRetriever() {
|
||||
@Override
|
||||
public List<RagHit> retrieve(RagQuery query) {
|
||||
return Arrays.asList(
|
||||
hit("2", "beta", 4.1D, HitSource.KEYWORD),
|
||||
hit("3", "gamma", 3.9D, HitSource.KEYWORD)
|
||||
);
|
||||
}
|
||||
},
|
||||
new RrfFusionStrategy()
|
||||
);
|
||||
|
||||
RagQuery query = new RagQuery();
|
||||
query.setRetrievalMode(RetrievalMode.HYBRID);
|
||||
query.setTopK(10);
|
||||
|
||||
RagRetrievalResult retrieved = executor.retrieve(query);
|
||||
RagRetrievalResult reranked = executor.rerank("query", retrieved.getHits(), new ReverseRerankModel(), 10);
|
||||
|
||||
Assert.assertEquals("3", String.valueOf(reranked.getHits().get(0).getDocumentId()));
|
||||
Assert.assertEquals("2", String.valueOf(reranked.getHits().get(2).getDocumentId()));
|
||||
Assert.assertEquals(HitSource.BOTH, reranked.getHits().get(2).getHitSource());
|
||||
}
|
||||
|
||||
private RagHit hit(String id, String content, double score, HitSource source) {
|
||||
RagHit hit = new RagHit();
|
||||
hit.setDocumentId(id);
|
||||
hit.setContent(content);
|
||||
hit.setScore(score);
|
||||
hit.setHitSource(source);
|
||||
if (source == HitSource.VECTOR) {
|
||||
hit.setVectorScore(score);
|
||||
} else if (source == HitSource.KEYWORD) {
|
||||
hit.setKeywordScore(score);
|
||||
}
|
||||
return hit;
|
||||
}
|
||||
|
||||
private static class ReverseRerankModel implements RerankModel {
|
||||
|
||||
@Override
|
||||
public List<Document> rerank(String query, List<Document> documents, RerankOptions options) {
|
||||
List<Document> result = new ArrayList<Document>();
|
||||
double score = documents.size();
|
||||
for (int i = documents.size() - 1; i >= 0; i--) {
|
||||
Document document = documents.get(i);
|
||||
document.setScore(score--);
|
||||
result.add(document);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package com.easyagents.rag.retrieval;
|
||||
|
||||
import com.easyagents.core.document.Document;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
public class RagScoreNormalizerTest {
|
||||
|
||||
@Test
|
||||
public void shouldNormalizeKeywordScoresToZeroAndOneRange() {
|
||||
Document first = document(1, 9D, RagRetrievalMetadataKeys.KEYWORD_SCORE);
|
||||
Document second = document(2, 0D, RagRetrievalMetadataKeys.KEYWORD_SCORE);
|
||||
|
||||
RagScoreNormalizer.normalize(Arrays.asList(first, second), RetrievalMode.KEYWORD, false);
|
||||
|
||||
Assert.assertEquals(0.9D, first.getScore(), 0.0001D);
|
||||
Assert.assertEquals(0D, second.getScore(), 0.0001D);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldNormalizeHybridFusionScoreByRrfUpperBound() {
|
||||
Document document = document(1, 2D / (RrfFusionStrategy.DEFAULT_RRF_K + 1D), RagRetrievalMetadataKeys.FUSION_SCORE);
|
||||
|
||||
RagScoreNormalizer.normalize(Arrays.asList(document), RetrievalMode.HYBRID, false);
|
||||
|
||||
Assert.assertEquals(1D, document.getScore(), 0.0001D);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldNormalizeRerankScoresByMinMax() {
|
||||
List<Document> documents = Arrays.asList(
|
||||
document(1, 10D, RagRetrievalMetadataKeys.RERANK_SCORE),
|
||||
document(2, 20D, RagRetrievalMetadataKeys.RERANK_SCORE),
|
||||
document(3, 30D, RagRetrievalMetadataKeys.RERANK_SCORE)
|
||||
);
|
||||
|
||||
RagScoreNormalizer.normalize(documents, RetrievalMode.HYBRID, true);
|
||||
|
||||
Assert.assertEquals(0D, documents.get(0).getScore(), 0.0001D);
|
||||
Assert.assertEquals(0.5D, documents.get(1).getScore(), 0.0001D);
|
||||
Assert.assertEquals(1D, documents.get(2).getScore(), 0.0001D);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldFallbackToRankBasedNormalizationWhenRerankScoresAreEqual() {
|
||||
List<Document> documents = Arrays.asList(
|
||||
document(1, 5D, RagRetrievalMetadataKeys.RERANK_SCORE),
|
||||
document(2, 5D, RagRetrievalMetadataKeys.RERANK_SCORE),
|
||||
document(3, 5D, RagRetrievalMetadataKeys.RERANK_SCORE)
|
||||
);
|
||||
|
||||
RagScoreNormalizer.normalize(documents, RetrievalMode.HYBRID, true);
|
||||
|
||||
Assert.assertEquals(1D, documents.get(0).getScore(), 0.0001D);
|
||||
Assert.assertEquals(0.5D, documents.get(1).getScore(), 0.0001D);
|
||||
Assert.assertEquals(0D, documents.get(2).getScore(), 0.0001D);
|
||||
}
|
||||
|
||||
private Document document(Object id, Double score, String metadataKey) {
|
||||
Document document = new Document();
|
||||
document.setId(id);
|
||||
document.setScore(score);
|
||||
document.addMetadata(metadataKey, score);
|
||||
return document;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package com.easyagents.rag.retrieval;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
public class RrfFusionStrategyTest {
|
||||
|
||||
@Test
|
||||
public void shouldMarkDuplicateHitsAsBothAndRankFirst() {
|
||||
RrfFusionStrategy strategy = new RrfFusionStrategy();
|
||||
|
||||
RagHit vectorOnly = hit("1", 0.91D, HitSource.VECTOR);
|
||||
RagHit bothVector = hit("2", 0.88D, HitSource.VECTOR);
|
||||
RagHit bothKeyword = hit("2", 5.2D, HitSource.KEYWORD);
|
||||
RagHit keywordOnly = hit("3", 4.8D, HitSource.KEYWORD);
|
||||
|
||||
List<RagHit> fused = strategy.fuse(
|
||||
Arrays.asList(vectorOnly, bothVector),
|
||||
Arrays.asList(bothKeyword, keywordOnly),
|
||||
10
|
||||
);
|
||||
|
||||
Assert.assertEquals(3, fused.size());
|
||||
Assert.assertEquals("2", String.valueOf(fused.get(0).getDocumentId()));
|
||||
Assert.assertEquals(HitSource.BOTH, fused.get(0).getHitSource());
|
||||
Assert.assertEquals("1", String.valueOf(fused.get(1).getDocumentId()));
|
||||
Assert.assertEquals("3", String.valueOf(fused.get(2).getDocumentId()));
|
||||
}
|
||||
|
||||
private RagHit hit(String id, double score, HitSource source) {
|
||||
RagHit hit = new RagHit();
|
||||
hit.setDocumentId(id);
|
||||
hit.setContent("doc-" + id);
|
||||
hit.setScore(score);
|
||||
hit.setHitSource(source);
|
||||
if (source == HitSource.VECTOR) {
|
||||
hit.setVectorScore(score);
|
||||
} else if (source == HitSource.KEYWORD) {
|
||||
hit.setKeywordScore(score);
|
||||
}
|
||||
return hit;
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,11 @@
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
<version>2.15.2</version> <!-- 或与Elasticsearch客户端兼容的版本 -->
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
</project>
|
||||
|
||||
@@ -4,12 +4,15 @@ import co.elastic.clients.elasticsearch.ElasticsearchClient;
|
||||
import co.elastic.clients.elasticsearch.core.*;
|
||||
import co.elastic.clients.elasticsearch.core.bulk.BulkOperation;
|
||||
import co.elastic.clients.elasticsearch.core.bulk.IndexOperation;
|
||||
import co.elastic.clients.elasticsearch.core.search.SourceConfig;
|
||||
import co.elastic.clients.json.JsonData;
|
||||
import co.elastic.clients.json.jackson.JacksonJsonpMapper;
|
||||
import co.elastic.clients.transport.ElasticsearchTransport;
|
||||
import co.elastic.clients.transport.rest_client.RestClientTransport;
|
||||
import com.easyagents.core.document.Document;
|
||||
import com.easyagents.search.engine.service.DocumentSearcher;
|
||||
import com.easyagents.search.engine.service.KeywordSearchMetadataKeys;
|
||||
import com.easyagents.search.engine.service.KeywordSearchRequest;
|
||||
import org.apache.http.HttpHost;
|
||||
import org.apache.http.auth.AuthScope;
|
||||
import org.apache.http.auth.UsernamePasswordCredentials;
|
||||
@@ -88,13 +91,7 @@ public class ElasticSearcher implements DocumentSearcher {
|
||||
transport = new RestClientTransport(restClient, new JacksonJsonpMapper());
|
||||
ElasticsearchClient client = new ElasticsearchClient(transport);
|
||||
|
||||
Map<String, Object> source = new HashMap<>();
|
||||
source.put("id", document.getId());
|
||||
source.put("content", document.getContent());
|
||||
if (document.getTitle() != null) {
|
||||
source.put("title", document.getTitle());
|
||||
}
|
||||
|
||||
Map<String, Object> source = buildSource(document);
|
||||
String documentId = document.getId().toString();
|
||||
IndexOperation<?> indexOp = IndexOperation.of(i -> i
|
||||
.index(esConfig.getIndexName())
|
||||
@@ -116,7 +113,7 @@ public class ElasticSearcher implements DocumentSearcher {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Document> searchDocuments(String keyword, int count) {
|
||||
public List<Document> searchDocuments(KeywordSearchRequest request) {
|
||||
RestClient restClient = null;
|
||||
ElasticsearchTransport transport = null;
|
||||
|
||||
@@ -125,21 +122,16 @@ public class ElasticSearcher implements DocumentSearcher {
|
||||
transport = new RestClientTransport(restClient, new JacksonJsonpMapper());
|
||||
ElasticsearchClient client = new ElasticsearchClient(transport);
|
||||
|
||||
SearchRequest request = SearchRequest.of(s -> s
|
||||
.index(esConfig.getIndexName())
|
||||
.size(count)
|
||||
.query(q -> q
|
||||
.match(m -> m
|
||||
.field("title")
|
||||
.field("content")
|
||||
.query(keyword)
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
SearchResponse<Document> response = client.search(request, Document.class);
|
||||
SearchResponse<Map> response = client.search(buildSearchRequest(request), Map.class);
|
||||
List<Document> results = new ArrayList<>();
|
||||
response.hits().hits().forEach(hit -> results.add(hit.source()));
|
||||
response.hits().hits().forEach(hit -> {
|
||||
Map source = hit.source();
|
||||
Document document = toDocument(hit.id(), source, hit.score());
|
||||
if (document == null) {
|
||||
return;
|
||||
}
|
||||
results.add(document);
|
||||
});
|
||||
return results;
|
||||
|
||||
} catch (Exception e) {
|
||||
@@ -193,14 +185,17 @@ public class ElasticSearcher implements DocumentSearcher {
|
||||
transport = new RestClientTransport(restClient, new JacksonJsonpMapper());
|
||||
ElasticsearchClient client = new ElasticsearchClient(transport);
|
||||
|
||||
UpdateRequest<Document, Object> request = UpdateRequest.of(u -> u
|
||||
UpdateRequest<Map<String, Object>, Map<String, Object>> request = UpdateRequest.of(u -> u
|
||||
.index(esConfig.getIndexName())
|
||||
.id(document.getId().toString())
|
||||
.doc(document)
|
||||
.doc(buildSource(document))
|
||||
);
|
||||
|
||||
UpdateResponse<Document> response = client.update(request, Object.class);
|
||||
return response.result() == co.elastic.clients.elasticsearch._types.Result.Updated;
|
||||
@SuppressWarnings("unchecked")
|
||||
Class<Map<String, Object>> documentClass = (Class<Map<String, Object>>) (Class<?>) Map.class;
|
||||
UpdateResponse<Map<String, Object>> response = client.update(request, documentClass);
|
||||
return response.result() == co.elastic.clients.elasticsearch._types.Result.Updated
|
||||
|| response.result() == co.elastic.clients.elasticsearch._types.Result.NoOp;
|
||||
} catch (Exception e) {
|
||||
LOG.error("Error updating document with id: " + document.getId(), e);
|
||||
return false;
|
||||
@@ -220,4 +215,88 @@ public class ElasticSearcher implements DocumentSearcher {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private Document toDocument(String hitId, Map source, Double score) {
|
||||
if (source == null || source.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
Document document = new Document();
|
||||
Object id = source.get("id");
|
||||
document.setId(id != null ? id : hitId);
|
||||
|
||||
Object title = source.get("title");
|
||||
if (title != null) {
|
||||
document.setTitle(String.valueOf(title));
|
||||
}
|
||||
|
||||
Object content = source.get("content");
|
||||
if (content != null) {
|
||||
document.setContent(String.valueOf(content));
|
||||
}
|
||||
|
||||
Object metadataMap = source.get("metadataMap");
|
||||
if (metadataMap instanceof Map<?, ?>) {
|
||||
document.setMetadataMap(new HashMap<>((Map<String, Object>) metadataMap));
|
||||
}
|
||||
|
||||
document.setScore(score);
|
||||
return document;
|
||||
}
|
||||
|
||||
Map<String, Object> buildSource(Document document) {
|
||||
Map<String, Object> source = new HashMap<String, Object>();
|
||||
source.put("id", document.getId());
|
||||
source.put("content", document.getContent());
|
||||
if (document.getTitle() != null) {
|
||||
source.put("title", document.getTitle());
|
||||
}
|
||||
if (document.getMetadataMap() != null && !document.getMetadataMap().isEmpty()) {
|
||||
source.put("metadataMap", new HashMap<String, Object>(document.getMetadataMap()));
|
||||
Object knowledgeId = document.getMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID);
|
||||
if (knowledgeId != null) {
|
||||
source.put(KeywordSearchMetadataKeys.KNOWLEDGE_ID, String.valueOf(knowledgeId));
|
||||
}
|
||||
}
|
||||
return source;
|
||||
}
|
||||
|
||||
SearchRequest buildSearchRequest(KeywordSearchRequest request) {
|
||||
KeywordSearchRequest effectiveRequest = request == null ? new KeywordSearchRequest() : request;
|
||||
return SearchRequest.of(s -> s
|
||||
.index(esConfig.getIndexName())
|
||||
.size(effectiveRequest.getCount())
|
||||
.source(SourceConfig.of(sc -> sc.filter(f -> f.includes("id", "title", "content", "metadataMap"))))
|
||||
.query(q -> q.bool(b -> {
|
||||
b.must(m -> m.multiMatch(mm -> mm
|
||||
.query(effectiveRequest.getKeyword())
|
||||
.fields("title", "content")
|
||||
));
|
||||
if (effectiveRequest.getKnowledgeId() != null && !effectiveRequest.getKnowledgeId().trim().isEmpty()) {
|
||||
b.filter(f -> f.term(t -> t
|
||||
.field(KeywordSearchMetadataKeys.KNOWLEDGE_ID)
|
||||
.value(v -> v.stringValue(effectiveRequest.getKnowledgeId().trim()))
|
||||
));
|
||||
}
|
||||
return b;
|
||||
}))
|
||||
);
|
||||
}
|
||||
|
||||
public boolean checkAvailable() {
|
||||
RestClient restClient = null;
|
||||
ElasticsearchTransport transport = null;
|
||||
try {
|
||||
restClient = buildRestClient();
|
||||
transport = new RestClientTransport(restClient, new JacksonJsonpMapper());
|
||||
ElasticsearchClient client = new ElasticsearchClient(transport);
|
||||
return client.info() != null;
|
||||
} catch (Exception e) {
|
||||
LOG.error("Elasticsearch availability check failed", e);
|
||||
return false;
|
||||
} finally {
|
||||
closeResources(transport, restClient);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
package com.easyagents.engine.es;
|
||||
|
||||
import co.elastic.clients.elasticsearch.core.SearchRequest;
|
||||
import com.easyagents.core.document.Document;
|
||||
import com.easyagents.search.engine.service.KeywordSearchMetadataKeys;
|
||||
import com.easyagents.search.engine.service.KeywordSearchRequest;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public class ElasticSearcherQueryBuilderTest {
|
||||
|
||||
@Test
|
||||
public void shouldBuildSearchRequestWithMultiMatchAndKnowledgeFilter() {
|
||||
ElasticSearcher searcher = new ElasticSearcher(config());
|
||||
KeywordSearchRequest request = KeywordSearchRequest.of("客服", 5);
|
||||
request.setKnowledgeId("100");
|
||||
|
||||
SearchRequest searchRequest = searcher.buildSearchRequest(request);
|
||||
|
||||
Assert.assertEquals(5, searchRequest.size().intValue());
|
||||
Assert.assertNotNull(searchRequest.query().bool());
|
||||
Assert.assertEquals(1, searchRequest.query().bool().must().size());
|
||||
Assert.assertNotNull(searchRequest.query().bool().must().get(0).multiMatch());
|
||||
Assert.assertEquals(2, searchRequest.query().bool().must().get(0).multiMatch().fields().size());
|
||||
Assert.assertEquals(1, searchRequest.query().bool().filter().size());
|
||||
Assert.assertEquals("knowledgeId", searchRequest.query().bool().filter().get(0).term().field());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldExtractKnowledgeIdToTopLevelSource() {
|
||||
ElasticSearcher searcher = new ElasticSearcher(config());
|
||||
Document document = new Document();
|
||||
document.setId("1");
|
||||
document.setTitle("title");
|
||||
document.setContent("content");
|
||||
document.addMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID, "100");
|
||||
|
||||
Map<String, Object> source = searcher.buildSource(document);
|
||||
|
||||
Assert.assertEquals("100", source.get(KeywordSearchMetadataKeys.KNOWLEDGE_ID));
|
||||
Assert.assertTrue(source.get("metadataMap") instanceof Map);
|
||||
}
|
||||
|
||||
private ESConfig config() {
|
||||
ESConfig config = new ESConfig();
|
||||
config.setHost("http://127.0.0.1:9200");
|
||||
config.setUserName("elastic");
|
||||
config.setPassword("elastic");
|
||||
config.setIndexName("easyflow");
|
||||
return config;
|
||||
}
|
||||
}
|
||||
@@ -51,5 +51,10 @@
|
||||
<groupId>com.easyagents</groupId>
|
||||
<artifactId>easy-agents-search-engine-service</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
|
||||
@@ -17,6 +17,8 @@ package com.easyagents.search.engine.lucene;
|
||||
|
||||
import com.easyagents.core.document.Document;
|
||||
import com.easyagents.search.engine.service.DocumentSearcher;
|
||||
import com.easyagents.search.engine.service.KeywordSearchMetadataKeys;
|
||||
import com.easyagents.search.engine.service.KeywordSearchRequest;
|
||||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.StringField;
|
||||
@@ -78,7 +80,7 @@ public class LuceneSearcher implements DocumentSearcher {
|
||||
if (document.getTitle() != null) {
|
||||
luceneDoc.add(new TextField("title", document.getTitle(), Field.Store.YES));
|
||||
}
|
||||
|
||||
appendKnowledgeId(document, luceneDoc);
|
||||
|
||||
indexWriter.addDocument(luceneDoc);
|
||||
indexWriter.commit();
|
||||
@@ -127,7 +129,7 @@ public class LuceneSearcher implements DocumentSearcher {
|
||||
if (document.getTitle() != null) {
|
||||
luceneDoc.add(new TextField("title", document.getTitle(), Field.Store.YES));
|
||||
}
|
||||
|
||||
appendKnowledgeId(document, luceneDoc);
|
||||
indexWriter.updateDocument(term, luceneDoc);
|
||||
indexWriter.commit();
|
||||
return true;
|
||||
@@ -140,18 +142,21 @@ public class LuceneSearcher implements DocumentSearcher {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Document> searchDocuments(String keyword, int count) {
|
||||
public List<Document> searchDocuments(KeywordSearchRequest request) {
|
||||
List<Document> results = new ArrayList<>();
|
||||
try (IndexReader reader = DirectoryReader.open(directory)) {
|
||||
IndexSearcher searcher = new IndexSearcher(reader);
|
||||
Query query = buildQuery(keyword);
|
||||
TopDocs topDocs = searcher.search(query, count);
|
||||
Query query = buildQuery(request);
|
||||
TopDocs topDocs = searcher.search(query, request == null ? 10 : request.getCount());
|
||||
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
|
||||
org.apache.lucene.document.Document doc = searcher.doc(scoreDoc.doc);
|
||||
Document resultDoc = new Document();
|
||||
resultDoc.setId(doc.get("id"));
|
||||
resultDoc.setContent(doc.get("content"));
|
||||
resultDoc.setTitle(doc.get("title"));
|
||||
if (doc.get(KeywordSearchMetadataKeys.KNOWLEDGE_ID) != null) {
|
||||
resultDoc.addMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID, doc.get(KeywordSearchMetadataKeys.KNOWLEDGE_ID));
|
||||
}
|
||||
|
||||
resultDoc.setScore((double) scoreDoc.score);
|
||||
|
||||
@@ -164,9 +169,10 @@ public class LuceneSearcher implements DocumentSearcher {
|
||||
return results;
|
||||
}
|
||||
|
||||
private static Query buildQuery(String keyword) {
|
||||
Query buildQuery(KeywordSearchRequest request) {
|
||||
try {
|
||||
Analyzer analyzer = createAnalyzer();
|
||||
String keyword = request == null ? null : request.getKeyword();
|
||||
|
||||
QueryParser titleQueryParser = new QueryParser("title", analyzer);
|
||||
Query titleQuery = titleQueryParser.parse(keyword);
|
||||
@@ -179,6 +185,9 @@ public class LuceneSearcher implements DocumentSearcher {
|
||||
BooleanQuery.Builder builder = new BooleanQuery.Builder();
|
||||
builder.add(titleBooleanClause)
|
||||
.add(contentBooleanClause);
|
||||
if (request != null && request.getKnowledgeId() != null && !request.getKnowledgeId().trim().isEmpty()) {
|
||||
builder.add(new TermQuery(new Term(KeywordSearchMetadataKeys.KNOWLEDGE_ID, request.getKnowledgeId().trim())), BooleanClause.Occur.MUST);
|
||||
}
|
||||
return builder.build();
|
||||
} catch (ParseException e) {
|
||||
LOG.error(e.toString(), e);
|
||||
@@ -200,6 +209,16 @@ public class LuceneSearcher implements DocumentSearcher {
|
||||
return new JcsegAnalyzer(ISegment.Type.NLP, config, DictionaryFactory.createSingletonDictionary(config));
|
||||
}
|
||||
|
||||
private void appendKnowledgeId(Document document, org.apache.lucene.document.Document luceneDoc) {
|
||||
if (document == null || document.getMetadataMap() == null) {
|
||||
return;
|
||||
}
|
||||
Object knowledgeId = document.getMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID);
|
||||
if (knowledgeId != null) {
|
||||
luceneDoc.add(new StringField(KeywordSearchMetadataKeys.KNOWLEDGE_ID, String.valueOf(knowledgeId), Field.Store.YES));
|
||||
}
|
||||
}
|
||||
|
||||
public void close(IndexWriter indexWriter) {
|
||||
try {
|
||||
if (indexWriter != null) {
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
package com.easyagents.search.engine.lucene;
|
||||
|
||||
import com.easyagents.core.document.Document;
|
||||
import com.easyagents.search.engine.service.KeywordSearchMetadataKeys;
|
||||
import com.easyagents.search.engine.service.KeywordSearchRequest;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.List;
|
||||
|
||||
public class LuceneSearcherTest {
|
||||
|
||||
@Test
|
||||
public void shouldFilterByKnowledgeIdAndSearchTitleAndContent() throws Exception {
|
||||
Path tempDir = Files.createTempDirectory("lucene-searcher-test");
|
||||
LuceneConfig config = new LuceneConfig();
|
||||
config.setIndexDirPath(tempDir.toString());
|
||||
LuceneSearcher searcher = new LuceneSearcher(config);
|
||||
|
||||
Document first = new Document();
|
||||
first.setId("1");
|
||||
first.setTitle("客服标题");
|
||||
first.setContent("这里没有关键字");
|
||||
first.addMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID, "100");
|
||||
|
||||
Document second = new Document();
|
||||
second.setId("2");
|
||||
second.setTitle("别的知识库");
|
||||
second.setContent("客服内容");
|
||||
second.addMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID, "200");
|
||||
|
||||
Assert.assertTrue(searcher.addDocument(first));
|
||||
Assert.assertTrue(searcher.addDocument(second));
|
||||
|
||||
KeywordSearchRequest request = KeywordSearchRequest.of("客服", 10);
|
||||
request.setKnowledgeId("100");
|
||||
List<Document> results = searcher.searchDocuments(request);
|
||||
|
||||
Assert.assertEquals(1, results.size());
|
||||
Assert.assertEquals("1", String.valueOf(results.get(0).getId()));
|
||||
Assert.assertEquals("100", String.valueOf(results.get(0).getMetadata(KeywordSearchMetadataKeys.KNOWLEDGE_ID)));
|
||||
}
|
||||
}
|
||||
@@ -28,8 +28,12 @@ public interface DocumentSearcher {
|
||||
boolean updateDocument(Document document);
|
||||
|
||||
default List<Document> searchDocuments(String keyword) {
|
||||
return searchDocuments(keyword, 10);
|
||||
return searchDocuments(KeywordSearchRequest.of(keyword, 10));
|
||||
}
|
||||
|
||||
List<Document> searchDocuments(String keyword, int count);
|
||||
default List<Document> searchDocuments(String keyword, int count) {
|
||||
return searchDocuments(KeywordSearchRequest.of(keyword, count));
|
||||
}
|
||||
|
||||
List<Document> searchDocuments(KeywordSearchRequest request);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
package com.easyagents.search.engine.service;
|
||||
|
||||
public final class KeywordSearchMetadataKeys {
|
||||
|
||||
private KeywordSearchMetadataKeys() {
|
||||
}
|
||||
|
||||
public static final String KNOWLEDGE_ID = "knowledgeId";
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package com.easyagents.search.engine.service;
|
||||
|
||||
public class KeywordSearchRequest {
|
||||
|
||||
private String keyword;
|
||||
private int count = 10;
|
||||
private String knowledgeId;
|
||||
|
||||
public static KeywordSearchRequest of(String keyword, int count) {
|
||||
KeywordSearchRequest request = new KeywordSearchRequest();
|
||||
request.setKeyword(keyword);
|
||||
request.setCount(count);
|
||||
return request;
|
||||
}
|
||||
|
||||
public String getKeyword() {
|
||||
return keyword;
|
||||
}
|
||||
|
||||
public void setKeyword(String keyword) {
|
||||
this.keyword = keyword;
|
||||
}
|
||||
|
||||
public int getCount() {
|
||||
return count;
|
||||
}
|
||||
|
||||
public void setCount(int count) {
|
||||
this.count = count <= 0 ? 10 : count;
|
||||
}
|
||||
|
||||
public String getKnowledgeId() {
|
||||
return knowledgeId;
|
||||
}
|
||||
|
||||
public void setKnowledgeId(String knowledgeId) {
|
||||
this.knowledgeId = knowledgeId;
|
||||
}
|
||||
}
|
||||
@@ -535,4 +535,15 @@ public class MilvusVectorStore extends DocumentStore {
|
||||
public MilvusClientV2 getClient() {
|
||||
return client;
|
||||
}
|
||||
|
||||
public boolean checkAvailable() {
|
||||
try {
|
||||
return client.hasCollection(HasCollectionReq.builder()
|
||||
.collectionName("__milvus_boot_probe__")
|
||||
.build()) != null;
|
||||
} catch (Exception e) {
|
||||
LOG.warn("Milvus availability check failed. message={}", e.getMessage());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user