初始化

This commit is contained in:
2026-02-22 18:55:40 +08:00
commit 8392cdd861
496 changed files with 45020 additions and 0 deletions

View File

@@ -0,0 +1,47 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven module for the Chroma-backed store; child of easy-agents-store. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>com.easyagents</groupId>
<artifactId>easy-agents-store</artifactId>
<version>${revision}</version>
</parent>
<name>easy-agents-store-chroma</name>
<artifactId>easy-agents-store-chroma</artifactId>
<properties>
<!-- Source and bytecode level are Java 8; sources are read as UTF-8. -->
<maven.compiler.source>8</maven.compiler.source>
<maven.compiler.target>8</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies>
<dependency>
<groupId>com.easyagents</groupId>
<artifactId>easy-agents-core</artifactId>
</dependency>
<!-- NOTE(review): the module's Java sources use org.slf4j and com.google.gson directly,
     yet both declarations below are disabled. Presumably they arrive transitively through
     easy-agents-core; confirm before deleting this block or relying on it. -->
<!-- <dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.10.1</version>
</dependency> -->
<!-- Test dependencies -->
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.13.2</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@@ -0,0 +1,101 @@
/*
* Copyright (c) 2023-2026, Easy-Agents (fuhai999@gmail.com).
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.easyagents.store.chroma;
import com.easyagents.core.store.condition.Condition;
import com.easyagents.core.store.condition.ConditionType;
import com.easyagents.core.store.condition.ExpressionAdaptor;
import com.easyagents.core.store.condition.Value;
import java.util.StringJoiner;
/**
 * {@link ExpressionAdaptor} used by the Chroma store to render {@link Condition}
 * objects as textual filter expressions.
 *
 * <p>Comparison operators are rendered in infix form (for example {@code " == "}),
 * {@code BETWEEN} is expanded into two range comparisons joined with {@code &&},
 * and {@code IN} values are rendered as a bracketed list of quoted items.
 */
public class ChromaExpressionAdaptor implements ExpressionAdaptor {

    /** Shared instance; this class declares no instance state of its own. */
    public static final ChromaExpressionAdaptor DEFAULT = new ChromaExpressionAdaptor();

    /**
     * Maps a condition type to the operator text used in the rendered expression.
     *
     * @param type the condition type to translate
     * @return the operator surrounded by single spaces, or the type's default
     *         symbol when no specific mapping is defined here
     */
    @Override
    public String toOperationSymbol(ConditionType type) {
        if (type == ConditionType.EQ) {
            return " == ";
        }
        if (type == ConditionType.NE) {
            return " != ";
        }
        if (type == ConditionType.GT) {
            return " > ";
        }
        if (type == ConditionType.GE) {
            return " >= ";
        }
        if (type == ConditionType.LT) {
            return " < ";
        }
        if (type == ConditionType.LE) {
            return " <= ";
        }
        if (type == ConditionType.IN) {
            return " IN ";
        }
        return type.getDefaultSymbol();
    }

    /**
     * Renders a whole condition. {@code BETWEEN} is expanded to
     * {@code (left >= low && left <= high)}; every other type is delegated to
     * the interface's default implementation.
     *
     * @param condition the condition to render
     * @return the rendered expression
     */
    @Override
    public String toCondition(Condition condition) {
        if (condition.getType() != ConditionType.BETWEEN) {
            return ExpressionAdaptor.super.toCondition(condition);
        }
        // The right-hand side of BETWEEN carries its two bounds as an Object[].
        Object[] bounds = (Object[]) ((Value) condition.getRight()).getValue();
        String field = toLeft(condition.getLeft());
        StringBuilder expression = new StringBuilder("(");
        expression.append(field).append(toOperationSymbol(ConditionType.GE)).append(bounds[0]);
        expression.append(" && ");
        expression.append(field).append(toOperationSymbol(ConditionType.LE)).append(bounds[1]);
        return expression.append(")").toString();
    }

    /**
     * Renders a right-hand value.
     *
     * @param condition the condition the value belongs to
     * @param value     the raw value, may be {@code null}
     * @return {@code "null"} for a null value; a bracketed, comma-separated list of
     *         double-quoted non-null elements for {@code IN}; a double-quoted string
     *         for strings; {@code toString()} for booleans and numbers; otherwise
     *         whatever the interface's default implementation returns
     */
    @Override
    public String toValue(Condition condition, Object value) {
        if (value == null) {
            return "null";
        }
        if (condition.getType() == ConditionType.IN) {
            StringJoiner items = new StringJoiner(",", "[", "]");
            for (Object element : (Object[]) value) {
                if (element == null) {
                    continue; // null elements are left out of the list
                }
                items.add("\"" + element + "\"");
            }
            return items.toString();
        }
        if (value instanceof String) {
            return "\"" + value + "\"";
        }
        if (value instanceof Boolean || value instanceof Number) {
            return value.toString();
        }
        return ExpressionAdaptor.super.toValue(condition, value);
    }

    /**
     * Renders the left-hand side (field reference) of a condition.
     *
     * @param left the left operand; strings are returned unchanged, anything else
     *             is converted with {@code toString()}
     * @return the field text
     */
    public String toLeft(Object left) {
        // Dotted and plain field names are both returned as-is.
        return left instanceof String ? (String) left : left.toString();
    }
}

View File

@@ -0,0 +1,794 @@
/*
* Copyright (c) 2023-2026, Easy-Agents (fuhai999@gmail.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.easyagents.store.chroma;
import com.easyagents.core.document.Document;
import com.easyagents.core.store.DocumentStore;
import com.easyagents.core.store.SearchWrapper;
import com.easyagents.core.store.StoreOptions;
import com.easyagents.core.store.StoreResult;
import com.easyagents.core.store.condition.ExpressionAdaptor;
import com.easyagents.core.model.client.HttpClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.*;
import java.util.stream.Collectors;
/**
 * ChromaVectorStore class provides an interface to interact with Chroma Vector Database
 * using direct HTTP calls to the Chroma REST API.
 *
 * <p>Requests are sent through the project's {@code HttpClient}; request and response
 * bodies are JSON, serialized and parsed with Gson.
 */
public class ChromaVectorStore extends DocumentStore {
private static final Logger logger = LoggerFactory.getLogger(ChromaVectorStore.class);
// Server root, taken from config.getBaseUrl() in the constructor.
private final String baseUrl;
// Default collection, used when StoreOptions does not name another one.
private final String collectionName;
// Optional tenant; when set it is added to request paths and headers.
private final String tenant;
// Optional database; only added to request paths when a tenant is also set.
private final String database;
private final ChromaVectorStoreConfig config;
// Renders search conditions; always ChromaExpressionAdaptor.DEFAULT.
private final ExpressionAdaptor expressionAdaptor;
private final HttpClient httpClient;
// Maximum number of attempts made by executeWithRetry.
private final int MAX_RETRIES = 3;
// Pause between two attempts in executeWithRetry, in milliseconds.
private final long RETRY_INTERVAL_MS = 1000;
// Path prefix appended to baseUrl for every REST call.
private static final String BASE_API = "/api/v2";
/**
 * Creates a store bound to the given configuration.
 *
 * <p>Copies the connection settings out of {@code config}, creates the HTTP client and
 * validates the base URL. When {@code config.isAutoCreateCollection()} is true it then
 * tries to make sure the tenant, database and collection exist; a failure at that stage
 * is logged as a warning and does not abort construction.
 *
 * @param config the store configuration, must not be null
 * @throws NullPointerException     if {@code config} is null
 * @throws IllegalArgumentException if the base URL is empty or lacks an http(s) scheme
 */
public ChromaVectorStore(ChromaVectorStoreConfig config) {
    Objects.requireNonNull(config, "ChromaVectorStoreConfig cannot be null");
    this.config = config;
    this.baseUrl = config.getBaseUrl();
    this.tenant = config.getTenant();
    this.database = config.getDatabase();
    this.collectionName = config.getCollectionName();
    this.expressionAdaptor = ChromaExpressionAdaptor.DEFAULT;
    this.httpClient = createHttpClient();

    validateConfig();

    if (!config.isAutoCreateCollection()) {
        return;
    }
    try {
        // Tenant and database first, then the collection that lives inside them.
        ensureTenantAndDatabaseExists();
        ensureCollectionExists();
    } catch (Exception e) {
        logger.warn("Failed to ensure collection exists: {}. Will retry on first operation.", e.getMessage());
    }
}
/**
 * Creates the HTTP client used for all requests of this store.
 *
 * @return a new {@code HttpClient} built with its no-argument constructor
 */
private HttpClient createHttpClient() {
    return new HttpClient();
}
/**
 * Checks that the configured base URL is usable.
 *
 * @throws IllegalArgumentException if the base URL is null or empty, or does not start
 *                                  with {@code http://} or {@code https://}
 */
private void validateConfig() {
    boolean missing = baseUrl == null || baseUrl.isEmpty();
    if (missing) {
        throw new IllegalArgumentException("Base URL cannot be empty");
    }
    boolean hasHttpScheme = baseUrl.startsWith("http://") || baseUrl.startsWith("https://");
    if (!hasHttpScheme) {
        throw new IllegalArgumentException("Base URL must start with http:// or https://");
    }
}
/**
 * Makes sure the configured tenant, and the database inside it, exist.
 *
 * <p>Nothing is done when no tenant is configured; the database is only handled when a
 * tenant is configured too. Any exception is logged at error level and not rethrown.
 */
private void ensureTenantAndDatabaseExists() {
    if (tenant == null || tenant.isEmpty()) {
        return;
    }
    try {
        ensureTenantExists();
        boolean hasDatabase = database != null && !database.isEmpty();
        if (hasDatabase) {
            ensureDatabaseExists();
        }
    } catch (Exception e) {
        logger.error("Error ensuring tenant and database exist", e);
    }
}
/**
 * Makes sure the configured tenant exists, creating it when the lookup fails.
 *
 * <p>The tenant is looked up with a GET on {@code /tenants/{tenant}}. A lookup that
 * throws, or that yields a {@code null} body, is treated as "tenant missing" and
 * followed by a POST to {@code /tenants}. Treating {@code null} as failure matches how
 * the other methods of this class interpret a {@code null} response body.
 *
 * @throws IOException if the creation request fails after all retries or yields no response
 */
private void ensureTenantExists() throws IOException {
    String tenantUrl = baseUrl + BASE_API + "/tenants/" + tenant;
    Map<String, String> headers = createHeaders();

    boolean exists;
    try {
        String lookupBody = executeWithRetry(() -> httpClient.get(tenantUrl, headers));
        exists = lookupBody != null;
    } catch (IOException e) {
        logger.debug("Lookup of tenant '{}' failed: {}", tenant, e.getMessage());
        exists = false;
    }
    if (exists) {
        logger.debug("Successfully verified tenant '{}' exists", tenant);
        return;
    }

    logger.info("Creating tenant '{}' as it does not exist", tenant);
    Map<String, Object> requestBody = new HashMap<>();
    requestBody.put("name", tenant);
    String createTenantUrl = baseUrl + BASE_API + "/tenants";
    String jsonRequestBody = safeJsonSerialize(requestBody);
    String createBody = executeWithRetry(() -> httpClient.post(createTenantUrl, headers, jsonRequestBody));
    if (createBody == null) {
        // Without a response there is no evidence the tenant was created.
        throw new IOException("Failed to create tenant '" + tenant + "': no response");
    }
    logger.info("Successfully created tenant '{}'", tenant);
}
/**
 * Makes sure the configured database exists inside the configured tenant, creating it
 * when the lookup fails.
 *
 * <p>The database is looked up with a GET on
 * {@code /tenants/{tenant}/databases/{database}}. A lookup that throws, or that yields
 * a {@code null} body, is treated as "database missing" and followed by a POST to
 * {@code /tenants/{tenant}/databases}. Treating {@code null} as failure matches how the
 * other methods of this class interpret a {@code null} response body.
 *
 * @throws IllegalStateException if no tenant is configured
 * @throws IOException           if the creation request fails after all retries or
 *                               yields no response
 */
private void ensureDatabaseExists() throws IOException {
    if (tenant == null || tenant.isEmpty()) {
        throw new IllegalStateException("Cannot create database without tenant");
    }
    String databaseUrl = baseUrl + BASE_API + "/tenants/" + tenant + "/databases/" + database;
    Map<String, String> headers = createHeaders();

    boolean exists;
    try {
        String lookupBody = executeWithRetry(() -> httpClient.get(databaseUrl, headers));
        exists = lookupBody != null;
    } catch (IOException e) {
        logger.debug("Lookup of database '{}' in tenant '{}' failed: {}",
                database, tenant, e.getMessage());
        exists = false;
    }
    if (exists) {
        logger.debug("Successfully verified database '{}' exists in tenant '{}'",
                database, tenant);
        return;
    }

    logger.info("Creating database '{}' in tenant '{}' as it does not exist",
            database, tenant);
    Map<String, Object> requestBody = new HashMap<>();
    requestBody.put("name", database);
    String createDatabaseUrl = baseUrl + BASE_API + "/tenants/" + tenant + "/databases";
    String jsonRequestBody = safeJsonSerialize(requestBody);
    String createBody = executeWithRetry(() -> httpClient.post(createDatabaseUrl, headers, jsonRequestBody));
    if (createBody == null) {
        // Without a response there is no evidence the database was created.
        throw new IOException("Failed to create database '" + database
                + "' in tenant '" + tenant + "': no response");
    }
    logger.info("Successfully created database '{}' in tenant '{}'",
            database, tenant);
}
/**
 * Looks up the id of a collection by listing all collections and matching on name.
 *
 * <p>The listing response may be either a JSON array of collection objects or a JSON
 * object carrying that array under {@code "collections"}. Entries that are not JSON
 * objects are skipped instead of causing a {@code ClassCastException}.
 *
 * @param collectionName the collection name to look for
 * @return the id of the first collection whose {@code name} equals {@code collectionName}
 * @throws IOException           if the listing request yields no response or no
 *                               collection with that name is listed
 * @throws IllegalStateException if the matching entry carries no {@code id}
 */
private String getCollectionId(String collectionName) throws IOException {
    String collectionsUrl = buildCollectionsUrl();
    Map<String, String> headers = createHeaders();
    String responseBody = executeWithRetry(() -> httpClient.get(collectionsUrl, headers));
    if (responseBody == null) {
        throw new IOException("Failed to get collections, no response");
    }

    Object responseObj = parseJsonResponse(responseBody);
    // Unwrap {"collections": [...]} so both response shapes share one code path.
    Object listing = responseObj instanceof Map
            ? ((Map<?, ?>) responseObj).get("collections")
            : responseObj;

    if (listing instanceof List) {
        for (Object item : (List<?>) listing) {
            if (!(item instanceof Map)) {
                continue;
            }
            Map<?, ?> collection = (Map<?, ?>) item;
            if (!collectionName.equals(collection.get("name"))) {
                continue;
            }
            Object id = collection.get("id");
            if (id == null) {
                // Kept unchecked (as the former NullPointerException was) so callers that
                // only catch IOException do not mistake this for "collection missing".
                throw new IllegalStateException(
                        "Collection '" + collectionName + "' was listed without an id");
            }
            return id.toString();
        }
    }
    throw new IOException("Collection not found: " + collectionName);
}
/**
 * Creates the default collection ({@code collectionName}) with a POST to the
 * collections URL.
 *
 * <p>The response must be a JSON object without an {@code "error"} entry. A missing,
 * unparsable or non-object response, or one that reports an error, is raised as an
 * {@link IOException}. The server-reported error is thrown directly rather than being
 * re-wrapped as a response-processing failure.
 *
 * @throws IOException if the request fails after all retries or the response does not
 *                     indicate success
 */
private void createCollection() throws IOException {
    String createCollectionUrl = buildCollectionsUrl();
    Map<String, String> headers = createHeaders();
    Map<String, Object> requestBody = new HashMap<>();
    requestBody.put("name", collectionName);
    String jsonRequestBody = safeJsonSerialize(requestBody);
    String responseBody = executeWithRetry(() -> httpClient.post(createCollectionUrl, headers, jsonRequestBody));
    if (responseBody == null) {
        throw new IOException("Failed to create collection: no response");
    }

    Object responseObj;
    try {
        responseObj = parseJsonResponse(responseBody);
    } catch (RuntimeException e) {
        throw new IOException("Failed to process collection creation response: " + e.getMessage(), e);
    }
    if (!(responseObj instanceof Map)) {
        // Same outcome as before (an IOException), but without relying on a NullPointerException.
        throw new IOException("Failed to process collection creation response: response is not a JSON object");
    }
    Map<?, ?> responseMap = (Map<?, ?>) responseObj;
    if (responseMap.containsKey("error")) {
        throw new IOException("Failed to create collection: " + responseMap.get("error"));
    }
    logger.info("Collection '{}' created successfully", collectionName);
}
/**
 * Adds documents to a collection through the collection's {@code add} endpoint.
 *
 * <p>Each document contributes its id (as a string), its vector as a list of doubles
 * (or {@code null} when it has no vector), a copy of its metadata map and its content.
 * All failures are logged and reported as {@code StoreResult.fail()}; no exception
 * escapes except the null check on {@code documents}.
 *
 * <p>NOTE(review): {@code ensureCollectionExists()} checks the default collection only,
 * while the request targets the collection named by {@code options}; when the two
 * differ the target is not auto-created. Confirm whether that is intended.
 *
 * @param documents the documents to add, must not be null
 * @param options   store options, may be null; may name a non-default collection
 * @return success (carrying the documents' ids) or failure
 */
@Override
public StoreResult doStore(List<Document> documents, StoreOptions options) {
    Objects.requireNonNull(documents, "Documents cannot be null");
    if (documents.isEmpty()) {
        logger.debug("No documents to store");
        return StoreResult.success();
    }
    try {
        ensureCollectionExists();
        String collectionName = getCollectionName(options);

        List<String> ids = new ArrayList<>(documents.size());
        List<List<Double>> embeddings = new ArrayList<>(documents.size());
        List<Map<String, Object>> metadatas = new ArrayList<>(documents.size());
        List<String> documentsContent = new ArrayList<>(documents.size());
        for (Document doc : documents) {
            ids.add(String.valueOf(doc.getId()));
            // The four lists are positional, so a missing vector keeps its slot as null.
            embeddings.add(doc.getVector() != null ? doc.getVectorAsDoubleList() : null);
            Map<String, Object> metadata = doc.getMetadataMap() != null ?
                    new HashMap<>(doc.getMetadataMap()) : new HashMap<>();
            metadatas.add(metadata);
            documentsContent.add(doc.getContent());
        }

        Map<String, Object> requestBody = new HashMap<>();
        requestBody.put("ids", ids);
        requestBody.put("embeddings", embeddings);
        requestBody.put("metadatas", metadatas);
        requestBody.put("documents", documentsContent);

        String collectionId = getCollectionId(collectionName);
        String collectionUrl = buildCollectionUrl(collectionId, "add");
        Map<String, String> headers = createHeaders();
        String jsonRequestBody = safeJsonSerialize(requestBody);
        logger.debug("Storing {} documents to collection '{}'", documents.size(), collectionName);
        String responseBody = executeWithRetry(() -> httpClient.post(collectionUrl, headers, jsonRequestBody));
        if (responseBody == null) {
            logger.error("Error storing documents: no response");
            return StoreResult.fail();
        }

        Object responseObj = parseJsonResponse(responseBody);
        if (!(responseObj instanceof Map)) {
            // Same outcome as before (failure), but without relying on a NullPointerException.
            logger.error("Error storing documents: response is not a JSON object");
            return StoreResult.fail();
        }
        Map<?, ?> responseMap = (Map<?, ?>) responseObj;
        if (responseMap.containsKey("error")) {
            String errorMsg = "Error storing documents: " + responseMap.get("error");
            logger.error(errorMsg);
            return StoreResult.fail();
        }
        logger.debug("Successfully stored {} documents", documents.size());
        return StoreResult.successWithIds(documents);
    } catch (Exception e) {
        logger.error("Error storing documents to Chroma", e);
        return StoreResult.fail();
    }
}
/**
 * Deletes documents by id through the collection's {@code delete} endpoint.
 *
 * <p>Ids are converted with {@code toString()} before being sent. All failures are
 * logged and reported as {@code StoreResult.fail()}; no exception escapes except the
 * null check on {@code ids}.
 *
 * @param ids     the ids to delete, must not be null
 * @param options store options, may be null; may name a non-default collection
 * @return success or failure
 */
@Override
public StoreResult doDelete(Collection<?> ids, StoreOptions options) {
    Objects.requireNonNull(ids, "IDs cannot be null");
    if (ids.isEmpty()) {
        logger.debug("No IDs to delete");
        return StoreResult.success();
    }
    try {
        ensureCollectionExists();
        String collectionName = getCollectionName(options);
        List<String> stringIds = ids.stream()
                .map(Object::toString)
                .collect(Collectors.toList());
        Map<String, Object> requestBody = new HashMap<>();
        requestBody.put("ids", stringIds);

        String collectionId = getCollectionId(collectionName);
        String collectionUrl = buildCollectionUrl(collectionId, "delete");
        Map<String, String> headers = createHeaders();
        String jsonRequestBody = safeJsonSerialize(requestBody);
        logger.debug("Deleting {} documents from collection '{}'", ids.size(), collectionName);
        String responseBody = executeWithRetry(() -> httpClient.post(collectionUrl, headers, jsonRequestBody));
        if (responseBody == null) {
            logger.error("Error deleting documents: no response");
            return StoreResult.fail();
        }

        Object responseObj = parseJsonResponse(responseBody);
        if (!(responseObj instanceof Map)) {
            // Same outcome as before (failure), but without relying on a NullPointerException.
            logger.error("Error deleting documents: response is not a JSON object");
            return StoreResult.fail();
        }
        Map<?, ?> responseMap = (Map<?, ?>) responseObj;
        if (responseMap.containsKey("error")) {
            String errorMsg = "Error deleting documents: " + responseMap.get("error");
            logger.error(errorMsg);
            return StoreResult.fail();
        }
        logger.debug("Successfully deleted {} documents", ids.size());
        return StoreResult.success();
    } catch (Exception e) {
        logger.error("Error deleting documents from Chroma", e);
        return StoreResult.fail();
    }
}
/**
 * Updates documents by deleting them by id and then storing them again.
 *
 * <p>A failed delete is only logged as a warning; the store step still runs, since some
 * of the documents may be new. The result of the store step is what is returned.
 *
 * @param documents the documents to update, must not be null
 * @param options   store options, passed on to both the delete and the store step
 * @return the result of the store step, or failure when an exception occurs
 */
@Override
public StoreResult doUpdate(List<Document> documents, StoreOptions options) {
    Objects.requireNonNull(documents, "Documents cannot be null");
    if (documents.isEmpty()) {
        logger.debug("No documents to update");
        return StoreResult.success();
    }
    try {
        List<Object> ids = new ArrayList<>();
        for (Document document : documents) {
            ids.add(document.getId());
        }

        StoreResult deleteResult = doDelete(ids, options);
        if (!deleteResult.isSuccess()) {
            logger.warn("Delete failed during update operation: {}", deleteResult.toString());
        }

        StoreResult outcome = doStore(documents, options);
        if (outcome.isSuccess()) {
            logger.debug("Successfully updated {} documents", documents.size());
        }
        return outcome;
    } catch (Exception e) {
        logger.error("Error updating documents in Chroma", e);
        return StoreResult.fail();
    }
}
/**
 * Queries a collection through its {@code query} endpoint.
 *
 * <p>The query uses the wrapper's vector when present, otherwise its text. The result
 * count is {@code wrapper.getMaxResults()} when positive, else 10. A filter condition
 * is rendered by the expression adaptor and only sent as {@code where} when the
 * rendered text parses to a JSON object; otherwise it is dropped with a warning.
 * All failures are logged and produce an empty list.
 *
 * <p>NOTE(review): {@code ChromaExpressionAdaptor} renders infix text such as
 * {@code field == "x"}, which is not a JSON object, so the filter is presumably always
 * dropped at the moment. Confirm against the adaptor's intended output.
 *
 * @param wrapper the search request, must not be null
 * @param options store options, may be null; may name a non-default collection
 * @return matching documents, or an empty list on any error
 */
@Override
public List<Document> doSearch(SearchWrapper wrapper, StoreOptions options) {
    Objects.requireNonNull(wrapper, "SearchWrapper cannot be null");
    try {
        // Reject an unusable request before spending a network round trip on it.
        if (wrapper.getVector() == null && wrapper.getText() == null) {
            throw new IllegalArgumentException("Either vector or text must be provided for search");
        }

        ensureCollectionExists();
        String collectionName = getCollectionName(options);
        int limit = wrapper.getMaxResults() > 0 ? wrapper.getMaxResults() : 10;

        Map<String, Object> requestBody = new HashMap<>();
        if (wrapper.getVector() != null) {
            List<Double> queryEmbedding = wrapper.getVectorAsDoubleList();
            requestBody.put("query_embeddings", Collections.singletonList(queryEmbedding));
            logger.debug("Performing vector search with dimension: {}", queryEmbedding.size());
        } else {
            requestBody.put("query_texts", Collections.singletonList(wrapper.getText()));
            logger.debug("Performing text search: {}", sanitizeLogString(wrapper.getText(), 100));
        }
        requestBody.put("n_results", limit);

        if (wrapper.getCondition() != null) {
            try {
                String whereClause = expressionAdaptor.toCondition(wrapper.getCondition());
                Object whereObj = parseJsonResponse(whereClause);
                if (whereObj instanceof Map) {
                    requestBody.put("where", whereObj);
                    logger.debug("Search with filter condition: {}", whereClause);
                } else {
                    logger.warn("Filter condition is not a JSON object, ignoring condition: {}", whereClause);
                }
            } catch (Exception e) {
                logger.warn("Failed to parse filter condition: {}, ignoring condition", e.getMessage());
            }
        }

        String collectionId = getCollectionId(collectionName);
        String collectionUrl = buildCollectionUrl(collectionId, "query");
        Map<String, String> headers = createHeaders();
        String jsonRequestBody = safeJsonSerialize(requestBody);
        String responseBody = executeWithRetry(() -> httpClient.post(collectionUrl, headers, jsonRequestBody));
        if (responseBody == null) {
            logger.error("Error searching documents: no response");
            return Collections.emptyList();
        }

        Object responseObj = parseJsonResponse(responseBody);
        if (!(responseObj instanceof Map)) {
            // Same outcome as before (empty list), but without relying on a NullPointerException.
            logger.error("Error searching documents: response is not a JSON object");
            return Collections.emptyList();
        }
        @SuppressWarnings("unchecked") // Gson parses JSON objects into String-keyed maps
        Map<String, Object> responseMap = (Map<String, Object>) responseObj;
        if (responseMap.containsKey("error")) {
            logger.error("Error searching documents: {}", responseMap.get("error"));
            return Collections.emptyList();
        }
        return parseSearchResults(responseMap);
    } catch (Exception e) {
        logger.error("Error searching documents in Chroma", e);
        return Collections.emptyList();
    }
}
/**
 * Searches directly with a raw vector and a result count.
 *
 * <p>All failures are logged and produce an empty list; no exception escapes except the
 * null check on {@code vector}.
 *
 * @param vector  the query vector, must not be null
 * @param topK    the number of results to request; values of 0 or less fall back to 10
 * @param options store options, may be null; may name a non-default collection
 * @return matching documents, or an empty list on any error
 */
public List<Document> searchInternal(double[] vector, int topK, StoreOptions options) {
    Objects.requireNonNull(vector, "Vector cannot be null");
    int resultCount = topK > 0 ? topK : 10;
    try {
        ensureCollectionExists();
        String collectionName = getCollectionName(options);

        Map<String, Object> requestBody = new HashMap<>();
        List<Double> queryEmbedding = Arrays.stream(vector)
                .boxed()
                .collect(Collectors.toList());
        requestBody.put("query_embeddings", Collections.singletonList(queryEmbedding));
        requestBody.put("n_results", resultCount);

        String collectionId = getCollectionId(collectionName);
        String collectionUrl = buildCollectionUrl(collectionId, "query");
        Map<String, String> headers = createHeaders();
        String jsonRequestBody = safeJsonSerialize(requestBody);
        logger.debug("Performing direct vector search with dimension: {}", vector.length);
        String responseBody = executeWithRetry(() -> httpClient.post(collectionUrl, headers, jsonRequestBody));
        if (responseBody == null) {
            logger.error("Error searching documents: no response");
            return Collections.emptyList();
        }

        Object responseObj = parseJsonResponse(responseBody);
        if (!(responseObj instanceof Map)) {
            // Same outcome as before (empty list), but without relying on a NullPointerException.
            logger.error("Error searching documents: response is not a JSON object");
            return Collections.emptyList();
        }
        @SuppressWarnings("unchecked") // Gson parses JSON objects into String-keyed maps
        Map<String, Object> responseMap = (Map<String, Object>) responseObj;
        if (responseMap.containsKey("error")) {
            logger.error("Error searching documents: {}", responseMap.get("error"));
            return Collections.emptyList();
        }
        return parseSearchResults(responseMap);
    } catch (Exception e) {
        logger.error("Error searching documents in Chroma", e);
        return Collections.emptyList();
    }
}
/**
 * Converts a parsed query response into {@code Document} objects.
 *
 * <p>Reads the first-query slice of {@code ids}, {@code documents}, {@code metadatas},
 * {@code embeddings} and {@code distances}. One document is built per id; the other
 * lists are consulted by position when present and long enough. The score is
 * {@code 1 - distance}, clamped to the range 0..1.
 *
 * @param responseMap the parsed response
 * @return the resulting documents; an empty list when there are no ids or parsing fails
 */
private List<Document> parseSearchResults(Map<String, Object> responseMap) {
    try {
        List<String> ids = extractResultsFromNestedList(responseMap, "ids");
        List<String> contents = extractResultsFromNestedList(responseMap, "documents");
        List<Map<String, Object>> metadatas = extractResultsFromNestedList(responseMap, "metadatas");
        List<List<Double>> embeddings = extractResultsFromNestedList(responseMap, "embeddings");
        List<Double> distances = extractResultsFromNestedList(responseMap, "distances");

        if (ids == null || ids.isEmpty()) {
            logger.debug("No documents found in search results");
            return Collections.emptyList();
        }

        List<Document> found = new ArrayList<>();
        for (int index = 0; index < ids.size(); index++) {
            Document doc = new Document();
            doc.setId(ids.get(index));

            boolean hasContent = contents != null && index < contents.size();
            if (hasContent) {
                doc.setContent(contents.get(index));
            }
            boolean hasMetadata = metadatas != null && index < metadatas.size();
            if (hasMetadata) {
                doc.setMetadataMap(metadatas.get(index));
            }
            boolean hasEmbedding = embeddings != null && index < embeddings.size()
                    && embeddings.get(index) != null;
            if (hasEmbedding) {
                doc.setVector(embeddings.get(index));
            }
            boolean hasDistance = distances != null && index < distances.size();
            if (hasDistance) {
                // Smaller distance means more similar, so invert it and clamp to 0..1.
                double score = Math.max(0, Math.min(1, 1.0 - distances.get(index)));
                doc.setScore(score);
            }
            found.add(doc);
        }
        logger.debug("Found {} documents in search results", found.size());
        return found;
    } catch (Exception e) {
        logger.error("Failed to parse search results", e);
        return Collections.emptyList();
    }
}
/**
 * Returns the first inner list stored under {@code key} in a query response.
 *
 * <p>Query results are nested lists; the first element is taken as the result of the
 * single query sent by this class.
 *
 * @param responseMap the parsed response
 * @param key         the response field to read
 * @param <T>         the element type the caller expects (unchecked)
 * @return the first inner list, or {@code null} when the key is absent, its value is
 *         null or empty, or extraction throws (a warning is logged in that case)
 */
@SuppressWarnings("unchecked")
private <T> List<T> extractResultsFromNestedList(Map<String, Object> responseMap, String key) {
    try {
        if (responseMap.containsKey(key)) {
            List<?> perQuery = (List<?>) responseMap.get(key);
            boolean hasResults = perQuery != null && !perQuery.isEmpty();
            if (hasResults) {
                return (List<T>) perQuery.get(0);
            }
        }
        return null;
    } catch (Exception e) {
        logger.warn("Failed to extract '{}' from response: {}", key, e.getMessage());
        return null;
    }
}
/**
 * Builds the HTTP headers sent with every request.
 *
 * @return a new mutable map with {@code Content-Type: application/json}, plus
 *         {@code X-Chroma-Token}, {@code X-Chroma-Tenant} and {@code X-Chroma-Database}
 *         for whichever of API key, tenant and database is configured and non-empty
 */
private Map<String, String> createHeaders() {
    Map<String, String> headers = new HashMap<>();
    headers.put("Content-Type", "application/json");

    String apiKey = config.getApiKey();
    boolean hasApiKey = apiKey != null && !apiKey.isEmpty();
    if (hasApiKey) {
        headers.put("X-Chroma-Token", apiKey);
    }
    boolean hasTenant = tenant != null && !tenant.isEmpty();
    if (hasTenant) {
        headers.put("X-Chroma-Tenant", tenant);
    }
    boolean hasDatabase = database != null && !database.isEmpty();
    if (hasDatabase) {
        headers.put("X-Chroma-Database", database);
    }
    return headers;
}
/**
 * Runs an HTTP operation, retrying when it fails with an {@link IOException}.
 *
 * <p>Up to {@code MAX_RETRIES} attempts are made, with a pause of
 * {@code RETRY_INTERVAL_MS} milliseconds between attempts. Exceptions other than
 * {@code IOException} are not retried and propagate unchanged.
 *
 * @param operation the operation to run
 * @param <T>       the operation's result type
 * @return the result of the first attempt that completes
 * @throws IOException if the last attempt fails (the cause is the last failure), or if
 *                     the thread is interrupted while waiting; in the latter case the
 *                     interrupt flag is set again before throwing
 */
private <T> T executeWithRetry(HttpOperation<T> operation) throws IOException {
    for (int attempt = 1; attempt <= MAX_RETRIES; attempt++) {
        try {
            return operation.execute();
        } catch (IOException failure) {
            boolean lastAttempt = attempt >= MAX_RETRIES;
            if (lastAttempt) {
                throw new IOException("Operation failed after " + MAX_RETRIES + " attempts: " + failure.getMessage(), failure);
            }
            logger.warn("Operation failed (attempt {} of {}), retrying in {}ms: {}",
                    attempt, MAX_RETRIES, RETRY_INTERVAL_MS, failure.getMessage());
            try {
                Thread.sleep(RETRY_INTERVAL_MS);
            } catch (InterruptedException interrupted) {
                Thread.currentThread().interrupt();
                throw new IOException("Retry interrupted", interrupted);
            }
        }
    }
    // Only reachable when no attempt was made at all.
    throw new IOException("Operation failed without exception");
}
/**
 * Serializes a request body to JSON with a default-configured Gson instance.
 *
 * @param map the request body
 * @return the JSON text
 * @throws RuntimeException if serialization throws; the original exception is the cause
 */
private String safeJsonSerialize(Map<String, Object> map) {
    try {
        com.google.gson.Gson gson = new com.google.gson.Gson();
        String json = gson.toJson(map);
        return json;
    } catch (Exception cause) {
        throw new RuntimeException("Failed to serialize request body to JSON", cause);
    }
}
/**
 * Parses JSON text into a {@code List} (when it starts with {@code [}) or a {@code Map}
 * (anything else), using a default-configured Gson instance.
 *
 * @param json the JSON text, may be null
 * @return the parsed list or map, or {@code null} when the text is null or blank
 * @throws RuntimeException if parsing throws; the message contains the input text
 */
private Object parseJsonResponse(String json) {
    try {
        if (json == null) {
            return null;
        }
        String trimmed = json.trim();
        if (trimmed.isEmpty()) {
            return null;
        }
        com.google.gson.Gson gson = new com.google.gson.Gson();
        // A leading '[' marks a JSON array; everything else is parsed as an object.
        if (trimmed.startsWith("[")) {
            return gson.fromJson(json, List.class);
        }
        return gson.fromJson(json, Map.class);
    } catch (Exception e) {
        throw new RuntimeException("Failed to parse JSON response: " + json, e);
    }
}
/**
 * Prepares user-supplied text for logging: line breaks become spaces and text longer
 * than {@code maxLength} is cut and suffixed with {@code "..."}.
 *
 * @param input     the text to sanitize, may be null
 * @param maxLength the maximum number of characters kept before the suffix
 * @return the sanitized text, or {@code null} when {@code input} is null
 */
private String sanitizeLogString(String input, int maxLength) {
    if (input == null) {
        return null;
    }
    String singleLine = input.replaceAll("[\n\r]", " ");
    if (singleLine.length() <= maxLength) {
        return singleLine;
    }
    return singleLine.substring(0, maxLength) + "...";
}
/**
 * Resolves the collection an operation should target.
 *
 * @param options store options, may be null
 * @return the store's default collection name when {@code options} is null, otherwise
 *         whatever {@code options.getCollectionNameOrDefault} returns for that default
 */
private String getCollectionName(StoreOptions options) {
    if (options == null) {
        return collectionName;
    }
    return options.getCollectionNameOrDefault(collectionName);
}
/**
 * Builds the URL of one operation on one collection.
 *
 * <p>The result is {@code baseUrl + BASE_API}, then {@code /tenants/{tenant}} when a
 * tenant is configured, then {@code /databases/{database}} when a database is also
 * configured, then {@code /collections/{collectionId}/{operation}}.
 *
 * @param collectionId the collection id
 * @param operation    the operation path segment, for example {@code "query"}
 * @return the complete URL
 */
private String buildCollectionUrl(String collectionId, String operation) {
    String scope = "";
    boolean hasTenant = tenant != null && !tenant.isEmpty();
    if (hasTenant) {
        scope = "/tenants/" + tenant;
        boolean hasDatabase = database != null && !database.isEmpty();
        if (hasDatabase) {
            scope = scope + "/databases/" + database;
        }
    }
    return baseUrl + BASE_API + scope + "/collections/" + collectionId + "/" + operation;
}
/**
 * Close the connection to Chroma database.
 *
 * <p>Currently this only writes an info log entry; no resource is released here.
 */
public void close() {
// Placeholder for any additional resource cleanup that may be needed later.
logger.info("Chroma client closed");
}
/**
 * Makes sure the default collection exists, creating it when its id cannot be resolved.
 *
 * <p>Any {@link IOException} from the id lookup, not just "collection not found", is
 * taken to mean the collection is missing and triggers creation.
 *
 * @throws IOException if creating the collection fails
 */
private void ensureCollectionExists() throws IOException {
    boolean exists;
    try {
        getCollectionId(collectionName);
        exists = true;
    } catch (IOException e) {
        exists = false;
    }
    if (exists) {
        logger.debug("Collection '{}' exists", collectionName);
        return;
    }
    logger.info("Collection '{}' does not exist, creating...", collectionName);
    createCollection();
    logger.info("Collection '{}' created successfully", collectionName);
}
/**
 * Builds the URL of the collections resource, used for listing and creating collections.
 *
 * <p>The result is {@code baseUrl + BASE_API}, then {@code /tenants/{tenant}} when a
 * tenant is configured, then {@code /databases/{database}} when a database is also
 * configured, then {@code /collections}.
 *
 * @return the complete URL
 */
private String buildCollectionsUrl() {
    String scope = "";
    boolean hasTenant = tenant != null && !tenant.isEmpty();
    if (hasTenant) {
        scope = "/tenants/" + tenant;
        boolean hasDatabase = database != null && !database.isEmpty();
        if (hasDatabase) {
            scope = scope + "/databases/" + database;
        }
    }
    return baseUrl + BASE_API + scope + "/collections";
}
/**
 * Functional interface wrapping a single HTTP call so it can be passed to
 * {@code executeWithRetry} and re-run on failure.
 *
 * @param <T> the type of the call's result
 */
private interface HttpOperation<T> {
// Performs the call; an IOException signals a failure that may be retried.
T execute() throws IOException;
}
}

View File

@@ -0,0 +1,203 @@
/*
* Copyright (c) 2023-2026, Easy-Agents (fuhai999@gmail.com).
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.easyagents.store.chroma;
import com.easyagents.core.store.DocumentStoreConfig;
import com.easyagents.core.util.StringUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.HashMap;
import java.util.Map;
/**
 * ChromaVectorStoreConfig class provides configuration for ChromaVectorStore.
 *
 * <p>Defaults: host {@code localhost}, port {@code 8000}, automatic collection creation
 * enabled. Collection name, API key, tenant and database are unset by default.
 */
public class ChromaVectorStoreConfig implements DocumentStoreConfig {
    private static final Logger logger = LoggerFactory.getLogger(ChromaVectorStoreConfig.class);

    private String host = "localhost";
    private int port = 8000;
    private String collectionName;
    private boolean autoCreateCollection = true;
    private String apiKey;
    private String tenant;
    private String database;

    public ChromaVectorStoreConfig() {
    }

    /**
     * Get the host of Chroma database
     *
     * @return the host of Chroma database
     */
    public String getHost() {
        return host;
    }

    /**
     * Set the host of Chroma database
     *
     * @param host the host name, optionally prefixed with {@code http://} or
     *             {@code https://} (see {@link #getBaseUrl()})
     */
    public void setHost(String host) {
        this.host = host;
    }

    /**
     * Get the port of Chroma database
     *
     * @return the port of Chroma database
     */
    public int getPort() {
        return port;
    }

    /**
     * Set the port of Chroma database
     *
     * @param port the port of Chroma database
     */
    public void setPort(int port) {
        this.port = port;
    }

    /**
     * Get the collection name of Chroma database
     *
     * @return the collection name of Chroma database
     */
    public String getCollectionName() {
        return collectionName;
    }

    /**
     * Set the collection name of Chroma database
     *
     * @param collectionName the collection name of Chroma database
     */
    public void setCollectionName(String collectionName) {
        this.collectionName = collectionName;
    }

    /**
     * Get whether to automatically create the collection if it doesn't exist
     *
     * @return true if the collection should be created automatically, false otherwise
     */
    public boolean isAutoCreateCollection() {
        return autoCreateCollection;
    }

    /**
     * Set whether to automatically create the collection if it doesn't exist
     *
     * @param autoCreateCollection true if the collection should be created automatically, false otherwise
     */
    public void setAutoCreateCollection(boolean autoCreateCollection) {
        this.autoCreateCollection = autoCreateCollection;
    }

    /**
     * Get the API key of Chroma database
     *
     * @return the API key of Chroma database
     */
    public String getApiKey() {
        return apiKey;
    }

    /**
     * Set the API key of Chroma database
     *
     * @param apiKey the API key of Chroma database
     */
    public void setApiKey(String apiKey) {
        this.apiKey = apiKey;
    }

    /**
     * Get the tenant of Chroma database
     *
     * @return the tenant of Chroma database
     */
    public String getTenant() {
        return tenant;
    }

    /**
     * Set the tenant of Chroma database
     *
     * @param tenant the tenant of Chroma database
     */
    public void setTenant(String tenant) {
        this.tenant = tenant;
    }

    /**
     * Get the database of Chroma database
     *
     * @return the database of Chroma database
     */
    public String getDatabase() {
        return database;
    }

    /**
     * Set the database of Chroma database
     *
     * @param database the database of Chroma database
     */
    public void setDatabase(String database) {
        this.database = database;
    }

    /**
     * Checks whether the server answers its heartbeat endpoint.
     *
     * <p>Sends a GET to {@code /api/v2/heartbeat} with 5 second connect and read
     * timeouts, adding the {@code X-Chroma-Token} header when an API key is set. The
     * connection is disconnected on every path, including when the request throws.
     *
     * @return true if the response code is 200; false for any other code or on an I/O error
     */
    @Override
    public boolean checkAvailable() {
        HttpURLConnection connection = null;
        try {
            URL url = new URL(getBaseUrl() + "/api/v2/heartbeat");
            connection = (HttpURLConnection) url.openConnection();
            connection.setRequestMethod("GET");
            connection.setConnectTimeout(5000);
            connection.setReadTimeout(5000);
            if (apiKey != null && !apiKey.isEmpty()) {
                connection.setRequestProperty("X-Chroma-Token", apiKey);
            }
            return connection.getResponseCode() == 200;
        } catch (IOException e) {
            logger.warn("Chroma database is not available: {}", e.getMessage());
            return false;
        } finally {
            // Release the connection even when getResponseCode() throws.
            if (connection != null) {
                connection.disconnect();
            }
        }
    }

    /**
     * Get the base URL of Chroma database
     *
     * <p>A host that already starts with {@code http://} or {@code https://} keeps its
     * scheme, which is the only way to target an HTTPS endpoint; any other host is
     * prefixed with {@code http://} as before. The port is always appended.
     *
     * @return the base URL of Chroma database
     */
    public String getBaseUrl() {
        boolean hasScheme = host != null
                && (host.startsWith("http://") || host.startsWith("https://"));
        String origin = hasScheme ? host : "http://" + host;
        return origin + ":" + port;
    }
}

View File

@@ -0,0 +1,383 @@
/*
* Copyright (c) 2023-2026, Easy-Agents (fuhai999@gmail.com).
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.easyagents.store.chroma;
import com.easyagents.core.document.Document;
import com.easyagents.core.store.SearchWrapper;
import com.easyagents.core.store.StoreOptions;
import com.easyagents.core.store.StoreResult;
import com.easyagents.core.model.client.HttpClient;
import org.junit.AfterClass;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.*;
/**
 * Tests for {@code ChromaVectorStore} covering document store, search, update and delete.
 * <p>
 * A heartbeat check runs once before the tests. When no Chroma server answers and mock mode
 * is disabled, every test is skipped through a JUnit assumption instead of failing.
 */
public class ChromaVectorStoreTest {

    private static final String TEST_TENANT = "default_tenant";
    private static final String TEST_DATABASE = "default_database";
    private static final String TEST_COLLECTION_NAME = "test_collection";

    private static ChromaVectorStore store;
    private static boolean isChromaAvailable = false;
    // Set to true to run the simulated code paths when no real Chroma server is running.
    private static boolean useMock = false;

    /**
     * Creates the shared {@code ChromaVectorStore} instance and records whether the
     * server answers the heartbeat request.
     */
    @BeforeClass
    public static void setUp() {
        // Build the configuration pointing at a local Chroma server.
        ChromaVectorStoreConfig config = new ChromaVectorStoreConfig();
        config.setHost("localhost");
        config.setPort(8000);
        config.setCollectionName(TEST_COLLECTION_NAME);
        config.setTenant(TEST_TENANT);
        config.setDatabase(TEST_DATABASE);
        config.setAutoCreateCollection(true);

        try {
            store = new ChromaVectorStore(config);
            System.out.println("ChromaVectorStore initialized successfully.");

            // Probe the server so the tests can be skipped when it is unreachable.
            isChromaAvailable = checkChromaConnection(config);
            if (!isChromaAvailable && !useMock) {
                System.out.println("Chroma server is not available. Tests will be skipped unless useMock is set to true.");
            }
        } catch (Exception e) {
            // Deliberately best-effort: a failed initialization leaves isChromaAvailable
            // false, so the tests are skipped (or run in mock mode) rather than erroring here.
            System.err.println("Failed to initialize ChromaVectorStore: " + e.getMessage());
            e.printStackTrace();
        }
    }

    /**
     * Checks whether the Chroma server answers its heartbeat endpoint.
     *
     * @param config the configuration whose host and port are probed
     * @return {@code true} when the request returns a non-null response; {@code false}
     *         when the response is null or any exception is thrown
     */
    private static boolean checkChromaConnection(ChromaVectorStoreConfig config) {
        String baseUrl = "http://" + config.getHost() + ":" + config.getPort();
        try {
            String healthCheckUrl = baseUrl + "/api/v2/heartbeat";
            HttpClient httpClient = new HttpClient();
            System.out.println("Checking Chroma server connection at: " + healthCheckUrl);

            // NOTE(review): no explicit timeout is configured here; the request relies on
            // whatever HttpClient applies by default.
            String response = httpClient.get(healthCheckUrl);
            if (response == null) {
                System.out.println("Chroma server connection failed: Empty response");
                return false;
            }
            System.out.println("Chroma server connection successful! Response: " + response);
            return true;
        } catch (Exception e) {
            System.out.println("Chroma server connection failed: " + e.getMessage());
            System.out.println("Please ensure Chroma server is running on " + baseUrl);
            System.out.println("To run tests without a real Chroma server, set 'useMock = true'");
            return false;
        }
    }

    /**
     * Skips the calling test unless the server is reachable or mock mode is enabled.
     */
    private void assumeChromaAvailable() {
        Assume.assumeTrue("Chroma server is not available and mock mode is disabled",
                isChromaAvailable || useMock);
    }

    /**
     * Closes the shared store after all tests have finished.
     */
    @AfterClass
    public static void tearDown() {
        if (store == null) {
            return;
        }
        try {
            store.close();
            System.out.println("ChromaVectorStore closed successfully.");
        } catch (Exception e) {
            System.err.println("Error closing ChromaVectorStore: " + e.getMessage());
        }
    }

    /**
     * Verifies that storing documents succeeds and returns one id per document.
     */
    @Test
    public void testStoreDocuments() {
        assumeChromaAvailable();
        System.out.println("Starting testStoreDocuments...");

        List<Document> documents = createTestDocuments();

        // In mock mode, assert against a simulated successful result and stop.
        if (useMock) {
            System.out.println("Running in mock mode. Simulating store operation.");
            StoreResult mockResult = StoreResult.successWithIds(documents);
            assertTrue("Mock store operation should be successful", mockResult.isSuccess());
            assertEquals("All document IDs should be returned in mock mode",
                    documents.size(), mockResult.ids().size());
            System.out.println("testStoreDocuments completed successfully in mock mode.");
            return;
        }

        try {
            StoreResult result = store.doStore(documents, StoreOptions.DEFAULT);
            System.out.println("Store result: " + result);

            assertTrue("Store operation should be successful", result.isSuccess());
            assertEquals("All document IDs should be returned", documents.size(), result.ids().size());
            System.out.println("testStoreDocuments completed successfully.");
        } catch (Exception e) {
            System.err.println("Failed to store documents: " + e.getMessage());
            // Attach the cause so the test report carries the full stack trace.
            throw new AssertionError("Store operation failed with exception: " + e.getMessage(), e);
        }
    }

    /**
     * Verifies that a vector search returns a non-empty result list no larger than the
     * requested maximum.
     */
    @Test
    public void testSearchDocuments() {
        assumeChromaAvailable();
        System.out.println("Starting testSearchDocuments...");

        List<Document> documents = createTestDocuments();

        if (useMock) {
            System.out.println("Running in mock mode. Simulating search operation.");
            // Simulated search result: the first three documents with descending scores.
            List<Document> mockResults = new ArrayList<>(documents.subList(0, Math.min(3, documents.size())));
            for (int i = 0; i < mockResults.size(); i++) {
                mockResults.get(i).setScore(1.0 - i * 0.1);
            }

            assertNotNull("Mock search results should not be null", mockResults);
            assertFalse("Mock search results should not be empty", mockResults.isEmpty());
            assertTrue("Mock search results should have the correct maximum size", mockResults.size() <= 3);
            System.out.println("testSearchDocuments completed successfully in mock mode.");
            return;
        }

        try {
            // Store the test documents first so there is something to find.
            store.doStore(documents, StoreOptions.DEFAULT);

            // Search with the first document's own vector.
            SearchWrapper searchWrapper = new SearchWrapper();
            searchWrapper.setVector(documents.get(0).getVector());
            searchWrapper.setMaxResults(3);

            List<Document> searchResults = store.doSearch(searchWrapper, StoreOptions.DEFAULT);

            assertNotNull("Search results should not be null", searchResults);
            assertFalse("Search results should not be empty", searchResults.isEmpty());
            assertTrue("Search results should have the correct maximum size",
                    searchResults.size() <= searchWrapper.getMaxResults());

            System.out.println("Search results:");
            for (Document doc : searchResults) {
                System.out.printf("id=%s, content=%s, vector=%s, score=%s\n",
                        doc.getId(), doc.getContent(), Arrays.toString(doc.getVector()), doc.getScore());
            }
            System.out.println("testSearchDocuments completed successfully.");
        } catch (Exception e) {
            System.err.println("Failed to search documents: " + e.getMessage());
            // Attach the cause so the test report carries the full stack trace.
            throw new AssertionError("Search operation failed with exception: " + e.getMessage(), e);
        }
    }

    /**
     * Verifies that updating a document succeeds and that a subsequent search returns the
     * updated content.
     */
    @Test
    public void testUpdateDocuments() {
        assumeChromaAvailable();
        System.out.println("Starting testUpdateDocuments...");

        List<Document> documents = createTestDocuments();

        if (useMock) {
            System.out.println("Running in mock mode. Simulating update operation.");
            Document updatedDoc = documents.get(0);
            String originalContent = updatedDoc.getContent();
            updatedDoc.setContent(originalContent + " [UPDATED]");

            // Simulated update result.
            StoreResult mockResult = StoreResult.successWithIds(Arrays.asList(updatedDoc));
            assertTrue("Mock update operation should be successful", mockResult.isSuccess());
            System.out.println("testUpdateDocuments completed successfully in mock mode.");
            return;
        }

        try {
            // Store the test documents first so there is something to update.
            store.doStore(documents, StoreOptions.DEFAULT);

            Document updatedDoc = documents.get(0);
            String originalContent = updatedDoc.getContent();
            updatedDoc.setContent(originalContent + " [UPDATED]");

            StoreResult result = store.doUpdate(Arrays.asList(updatedDoc), StoreOptions.DEFAULT);
            assertTrue("Update operation should be successful", result.isSuccess());

            // Search with the updated document's vector to confirm the change.
            SearchWrapper searchWrapper = new SearchWrapper();
            searchWrapper.setVector(updatedDoc.getVector());
            searchWrapper.setMaxResults(1);
            List<Document> searchResults = store.doSearch(searchWrapper, StoreOptions.DEFAULT);

            // Fail with a clear message instead of a NullPointerException on a null result.
            assertNotNull("Should find the updated document", searchResults);
            assertFalse("Should find the updated document", searchResults.isEmpty());
            assertEquals("Document content should be updated",
                    updatedDoc.getContent(), searchResults.get(0).getContent());
            System.out.println("testUpdateDocuments completed successfully.");
        } catch (Exception e) {
            System.err.println("Failed to update documents: " + e.getMessage());
            // Attach the cause so the test report carries the full stack trace.
            throw new AssertionError("Update operation failed with exception: " + e.getMessage(), e);
        }
    }

    /**
     * Verifies that deleting a document succeeds and that the document no longer appears
     * in search results.
     */
    @Test
    public void testDeleteDocuments() {
        assumeChromaAvailable();
        System.out.println("Starting testDeleteDocuments...");

        List<Document> documents = createTestDocuments();
        Object deletedId = documents.get(0).getId();

        if (useMock) {
            System.out.println("Running in mock mode. Simulating delete operation.");
            // Simulated delete result.
            StoreResult mockResult = StoreResult.success();
            assertTrue("Mock delete operation should be successful", mockResult.isSuccess());
            System.out.println("testDeleteDocuments completed successfully in mock mode.");
            return;
        }

        try {
            // Store the test documents first so there is something to delete.
            store.doStore(documents, StoreOptions.DEFAULT);

            List<Object> idsToDelete = new ArrayList<>();
            idsToDelete.add(deletedId);

            StoreResult result = store.doDelete(idsToDelete, StoreOptions.DEFAULT);
            assertTrue("Delete operation should be successful", result.isSuccess());

            // Search with the deleted document's vector and make sure it is gone.
            SearchWrapper searchWrapper = new SearchWrapper();
            searchWrapper.setVector(documents.get(0).getVector());
            searchWrapper.setMaxResults(10);
            List<Document> searchResults = store.doSearch(searchWrapper, StoreOptions.DEFAULT);

            // A null result is treated like an empty one: the deleted document was not found.
            boolean deletedDocFound = searchResults != null && searchResults.stream()
                    .anyMatch(doc -> deletedId.equals(doc.getId()));
            assertFalse("Deleted document should not be found", deletedDocFound);
            System.out.println("testDeleteDocuments completed successfully.");
        } catch (Exception e) {
            System.err.println("Failed to delete documents: " + e.getMessage());
            // Attach the cause so the test report carries the full stack trace.
            throw new AssertionError("Delete operation failed with exception: " + e.getMessage(), e);
        }
    }

    /**
     * Builds five test documents with ids {@code doc_0} to {@code doc_4}, each with its own
     * content, title and a 10-dimensional vector.
     *
     * @return a new mutable list containing the five documents
     */
    private List<Document> createTestDocuments() {
        List<Document> documents = new ArrayList<>();
        for (int i = 0; i < 5; i++) {
            Document doc = new Document();
            doc.setId("doc_" + i);
            doc.setContent("This is test document content " + i);
            doc.setTitle("Test Document " + i);

            // Simple 10-dimensional vector whose values depend on the document index.
            float[] vector = new float[10];
            for (int j = 0; j < vector.length; j++) {
                vector[j] = i + j * 0.1f;
            }
            doc.setVector(vector);
            documents.add(doc);
        }
        return documents;
    }
}