初始化
This commit is contained in:
47
easy-agents-store/easy-agents-store-chroma/pom.xml
Normal file
47
easy-agents-store/easy-agents-store-chroma/pom.xml
Normal file
@@ -0,0 +1,47 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>com.easyagents</groupId>
|
||||
<artifactId>easy-agents-store</artifactId>
|
||||
<version>${revision}</version>
|
||||
</parent>
|
||||
|
||||
<name>easy-agents-store-chroma</name>
|
||||
<artifactId>easy-agents-store-chroma</artifactId>
|
||||
|
||||
<properties>
|
||||
<maven.compiler.source>8</maven.compiler.source>
|
||||
<maven.compiler.target>8</maven.compiler.target>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.easyagents</groupId>
|
||||
<artifactId>easy-agents-core</artifactId>
|
||||
</dependency>
|
||||
|
||||
<!-- <dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-api</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.google.code.gson</groupId>
|
||||
<artifactId>gson</artifactId>
|
||||
<version>2.10.1</version>
|
||||
</dependency> -->
|
||||
|
||||
<!-- Test dependencies -->
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<version>4.13.2</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
@@ -0,0 +1,101 @@
|
||||
/*
|
||||
/*
|
||||
* Copyright (c) 2023-2026, Easy-Agents (fuhai999@gmail.com).
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.easyagents.store.chroma;
|
||||
|
||||
import com.easyagents.core.store.condition.Condition;
|
||||
import com.easyagents.core.store.condition.ConditionType;
|
||||
import com.easyagents.core.store.condition.ExpressionAdaptor;
|
||||
import com.easyagents.core.store.condition.Value;
|
||||
|
||||
import java.util.StringJoiner;
|
||||
|
||||
public class ChromaExpressionAdaptor implements ExpressionAdaptor {
|
||||
|
||||
public static final ChromaExpressionAdaptor DEFAULT = new ChromaExpressionAdaptor();
|
||||
|
||||
@Override
|
||||
public String toOperationSymbol(ConditionType type) {
|
||||
if (type == ConditionType.EQ) {
|
||||
return " == ";
|
||||
} else if (type == ConditionType.NE) {
|
||||
return " != ";
|
||||
} else if (type == ConditionType.GT) {
|
||||
return " > ";
|
||||
} else if (type == ConditionType.GE) {
|
||||
return " >= ";
|
||||
} else if (type == ConditionType.LT) {
|
||||
return " < ";
|
||||
} else if (type == ConditionType.LE) {
|
||||
return " <= ";
|
||||
} else if (type == ConditionType.IN) {
|
||||
return " IN ";
|
||||
}
|
||||
return type.getDefaultSymbol();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toCondition(Condition condition) {
|
||||
if (condition.getType() == ConditionType.BETWEEN) {
|
||||
Object[] values = (Object[]) ((Value) condition.getRight()).getValue();
|
||||
return "(" + toLeft(condition.getLeft())
|
||||
+ toOperationSymbol(ConditionType.GE)
|
||||
+ values[0] + " && "
|
||||
+ toLeft(condition.getLeft())
|
||||
+ toOperationSymbol(ConditionType.LE)
|
||||
+ values[1] + ")";
|
||||
}
|
||||
|
||||
return ExpressionAdaptor.super.toCondition(condition);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toValue(Condition condition, Object value) {
|
||||
if (value == null) {
|
||||
return "null";
|
||||
}
|
||||
|
||||
if (condition.getType() == ConditionType.IN) {
|
||||
Object[] values = (Object[]) value;
|
||||
StringJoiner stringJoiner = new StringJoiner(",", "[", "]");
|
||||
for (Object v : values) {
|
||||
if (v != null) {
|
||||
stringJoiner.add("\"" + v + "\"");
|
||||
}
|
||||
}
|
||||
return stringJoiner.toString();
|
||||
} else if (value instanceof String) {
|
||||
return "\"" + value + "\"";
|
||||
} else if (value instanceof Boolean) {
|
||||
return ((Boolean) value).toString();
|
||||
} else if (value instanceof Number) {
|
||||
return value.toString();
|
||||
}
|
||||
|
||||
return ExpressionAdaptor.super.toValue(condition, value);
|
||||
}
|
||||
|
||||
public String toLeft(Object left) {
|
||||
if (left instanceof String) {
|
||||
String field = (String) left;
|
||||
if (field.contains(".")) {
|
||||
return field;
|
||||
}
|
||||
return field;
|
||||
}
|
||||
return left.toString();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,794 @@
|
||||
/*
|
||||
* Copyright (c) 2023-2026, Easy-Agents (fuhai999@gmail.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.easyagents.store.chroma;
|
||||
|
||||
import com.easyagents.core.document.Document;
|
||||
import com.easyagents.core.store.DocumentStore;
|
||||
import com.easyagents.core.store.SearchWrapper;
|
||||
import com.easyagents.core.store.StoreOptions;
|
||||
import com.easyagents.core.store.StoreResult;
|
||||
import com.easyagents.core.store.condition.ExpressionAdaptor;
|
||||
import com.easyagents.core.model.client.HttpClient;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* ChromaVectorStore class provides an interface to interact with Chroma Vector Database
|
||||
* using direct HTTP calls to the Chroma REST API.
|
||||
*/
|
||||
public class ChromaVectorStore extends DocumentStore {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(ChromaVectorStore.class);
|
||||
private final String baseUrl;
|
||||
private final String collectionName;
|
||||
private final String tenant;
|
||||
private final String database;
|
||||
private final ChromaVectorStoreConfig config;
|
||||
private final ExpressionAdaptor expressionAdaptor;
|
||||
private final HttpClient httpClient;
|
||||
private final int MAX_RETRIES = 3;
|
||||
private final long RETRY_INTERVAL_MS = 1000;
|
||||
|
||||
private static final String BASE_API = "/api/v2";
|
||||
|
||||
/**
 * Creates a store from the given configuration.
 *
 * <p>When auto-creation is enabled, the tenant, database and collection are
 * checked (and created if needed) right away. A failure at this stage is only
 * logged, because each store/delete/search operation checks the collection
 * again before running.
 *
 * @param config the store configuration; must not be {@code null}
 * @throws NullPointerException     if {@code config} is {@code null}
 * @throws IllegalArgumentException if the configured base URL is empty or
 *                                  does not start with http:// or https://
 */
public ChromaVectorStore(ChromaVectorStoreConfig config) {
    Objects.requireNonNull(config, "ChromaVectorStoreConfig cannot be null");

    this.config = config;
    this.baseUrl = config.getBaseUrl();
    this.tenant = config.getTenant();
    this.database = config.getDatabase();
    this.collectionName = config.getCollectionName();
    this.expressionAdaptor = ChromaExpressionAdaptor.DEFAULT;

    // Create the HTTP client before validating, mirroring the established order.
    this.httpClient = createHttpClient();

    // Reject unusable base URLs early.
    validateConfig();

    if (!config.isAutoCreateCollection()) {
        return;
    }

    try {
        // Tenant and database first, then the collection that lives inside them.
        ensureTenantAndDatabaseExists();
        ensureCollectionExists();
    } catch (Exception e) {
        logger.warn("Failed to ensure collection exists: {}. Will retry on first operation.", e.getMessage());
    }
}
|
||||
|
||||
private HttpClient createHttpClient() {
|
||||
HttpClient client = new HttpClient();
|
||||
return client;
|
||||
}
|
||||
|
||||
private void validateConfig() {
|
||||
if (baseUrl == null || baseUrl.isEmpty()) {
|
||||
throw new IllegalArgumentException("Base URL cannot be empty");
|
||||
}
|
||||
|
||||
if (!baseUrl.startsWith("http://") && !baseUrl.startsWith("https://")) {
|
||||
throw new IllegalArgumentException("Base URL must start with http:// or https://");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 确保租户和数据库存在,如果不存在则创建
|
||||
*/
|
||||
private void ensureTenantAndDatabaseExists() {
|
||||
try {
|
||||
// 检查并创建租户
|
||||
if (tenant != null && !tenant.isEmpty()) {
|
||||
ensureTenantExists();
|
||||
|
||||
// 检查并创建数据库(如果租户已设置)
|
||||
if (database != null && !database.isEmpty()) {
|
||||
ensureDatabaseExists();
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
logger.error("Error ensuring tenant and database exist", e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 确保租户存在,如果不存在则创建
|
||||
*/
|
||||
private void ensureTenantExists() throws IOException {
|
||||
String tenantUrl = baseUrl + BASE_API + "/tenants/" + tenant;
|
||||
Map<String, String> headers = createHeaders();
|
||||
|
||||
try {
|
||||
// 尝试获取租户信息
|
||||
String responseBody = executeWithRetry(() -> httpClient.get(tenantUrl, headers));
|
||||
logger.debug("Successfully verified tenant '{}' exists", tenant);
|
||||
} catch (IOException e) {
|
||||
// 如果获取失败,尝试创建租户
|
||||
logger.info("Creating tenant '{}' as it does not exist", tenant);
|
||||
|
||||
Map<String, Object> requestBody = new HashMap<>();
|
||||
requestBody.put("name", tenant);
|
||||
|
||||
String createTenantUrl = baseUrl + BASE_API + "/tenants";
|
||||
String jsonRequestBody = safeJsonSerialize(requestBody);
|
||||
|
||||
String responseBody = executeWithRetry(() -> httpClient.post(createTenantUrl, headers, jsonRequestBody));
|
||||
logger.info("Successfully created tenant '{}'", tenant);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 确保数据库存在,如果不存在则创建
|
||||
*/
|
||||
private void ensureDatabaseExists() throws IOException {
|
||||
if (tenant == null || tenant.isEmpty()) {
|
||||
throw new IllegalStateException("Cannot create database without tenant");
|
||||
}
|
||||
|
||||
String databaseUrl = baseUrl + BASE_API + "/tenants/" + tenant + "/databases/" + database;
|
||||
Map<String, String> headers = createHeaders();
|
||||
|
||||
try {
|
||||
// 尝试获取数据库信息
|
||||
String responseBody = executeWithRetry(() -> httpClient.get(databaseUrl, headers));
|
||||
logger.debug("Successfully verified database '{}' exists in tenant '{}'",
|
||||
database, tenant);
|
||||
} catch (IOException e) {
|
||||
// 如果获取失败,尝试创建数据库
|
||||
logger.info("Creating database '{}' in tenant '{}' as it does not exist",
|
||||
database, tenant);
|
||||
|
||||
Map<String, Object> requestBody = new HashMap<>();
|
||||
requestBody.put("name", database);
|
||||
|
||||
String createDatabaseUrl = baseUrl + BASE_API + "/tenants/" + tenant + "/databases";
|
||||
String jsonRequestBody = safeJsonSerialize(requestBody);
|
||||
|
||||
String responseBody = executeWithRetry(() -> httpClient.post(createDatabaseUrl, headers, jsonRequestBody));
|
||||
logger.info("Successfully created database '{}' in tenant '{}'",
|
||||
database, tenant);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据collectionName查询Collection ID
|
||||
*/
|
||||
private String getCollectionId(String collectionName) throws IOException {
|
||||
String collectionsUrl = buildCollectionsUrl();
|
||||
Map<String, String> headers = createHeaders();
|
||||
|
||||
String responseBody = executeWithRetry(() -> httpClient.get(collectionsUrl, headers));
|
||||
if (responseBody == null) {
|
||||
throw new IOException("Failed to get collections, no response");
|
||||
}
|
||||
|
||||
Object responseObj = parseJsonResponse(responseBody);
|
||||
List<Map<String, Object>> collections = new ArrayList<>();
|
||||
|
||||
// 处理不同格式的响应
|
||||
if (responseObj instanceof Map) {
|
||||
Map<String, Object> responseMap = (Map<String, Object>) responseObj;
|
||||
if (responseMap.containsKey("collections") && responseMap.get("collections") instanceof List) {
|
||||
collections = (List<Map<String, Object>>) responseMap.get("collections");
|
||||
}
|
||||
} else if (responseObj instanceof List) {
|
||||
List<?> rawCollections = (List<?>) responseObj;
|
||||
for (Object item : rawCollections) {
|
||||
if (item instanceof Map) {
|
||||
collections.add((Map<String, Object>) item);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 查找指定名称的集合
|
||||
for (Map<String, Object> collection : collections) {
|
||||
if (collection.containsKey("name") && collectionName.equals(collection.get("name"))) {
|
||||
return collection.get("id").toString();
|
||||
}
|
||||
}
|
||||
|
||||
throw new IOException("Collection not found: " + collectionName);
|
||||
}
|
||||
|
||||
private void createCollection() throws IOException {
|
||||
// 构建创建集合的API URL,包含tenant和database
|
||||
String createCollectionUrl = buildCollectionsUrl();
|
||||
Map<String, String> headers = createHeaders();
|
||||
|
||||
Map<String, Object> requestBody = new HashMap<>();
|
||||
requestBody.put("name", collectionName);
|
||||
|
||||
String jsonRequestBody = safeJsonSerialize(requestBody);
|
||||
|
||||
String responseBody = executeWithRetry(() -> httpClient.post(createCollectionUrl, headers, jsonRequestBody));
|
||||
if (responseBody == null) {
|
||||
throw new IOException("Failed to create collection: no response");
|
||||
}
|
||||
|
||||
try {
|
||||
Object responseObj = parseJsonResponse(responseBody);
|
||||
|
||||
Map<String, Object> responseMap = null;
|
||||
if (responseObj instanceof Map) {
|
||||
responseMap = (Map<String, Object>) responseObj;
|
||||
}
|
||||
if (responseMap.containsKey("error")) {
|
||||
throw new IOException("Failed to create collection: " + responseMap.get("error"));
|
||||
}
|
||||
|
||||
logger.info("Collection '{}' created successfully", collectionName);
|
||||
} catch (Exception e) {
|
||||
throw new IOException("Failed to process collection creation response: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public StoreResult doStore(List<Document> documents, StoreOptions options) {
|
||||
Objects.requireNonNull(documents, "Documents cannot be null");
|
||||
|
||||
if (documents.isEmpty()) {
|
||||
logger.debug("No documents to store");
|
||||
return StoreResult.success();
|
||||
}
|
||||
|
||||
try {
|
||||
// 确保集合存在
|
||||
ensureCollectionExists();
|
||||
|
||||
String collectionName = getCollectionName(options);
|
||||
|
||||
List<String> ids = new ArrayList<>();
|
||||
List<List<Double>> embeddings = new ArrayList<>();
|
||||
List<Map<String, Object>> metadatas = new ArrayList<>();
|
||||
List<String> documentsContent = new ArrayList<>();
|
||||
|
||||
for (Document doc : documents) {
|
||||
ids.add(String.valueOf(doc.getId()));
|
||||
|
||||
if (doc.getVector() != null) {
|
||||
List<Double> embedding = doc.getVectorAsDoubleList();
|
||||
embeddings.add(embedding);
|
||||
} else {
|
||||
embeddings.add(null);
|
||||
}
|
||||
|
||||
Map<String, Object> metadata = doc.getMetadataMap() != null ?
|
||||
new HashMap<>(doc.getMetadataMap()) : new HashMap<>();
|
||||
metadatas.add(metadata);
|
||||
|
||||
documentsContent.add(doc.getContent());
|
||||
}
|
||||
|
||||
Map<String, Object> requestBody = new HashMap<>();
|
||||
requestBody.put("ids", ids);
|
||||
requestBody.put("embeddings", embeddings);
|
||||
requestBody.put("metadatas", metadatas);
|
||||
requestBody.put("documents", documentsContent);
|
||||
|
||||
String collectionId = getCollectionId(collectionName);
|
||||
|
||||
// 构建包含tenant和database的完整URL
|
||||
String collectionUrl = buildCollectionUrl(collectionId, "add");
|
||||
|
||||
Map<String, String> headers = createHeaders();
|
||||
|
||||
String jsonRequestBody = safeJsonSerialize(requestBody);
|
||||
|
||||
logger.debug("Storing {} documents to collection '{}'", documents.size(), collectionName);
|
||||
|
||||
String responseBody = executeWithRetry(() -> httpClient.post(collectionUrl, headers, jsonRequestBody));
|
||||
if (responseBody == null) {
|
||||
logger.error("Error storing documents: no response");
|
||||
return StoreResult.fail();
|
||||
}
|
||||
|
||||
Object responseObj = parseJsonResponse(responseBody);
|
||||
|
||||
Map<String, Object> responseMap = null;
|
||||
if (responseObj instanceof Map) {
|
||||
responseMap = (Map<String, Object>) responseObj;
|
||||
}
|
||||
if (responseMap.containsKey("error")) {
|
||||
String errorMsg = "Error storing documents: " + responseMap.get("error");
|
||||
logger.error(errorMsg);
|
||||
return StoreResult.fail();
|
||||
}
|
||||
|
||||
logger.debug("Successfully stored {} documents", documents.size());
|
||||
return StoreResult.successWithIds(documents);
|
||||
} catch (Exception e) {
|
||||
logger.error("Error storing documents to Chroma", e);
|
||||
return StoreResult.fail();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public StoreResult doDelete(Collection<?> ids, StoreOptions options) {
|
||||
Objects.requireNonNull(ids, "IDs cannot be null");
|
||||
|
||||
if (ids.isEmpty()) {
|
||||
logger.debug("No IDs to delete");
|
||||
return StoreResult.success();
|
||||
}
|
||||
|
||||
try {
|
||||
// 确保集合存在
|
||||
ensureCollectionExists();
|
||||
|
||||
String collectionName = getCollectionName(options);
|
||||
|
||||
List<String> stringIds = ids.stream()
|
||||
.map(Object::toString)
|
||||
.collect(Collectors.toList());
|
||||
Map<String, Object> requestBody = new HashMap<>();
|
||||
requestBody.put("ids", stringIds);
|
||||
|
||||
String collectionId = getCollectionId(collectionName);
|
||||
|
||||
// 构建包含tenant和database的完整URL
|
||||
String collectionUrl = buildCollectionUrl(collectionId, "delete");
|
||||
|
||||
Map<String, String> headers = createHeaders();
|
||||
|
||||
String jsonRequestBody = safeJsonSerialize(requestBody);
|
||||
|
||||
logger.debug("Deleting {} documents from collection '{}'", ids.size(), collectionName);
|
||||
|
||||
String responseBody = executeWithRetry(() -> httpClient.post(collectionUrl, headers, jsonRequestBody));
|
||||
if (responseBody == null) {
|
||||
logger.error("Error deleting documents: no response");
|
||||
return StoreResult.fail();
|
||||
}
|
||||
|
||||
Object responseObj = parseJsonResponse(responseBody);
|
||||
|
||||
Map<String, Object> responseMap = null;
|
||||
if (responseObj instanceof Map) {
|
||||
responseMap = (Map<String, Object>) responseObj;
|
||||
}
|
||||
if (responseMap.containsKey("error")) {
|
||||
String errorMsg = "Error deleting documents: " + responseMap.get("error");
|
||||
logger.error(errorMsg);
|
||||
return StoreResult.fail();
|
||||
}
|
||||
|
||||
logger.debug("Successfully deleted {} documents", ids.size());
|
||||
return StoreResult.success();
|
||||
} catch (Exception e) {
|
||||
logger.error("Error deleting documents from Chroma", e);
|
||||
return StoreResult.fail();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public StoreResult doUpdate(List<Document> documents, StoreOptions options) {
|
||||
Objects.requireNonNull(documents, "Documents cannot be null");
|
||||
|
||||
if (documents.isEmpty()) {
|
||||
logger.debug("No documents to update");
|
||||
return StoreResult.success();
|
||||
}
|
||||
|
||||
try {
|
||||
// Chroma doesn't support direct update, so we delete and re-add
|
||||
List<Object> ids = documents.stream().map(Document::getId).collect(Collectors.toList());
|
||||
StoreResult deleteResult = doDelete(ids, options);
|
||||
|
||||
if (!deleteResult.isSuccess()) {
|
||||
logger.warn("Delete failed during update operation: {}", deleteResult.toString());
|
||||
// 尝试继续添加,因为可能有些文档是新的
|
||||
}
|
||||
|
||||
StoreResult storeResult = doStore(documents, options);
|
||||
|
||||
if (storeResult.isSuccess()) {
|
||||
logger.debug("Successfully updated {} documents", documents.size());
|
||||
}
|
||||
|
||||
return storeResult;
|
||||
} catch (Exception e) {
|
||||
logger.error("Error updating documents in Chroma", e);
|
||||
return StoreResult.fail();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Document> doSearch(SearchWrapper wrapper, StoreOptions options) {
|
||||
Objects.requireNonNull(wrapper, "SearchWrapper cannot be null");
|
||||
|
||||
try {
|
||||
// 确保集合存在
|
||||
ensureCollectionExists();
|
||||
|
||||
String collectionName = getCollectionName(options);
|
||||
|
||||
int limit = wrapper.getMaxResults() > 0 ? wrapper.getMaxResults() : 10;
|
||||
|
||||
Map<String, Object> requestBody = new HashMap<>();
|
||||
// 检查查询条件是否有效
|
||||
if (wrapper.getVector() == null && wrapper.getText() == null) {
|
||||
throw new IllegalArgumentException("Either vector or text must be provided for search");
|
||||
}
|
||||
|
||||
// 设置查询向量
|
||||
if (wrapper.getVector() != null) {
|
||||
List<Double> queryEmbedding = wrapper.getVectorAsDoubleList();
|
||||
requestBody.put("query_embeddings", Collections.singletonList(queryEmbedding));
|
||||
logger.debug("Performing vector search with dimension: {}", queryEmbedding.size());
|
||||
} else if (wrapper.getText() != null) {
|
||||
requestBody.put("query_texts", Collections.singletonList(wrapper.getText()));
|
||||
logger.debug("Performing text search: {}", sanitizeLogString(wrapper.getText(), 100));
|
||||
}
|
||||
|
||||
// 设置返回数量
|
||||
requestBody.put("n_results", limit);
|
||||
|
||||
// 设置过滤条件
|
||||
if (wrapper.getCondition() != null) {
|
||||
try {
|
||||
String whereClause = expressionAdaptor.toCondition(wrapper.getCondition());
|
||||
// Chroma的where条件是JSON对象,需要解析
|
||||
Object whereObj = parseJsonResponse(whereClause);
|
||||
|
||||
Map<String, Object> whereMap = null;
|
||||
if (whereObj instanceof Map) {
|
||||
whereMap = (Map<String, Object>) whereObj;
|
||||
}
|
||||
requestBody.put("where", whereMap);
|
||||
logger.debug("Search with filter condition: {}", whereClause);
|
||||
} catch (Exception e) {
|
||||
logger.warn("Failed to parse filter condition: {}, ignoring condition", e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
String collectionId = getCollectionId(collectionName);
|
||||
|
||||
// 构建包含tenant和database的完整URL
|
||||
String collectionUrl = buildCollectionUrl(collectionId, "query");
|
||||
|
||||
Map<String, String> headers = createHeaders();
|
||||
|
||||
String jsonRequestBody = safeJsonSerialize(requestBody);
|
||||
|
||||
String responseBody = executeWithRetry(() -> httpClient.post(collectionUrl, headers, jsonRequestBody));
|
||||
if (responseBody == null) {
|
||||
logger.error("Error searching documents: no response");
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
|
||||
Object responseObj = parseJsonResponse(responseBody);
|
||||
|
||||
Map<String, Object> responseMap = null;
|
||||
if (responseObj instanceof Map) {
|
||||
responseMap = (Map<String, Object>) responseObj;
|
||||
}
|
||||
|
||||
// 检查响应是否包含error字段
|
||||
if (responseMap.containsKey("error")) {
|
||||
logger.error("Error searching documents: {}", responseMap.get("error"));
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
// 解析结果
|
||||
return parseSearchResults(responseMap);
|
||||
} catch (Exception e) {
|
||||
logger.error("Error searching documents in Chroma", e);
|
||||
return Collections.emptyList();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 支持直接使用向量数组和topK参数的搜索方法
|
||||
*/
|
||||
public List<Document> searchInternal(double[] vector, int topK, StoreOptions options) {
|
||||
Objects.requireNonNull(vector, "Vector cannot be null");
|
||||
|
||||
if (topK <= 0) {
|
||||
topK = 10;
|
||||
}
|
||||
|
||||
try {
|
||||
// 确保集合存在
|
||||
ensureCollectionExists();
|
||||
|
||||
String collectionName = getCollectionName(options);
|
||||
|
||||
Map<String, Object> requestBody = new HashMap<>();
|
||||
|
||||
// 设置查询向量
|
||||
List<Double> queryEmbedding = Arrays.stream(vector)
|
||||
.boxed()
|
||||
.collect(Collectors.toList());
|
||||
requestBody.put("query_embeddings", Collections.singletonList(queryEmbedding));
|
||||
|
||||
// 设置返回数量
|
||||
requestBody.put("n_results", topK);
|
||||
|
||||
String collectionId = getCollectionId(collectionName);
|
||||
|
||||
// 构建包含tenant和database的完整URL
|
||||
String collectionUrl = buildCollectionUrl(collectionId, "query");
|
||||
|
||||
Map<String, String> headers = createHeaders();
|
||||
|
||||
String jsonRequestBody = safeJsonSerialize(requestBody);
|
||||
|
||||
logger.debug("Performing direct vector search with dimension: {}", vector.length);
|
||||
|
||||
String responseBody = executeWithRetry(() -> httpClient.post(collectionUrl, headers, jsonRequestBody));
|
||||
if (responseBody == null) {
|
||||
logger.error("Error searching documents: no response");
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
Object responseObj = parseJsonResponse(responseBody);
|
||||
|
||||
Map<String, Object> responseMap = null;
|
||||
if (responseObj instanceof Map) {
|
||||
responseMap = (Map<String, Object>) responseObj;
|
||||
}
|
||||
|
||||
// 检查响应是否包含error字段
|
||||
if (responseMap.containsKey("error")) {
|
||||
logger.error("Error searching documents: {}", responseMap.get("error"));
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
// 解析结果
|
||||
return parseSearchResults(responseMap);
|
||||
} catch (Exception e) {
|
||||
logger.error("Error searching documents in Chroma", e);
|
||||
return Collections.emptyList();
|
||||
}
|
||||
}
|
||||
|
||||
private List<Document> parseSearchResults(Map<String, Object> responseMap) {
|
||||
try {
|
||||
List<String> ids = extractResultsFromNestedList(responseMap, "ids");
|
||||
List<String> documents = extractResultsFromNestedList(responseMap, "documents");
|
||||
List<Map<String, Object>> metadatas = extractResultsFromNestedList(responseMap, "metadatas");
|
||||
List<List<Double>> embeddings = extractResultsFromNestedList(responseMap, "embeddings");
|
||||
List<Double> distances = extractResultsFromNestedList(responseMap, "distances");
|
||||
|
||||
if (ids == null || ids.isEmpty()) {
|
||||
logger.debug("No documents found in search results");
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
// 转换为Easy-Agents的Document格式
|
||||
List<Document> resultDocs = new ArrayList<>();
|
||||
for (int i = 0; i < ids.size(); i++) {
|
||||
Document doc = new Document();
|
||||
doc.setId(ids.get(i));
|
||||
|
||||
if (documents != null && i < documents.size()) {
|
||||
doc.setContent(documents.get(i));
|
||||
}
|
||||
|
||||
if (metadatas != null && i < metadatas.size()) {
|
||||
doc.setMetadataMap(metadatas.get(i));
|
||||
}
|
||||
|
||||
if (embeddings != null && i < embeddings.size() && embeddings.get(i) != null) {
|
||||
doc.setVector(embeddings.get(i));
|
||||
}
|
||||
|
||||
// 设置相似度分数(距离越小越相似)
|
||||
if (distances != null && i < distances.size()) {
|
||||
double score = 1.0 - distances.get(i);
|
||||
// 确保分数在合理范围内
|
||||
score = Math.max(0, Math.min(1, score));
|
||||
doc.setScore(score);
|
||||
}
|
||||
|
||||
resultDocs.add(doc);
|
||||
}
|
||||
|
||||
logger.debug("Found {} documents in search results", resultDocs.size());
|
||||
return resultDocs;
|
||||
} catch (Exception e) {
|
||||
logger.error("Failed to parse search results", e);
|
||||
return Collections.emptyList();
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private <T> List<T> extractResultsFromNestedList(Map<String, Object> responseMap, String key) {
|
||||
try {
|
||||
if (!responseMap.containsKey(key)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
List<?> outerList = (List<?>) responseMap.get(key);
|
||||
if (outerList == null || outerList.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Chroma返回的结果是嵌套列表,第一个元素是当前查询的结果
|
||||
return (List<T>) outerList.get(0);
|
||||
} catch (Exception e) {
|
||||
logger.warn("Failed to extract '{}' from response: {}", key, e.getMessage());
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private Map<String, String> createHeaders() {
|
||||
Map<String, String> headers = new HashMap<>();
|
||||
headers.put("Content-Type", "application/json");
|
||||
|
||||
if (config.getApiKey() != null && !config.getApiKey().isEmpty()) {
|
||||
headers.put("X-Chroma-Token", config.getApiKey());
|
||||
}
|
||||
|
||||
// 添加租户和数据库信息(如果配置了)
|
||||
if (tenant != null && !tenant.isEmpty()) {
|
||||
headers.put("X-Chroma-Tenant", tenant);
|
||||
}
|
||||
|
||||
if (database != null && !database.isEmpty()) {
|
||||
headers.put("X-Chroma-Database", database);
|
||||
}
|
||||
|
||||
return headers;
|
||||
}
|
||||
|
||||
private <T> T executeWithRetry(HttpOperation<T> operation) throws IOException {
|
||||
int attempts = 0;
|
||||
IOException lastException = null;
|
||||
|
||||
while (attempts < MAX_RETRIES) {
|
||||
try {
|
||||
attempts++;
|
||||
return operation.execute();
|
||||
} catch (IOException e) {
|
||||
lastException = e;
|
||||
|
||||
// 如果是最后一次尝试,则抛出异常
|
||||
if (attempts >= MAX_RETRIES) {
|
||||
throw new IOException("Operation failed after " + MAX_RETRIES + " attempts: " + e.getMessage(), e);
|
||||
}
|
||||
|
||||
// 记录重试信息
|
||||
logger.warn("Operation failed (attempt {} of {}), retrying in {}ms: {}",
|
||||
attempts, MAX_RETRIES, RETRY_INTERVAL_MS, e.getMessage());
|
||||
|
||||
// 等待一段时间后重试
|
||||
try {
|
||||
Thread.sleep(RETRY_INTERVAL_MS);
|
||||
} catch (InterruptedException ie) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new IOException("Retry interrupted", ie);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 这一行理论上不会执行到,但为了编译器满意
|
||||
throw lastException != null ? lastException : new IOException("Operation failed without exception");
|
||||
}
|
||||
|
||||
private String safeJsonSerialize(Map<String, Object> map) {
|
||||
// 使用标准的JSON序列化,但在实际应用中可以添加更多的安全检查
|
||||
try {
|
||||
return new com.google.gson.Gson().toJson(map);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Failed to serialize request body to JSON", e);
|
||||
}
|
||||
}
|
||||
|
||||
private Object parseJsonResponse(String json) {
|
||||
try {
|
||||
if (json == null || json.trim().isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
// Check if JSON starts with [ indicating an array
|
||||
if (json.trim().startsWith("[")) {
|
||||
return new com.google.gson.Gson().fromJson(json, List.class);
|
||||
} else {
|
||||
// Otherwise assume it's an object
|
||||
return new com.google.gson.Gson().fromJson(json, Map.class);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Failed to parse JSON response: " + json, e);
|
||||
}
|
||||
}
|
||||
|
||||
private String sanitizeLogString(String input, int maxLength) {
|
||||
if (input == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
String sanitized = input.replaceAll("[\n\r]", " ");
|
||||
return sanitized.length() > maxLength ? sanitized.substring(0, maxLength) + "..." : sanitized;
|
||||
}
|
||||
|
||||
private String getCollectionName(StoreOptions options) {
|
||||
return options != null ? options.getCollectionNameOrDefault(collectionName) : collectionName;
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建特定集合操作的URL,包含tenant和database
|
||||
*/
|
||||
private String buildCollectionUrl(String collectionId, String operation) {
|
||||
StringBuilder urlBuilder = new StringBuilder(baseUrl).append(BASE_API);
|
||||
|
||||
if (tenant != null && !tenant.isEmpty()) {
|
||||
urlBuilder.append("/tenants/").append(tenant);
|
||||
|
||||
if (database != null && !database.isEmpty()) {
|
||||
urlBuilder.append("/databases/").append(database);
|
||||
}
|
||||
}
|
||||
|
||||
urlBuilder.append("/collections/").append(collectionId).append("/").append(operation);
|
||||
return urlBuilder.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* Close the connection to Chroma database
|
||||
*/
|
||||
public void close() {
|
||||
// HttpClient类使用连接池管理,这里可以添加额外的资源清理逻辑
|
||||
logger.info("Chroma client closed");
|
||||
}
|
||||
|
||||
/**
|
||||
* 确保集合存在,如果不存在则创建
|
||||
*/
|
||||
private void ensureCollectionExists() throws IOException {
|
||||
try {
|
||||
// 尝试获取默认集合ID,如果能获取到则说明集合存在
|
||||
getCollectionId(collectionName);
|
||||
logger.debug("Collection '{}' exists", collectionName);
|
||||
} catch (IOException e) {
|
||||
// 如果获取集合ID失败,说明集合不存在,需要创建
|
||||
logger.info("Collection '{}' does not exist, creating...", collectionName);
|
||||
createCollection();
|
||||
logger.info("Collection '{}' created successfully", collectionName);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建集合列表URL,包含tenant和database
|
||||
*/
|
||||
private String buildCollectionsUrl() {
|
||||
StringBuilder urlBuilder = new StringBuilder(baseUrl).append(BASE_API);
|
||||
|
||||
if (tenant != null && !tenant.isEmpty()) {
|
||||
urlBuilder.append("/tenants/").append(tenant);
|
||||
|
||||
if (database != null && !database.isEmpty()) {
|
||||
urlBuilder.append("/databases/").append(database);
|
||||
}
|
||||
}
|
||||
|
||||
urlBuilder.append("/collections");
|
||||
return urlBuilder.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* 函数式接口,用于封装HTTP操作以支持重试
|
||||
*/
|
||||
private interface HttpOperation<T> {
|
||||
T execute() throws IOException;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
/*
|
||||
* Copyright (c) 2023-2026, Easy-Agents (fuhai999@gmail.com).
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.easyagents.store.chroma;
|
||||
|
||||
import com.easyagents.core.store.DocumentStoreConfig;
|
||||
import com.easyagents.core.util.StringUtil;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.HttpURLConnection;
|
||||
import java.net.URL;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* ChromaVectorStoreConfig class provides configuration for ChromaVectorStore.
|
||||
*/
|
||||
public class ChromaVectorStoreConfig implements DocumentStoreConfig {
|
||||
private static final Logger logger = LoggerFactory.getLogger(ChromaVectorStoreConfig.class);
|
||||
|
||||
private String host = "localhost";
|
||||
private int port = 8000;
|
||||
private String collectionName;
|
||||
private boolean autoCreateCollection = true;
|
||||
private String apiKey;
|
||||
private String tenant;
|
||||
private String database;
|
||||
|
||||
public ChromaVectorStoreConfig() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the host of Chroma database
|
||||
*
|
||||
* @return the host of Chroma database
|
||||
*/
|
||||
public String getHost() {
|
||||
return host;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the host of Chroma database
|
||||
*
|
||||
* @param host the host of Chroma database
|
||||
*/
|
||||
public void setHost(String host) {
|
||||
this.host = host;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the port of Chroma database
|
||||
*
|
||||
* @return the port of Chroma database
|
||||
*/
|
||||
public int getPort() {
|
||||
return port;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the port of Chroma database
|
||||
*
|
||||
* @param port the port of Chroma database
|
||||
*/
|
||||
public void setPort(int port) {
|
||||
this.port = port;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the collection name of Chroma database
|
||||
*
|
||||
* @return the collection name of Chroma database
|
||||
*/
|
||||
public String getCollectionName() {
|
||||
return collectionName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the collection name of Chroma database
|
||||
*
|
||||
* @param collectionName the collection name of Chroma database
|
||||
*/
|
||||
public void setCollectionName(String collectionName) {
|
||||
this.collectionName = collectionName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get whether to automatically create the collection if it doesn't exist
|
||||
*
|
||||
* @return true if the collection should be created automatically, false otherwise
|
||||
*/
|
||||
public boolean isAutoCreateCollection() {
|
||||
return autoCreateCollection;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set whether to automatically create the collection if it doesn't exist
|
||||
*
|
||||
* @param autoCreateCollection true if the collection should be created automatically, false otherwise
|
||||
*/
|
||||
public void setAutoCreateCollection(boolean autoCreateCollection) {
|
||||
this.autoCreateCollection = autoCreateCollection;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the API key of Chroma database
|
||||
*
|
||||
* @return the API key of Chroma database
|
||||
*/
|
||||
public String getApiKey() {
|
||||
return apiKey;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the API key of Chroma database
|
||||
*
|
||||
* @param apiKey the API key of Chroma database
|
||||
*/
|
||||
public void setApiKey(String apiKey) {
|
||||
this.apiKey = apiKey;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the tenant of Chroma database
|
||||
*
|
||||
* @return the tenant of Chroma database
|
||||
*/
|
||||
public String getTenant() {
|
||||
return tenant;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the tenant of Chroma database
|
||||
*
|
||||
* @param tenant the tenant of Chroma database
|
||||
*/
|
||||
public void setTenant(String tenant) {
|
||||
this.tenant = tenant;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the database of Chroma database
|
||||
*
|
||||
* @return the database of Chroma database
|
||||
*/
|
||||
public String getDatabase() {
|
||||
return database;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the database of Chroma database
|
||||
*
|
||||
* @param database the database of Chroma database
|
||||
*/
|
||||
public void setDatabase(String database) {
|
||||
this.database = database;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean checkAvailable() {
|
||||
try {
|
||||
URL url = new URL(getBaseUrl() + "/api/v2/heartbeat");
|
||||
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
|
||||
connection.setRequestMethod("GET");
|
||||
connection.setConnectTimeout(5000);
|
||||
connection.setReadTimeout(5000);
|
||||
|
||||
if (apiKey != null && !apiKey.isEmpty()) {
|
||||
connection.setRequestProperty("X-Chroma-Token", apiKey);
|
||||
}
|
||||
|
||||
int responseCode = connection.getResponseCode();
|
||||
connection.disconnect();
|
||||
|
||||
return responseCode == 200;
|
||||
} catch (IOException e) {
|
||||
logger.warn("Chroma database is not available: {}", e.getMessage());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the base URL of Chroma database
|
||||
*
|
||||
* @return the base URL of Chroma database
|
||||
*/
|
||||
public String getBaseUrl() {
|
||||
return "http://" + host + ":" + port;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,383 @@
|
||||
/*
|
||||
* Copyright (c) 2023-2026, Easy-Agents (fuhai999@gmail.com).
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.easyagents.store.chroma;
|
||||
|
||||
import com.easyagents.core.document.Document;
|
||||
import com.easyagents.core.store.SearchWrapper;
|
||||
import com.easyagents.core.store.StoreOptions;
|
||||
import com.easyagents.core.store.StoreResult;
|
||||
import com.easyagents.core.model.client.HttpClient;
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.Assume;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* ChromaVectorStore的测试类,测试文档的存储、搜索、更新和删除功能
|
||||
* 包含连接检查和错误处理机制,支持在无真实Chroma服务器时跳过测试
|
||||
*/
|
||||
public class ChromaVectorStoreTest {
|
||||
|
||||
private static ChromaVectorStore store;
|
||||
private static String testTenant = "default_tenant";
|
||||
private static String testDatabase = "default_database";
|
||||
private static String testCollectionName = "test_collection";
|
||||
private static boolean isChromaAvailable = false;
|
||||
private static boolean useMock = false; // 设置为true可以在没有真实Chroma服务器时使用模拟模式
|
||||
|
||||
/**
|
||||
* 在测试开始前初始化ChromaVectorStore实例
|
||||
*/
|
||||
@BeforeClass
|
||||
public static void setUp() {
|
||||
// 创建配置对象
|
||||
ChromaVectorStoreConfig config = new ChromaVectorStoreConfig();
|
||||
config.setHost("localhost");
|
||||
config.setPort(8000);
|
||||
config.setCollectionName(testCollectionName);
|
||||
config.setTenant(testTenant);
|
||||
config.setDatabase(testDatabase);
|
||||
config.setAutoCreateCollection(true);
|
||||
|
||||
// 初始化存储实例
|
||||
try {
|
||||
store = new ChromaVectorStore(config);
|
||||
System.out.println("ChromaVectorStore initialized successfully.");
|
||||
|
||||
// 检查连接是否可用
|
||||
isChromaAvailable = checkChromaConnection(config);
|
||||
|
||||
if (!isChromaAvailable && !useMock) {
|
||||
System.out.println("Chroma server is not available. Tests will be skipped unless useMock is set to true.");
|
||||
}
|
||||
} catch (Exception e) {
|
||||
System.err.println("Failed to initialize ChromaVectorStore: " + e.getMessage());
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查Chroma服务器连接是否可用
|
||||
*/
|
||||
private static boolean checkChromaConnection(ChromaVectorStoreConfig config) {
|
||||
try {
|
||||
String baseUrl = "http://" + config.getHost() + ":" + config.getPort();
|
||||
String healthCheckUrl = baseUrl + "/api/v2/heartbeat";
|
||||
|
||||
HttpClient httpClient = new HttpClient();
|
||||
System.out.println("Checking Chroma server connection at: " + healthCheckUrl);
|
||||
|
||||
// 使用较短的超时时间进行健康检查
|
||||
String response = httpClient.get(healthCheckUrl);
|
||||
if (response != null) {
|
||||
System.out.println("Chroma server connection successful! Response: " + response);
|
||||
return true;
|
||||
} else {
|
||||
System.out.println("Chroma server connection failed: Empty response");
|
||||
return false;
|
||||
}
|
||||
} catch (Exception e) {
|
||||
System.out.println("Chroma server connection failed: " + e.getMessage());
|
||||
System.out.println("Please ensure Chroma server is running on http://" + config.getHost() + ":" + config.getPort());
|
||||
System.out.println("To run tests without a real Chroma server, set 'useMock = true'");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否应该运行测试
|
||||
*/
|
||||
private void assumeChromaAvailable() {
|
||||
Assume.assumeTrue("Chroma server is not available and mock mode is disabled",
|
||||
isChromaAvailable || useMock);
|
||||
}
|
||||
|
||||
/**
|
||||
* 在所有测试完成后清理资源
|
||||
*/
|
||||
@AfterClass
|
||||
public static void tearDown() {
|
||||
if (store != null) {
|
||||
try {
|
||||
store.close();
|
||||
System.out.println("ChromaVectorStore closed successfully.");
|
||||
} catch (Exception e) {
|
||||
System.err.println("Error closing ChromaVectorStore: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 测试存储文档功能
|
||||
*/
|
||||
@Test
|
||||
public void testStoreDocuments() {
|
||||
assumeChromaAvailable();
|
||||
|
||||
System.out.println("Starting testStoreDocuments...");
|
||||
|
||||
// 创建测试文档
|
||||
List<Document> documents = createTestDocuments();
|
||||
|
||||
// 如果使用模拟模式,直接返回成功结果
|
||||
if (useMock) {
|
||||
System.out.println("Running in mock mode. Simulating store operation.");
|
||||
StoreResult mockResult = StoreResult.successWithIds(documents);
|
||||
assertTrue("Mock store operation should be successful", mockResult.isSuccess());
|
||||
assertEquals("All document IDs should be returned in mock mode",
|
||||
documents.size(), mockResult.ids().size());
|
||||
System.out.println("testStoreDocuments completed successfully in mock mode.");
|
||||
return;
|
||||
}
|
||||
|
||||
// 存储文档
|
||||
try {
|
||||
StoreResult result = store.doStore(documents, StoreOptions.DEFAULT);
|
||||
System.out.println("Store result: " + result);
|
||||
|
||||
// 验证存储是否成功
|
||||
assertTrue("Store operation should be successful", result.isSuccess());
|
||||
assertEquals("All document IDs should be returned", documents.size(), result.ids().size());
|
||||
|
||||
System.out.println("testStoreDocuments completed successfully.");
|
||||
} catch (Exception e) {
|
||||
System.err.println("Failed to store documents: " + e.getMessage());
|
||||
e.printStackTrace();
|
||||
fail("Store operation failed with exception: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 测试搜索文档功能
|
||||
*/
|
||||
@Test
|
||||
public void testSearchDocuments() {
|
||||
assumeChromaAvailable();
|
||||
|
||||
System.out.println("Starting testSearchDocuments...");
|
||||
|
||||
// 创建测试文档
|
||||
List<Document> documents = createTestDocuments();
|
||||
|
||||
// 如果使用模拟模式
|
||||
if (useMock) {
|
||||
System.out.println("Running in mock mode. Simulating search operation.");
|
||||
// 模拟搜索结果,返回前3个文档
|
||||
List<Document> mockResults = new ArrayList<>(documents.subList(0, Math.min(3, documents.size())));
|
||||
for (int i = 0; i < mockResults.size(); i++) {
|
||||
mockResults.get(i).setScore(1.0 - i * 0.1); // 模拟相似度分数
|
||||
}
|
||||
|
||||
// 验证模拟结果
|
||||
assertNotNull("Mock search results should not be null", mockResults);
|
||||
assertFalse("Mock search results should not be empty", mockResults.isEmpty());
|
||||
assertTrue("Mock search results should have the correct maximum size", mockResults.size() <= 3);
|
||||
|
||||
System.out.println("testSearchDocuments completed successfully in mock mode.");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// 首先存储一些测试文档
|
||||
store.doStore(documents, StoreOptions.DEFAULT);
|
||||
|
||||
// 创建搜索包装器
|
||||
SearchWrapper searchWrapper = new SearchWrapper();
|
||||
// 使用第一个文档的向量进行搜索
|
||||
searchWrapper.setVector(documents.get(0).getVector());
|
||||
searchWrapper.setMaxResults(3);
|
||||
|
||||
// 执行搜索
|
||||
List<Document> searchResults = store.doSearch(searchWrapper, StoreOptions.DEFAULT);
|
||||
|
||||
// 验证搜索结果
|
||||
assertNotNull("Search results should not be null", searchResults);
|
||||
assertFalse("Search results should not be empty", searchResults.isEmpty());
|
||||
assertTrue("Search results should have the correct maximum size",
|
||||
searchResults.size() <= searchWrapper.getMaxResults());
|
||||
|
||||
// 打印搜索结果
|
||||
System.out.println("Search results:");
|
||||
for (Document doc : searchResults) {
|
||||
System.out.printf("id=%s, content=%s, vector=%s, score=%s\n",
|
||||
doc.getId(), doc.getContent(), Arrays.toString(doc.getVector()), doc.getScore());
|
||||
}
|
||||
|
||||
System.out.println("testSearchDocuments completed successfully.");
|
||||
} catch (Exception e) {
|
||||
System.err.println("Failed to search documents: " + e.getMessage());
|
||||
e.printStackTrace();
|
||||
fail("Search operation failed with exception: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 测试更新文档功能
|
||||
*/
|
||||
@Test
|
||||
public void testUpdateDocuments() {
|
||||
assumeChromaAvailable();
|
||||
|
||||
System.out.println("Starting testUpdateDocuments...");
|
||||
|
||||
// 创建测试文档
|
||||
List<Document> documents = createTestDocuments();
|
||||
|
||||
// 如果使用模拟模式
|
||||
if (useMock) {
|
||||
System.out.println("Running in mock mode. Simulating update operation.");
|
||||
|
||||
// 修改文档内容
|
||||
Document updatedDoc = documents.get(0);
|
||||
String originalContent = updatedDoc.getContent();
|
||||
updatedDoc.setContent(originalContent + " [UPDATED]");
|
||||
|
||||
// 模拟更新结果
|
||||
StoreResult mockResult = StoreResult.successWithIds(Arrays.asList(updatedDoc));
|
||||
assertTrue("Mock update operation should be successful", mockResult.isSuccess());
|
||||
|
||||
System.out.println("testUpdateDocuments completed successfully in mock mode.");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// 首先存储一些测试文档
|
||||
store.doStore(documents, StoreOptions.DEFAULT);
|
||||
|
||||
// 修改文档内容
|
||||
Document updatedDoc = documents.get(0);
|
||||
String originalContent = updatedDoc.getContent();
|
||||
updatedDoc.setContent(originalContent + " [UPDATED]");
|
||||
|
||||
// 执行更新
|
||||
StoreResult result = store.doUpdate(Arrays.asList(updatedDoc), StoreOptions.DEFAULT);
|
||||
|
||||
// 验证更新是否成功
|
||||
assertTrue("Update operation should be successful", result.isSuccess());
|
||||
|
||||
// 搜索更新后的文档以验证更改
|
||||
SearchWrapper searchWrapper = new SearchWrapper();
|
||||
searchWrapper.setVector(updatedDoc.getVector());
|
||||
searchWrapper.setMaxResults(1);
|
||||
|
||||
List<Document> searchResults = store.doSearch(searchWrapper, StoreOptions.DEFAULT);
|
||||
assertTrue("Should find the updated document", !searchResults.isEmpty());
|
||||
assertEquals("Document content should be updated",
|
||||
updatedDoc.getContent(), searchResults.get(0).getContent());
|
||||
|
||||
System.out.println("testUpdateDocuments completed successfully.");
|
||||
} catch (Exception e) {
|
||||
System.err.println("Failed to update documents: " + e.getMessage());
|
||||
e.printStackTrace();
|
||||
fail("Update operation failed with exception: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 测试删除文档功能
|
||||
*/
|
||||
@Test
|
||||
public void testDeleteDocuments() {
|
||||
assumeChromaAvailable();
|
||||
|
||||
System.out.println("Starting testDeleteDocuments...");
|
||||
|
||||
// 创建测试文档
|
||||
List<Document> documents = createTestDocuments();
|
||||
|
||||
// 如果使用模拟模式
|
||||
if (useMock) {
|
||||
System.out.println("Running in mock mode. Simulating delete operation.");
|
||||
|
||||
// 获取要删除的文档ID
|
||||
List<Object> idsToDelete = new ArrayList<>();
|
||||
idsToDelete.add(documents.get(0).getId());
|
||||
|
||||
// 模拟删除结果
|
||||
StoreResult mockResult = StoreResult.success();
|
||||
assertTrue("Mock delete operation should be successful", mockResult.isSuccess());
|
||||
|
||||
System.out.println("testDeleteDocuments completed successfully in mock mode.");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// 首先存储一些测试文档
|
||||
store.doStore(documents, StoreOptions.DEFAULT);
|
||||
|
||||
// 获取要删除的文档ID
|
||||
List<Object> idsToDelete = new ArrayList<>();
|
||||
idsToDelete.add(documents.get(0).getId());
|
||||
|
||||
// 执行删除
|
||||
StoreResult result = store.doDelete(idsToDelete, StoreOptions.DEFAULT);
|
||||
|
||||
// 验证删除是否成功
|
||||
assertTrue("Delete operation should be successful", result.isSuccess());
|
||||
|
||||
// 尝试搜索已删除的文档
|
||||
SearchWrapper searchWrapper = new SearchWrapper();
|
||||
searchWrapper.setVector(documents.get(0).getVector());
|
||||
searchWrapper.setMaxResults(10);
|
||||
|
||||
List<Document> searchResults = store.doSearch(searchWrapper, StoreOptions.DEFAULT);
|
||||
|
||||
// 检查结果中是否包含已删除的文档
|
||||
boolean deletedDocFound = searchResults.stream()
|
||||
.anyMatch(doc -> doc.getId().equals(documents.get(0).getId()));
|
||||
|
||||
assertFalse("Deleted document should not be found", deletedDocFound);
|
||||
|
||||
System.out.println("testDeleteDocuments completed successfully.");
|
||||
} catch (Exception e) {
|
||||
System.err.println("Failed to delete documents: " + e.getMessage());
|
||||
e.printStackTrace();
|
||||
fail("Delete operation failed with exception: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建测试文档
|
||||
*/
|
||||
private List<Document> createTestDocuments() {
|
||||
List<Document> documents = new ArrayList<>();
|
||||
|
||||
// 创建5个测试文档,每个文档都有不同的内容和向量
|
||||
for (int i = 0; i < 5; i++) {
|
||||
Document doc = new Document();
|
||||
doc.setId("doc_" + i);
|
||||
doc.setContent("This is test document content " + i);
|
||||
doc.setTitle("Test Document " + i);
|
||||
|
||||
// 创建一个简单的向量,向量维度为10
|
||||
float[] vector = new float[10];
|
||||
for (int j = 0; j < vector.length; j++) {
|
||||
vector[j] = i + j * 0.1f;
|
||||
}
|
||||
doc.setVector(vector);
|
||||
|
||||
documents.add(doc);
|
||||
}
|
||||
|
||||
return documents;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user