Compare commits

...

4 Commits

Author SHA1 Message Date
Xin@@Gar
f0266e13c5 refactor: improve type annotations in HitTestingService (#27838)
Some checks are pending
autofix.ci / autofix (push) Waiting to run
Build and Push API & Web / build (api, {{defaultContext}}:api, Dockerfile, DIFY_API_IMAGE_NAME, linux/amd64, ubuntu-latest, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, {{defaultContext}}:api, Dockerfile, DIFY_API_IMAGE_NAME, linux/arm64, ubuntu-24.04-arm, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, {{defaultContext}}, web/Dockerfile, DIFY_WEB_IMAGE_NAME, linux/amd64, ubuntu-latest, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, {{defaultContext}}, web/Dockerfile, DIFY_WEB_IMAGE_NAME, linux/arm64, ubuntu-24.04-arm, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Skip Duplicate Checks (push) Waiting to run
Main CI Pipeline / Check Changed Files (push) Blocked by required conditions
Main CI Pipeline / Run API Tests (push) Blocked by required conditions
Main CI Pipeline / Skip API Tests (push) Blocked by required conditions
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Run Web Tests (push) Blocked by required conditions
Main CI Pipeline / Skip Web Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Run Web Full-Stack E2E (push) Blocked by required conditions
Main CI Pipeline / Skip Web Full-Stack E2E (push) Blocked by required conditions
Main CI Pipeline / Web Full-Stack E2E (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Blocked by required conditions
Main CI Pipeline / Run VDB Tests (push) Blocked by required conditions
Main CI Pipeline / Skip VDB Tests (push) Blocked by required conditions
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / Run DB Migration Test (push) Blocked by required conditions
Main CI Pipeline / Skip DB Migration Test (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-04-13 10:31:31 +00:00
Yunlu Wen
ae898652b2 refactor: move vdb implementations to workspaces (#34900)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: wangxiaolei <fatelei@gmail.com>
2026-04-13 08:56:43 +00:00
Ke Wang
c34f67495c refactor(api): type WorkflowRun.to_dict with WorkflowRunDict TypedDict (#35047)
Co-authored-by: Ke Wang <ke@pika.art>
2026-04-13 08:30:28 +00:00
hj24
815c536e05 fix: optimize trigger long running read transactions (#35046) 2026-04-13 08:22:54 +00:00
227 changed files with 2070 additions and 1016 deletions

View File

@@ -92,6 +92,7 @@ jobs:
vdb:
- 'api/core/rag/datasource/**'
- 'api/tests/integration_tests/vdb/**'
- 'api/providers/vdb/*/tests/**'
- '.github/workflows/vdb-tests.yml'
- '.github/workflows/expose_service_ports.sh'
- 'docker/.env.example'

View File

@@ -89,7 +89,7 @@ jobs:
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
# - name: Check VDB Ready (TiDB)
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
- name: Test Vector Stores
run: uv run --project api bash dev/pytest/pytest_vdb.sh

View File

@@ -81,12 +81,12 @@ jobs:
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
# - name: Check VDB Ready (TiDB)
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
- name: Test Vector Stores
run: |
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
api/tests/integration_tests/vdb/chroma \
api/tests/integration_tests/vdb/pgvector \
api/tests/integration_tests/vdb/qdrant \
api/tests/integration_tests/vdb/weaviate
api/providers/vdb/vdb-chroma/tests/integration_tests \
api/providers/vdb/vdb-pgvector/tests/integration_tests \
api/providers/vdb/vdb-qdrant/tests/integration_tests \
api/providers/vdb/vdb-weaviate/tests/integration_tests

View File

@@ -21,8 +21,9 @@ RUN apt-get update \
# for building gmpy2
libmpfr-dev libmpc-dev
# Install Python dependencies
# Install Python dependencies (workspace members under providers/vdb/)
COPY pyproject.toml uv.lock ./
COPY providers ./providers
RUN uv sync --locked --no-dev
# production stage

View File

@@ -341,11 +341,10 @@ def add_qdrant_index(field: str):
click.echo(click.style("No dataset collection bindings found.", fg="red"))
return
import qdrant_client
from dify_vdb_qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PayloadSchemaType
from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
for binding in bindings:
if dify_config.QDRANT_URL is None:
raise ValueError("Qdrant URL is required.")

View File

@@ -1,4 +1,3 @@
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
from pydantic import Field
from pydantic_settings import BaseSettings
@@ -42,17 +41,17 @@ class HologresConfig(BaseSettings):
default="public",
)
HOLOGRES_TOKENIZER: TokenizerType = Field(
HOLOGRES_TOKENIZER: str = Field(
description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
default="jieba",
)
HOLOGRES_DISTANCE_METHOD: DistanceType = Field(
HOLOGRES_DISTANCE_METHOD: str = Field(
description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
default="Cosine",
)
HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field(
HOLOGRES_BASE_QUANTIZATION_TYPE: str = Field(
description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').",
default="rabitq",
)

View File

@@ -0,0 +1,87 @@
"""Vector store backend discovery.
Backends live in workspace packages under ``api/packages/dify-vdb-*/src/dify_vdb_*``. Each package
declares third-party dependencies and registers ``importlib`` entry points in group
``dify.vector_backends`` (see each package's ``pyproject.toml``).
Shared types and the :class:`~core.rag.datasource.vdb.vector_factory.AbstractVectorFactory` protocol
remain in this package (``vector_base``, ``vector_factory``, ``vector_type``, ``field``).
Optional **built-in** targets in ``_BUILTIN_VECTOR_FACTORY_TARGETS`` (normally empty) load without a
distribution; entry points take precedence when both exist.
After changing packages, run ``uv sync`` so installed dist-info entry points match ``pyproject.toml``.
"""
from __future__ import annotations
import importlib
import logging
from importlib.metadata import entry_points
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
logger = logging.getLogger(__name__)
_VECTOR_FACTORY_CACHE: dict[str, type[AbstractVectorFactory]] = {}
# module_path:class_name — optional fallback when no distribution registers the backend.
_BUILTIN_VECTOR_FACTORY_TARGETS: dict[str, str] = {}
def clear_vector_factory_cache() -> None:
"""Drop lazily loaded factories (for tests or plugin reload)."""
_VECTOR_FACTORY_CACHE.clear()
def _vector_backend_entry_points():
return entry_points().select(group="dify.vector_backends")
def _load_plugin_factory(vector_type: str) -> type[AbstractVectorFactory] | None:
for ep in _vector_backend_entry_points():
if ep.name != vector_type:
continue
try:
loaded = ep.load()
except Exception:
logger.exception("Failed to load vector backend entry point %s", ep.name)
raise
return loaded # type: ignore[return-value]
return None
def _unsupported(vector_type: str) -> ValueError:
installed = sorted(ep.name for ep in _vector_backend_entry_points())
available_msg = f" Installed backends: {', '.join(installed)}." if installed else " No backends installed."
return ValueError(
f"Vector store {vector_type!r} is not supported.{available_msg} "
"Install a plugin (uv sync --group vdb-all, or vdb-<backend> per api/pyproject.toml), "
"or register a dify.vector_backends entry point."
)
def _load_builtin_factory(vector_type: str) -> type[AbstractVectorFactory]:
target = _BUILTIN_VECTOR_FACTORY_TARGETS.get(vector_type)
if not target:
raise _unsupported(vector_type)
module_path, _, attr = target.partition(":")
module = importlib.import_module(module_path)
return getattr(module, attr) # type: ignore[no-any-return]
def get_vector_factory_class(vector_type: str) -> type[AbstractVectorFactory]:
"""Resolve :class:`AbstractVectorFactory` for a :class:`~VectorType` string value."""
if vector_type in _VECTOR_FACTORY_CACHE:
return _VECTOR_FACTORY_CACHE[vector_type]
plugin_cls = _load_plugin_factory(vector_type)
if plugin_cls is not None:
_VECTOR_FACTORY_CACHE[vector_type] = plugin_cls
return plugin_cls
cls = _load_builtin_factory(vector_type)
_VECTOR_FACTORY_CACHE[vector_type] = cls
return cls

View File

@@ -9,6 +9,7 @@ from sqlalchemy import select
from configs import dify_config
from core.model_manager import ModelManager
from core.rag.datasource.vdb.vector_backend_registry import get_vector_factory_class
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding
@@ -85,137 +86,7 @@ class Vector:
@staticmethod
def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
match vector_type:
case VectorType.CHROMA:
from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory
return ChromaVectorFactory
case VectorType.MILVUS:
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
return MilvusVectorFactory
case VectorType.ALIBABACLOUD_MYSQL:
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
AlibabaCloudMySQLVectorFactory,
)
return AlibabaCloudMySQLVectorFactory
case VectorType.MYSCALE:
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory
return MyScaleVectorFactory
case VectorType.PGVECTOR:
from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
return PGVectorFactory
case VectorType.VASTBASE:
from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVectorFactory
return VastbaseVectorFactory
case VectorType.PGVECTO_RS:
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory
return PGVectoRSFactory
case VectorType.QDRANT:
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory
return QdrantVectorFactory
case VectorType.RELYT:
from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
return RelytVectorFactory
case VectorType.ELASTICSEARCH:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
return ElasticSearchVectorFactory
case VectorType.ELASTICSEARCH_JA:
from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
ElasticSearchJaVectorFactory,
)
return ElasticSearchJaVectorFactory
case VectorType.TIDB_VECTOR:
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
return TiDBVectorFactory
case VectorType.WEAVIATE:
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
return WeaviateVectorFactory
case VectorType.TENCENT:
from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
return TencentVectorFactory
case VectorType.ORACLE:
from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory
return OracleVectorFactory
case VectorType.OPENSEARCH:
from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory
return OpenSearchVectorFactory
case VectorType.ANALYTICDB:
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
return AnalyticdbVectorFactory
case VectorType.COUCHBASE:
from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseVectorFactory
return CouchbaseVectorFactory
case VectorType.BAIDU:
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory
return BaiduVectorFactory
case VectorType.VIKINGDB:
from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBVectorFactory
return VikingDBVectorFactory
case VectorType.UPSTASH:
from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory
return UpstashVectorFactory
case VectorType.TIDB_ON_QDRANT:
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory
return TidbOnQdrantVectorFactory
case VectorType.LINDORM:
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory
return LindormVectorStoreFactory
case VectorType.OCEANBASE | VectorType.SEEKDB:
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory
return OceanBaseVectorFactory
case VectorType.OPENGAUSS:
from core.rag.datasource.vdb.opengauss.opengauss import OpenGaussFactory
return OpenGaussFactory
case VectorType.TABLESTORE:
from core.rag.datasource.vdb.tablestore.tablestore_vector import TableStoreVectorFactory
return TableStoreVectorFactory
case VectorType.HUAWEI_CLOUD:
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory
return HuaweiCloudVectorFactory
case VectorType.MATRIXONE:
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory
return MatrixoneVectorFactory
case VectorType.CLICKZETTA:
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory
return ClickzettaVectorFactory
case VectorType.IRIS:
from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory
return IrisVectorFactory
case VectorType.HOLOGRES:
from core.rag.datasource.vdb.hologres.hologres_vector import HologresVectorFactory
return HologresVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")
return get_vector_factory_class(vector_type)
def create(self, texts: list | None = None, **kwargs):
if texts:

View File

@@ -1,10 +1,19 @@
"""Shared helpers for vector DB integration tests (used by workspace packages under ``api/packages``).
:class:`AbstractVectorTest` and helper functions live here so package tests can import
``core.rag.datasource.vdb.vector_integration_test_support`` without relying on the
``tests.*`` package.
The ``setup_mock_redis`` fixture lives in ``api/packages/conftest.py`` and is
auto-discovered by pytest for all package tests.
"""
import uuid
from unittest.mock import MagicMock
import pytest
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions import ext_redis
from models.dataset import Dataset
@@ -25,24 +34,10 @@ def get_example_document(doc_id: str) -> Document:
return doc
@pytest.fixture
def setup_mock_redis():
# get
ext_redis.redis_client.get = MagicMock(return_value=None)
# set
ext_redis.redis_client.set = MagicMock(return_value=None)
# lock
mock_redis_lock = MagicMock()
mock_redis_lock.__enter__ = MagicMock()
mock_redis_lock.__exit__ = MagicMock()
ext_redis.redis_client.lock = mock_redis_lock
class AbstractVectorTest:
vector: BaseVector
def __init__(self):
self.vector = None
self.dataset_id = str(uuid.uuid4())
self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test"
self.example_doc_id = str(uuid.uuid4())

View File

@@ -671,6 +671,29 @@ class Workflow(Base): # bug
return str(d)
class WorkflowRunDict(TypedDict):
id: str
tenant_id: str
app_id: str
workflow_id: str
type: WorkflowType
triggered_from: WorkflowRunTriggeredFrom
version: str
graph: Mapping[str, Any]
inputs: Mapping[str, Any]
status: WorkflowExecutionStatus
outputs: Mapping[str, Any]
error: str | None
elapsed_time: float
total_tokens: int
total_steps: int
created_by_role: CreatorUserRole
created_by: str
created_at: datetime
finished_at: datetime | None
exceptions_count: int
class WorkflowRun(Base):
"""
Workflow Run
@@ -790,29 +813,29 @@ class WorkflowRun(Base):
def workflow(self):
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
def to_dict(self):
return {
"id": self.id,
"tenant_id": self.tenant_id,
"app_id": self.app_id,
"workflow_id": self.workflow_id,
"type": self.type,
"triggered_from": self.triggered_from,
"version": self.version,
"graph": self.graph_dict,
"inputs": self.inputs_dict,
"status": self.status,
"outputs": self.outputs_dict,
"error": self.error,
"elapsed_time": self.elapsed_time,
"total_tokens": self.total_tokens,
"total_steps": self.total_steps,
"created_by_role": self.created_by_role,
"created_by": self.created_by,
"created_at": self.created_at,
"finished_at": self.finished_at,
"exceptions_count": self.exceptions_count,
}
def to_dict(self) -> WorkflowRunDict:
return WorkflowRunDict(
id=self.id,
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_id=self.workflow_id,
type=self.type,
triggered_from=self.triggered_from,
version=self.version,
graph=self.graph_dict,
inputs=self.inputs_dict,
status=self.status,
outputs=self.outputs_dict,
error=self.error,
elapsed_time=self.elapsed_time,
total_tokens=self.total_tokens,
total_steps=self.total_steps,
created_by_role=self.created_by_role,
created_by=self.created_by,
created_at=self.created_at,
finished_at=self.finished_at,
exceptions_count=self.exceptions_count,
)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":

12
api/providers/README.md Normal file
View File

@@ -0,0 +1,12 @@
# Providers
This directory holds **optional workspace packages** that plug into Difys API core. Providers are responsible for implementing the interfaces and registering themselves to the API core. Provider mechanism allows building the software with selected set of providers so as to enhance the security and flexibility of distributions.
## Developing Providers
- [VDB Providers](vdb/README.md)
## Tests
Provider tests often live next to the package, e.g. `providers/<type>/<backend>/tests/unit_tests/`. Shared fixtures may live under `providers/` (e.g. `conftest.py`).

View File

@@ -0,0 +1,58 @@
# VDB providers
This directory contains all VDB providers.
## Architecture
1. **Core** (`api/core/rag/datasource/vdb/`) defines the contracts and loads plugins.
2. **Each provider** (`api/providers/vdb/<backend>/`) implements those contracts and registers an entry point.
3. At runtime, **`importlib.metadata.entry_points`** resolves the backend name (e.g. `pgvector`) to a factory class. The registry caches loaded classes (see `vector_backend_registry.py`).
### Interfaces
| Piece | Role |
|--------|----------|
| `AbstractVectorFactory` | You subclass this. Implement `init_vector(dataset, attributes, embeddings) -> BaseVector`. Optionally use `gen_index_struct_dict()` for new datasets. |
| `BaseVector` | Your store class subclasses this: `create`, `add_texts`, `search_by_vector`, `delete`, etc. |
| `VectorType` | `StrEnum` of supported backend **string ids**. Add a member when you introduce a new backend that should be selectable like existing ones. |
| Discovery | Loads `dify.vector_backends` entry points and caches `get_vector_factory_class(vector_type)`. |
The high-level caller is `Vector` in `vector_factory.py`: it reads the configured or dataset-specific vector type, calls `get_vector_factory_class`, instantiates the factory, and uses the returned `BaseVector` implementation.
### Entry point name must match the vector type string
Entry points are registered under the group **`dify.vector_backends`**. The **entry point name** (left-hand side) must be exactly the string used as `vector_type` everywhere else—typically the **`VectorType` enum value** (e.g. `PGVECTOR = "pgvector"` → entry point name `pgvector`; `TIDB_ON_QDRANT = "tidb_on_qdrant"``tidb_on_qdrant`).
In `pyproject.toml`:
```toml
[project.entry-points."dify.vector_backends"]
pgvector = "dify_vdb_pgvector.pgvector:PGVectorFactory"
```
The value is **`module:attribute`**: a importable module path and the class implementing `AbstractVectorFactory`.
### How registration works
1. On first use, `get_vector_factory_class(vector_type)` looks up `vector_type` in a process cache.
2. If missing, it scans **`entry_points().select(group="dify.vector_backends")`** for an entry whose **`name` equals `vector_type`**.
3. It loads that entry (`ep.load()`), which must return the **factory class** (not an instance).
4. There is an optional internal map `_BUILTIN_VECTOR_FACTORY_TARGETS` for non-distribution builtins; **normal VDB plugins use entry points only**.
After you change a providers `pyproject.toml` (entry points or dependencies), run **`uv sync`** in `api/` so the installed environments dist-info matches the project metadata.
### Package layout (VDB)
Each backend usually follows:
- `api/providers/vdb/<backend>/pyproject.toml` — project name `dify-vdb-<backend>`, dependencies, entry points.
- `api/providers/vdb/<backend>/src/dify_vdb_<python_package>/` — implementation (e.g. `PGVector`, `PGVectorFactory`).
See `vdb/pgvector/` as a reference implementation.
### Wiring a new backend into the API workspace
The API uses a **uv workspace** (`api/pyproject.toml`):
1. **`[tool.uv.workspace]`** — `members = ["providers/vdb/*"]` already includes every subdirectory under `vdb/`; new folders there are workspace members.
2. **`[tool.uv.sources]`** — add a line for your package: `dify-vdb-mine = { workspace = true }`.
3. **`[project.optional-dependencies]`** — add a group such as `vdb-mine = ["dify-vdb-mine"]`, and list `dify-vdb-mine` under `vdb-all` if it should install with the default bundle.

View File

@@ -0,0 +1,22 @@
from unittest.mock import MagicMock
import pytest
from extensions import ext_redis
@pytest.fixture(autouse=True)
def _init_mock_redis():
"""Ensure redis_client has a backing client so __getattr__ never raises."""
if ext_redis.redis_client._client is None:
ext_redis.redis_client.initialize(MagicMock())
@pytest.fixture
def setup_mock_redis(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(ext_redis.redis_client, "get", MagicMock(return_value=None))
monkeypatch.setattr(ext_redis.redis_client, "set", MagicMock(return_value=None))
mock_redis_lock = MagicMock()
mock_redis_lock.__enter__ = MagicMock()
mock_redis_lock.__exit__ = MagicMock()
monkeypatch.setattr(ext_redis.redis_client, "lock", mock_redis_lock)

View File

@@ -0,0 +1,13 @@
[project]
name = "dify-vdb-alibabacloud-mysql"
version = "0.0.1"
dependencies = [
"mysql-connector-python>=9.3.0",
]
description = "Dify vector store backend (dify-vdb-alibabacloud-mysql)."
[project.entry-points."dify.vector_backends"]
alibabacloud_mysql = "dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector:AlibabaCloudMySQLVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -1,10 +1,9 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module
import pytest
import core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory
from dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory
def test_validate_distance_function_accepts_supported_values():

View File

@@ -3,11 +3,11 @@ import unittest
from unittest.mock import MagicMock, patch
import pytest
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
from dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector import (
AlibabaCloudMySQLVector,
AlibabaCloudMySQLVectorConfig,
)
from core.rag.models.document import Document
try:
@@ -49,9 +49,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
# Sample embeddings
self.sample_embeddings = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_init(self, mock_pool_class):
"""Test AlibabaCloudMySQLVector initialization."""
# Mock the connection pool
@@ -76,10 +74,8 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert alibabacloud_mysql_vector.distance_function == "cosine"
assert alibabacloud_mysql_vector.pool is not None
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
def test_create_collection(self, mock_redis, mock_pool_class):
"""Test collection creation."""
# Mock Redis operations
@@ -110,9 +106,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert mock_cursor.execute.call_count >= 3 # CREATE TABLE + 2 indexes
mock_redis.set.assert_called_once()
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_vector_support_check_success(self, mock_pool_class):
"""Test successful vector support check."""
# Mock the connection pool
@@ -129,9 +123,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
assert vector_store is not None
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_vector_support_check_failure(self, mock_pool_class):
"""Test vector support check failure."""
# Mock the connection pool
@@ -149,9 +141,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert "RDS MySQL Vector functions are not available" in str(context.value)
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_vector_support_check_function_error(self, mock_pool_class):
"""Test vector support check with function not found error."""
# Mock the connection pool
@@ -170,10 +160,8 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert "RDS MySQL Vector functions are not available" in str(context.value)
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
def test_create_documents(self, mock_redis, mock_pool_class):
"""Test creating documents with embeddings."""
# Setup mocks
@@ -186,9 +174,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert "doc1" in result
assert "doc2" in result
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_add_texts(self, mock_pool_class):
"""Test adding texts to the vector store."""
# Mock the connection pool
@@ -207,9 +193,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert len(result) == 2
mock_cursor.executemany.assert_called_once()
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_text_exists(self, mock_pool_class):
"""Test checking if text exists."""
# Mock the connection pool
@@ -236,9 +220,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert "SELECT id FROM" in last_call[0][0]
assert last_call[0][1] == ("doc1",)
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_text_not_exists(self, mock_pool_class):
"""Test checking if text does not exist."""
# Mock the connection pool
@@ -260,9 +242,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert not exists
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_get_by_ids(self, mock_pool_class):
"""Test getting documents by IDs."""
# Mock the connection pool
@@ -288,9 +268,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert docs[0].page_content == "Test document 1"
assert docs[1].page_content == "Test document 2"
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_get_by_ids_empty_list(self, mock_pool_class):
"""Test getting documents with empty ID list."""
# Mock the connection pool
@@ -308,9 +286,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert len(docs) == 0
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_delete_by_ids(self, mock_pool_class):
"""Test deleting documents by IDs."""
# Mock the connection pool
@@ -334,9 +310,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert "DELETE FROM" in delete_call[0][0]
assert delete_call[0][1] == ["doc1", "doc2"]
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_delete_by_ids_empty_list(self, mock_pool_class):
"""Test deleting with empty ID list."""
# Mock the connection pool
@@ -357,9 +331,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
assert len(delete_calls) == 0
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_delete_by_ids_table_not_exists(self, mock_pool_class):
"""Test deleting when table doesn't exist."""
# Mock the connection pool
@@ -384,9 +356,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
# Should not raise an exception
vector_store.delete_by_ids(["doc1"])
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_delete_by_metadata_field(self, mock_pool_class):
"""Test deleting documents by metadata field."""
# Mock the connection pool
@@ -410,9 +380,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert "JSON_UNQUOTE(JSON_EXTRACT(meta" in delete_call[0][0]
assert delete_call[0][1] == ("$.document_id", "dataset1")
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_search_by_vector_cosine(self, mock_pool_class):
"""Test vector search with cosine distance."""
# Mock the connection pool
@@ -437,9 +405,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert abs(docs[0].metadata["score"] - 0.9) < 0.1 # 1 - 0.1 = 0.9
assert docs[0].metadata["distance"] == 0.1
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_search_by_vector_euclidean(self, mock_pool_class):
"""Test vector search with euclidean distance."""
config = AlibabaCloudMySQLVectorConfig(
@@ -472,9 +438,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert len(docs) == 1
assert abs(docs[0].metadata["score"] - 1.0 / 3.0) < 0.01 # 1/(1+2) = 1/3
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_search_by_vector_with_filter(self, mock_pool_class):
"""Test vector search with document ID filter."""
# Mock the connection pool
@@ -499,9 +463,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
search_call = search_calls[0]
assert "WHERE JSON_UNQUOTE" in search_call[0][0]
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_search_by_vector_with_score_threshold(self, mock_pool_class):
"""Test vector search with score threshold."""
# Mock the connection pool
@@ -536,9 +498,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert len(docs) == 1
assert docs[0].page_content == "High similarity document"
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_search_by_vector_invalid_top_k(self, mock_pool_class):
"""Test vector search with invalid top_k."""
# Mock the connection pool
@@ -560,9 +520,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
with pytest.raises(ValueError):
vector_store.search_by_vector(query_vector, top_k="invalid")
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_search_by_full_text(self, mock_pool_class):
"""Test full-text search."""
# Mock the connection pool
@@ -591,9 +549,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
assert docs[0].page_content == "This document contains machine learning content"
assert docs[0].metadata["score"] == 1.5
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_search_by_full_text_with_filter(self, mock_pool_class):
"""Test full-text search with document ID filter."""
# Mock the connection pool
@@ -617,9 +573,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
search_call = search_calls[0]
assert "AND JSON_UNQUOTE" in search_call[0][0]
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_search_by_full_text_invalid_top_k(self, mock_pool_class):
"""Test full-text search with invalid top_k."""
# Mock the connection pool
@@ -640,9 +594,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
with pytest.raises(ValueError):
vector_store.search_by_full_text("test", top_k="invalid")
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_delete_collection(self, mock_pool_class):
"""Test deleting the entire collection."""
# Mock the connection pool
@@ -665,9 +617,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
drop_call = drop_calls[0]
assert f"DROP TABLE IF EXISTS {self.collection_name.lower()}" in drop_call[0][0]
@patch(
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
)
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
def test_unsupported_distance_function(self, mock_pool_class):
"""Test that Pydantic validation rejects unsupported distance functions."""
# Test that creating config with unsupported distance function raises ValidationError

View File

@@ -0,0 +1,15 @@
[project]
name = "dify-vdb-analyticdb"
version = "0.0.1"
dependencies = [
"alibabacloud_gpdb20160503~=5.2.0",
"alibabacloud_tea_openapi~=0.4.3",
"clickhouse-connect~=0.15.0",
]
description = "Dify vector store backend (dify-vdb-analyticdb)."
[project.entry-points."dify.vector_backends"]
analyticdb = "dify_vdb_analyticdb.analyticdb_vector:AnalyticdbVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -2,16 +2,16 @@ import json
from typing import Any
from configs import dify_config
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
AnalyticdbVectorOpenAPI,
AnalyticdbVectorOpenAPIConfig,
)
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from dify_vdb_analyticdb.analyticdb_vector_openapi import (
AnalyticdbVectorOpenAPI,
AnalyticdbVectorOpenAPIConfig,
)
from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
from models.dataset import Dataset

View File

@@ -1,9 +1,8 @@
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest
from dify_vdb_analyticdb.analyticdb_vector import AnalyticdbVector
from dify_vdb_analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest
class AnalyticdbVectorTest(AbstractVectorTest):

View File

@@ -1,12 +1,12 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import dify_vdb_analyticdb.analyticdb_vector as analyticdb_module
import pytest
from dify_vdb_analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory
from dify_vdb_analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
import core.rag.datasource.vdb.analyticdb.analyticdb_vector as analyticdb_module
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
from core.rag.models.document import Document

View File

@@ -4,13 +4,13 @@ import types
from types import SimpleNamespace
from unittest.mock import MagicMock
import dify_vdb_analyticdb.analyticdb_vector_openapi as openapi_module
import pytest
import core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi as openapi_module
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
from dify_vdb_analyticdb.analyticdb_vector_openapi import (
AnalyticdbVectorOpenAPI,
AnalyticdbVectorOpenAPIConfig,
)
from core.rag.models.document import Document

View File

@@ -2,14 +2,14 @@ from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import MagicMock
import dify_vdb_analyticdb.analyticdb_vector_sql as sql_module
import psycopg2.errors
import pytest
import core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql as sql_module
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import (
from dify_vdb_analyticdb.analyticdb_vector_sql import (
AnalyticdbVectorBySql,
AnalyticdbVectorBySqlConfig,
)
from core.rag.models.document import Document

View File

@@ -0,0 +1,13 @@
[project]
name = "dify-vdb-baidu"
version = "0.0.1"
dependencies = [
"pymochow==2.4.0",
]
description = "Dify vector store backend (dify-vdb-baidu)."
[project.entry-points."dify.vector_backends"]
baidu = "dify_vdb_baidu.baidu_vector:BaiduVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -1,10 +1,6 @@
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text
from dify_vdb_baidu.baidu_vector import BaiduConfig, BaiduVector
pytest_plugins = (
"tests.integration_tests.vdb.test_vector_store",
"tests.integration_tests.vdb.__mock.baiduvectordb",
)
from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
class BaiduVectorTest(AbstractVectorTest):

View File

@@ -124,7 +124,7 @@ def _build_fake_pymochow_modules():
def baidu_module(monkeypatch):
for name, module in _build_fake_pymochow_modules().items():
monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.baidu.baidu_vector as module
import dify_vdb_baidu.baidu_vector as module
return importlib.reload(module)

View File

@@ -0,0 +1,13 @@
[project]
name = "dify-vdb-chroma"
version = "0.0.1"
dependencies = [
"chromadb==0.5.20",
]
description = "Dify vector store backend (dify-vdb-chroma)."
[project.entry-points."dify.vector_backends"]
chroma = "dify_vdb_chroma.chroma_vector:ChromaVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -1,13 +1,11 @@
import chromadb
from dify_vdb_chroma.chroma_vector import ChromaConfig, ChromaVector
from core.rag.datasource.vdb.chroma.chroma_vector import ChromaConfig, ChromaVector
from tests.integration_tests.vdb.test_vector_store import (
from core.rag.datasource.vdb.vector_integration_test_support import (
AbstractVectorTest,
get_example_text,
)
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
class ChromaVectorTest(AbstractVectorTest):
def __init__(self):

View File

@@ -47,7 +47,7 @@ def _build_fake_chroma_modules():
def chroma_module(monkeypatch):
fake_chroma = _build_fake_chroma_modules()
monkeypatch.setitem(sys.modules, "chromadb", fake_chroma)
import core.rag.datasource.vdb.chroma.chroma_vector as module
import dify_vdb_chroma.chroma_vector as module
return importlib.reload(module)

View File

@@ -198,4 +198,4 @@ Clickzetta supports advanced full-text search with multiple analyzers:
- [Clickzetta Vector Search Documentation](https://yunqi.tech/documents/vector-search)
- [Clickzetta Inverted Index Documentation](https://yunqi.tech/documents/inverted-index)
- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference)
- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference)

View File

@@ -0,0 +1,14 @@
[project]
name = "dify-vdb-clickzetta"
version = "0.0.1"
dependencies = [
"clickzetta-connector-python>=0.8.102",
]
description = "Dify vector store backend (dify-vdb-clickzetta)."
[project.entry-points."dify.vector_backends"]
clickzetta = "dify_vdb_clickzetta.clickzetta_vector:ClickzettaVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -2,10 +2,10 @@ import contextlib
import os
import pytest
from dify_vdb_clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector
from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
class TestClickzettaVector(AbstractVectorTest):
@@ -14,9 +14,8 @@ class TestClickzettaVector(AbstractVectorTest):
"""
@pytest.fixture
def vector_store(self):
def vector_store(self, setup_mock_redis):
"""Create a Clickzetta vector store instance for testing."""
# Skip test if Clickzetta credentials are not configured
if not os.getenv("CLICKZETTA_USERNAME"):
pytest.skip("CLICKZETTA_USERNAME is not configured")
if not os.getenv("CLICKZETTA_PASSWORD"):
@@ -32,21 +31,19 @@ class TestClickzettaVector(AbstractVectorTest):
workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
schema=os.getenv("CLICKZETTA_SCHEMA", "dify_test"),
batch_size=10, # Small batch size for testing
batch_size=10,
enable_inverted_index=True,
analyzer_type="chinese",
analyzer_mode="smart",
vector_distance_function="cosine_distance",
)
with setup_mock_redis():
vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config)
vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config)
yield vector
yield vector
# Cleanup: delete the test collection
with contextlib.suppress(Exception):
vector.delete()
with contextlib.suppress(Exception):
vector.delete()
def test_clickzetta_vector_basic_operations(self, vector_store):
"""Test basic CRUD operations on Clickzetta vector store."""

View File

@@ -3,16 +3,19 @@
Test Clickzetta integration in Docker environment
"""
import logging
import os
import time
import httpx
from clickzetta import connect
logger = logging.getLogger(__name__)
def test_clickzetta_connection():
"""Test direct connection to Clickzetta"""
print("=== Testing direct Clickzetta connection ===")
logger.info("=== Testing direct Clickzetta connection ===")
try:
conn = connect(
username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
@@ -25,100 +28,93 @@ def test_clickzetta_connection():
)
with conn.cursor() as cursor:
# Test basic connectivity
cursor.execute("SELECT 1 as test")
result = cursor.fetchone()
print(f"✓ Connection test: {result}")
logger.info("✓ Connection test: %s", result)
# Check if our test table exists
cursor.execute("SHOW TABLES IN dify")
tables = cursor.fetchall()
print(f"✓ Existing tables: {[t[1] for t in tables if t[0] == 'dify']}")
logger.info("✓ Existing tables: %s", [t[1] for t in tables if t[0] == "dify"])
# Check if test collection exists
test_collection = "collection_test_dataset"
if test_collection in [t[1] for t in tables if t[0] == "dify"]:
cursor.execute(f"DESCRIBE dify.{test_collection}")
columns = cursor.fetchall()
print(f"✓ Table structure for {test_collection}:")
logger.info("✓ Table structure for %s:", test_collection)
for col in columns:
print(f" - {col[0]}: {col[1]}")
logger.info(" - %s: %s", col[0], col[1])
# Check for indexes
cursor.execute(f"SHOW INDEXES IN dify.{test_collection}")
indexes = cursor.fetchall()
print(f"✓ Indexes on {test_collection}:")
logger.info("✓ Indexes on %s:", test_collection)
for idx in indexes:
print(f" - {idx}")
logger.info(" - %s", idx)
return True
except Exception as e:
print(f"✗ Connection test failed: {e}")
except Exception:
logger.exception("✗ Connection test failed")
return False
def test_dify_api():
"""Test Dify API with Clickzetta backend"""
print("\n=== Testing Dify API ===")
logger.info("\n=== Testing Dify API ===")
base_url = "http://localhost:5001"
# Wait for API to be ready
max_retries = 30
for i in range(max_retries):
try:
response = httpx.get(f"{base_url}/console/api/health")
if response.status_code == 200:
print("✓ Dify API is ready")
logger.info("✓ Dify API is ready")
break
except:
if i == max_retries - 1:
print("✗ Dify API is not responding")
logger.exception("✗ Dify API is not responding")
return False
time.sleep(2)
# Check vector store configuration
try:
# This is a simplified check - in production, you'd use proper auth
print("✓ Dify is configured to use Clickzetta as vector store")
logger.info("✓ Dify is configured to use Clickzetta as vector store")
return True
except Exception as e:
print(f"✗ API test failed: {e}")
except Exception:
logger.exception("✗ API test failed")
return False
def verify_table_structure():
"""Verify the table structure meets Dify requirements"""
print("\n=== Verifying Table Structure ===")
logger.info("\n=== Verifying Table Structure ===")
expected_columns = {
"id": "VARCHAR",
"page_content": "VARCHAR",
"metadata": "VARCHAR", # JSON stored as VARCHAR in Clickzetta
"metadata": "VARCHAR",
"vector": "ARRAY<FLOAT>",
}
expected_metadata_fields = ["doc_id", "doc_hash", "document_id", "dataset_id"]
print("✓ Expected table structure:")
logger.info("✓ Expected table structure:")
for col, dtype in expected_columns.items():
print(f" - {col}: {dtype}")
logger.info(" - %s: %s", col, dtype)
print("\n✓ Required metadata fields:")
logger.info("\n✓ Required metadata fields:")
for field in expected_metadata_fields:
print(f" - {field}")
logger.info(" - %s", field)
print("\n✓ Index requirements:")
print(" - Vector index (HNSW) on 'vector' column")
print(" - Full-text index on 'page_content' (optional)")
print(" - Functional index on metadata->>'$.doc_id' (recommended)")
print(" - Functional index on metadata->>'$.document_id' (recommended)")
logger.info("\n✓ Index requirements:")
logger.info(" - Vector index (HNSW) on 'vector' column")
logger.info(" - Full-text index on 'page_content' (optional)")
logger.info(" - Functional index on metadata->>'$.doc_id' (recommended)")
logger.info(" - Functional index on metadata->>'$.document_id' (recommended)")
return True
def main():
"""Run all tests"""
print("Starting Clickzetta integration tests for Dify Docker\n")
logger.info("Starting Clickzetta integration tests for Dify Docker\n")
tests = [
("Direct Clickzetta Connection", test_clickzetta_connection),
@@ -131,33 +127,34 @@ def main():
try:
success = test_func()
results.append((test_name, success))
except Exception as e:
print(f"\n{test_name} crashed: {e}")
except Exception:
logger.exception("\n%s crashed", test_name)
results.append((test_name, False))
# Summary
print("\n" + "=" * 50)
print("Test Summary:")
print("=" * 50)
logger.info("\n%s", "=" * 50)
logger.info("Test Summary:")
logger.info("=" * 50)
passed = sum(1 for _, success in results if success)
total = len(results)
for test_name, success in results:
status = "✅ PASSED" if success else "❌ FAILED"
print(f"{test_name}: {status}")
logger.info("%s: %s", test_name, status)
print(f"\nTotal: {passed}/{total} tests passed")
logger.info("\nTotal: %s/%s tests passed", passed, total)
if passed == total:
print("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.")
print("\nNext steps:")
print("1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d")
print("2. Access Dify at http://localhost:3000")
print("3. Create a dataset and test vector storage with Clickzetta")
logger.info("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.")
logger.info("\nNext steps:")
logger.info(
"1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d"
)
logger.info("2. Access Dify at http://localhost:3000")
logger.info("3. Create a dataset and test vector storage with Clickzetta")
return 0
else:
print("\n⚠️ Some tests failed. Please check the errors above.")
logger.error("\n⚠️ Some tests failed. Please check the errors above.")
return 1

View File

@@ -47,7 +47,7 @@ def _build_fake_clickzetta_module():
@pytest.fixture
def clickzetta_module(monkeypatch):
monkeypatch.setitem(sys.modules, "clickzetta", _build_fake_clickzetta_module())
import core.rag.datasource.vdb.clickzetta.clickzetta_vector as module
import dify_vdb_clickzetta.clickzetta_vector as module
return importlib.reload(module)

View File

@@ -0,0 +1,14 @@
[project]
name = "dify-vdb-couchbase"
version = "0.0.1"
dependencies = [
"couchbase~=4.6.0",
]
description = "Dify vector store backend (dify-vdb-couchbase)."
[project.entry-points."dify.vector_backends"]
couchbase = "dify_vdb_couchbase.couchbase_vector:CouchbaseVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -1,12 +1,14 @@
import logging
import subprocess
import time
from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector
from tests.integration_tests.vdb.test_vector_store import (
from dify_vdb_couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector
from core.rag.datasource.vdb.vector_integration_test_support import (
AbstractVectorTest,
)
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
logger = logging.getLogger(__name__)
def wait_for_healthy_container(service_name="couchbase-server", timeout=300):
@@ -16,10 +18,10 @@ def wait_for_healthy_container(service_name="couchbase-server", timeout=300):
["docker", "inspect", "--format", "{{.State.Health.Status}}", service_name], capture_output=True, text=True
)
if result.stdout.strip() == "healthy":
print(f"{service_name} is healthy!")
logger.info("%s is healthy!", service_name)
return True
else:
print(f"Waiting for {service_name} to be healthy...")
logger.info("Waiting for %s to be healthy...", service_name)
time.sleep(10)
raise TimeoutError(f"{service_name} did not become healthy in time")

View File

@@ -154,7 +154,7 @@ def couchbase_module(monkeypatch):
for name, module in _build_fake_couchbase_modules().items():
monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.couchbase.couchbase_vector as module
import dify_vdb_couchbase.couchbase_vector as module
return importlib.reload(module)

View File

@@ -0,0 +1,15 @@
[project]
name = "dify-vdb-elasticsearch"
version = "0.0.1"
dependencies = [
"elasticsearch==8.14.0",
]
description = "Dify vector store backend (dify-vdb-elasticsearch)."
[project.entry-points."dify.vector_backends"]
elasticsearch = "dify_vdb_elasticsearch.elasticsearch_vector:ElasticSearchVectorFactory"
elasticsearch-ja = "dify_vdb_elasticsearch.elasticsearch_ja_vector:ElasticSearchJaVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -4,14 +4,14 @@ from typing import Any
from flask import current_app
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from dify_vdb_elasticsearch.elasticsearch_vector import (
ElasticSearchConfig,
ElasticSearchVector,
ElasticSearchVectorFactory,
)
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from extensions.ext_redis import redis_client
from models.dataset import Dataset

View File

@@ -1,10 +1,9 @@
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector
from tests.integration_tests.vdb.test_vector_store import (
from dify_vdb_elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector
from core.rag.datasource.vdb.vector_integration_test_support import (
AbstractVectorTest,
)
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
class ElasticSearchVectorTest(AbstractVectorTest):
def __init__(self):

View File

@@ -32,8 +32,8 @@ def elasticsearch_ja_module(monkeypatch):
for name, module in _build_fake_elasticsearch_modules().items():
monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector as ja_module
import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as base_module
import dify_vdb_elasticsearch.elasticsearch_ja_vector as ja_module
import dify_vdb_elasticsearch.elasticsearch_vector as base_module
importlib.reload(base_module)
return importlib.reload(ja_module)

View File

@@ -42,7 +42,7 @@ def elasticsearch_module(monkeypatch):
for name, module in _build_fake_elasticsearch_modules().items():
monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as module
import dify_vdb_elasticsearch.elasticsearch_vector as module
return importlib.reload(module)

View File

@@ -0,0 +1,14 @@
[project]
name = "dify-vdb-hologres"
version = "0.0.1"
dependencies = [
"holo-search-sdk>=0.4.2",
]
description = "Dify vector store backend (dify-vdb-hologres)."
[project.entry-points."dify.vector_backends"]
hologres = "dify_vdb_hologres.hologres_vector:HologresVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -1,7 +1,7 @@
import json
import logging
import time
from typing import Any
from typing import Any, cast
import holo_search_sdk as holo # type: ignore
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
@@ -351,9 +351,9 @@ class HologresVectorFactory(AbstractVectorFactory):
access_key_id=dify_config.HOLOGRES_ACCESS_KEY_ID or "",
access_key_secret=dify_config.HOLOGRES_ACCESS_KEY_SECRET or "",
schema_name=dify_config.HOLOGRES_SCHEMA,
tokenizer=dify_config.HOLOGRES_TOKENIZER,
distance_method=dify_config.HOLOGRES_DISTANCE_METHOD,
base_quantization_type=dify_config.HOLOGRES_BASE_QUANTIZATION_TYPE,
tokenizer=cast(TokenizerType, dify_config.HOLOGRES_TOKENIZER),
distance_method=cast(DistanceType, dify_config.HOLOGRES_DISTANCE_METHOD),
base_quantization_type=cast(BaseQuantizationType, dify_config.HOLOGRES_BASE_QUANTIZATION_TYPE),
max_degree=dify_config.HOLOGRES_MAX_DEGREE,
ef_construction=dify_config.HOLOGRES_EF_CONSTRUCTION,
),

View File

@@ -7,13 +7,10 @@ import pytest
from _pytest.monkeypatch import MonkeyPatch
from psycopg import sql as psql
# Shared in-memory storage: {table_name: {doc_id: {"id", "text", "meta", "embedding"}}}
_mock_tables: dict[str, dict[str, dict[str, Any]]] = {}
class MockSearchQuery:
"""Mock query builder for search_vector and search_text results."""
def __init__(self, table_name: str, search_type: str):
self._table_name = table_name
self._search_type = search_type
@@ -32,17 +29,13 @@ class MockSearchQuery:
return self
def _apply_filter(self, row: dict[str, Any]) -> bool:
"""Apply the filter SQL to check if a row matches."""
if self._filter_sql is None:
return True
# Extract literals (the document IDs) from the filter SQL
# Filter format: meta->>'document_id' IN ('doc1', 'doc2')
literals = [v for t, v in _extract_identifiers_and_literals(self._filter_sql) if t == "literal"]
if not literals:
return True
# Get the document_id from the row's meta field
meta = row.get("meta", "{}")
if isinstance(meta, str):
meta = json.loads(meta)
@@ -54,22 +47,17 @@ class MockSearchQuery:
data = _mock_tables.get(self._table_name, {})
results = []
for row in list(data.values())[: self._limit_val]:
# Apply filter if present
if not self._apply_filter(row):
continue
if self._search_type == "vector":
# row format expected by _process_vector_results: (distance, id, text, meta)
results.append((0.1, row["id"], row["text"], row["meta"]))
else:
# row format expected by _process_full_text_results: (id, text, meta, embedding, score)
results.append((row["id"], row["text"], row["meta"], row.get("embedding", []), 0.9))
return results
class MockTable:
"""Mock table object returned by client.open_table()."""
def __init__(self, table_name: str):
self._table_name = table_name
@@ -97,7 +85,6 @@ class MockTable:
def _extract_sql_template(query) -> str:
"""Extract the SQL template string from a psycopg Composed object."""
if isinstance(query, psql.Composed):
for part in query:
if isinstance(part, psql.SQL):
@@ -108,7 +95,6 @@ def _extract_sql_template(query) -> str:
def _extract_identifiers_and_literals(query) -> list[Any]:
"""Extract Identifier and Literal values from a psycopg Composed object."""
values: list[Any] = []
if isinstance(query, psql.Composed):
for part in query:
@@ -117,7 +103,6 @@ def _extract_identifiers_and_literals(query) -> list[Any]:
elif isinstance(part, psql.Literal):
values.append(("literal", part._obj))
elif isinstance(part, psql.Composed):
# Handles SQL(...).join(...) for IN clauses
for sub in part:
if isinstance(sub, psql.Literal):
values.append(("literal", sub._obj))
@@ -125,8 +110,6 @@ def _extract_identifiers_and_literals(query) -> list[Any]:
class MockHologresClient:
"""Mock holo_search_sdk client that stores data in memory."""
def connect(self):
pass
@@ -141,21 +124,18 @@ class MockHologresClient:
params = _extract_identifiers_and_literals(query)
if "CREATE TABLE" in template.upper():
# Extract table name from first identifier
table_name = next((v for t, v in params if t == "ident"), "unknown")
if table_name not in _mock_tables:
_mock_tables[table_name] = {}
return None
if "SELECT 1" in template:
# text_exists: SELECT 1 FROM {table} WHERE id = {id} LIMIT 1
table_name = next((v for t, v in params if t == "ident"), "")
doc_id = next((v for t, v in params if t == "literal"), "")
data = _mock_tables.get(table_name, {})
return [(1,)] if doc_id in data else []
if "SELECT id" in template:
# get_ids_by_metadata_field: SELECT id FROM {table} WHERE meta->>{key} = {value}
table_name = next((v for t, v in params if t == "ident"), "")
literals = [v for t, v in params if t == "literal"]
key = literals[0] if len(literals) > 0 else ""
@@ -166,12 +146,10 @@ class MockHologresClient:
if "DELETE" in template.upper():
table_name = next((v for t, v in params if t == "ident"), "")
if "id IN" in template:
# delete_by_ids
ids_to_delete = [v for t, v in params if t == "literal"]
for did in ids_to_delete:
_mock_tables.get(table_name, {}).pop(did, None)
elif "meta->>" in template:
# delete_by_metadata_field
literals = [v for t, v in params if t == "literal"]
key = literals[0] if len(literals) > 0 else ""
value = literals[1] if len(literals) > 1 else ""
@@ -190,7 +168,6 @@ class MockHologresClient:
def mock_connect(**kwargs):
"""Replacement for holo_search_sdk.connect() that returns a mock client."""
return MockHologresClient()

View File

@@ -2,16 +2,11 @@ import os
import uuid
from typing import cast
from dify_vdb_hologres.hologres_vector import HologresVector, HologresVectorConfig
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
from core.rag.datasource.vdb.hologres.hologres_vector import HologresVector, HologresVectorConfig
from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text
pytest_plugins = (
"tests.integration_tests.vdb.test_vector_store",
"tests.integration_tests.vdb.__mock.hologres",
)
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"

View File

@@ -42,7 +42,7 @@ def hologres_module(monkeypatch):
for name, module in _build_fake_hologres_modules().items():
monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.hologres.hologres_vector as module
import dify_vdb_hologres.hologres_vector as module
return importlib.reload(module)

View File

@@ -0,0 +1,14 @@
[project]
name = "dify-vdb-huawei-cloud"
version = "0.0.1"
dependencies = [
"elasticsearch==8.14.0",
]
description = "Dify vector store backend (dify-vdb-huawei-cloud)."
[project.entry-points."dify.vector_backends"]
huawei_cloud = "dify_vdb_huawei_cloud.huawei_cloud_vector:HuaweiCloudVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -1,10 +1,6 @@
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVector, HuaweiCloudVectorConfig
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text
from dify_vdb_huawei_cloud.huawei_cloud_vector import HuaweiCloudVector, HuaweiCloudVectorConfig
pytest_plugins = (
"tests.integration_tests.vdb.test_vector_store",
"tests.integration_tests.vdb.__mock.huaweicloudvectordb",
)
from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
class HuaweiCloudVectorTest(AbstractVectorTest):

View File

@@ -33,7 +33,7 @@ def huawei_module(monkeypatch):
for name, module in _build_fake_elasticsearch_modules().items():
monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.huawei.huawei_cloud_vector as module
import dify_vdb_huawei_cloud.huawei_cloud_vector as module
return importlib.reload(module)

View File

@@ -0,0 +1,14 @@
[project]
name = "dify-vdb-iris"
version = "0.0.1"
dependencies = [
"intersystems-irispython>=5.1.0",
]
description = "Dify vector store backend (dify-vdb-iris)."
[project.entry-points."dify.vector_backends"]
iris = "dify_vdb_iris.iris_vector:IrisVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -1,12 +1,11 @@
"""Integration tests for IRIS vector database."""
from core.rag.datasource.vdb.iris.iris_vector import IrisVector, IrisVectorConfig
from tests.integration_tests.vdb.test_vector_store import (
from dify_vdb_iris.iris_vector import IrisVector, IrisVectorConfig
from core.rag.datasource.vdb.vector_integration_test_support import (
AbstractVectorTest,
)
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
class IrisVectorTest(AbstractVectorTest):
"""Test suite for IRIS vector store implementation."""

View File

@@ -26,7 +26,7 @@ def _build_fake_iris_module():
def iris_module(monkeypatch):
monkeypatch.setitem(sys.modules, "iris", _build_fake_iris_module())
import core.rag.datasource.vdb.iris.iris_vector as module
import dify_vdb_iris.iris_vector as module
reloaded = importlib.reload(module)
reloaded._pool_instance = None

View File

@@ -0,0 +1,15 @@
[project]
name = "dify-vdb-lindorm"
version = "0.0.1"
dependencies = [
"opensearch-py==3.1.0",
"tenacity>=8.0.0",
]
description = "Dify vector store backend (dify-vdb-lindorm)."
[project.entry-points."dify.vector_backends"]
lindorm = "dify_vdb_lindorm.lindorm_vector:LindormVectorStoreFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -1,9 +1,8 @@
import os
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest
from dify_vdb_lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest
class Config:

View File

@@ -51,7 +51,7 @@ def lindorm_module(monkeypatch):
for name, module in _build_fake_opensearch_modules().items():
monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.lindorm.lindorm_vector as module
import dify_vdb_lindorm.lindorm_vector as module
return importlib.reload(module)

View File

@@ -0,0 +1,14 @@
[project]
name = "dify-vdb-matrixone"
version = "0.0.1"
dependencies = [
"mo-vector~=0.1.13",
]
description = "Dify vector store backend (dify-vdb-matrixone)."
[project.entry-points."dify.vector_backends"]
matrixone = "dify_vdb_matrixone.matrixone_vector:MatrixoneVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -1,10 +1,9 @@
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneConfig, MatrixoneVector
from tests.integration_tests.vdb.test_vector_store import (
from dify_vdb_matrixone.matrixone_vector import MatrixoneConfig, MatrixoneVector
from core.rag.datasource.vdb.vector_integration_test_support import (
AbstractVectorTest,
)
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
class MatrixoneVectorTest(AbstractVectorTest):
def __init__(self):

View File

@@ -36,7 +36,7 @@ def matrixone_module(monkeypatch):
for name, module in _build_fake_mo_vector_modules().items():
monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.matrixone.matrixone_vector as module
import dify_vdb_matrixone.matrixone_vector as module
return importlib.reload(module)

View File

@@ -0,0 +1,14 @@
[project]
name = "dify-vdb-milvus"
version = "0.0.1"
dependencies = [
"pymilvus~=2.6.12",
]
description = "Dify vector store backend (dify-vdb-milvus)."
[project.entry-points."dify.vector_backends"]
milvus = "dify_vdb_milvus.milvus_vector:MilvusVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -1,11 +1,10 @@
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
from tests.integration_tests.vdb.test_vector_store import (
from dify_vdb_milvus.milvus_vector import MilvusConfig, MilvusVector
from core.rag.datasource.vdb.vector_integration_test_support import (
AbstractVectorTest,
get_example_text,
)
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
class MilvusVectorTest(AbstractVectorTest):
def __init__(self):

View File

@@ -103,7 +103,7 @@ def milvus_module(monkeypatch):
for name, module in _build_fake_pymilvus_modules().items():
monkeypatch.setitem(sys.modules, name, module)
import core.rag.datasource.vdb.milvus.milvus_vector as module
import dify_vdb_milvus.milvus_vector as module
return importlib.reload(module)

View File

@@ -0,0 +1,14 @@
[project]
name = "dify-vdb-myscale"
version = "0.0.1"
dependencies = [
"clickhouse-connect~=0.15.0",
]
description = "Dify vector store backend (dify-vdb-myscale)."
[project.entry-points."dify.vector_backends"]
myscale = "dify_vdb_myscale.myscale_vector:MyScaleVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -1,10 +1,9 @@
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleConfig, MyScaleVector
from tests.integration_tests.vdb.test_vector_store import (
from dify_vdb_myscale.myscale_vector import MyScaleConfig, MyScaleVector
from core.rag.datasource.vdb.vector_integration_test_support import (
AbstractVectorTest,
)
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
class MyScaleVectorTest(AbstractVectorTest):
def __init__(self):

View File

@@ -42,7 +42,7 @@ def myscale_module(monkeypatch):
fake_module = _build_fake_clickhouse_connect_module()
monkeypatch.setitem(sys.modules, "clickhouse_connect", fake_module)
import core.rag.datasource.vdb.myscale.myscale_vector as module
import dify_vdb_myscale.myscale_vector as module
return importlib.reload(module)

View File

@@ -0,0 +1,16 @@
[project]
name = "dify-vdb-oceanbase"
version = "0.0.1"
dependencies = [
"pyobvector~=0.2.17",
"mysql-connector-python>=9.3.0",
]
description = "Dify vector store backend (dify-vdb-oceanbase)."
[project.entry-points."dify.vector_backends"]
oceanbase = "dify_vdb_oceanbase.oceanbase_vector:OceanBaseVectorFactory"
seekdb = "dify_vdb_oceanbase.oceanbase_vector:OceanBaseVectorFactory"
[tool.setuptools.packages.find]
where = ["src"]

View File

@@ -2,11 +2,12 @@
Benchmark: OceanBase vector store old (single-row) vs new (batch) insertion,
metadata query with/without functional index, and vector search across metrics.
Usage:
uv run --project api python -m tests.integration_tests.vdb.oceanbase.bench_oceanbase
Usage (from repo root):
uv run --project api python api/packages/dify-vdb-oceanbase/tests/bench_oceanbase.py
"""
import json
import logging
import random
import statistics
import time
@@ -16,6 +17,8 @@ from pyobvector import VECTOR, ObVecClient, cosine_distance, inner_product, l2_d
from sqlalchemy import JSON, Column, String, text
from sqlalchemy.dialects.mysql import LONGTEXT
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
@@ -114,7 +117,7 @@ def bench_metadata_query(client, table, doc_id, with_index=False):
try:
client.perform_raw_text_sql(f"CREATE INDEX idx_metadata_doc_id ON `{table}` ((metadata->>'$.document_id'))")
except Exception:
pass # already exists
logger.debug("Index idx_metadata_doc_id already exists, skipping creation")
sql = text(f"SELECT id FROM `{table}` WHERE metadata->>'$.document_id' = :val")
times = []
@@ -164,11 +167,11 @@ def main():
client = _make_client()
client_pooled = _make_client(pool_size=5, max_overflow=10, pool_recycle=3600, pool_pre_ping=True)
print("=" * 70)
print("OceanBase Vector Store — Performance Benchmark")
print(f" Endpoint : {HOST}:{PORT}")
print(f" Vec dim : {VEC_DIM}")
print("=" * 70)
logger.info("=" * 70)
logger.info("OceanBase Vector Store — Performance Benchmark")
logger.info(" Endpoint : %s:%s", HOST, PORT)
logger.info(" Vec dim : %s", VEC_DIM)
logger.info("=" * 70)
# ------------------------------------------------------------------
# 1. Insertion benchmark
@@ -187,10 +190,10 @@ def main():
t_batch = bench_insert_batch(client_pooled, tbl_batch, rows, batch_size=100)
speedup = t_single / t_batch if t_batch > 0 else float("inf")
print(f"\n[Insert {n_docs} docs]")
print(f" Single-row : {t_single:.2f}s")
print(f" Batch(100) : {t_batch:.2f}s")
print(f" Speedup : {speedup:.1f}x")
logger.info("\n[Insert %s docs]", n_docs)
logger.info(" Single-row : %.2fs", t_single)
logger.info(" Batch(100) : %.2fs", t_batch)
logger.info(" Speedup : %.1fx", speedup)
# ------------------------------------------------------------------
# 2. Metadata query benchmark (use the 1000-doc batch table)
@@ -203,16 +206,16 @@ def main():
res = conn.execute(text(f"SELECT metadata->>'$.document_id' FROM `{tbl_meta}` LIMIT 1"))
doc_id_1000 = res.fetchone()[0]
print("\n[Metadata filter query — 1000 rows, by document_id]")
logger.info("\n[Metadata filter query — 1000 rows, by document_id]")
times_no_idx = bench_metadata_query(client, tbl_meta, doc_id_1000, with_index=False)
print(f" Without index : {_fmt(times_no_idx)}")
logger.info(" Without index : %s", _fmt(times_no_idx))
times_with_idx = bench_metadata_query(client, tbl_meta, doc_id_1000, with_index=True)
print(f" With index : {_fmt(times_with_idx)}")
logger.info(" With index : %s", _fmt(times_with_idx))
# ------------------------------------------------------------------
# 3. Vector search benchmark — across metrics
# ------------------------------------------------------------------
print("\n[Vector search — top-10, 20 queries each, on 1000 rows]")
logger.info("\n[Vector search — top-10, 20 queries each, on 1000 rows]")
for metric in ["l2", "cosine", "inner_product"]:
tbl_vs = f"bench_vs_{metric}"
@@ -222,7 +225,7 @@ def main():
rows_vs, _ = _gen_rows(1000)
bench_insert_batch(client_pooled, tbl_vs, rows_vs, batch_size=100)
times = bench_vector_search(client_pooled, tbl_vs, metric, topk=10, n_queries=20)
print(f" {metric:15s}: {_fmt(times)}")
logger.info(" %-15s: %s", metric, _fmt(times))
_drop(client_pooled, tbl_vs)
# ------------------------------------------------------------------
@@ -232,9 +235,9 @@ def main():
_drop(client, f"bench_single_{n}")
_drop(client, f"bench_batch_{n}")
print("\n" + "=" * 70)
print("Benchmark complete.")
print("=" * 70)
logger.info("\n%s", "=" * 70)
logger.info("Benchmark complete.")
logger.info("=" * 70)
if __name__ == "__main__":

View File

@@ -1,15 +1,13 @@
import pytest
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import (
from dify_vdb_oceanbase.oceanbase_vector import (
OceanBaseVector,
OceanBaseVectorConfig,
)
from tests.integration_tests.vdb.test_vector_store import (
from core.rag.datasource.vdb.vector_integration_test_support import (
AbstractVectorTest,
)
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
@pytest.fixture
def oceanbase_vector():

Some files were not shown because too many files have changed in this diff Show More