知识库功能增强,支持Milvus,并优化相关逻辑
This commit is contained in:
@@ -15,6 +15,10 @@
|
||||
*/
|
||||
package com.easyagents.rerank;
|
||||
|
||||
import com.alibaba.fastjson2.JSON;
|
||||
import com.alibaba.fastjson2.JSONArray;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import com.alibaba.fastjson2.JSONPath;
|
||||
import com.easyagents.core.document.Document;
|
||||
import com.easyagents.core.model.client.HttpClient;
|
||||
import com.easyagents.core.model.rerank.BaseRerankModel;
|
||||
@@ -22,15 +26,19 @@ import com.easyagents.core.model.rerank.RerankException;
|
||||
import com.easyagents.core.model.rerank.RerankOptions;
|
||||
import com.easyagents.core.util.Maps;
|
||||
import com.easyagents.core.util.StringUtil;
|
||||
import com.alibaba.fastjson2.JSON;
|
||||
import com.alibaba.fastjson2.JSONArray;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import com.alibaba.fastjson2.JSONPath;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class DefaultRerankModel extends BaseRerankModel<DefaultRerankModelConfig> {
|
||||
|
||||
private static final Logger LOG = LoggerFactory.getLogger(DefaultRerankModel.class);
|
||||
private static final int MAX_RESPONSE_LOG_LENGTH = 512;
|
||||
|
||||
private HttpClient httpClient = new HttpClient();
|
||||
|
||||
public DefaultRerankModel(DefaultRerankModelConfig config) {
|
||||
@@ -56,8 +64,18 @@ public class DefaultRerankModel extends BaseRerankModel<DefaultRerankModelConfig
|
||||
headers.put("Authorization", "Bearer " + config.getApiKey());
|
||||
|
||||
List<String> payloadDocuments = new ArrayList<>(documents.size());
|
||||
for (Document document : documents) {
|
||||
List<Integer> documentIndexMapping = new ArrayList<>(documents.size());
|
||||
for (int i = 0; i < documents.size(); i++) {
|
||||
Document document = documents.get(i);
|
||||
if (document == null || StringUtil.noText(document.getContent())) {
|
||||
continue;
|
||||
}
|
||||
payloadDocuments.add(document.getContent());
|
||||
documentIndexMapping.add(i);
|
||||
}
|
||||
|
||||
if (payloadDocuments.isEmpty()) {
|
||||
throw new RerankException("empty input documents");
|
||||
}
|
||||
|
||||
String payload = Maps.of("model", options.getModelOrDefault(config.getModel()))
|
||||
@@ -111,20 +129,69 @@ public class DefaultRerankModel extends BaseRerankModel<DefaultRerankModelConfig
|
||||
JSONArray results = (JSONArray) JSONPath.eval(jsonObject, config.getResultsJsonPath());
|
||||
|
||||
if (results == null || results.isEmpty()) {
|
||||
throw new RerankException("empty results");
|
||||
String error = extractErrorMessage(jsonObject);
|
||||
String detail = "empty results";
|
||||
if (StringUtil.hasText(error)) {
|
||||
detail = detail + ", error=" + error;
|
||||
}
|
||||
detail = detail + ", response=" + truncate(response, MAX_RESPONSE_LOG_LENGTH);
|
||||
LOG.warn("Rerank response has no results. query={}, response={}", query, truncate(response, MAX_RESPONSE_LOG_LENGTH));
|
||||
throw new RerankException(detail);
|
||||
}
|
||||
|
||||
|
||||
for (int i = 0; i < results.size(); i++) {
|
||||
JSONObject result = results.getJSONObject(i);
|
||||
int index = result.getIntValue(config.getIndexJsonKey());
|
||||
Document document = documents.get(index);
|
||||
document.setScore(result.getDoubleValue(config.getScoreJsonKey()));
|
||||
if (index < 0 || index >= documentIndexMapping.size()) {
|
||||
continue;
|
||||
}
|
||||
int originalIndex = documentIndexMapping.get(index);
|
||||
Document document = documents.get(originalIndex);
|
||||
if (document != null) {
|
||||
document.setScore(result.getDoubleValue(config.getScoreJsonKey()));
|
||||
}
|
||||
}
|
||||
|
||||
// 对 documents 排序, score 越大的越靠前
|
||||
documents.sort(Comparator.comparingDouble(Document::getScore).reversed());
|
||||
documents.sort((d1, d2) -> Double.compare(scoreOrMin(d2), scoreOrMin(d1)));
|
||||
|
||||
return documents;
|
||||
}
|
||||
|
||||
private double scoreOrMin(Document document) {
|
||||
if (document == null || document.getScore() == null) {
|
||||
return Double.NEGATIVE_INFINITY;
|
||||
}
|
||||
return document.getScore();
|
||||
}
|
||||
|
||||
private String extractErrorMessage(JSONObject jsonObject) {
|
||||
if (jsonObject == null || jsonObject.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
Object nestedMessage = JSONPath.eval(jsonObject, "$.error.message");
|
||||
if (nestedMessage != null && StringUtil.hasText(String.valueOf(nestedMessage))) {
|
||||
return String.valueOf(nestedMessage);
|
||||
}
|
||||
Object message = jsonObject.get("message");
|
||||
if (message != null && StringUtil.hasText(String.valueOf(message))) {
|
||||
return String.valueOf(message);
|
||||
}
|
||||
Object msg = jsonObject.get("msg");
|
||||
if (msg != null && StringUtil.hasText(String.valueOf(msg))) {
|
||||
return String.valueOf(msg);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private String truncate(String value, int maxLength) {
|
||||
if (value == null) {
|
||||
return null;
|
||||
}
|
||||
if (value.length() <= maxLength) {
|
||||
return value;
|
||||
}
|
||||
return value.substring(0, maxLength) + "...";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
/*
|
||||
* Copyright (c) 2023-2026, Easy-Agents (fuhai999@gmail.com).
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.easyagents.rerank;
|
||||
|
||||
import com.easyagents.core.document.Document;
|
||||
import com.easyagents.core.model.client.HttpClient;
|
||||
import com.easyagents.core.model.rerank.RerankException;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class DefaultRerankModelBehaviorTest {
|
||||
|
||||
@Test
|
||||
public void testRerankShouldThrowDetailedErrorWhenResultsEmpty() {
|
||||
DefaultRerankModel model = newModel();
|
||||
model.setHttpClient(new MockHttpClient("{\"error\":{\"message\":\"invalid documents\"}}"));
|
||||
|
||||
try {
|
||||
model.rerank("query", Arrays.asList(Document.of("doc-1")));
|
||||
Assert.fail("Expected RerankException");
|
||||
} catch (RerankException e) {
|
||||
Assert.assertTrue(e.getMessage().contains("empty results"));
|
||||
Assert.assertTrue(e.getMessage().contains("invalid documents"));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRerankShouldSkipNullContentAndMapOriginalIndex() {
|
||||
DefaultRerankModel model = newModel();
|
||||
model.setHttpClient(new MockHttpClient("{\"results\":[{\"index\":0,\"relevance_score\":0.8},{\"index\":1,\"relevance_score\":0.3}]}"));
|
||||
|
||||
Document empty = new Document();
|
||||
Document d1 = Document.of("doc-1");
|
||||
Document d2 = Document.of("doc-2");
|
||||
List<Document> rerankResult = model.rerank("query", Arrays.asList(empty, d1, d2));
|
||||
|
||||
Assert.assertNull(empty.getScore());
|
||||
Assert.assertEquals(0.8d, d1.getScore(), 0.0001d);
|
||||
Assert.assertEquals(0.3d, d2.getScore(), 0.0001d);
|
||||
Assert.assertEquals("doc-1", rerankResult.get(0).getContent());
|
||||
}
|
||||
|
||||
private DefaultRerankModel newModel() {
|
||||
DefaultRerankModelConfig config = new DefaultRerankModelConfig();
|
||||
config.setEndpoint("https://example.com");
|
||||
config.setRequestPath("/v1/rerank");
|
||||
config.setModel("test-rerank-model");
|
||||
config.setApiKey("test-key");
|
||||
return new DefaultRerankModel(config);
|
||||
}
|
||||
|
||||
private static class MockHttpClient extends HttpClient {
|
||||
private final String response;
|
||||
|
||||
private MockHttpClient(String response) {
|
||||
this.response = response;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String post(String url, Map<String, String> headers, String payload) {
|
||||
return response;
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user