mirror of
https://github.com/langgenius/dify.git
synced 2026-04-13 12:02:44 +00:00
Compare commits
4 Commits
codex/base
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f0266e13c5 | ||
|
|
ae898652b2 | ||
|
|
c34f67495c | ||
|
|
815c536e05 |
1
.github/workflows/main-ci.yml
vendored
1
.github/workflows/main-ci.yml
vendored
@@ -92,6 +92,7 @@ jobs:
|
||||
vdb:
|
||||
- 'api/core/rag/datasource/**'
|
||||
- 'api/tests/integration_tests/vdb/**'
|
||||
- 'api/providers/vdb/*/tests/**'
|
||||
- '.github/workflows/vdb-tests.yml'
|
||||
- '.github/workflows/expose_service_ports.sh'
|
||||
- 'docker/.env.example'
|
||||
|
||||
2
.github/workflows/vdb-tests-full.yml
vendored
2
.github/workflows/vdb-tests-full.yml
vendored
@@ -89,7 +89,7 @@ jobs:
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: uv run --project api bash dev/pytest/pytest_vdb.sh
|
||||
|
||||
10
.github/workflows/vdb-tests.yml
vendored
10
.github/workflows/vdb-tests.yml
vendored
@@ -81,12 +81,12 @@ jobs:
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: |
|
||||
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
api/tests/integration_tests/vdb/chroma \
|
||||
api/tests/integration_tests/vdb/pgvector \
|
||||
api/tests/integration_tests/vdb/qdrant \
|
||||
api/tests/integration_tests/vdb/weaviate
|
||||
api/providers/vdb/vdb-chroma/tests/integration_tests \
|
||||
api/providers/vdb/vdb-pgvector/tests/integration_tests \
|
||||
api/providers/vdb/vdb-qdrant/tests/integration_tests \
|
||||
api/providers/vdb/vdb-weaviate/tests/integration_tests
|
||||
|
||||
@@ -21,8 +21,9 @@ RUN apt-get update \
|
||||
# for building gmpy2
|
||||
libmpfr-dev libmpc-dev
|
||||
|
||||
# Install Python dependencies
|
||||
# Install Python dependencies (workspace members under providers/vdb/)
|
||||
COPY pyproject.toml uv.lock ./
|
||||
COPY providers ./providers
|
||||
RUN uv sync --locked --no-dev
|
||||
|
||||
# production stage
|
||||
|
||||
@@ -341,11 +341,10 @@ def add_qdrant_index(field: str):
|
||||
click.echo(click.style("No dataset collection bindings found.", fg="red"))
|
||||
return
|
||||
import qdrant_client
|
||||
from dify_vdb_qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
from qdrant_client.http.models import PayloadSchemaType
|
||||
|
||||
from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
|
||||
|
||||
for binding in bindings:
|
||||
if dify_config.QDRANT_URL is None:
|
||||
raise ValueError("Qdrant URL is required.")
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
@@ -42,17 +41,17 @@ class HologresConfig(BaseSettings):
|
||||
default="public",
|
||||
)
|
||||
|
||||
HOLOGRES_TOKENIZER: TokenizerType = Field(
|
||||
HOLOGRES_TOKENIZER: str = Field(
|
||||
description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
|
||||
default="jieba",
|
||||
)
|
||||
|
||||
HOLOGRES_DISTANCE_METHOD: DistanceType = Field(
|
||||
HOLOGRES_DISTANCE_METHOD: str = Field(
|
||||
description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
|
||||
default="Cosine",
|
||||
)
|
||||
|
||||
HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field(
|
||||
HOLOGRES_BASE_QUANTIZATION_TYPE: str = Field(
|
||||
description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').",
|
||||
default="rabitq",
|
||||
)
|
||||
|
||||
87
api/core/rag/datasource/vdb/vector_backend_registry.py
Normal file
87
api/core/rag/datasource/vdb/vector_backend_registry.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Vector store backend discovery.
|
||||
|
||||
Backends live in workspace packages under ``api/packages/dify-vdb-*/src/dify_vdb_*``. Each package
|
||||
declares third-party dependencies and registers ``importlib`` entry points in group
|
||||
``dify.vector_backends`` (see each package's ``pyproject.toml``).
|
||||
|
||||
Shared types and the :class:`~core.rag.datasource.vdb.vector_factory.AbstractVectorFactory` protocol
|
||||
remain in this package (``vector_base``, ``vector_factory``, ``vector_type``, ``field``).
|
||||
|
||||
Optional **built-in** targets in ``_BUILTIN_VECTOR_FACTORY_TARGETS`` (normally empty) load without a
|
||||
distribution; entry points take precedence when both exist.
|
||||
|
||||
After changing packages, run ``uv sync`` so installed dist-info entry points match ``pyproject.toml``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from importlib.metadata import entry_points
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_VECTOR_FACTORY_CACHE: dict[str, type[AbstractVectorFactory]] = {}
|
||||
|
||||
# module_path:class_name — optional fallback when no distribution registers the backend.
|
||||
_BUILTIN_VECTOR_FACTORY_TARGETS: dict[str, str] = {}
|
||||
|
||||
|
||||
def clear_vector_factory_cache() -> None:
|
||||
"""Drop lazily loaded factories (for tests or plugin reload)."""
|
||||
_VECTOR_FACTORY_CACHE.clear()
|
||||
|
||||
|
||||
def _vector_backend_entry_points():
|
||||
return entry_points().select(group="dify.vector_backends")
|
||||
|
||||
|
||||
def _load_plugin_factory(vector_type: str) -> type[AbstractVectorFactory] | None:
|
||||
for ep in _vector_backend_entry_points():
|
||||
if ep.name != vector_type:
|
||||
continue
|
||||
try:
|
||||
loaded = ep.load()
|
||||
except Exception:
|
||||
logger.exception("Failed to load vector backend entry point %s", ep.name)
|
||||
raise
|
||||
return loaded # type: ignore[return-value]
|
||||
return None
|
||||
|
||||
|
||||
def _unsupported(vector_type: str) -> ValueError:
|
||||
installed = sorted(ep.name for ep in _vector_backend_entry_points())
|
||||
available_msg = f" Installed backends: {', '.join(installed)}." if installed else " No backends installed."
|
||||
return ValueError(
|
||||
f"Vector store {vector_type!r} is not supported.{available_msg} "
|
||||
"Install a plugin (uv sync --group vdb-all, or vdb-<backend> per api/pyproject.toml), "
|
||||
"or register a dify.vector_backends entry point."
|
||||
)
|
||||
|
||||
|
||||
def _load_builtin_factory(vector_type: str) -> type[AbstractVectorFactory]:
|
||||
target = _BUILTIN_VECTOR_FACTORY_TARGETS.get(vector_type)
|
||||
if not target:
|
||||
raise _unsupported(vector_type)
|
||||
module_path, _, attr = target.partition(":")
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, attr) # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def get_vector_factory_class(vector_type: str) -> type[AbstractVectorFactory]:
|
||||
"""Resolve :class:`AbstractVectorFactory` for a :class:`~VectorType` string value."""
|
||||
if vector_type in _VECTOR_FACTORY_CACHE:
|
||||
return _VECTOR_FACTORY_CACHE[vector_type]
|
||||
|
||||
plugin_cls = _load_plugin_factory(vector_type)
|
||||
if plugin_cls is not None:
|
||||
_VECTOR_FACTORY_CACHE[vector_type] = plugin_cls
|
||||
return plugin_cls
|
||||
|
||||
cls = _load_builtin_factory(vector_type)
|
||||
_VECTOR_FACTORY_CACHE[vector_type] = cls
|
||||
return cls
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_manager import ModelManager
|
||||
from core.rag.datasource.vdb.vector_backend_registry import get_vector_factory_class
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.cached_embedding import CacheEmbedding
|
||||
@@ -85,137 +86,7 @@ class Vector:
|
||||
|
||||
@staticmethod
|
||||
def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
|
||||
match vector_type:
|
||||
case VectorType.CHROMA:
|
||||
from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory
|
||||
|
||||
return ChromaVectorFactory
|
||||
case VectorType.MILVUS:
|
||||
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory
|
||||
|
||||
return MilvusVectorFactory
|
||||
case VectorType.ALIBABACLOUD_MYSQL:
|
||||
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
|
||||
AlibabaCloudMySQLVectorFactory,
|
||||
)
|
||||
|
||||
return AlibabaCloudMySQLVectorFactory
|
||||
case VectorType.MYSCALE:
|
||||
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory
|
||||
|
||||
return MyScaleVectorFactory
|
||||
case VectorType.PGVECTOR:
|
||||
from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
|
||||
|
||||
return PGVectorFactory
|
||||
case VectorType.VASTBASE:
|
||||
from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVectorFactory
|
||||
|
||||
return VastbaseVectorFactory
|
||||
case VectorType.PGVECTO_RS:
|
||||
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory
|
||||
|
||||
return PGVectoRSFactory
|
||||
case VectorType.QDRANT:
|
||||
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory
|
||||
|
||||
return QdrantVectorFactory
|
||||
case VectorType.RELYT:
|
||||
from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory
|
||||
|
||||
return RelytVectorFactory
|
||||
case VectorType.ELASTICSEARCH:
|
||||
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
|
||||
return ElasticSearchVectorFactory
|
||||
case VectorType.ELASTICSEARCH_JA:
|
||||
from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
|
||||
ElasticSearchJaVectorFactory,
|
||||
)
|
||||
|
||||
return ElasticSearchJaVectorFactory
|
||||
case VectorType.TIDB_VECTOR:
|
||||
from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory
|
||||
|
||||
return TiDBVectorFactory
|
||||
case VectorType.WEAVIATE:
|
||||
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory
|
||||
|
||||
return WeaviateVectorFactory
|
||||
case VectorType.TENCENT:
|
||||
from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
|
||||
|
||||
return TencentVectorFactory
|
||||
case VectorType.ORACLE:
|
||||
from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory
|
||||
|
||||
return OracleVectorFactory
|
||||
case VectorType.OPENSEARCH:
|
||||
from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory
|
||||
|
||||
return OpenSearchVectorFactory
|
||||
case VectorType.ANALYTICDB:
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
|
||||
|
||||
return AnalyticdbVectorFactory
|
||||
case VectorType.COUCHBASE:
|
||||
from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseVectorFactory
|
||||
|
||||
return CouchbaseVectorFactory
|
||||
case VectorType.BAIDU:
|
||||
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory
|
||||
|
||||
return BaiduVectorFactory
|
||||
case VectorType.VIKINGDB:
|
||||
from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBVectorFactory
|
||||
|
||||
return VikingDBVectorFactory
|
||||
case VectorType.UPSTASH:
|
||||
from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory
|
||||
|
||||
return UpstashVectorFactory
|
||||
case VectorType.TIDB_ON_QDRANT:
|
||||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory
|
||||
|
||||
return TidbOnQdrantVectorFactory
|
||||
case VectorType.LINDORM:
|
||||
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory
|
||||
|
||||
return LindormVectorStoreFactory
|
||||
case VectorType.OCEANBASE | VectorType.SEEKDB:
|
||||
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory
|
||||
|
||||
return OceanBaseVectorFactory
|
||||
case VectorType.OPENGAUSS:
|
||||
from core.rag.datasource.vdb.opengauss.opengauss import OpenGaussFactory
|
||||
|
||||
return OpenGaussFactory
|
||||
case VectorType.TABLESTORE:
|
||||
from core.rag.datasource.vdb.tablestore.tablestore_vector import TableStoreVectorFactory
|
||||
|
||||
return TableStoreVectorFactory
|
||||
case VectorType.HUAWEI_CLOUD:
|
||||
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory
|
||||
|
||||
return HuaweiCloudVectorFactory
|
||||
case VectorType.MATRIXONE:
|
||||
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory
|
||||
|
||||
return MatrixoneVectorFactory
|
||||
case VectorType.CLICKZETTA:
|
||||
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory
|
||||
|
||||
return ClickzettaVectorFactory
|
||||
case VectorType.IRIS:
|
||||
from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory
|
||||
|
||||
return IrisVectorFactory
|
||||
case VectorType.HOLOGRES:
|
||||
from core.rag.datasource.vdb.hologres.hologres_vector import HologresVectorFactory
|
||||
|
||||
return HologresVectorFactory
|
||||
case _:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
return get_vector_factory_class(vector_type)
|
||||
|
||||
def create(self, texts: list | None = None, **kwargs):
|
||||
if texts:
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
"""Shared helpers for vector DB integration tests (used by workspace packages under ``api/packages``).
|
||||
|
||||
:class:`AbstractVectorTest` and helper functions live here so package tests can import
|
||||
``core.rag.datasource.vdb.vector_integration_test_support`` without relying on the
|
||||
``tests.*`` package.
|
||||
|
||||
The ``setup_mock_redis`` fixture lives in ``api/packages/conftest.py`` and is
|
||||
auto-discovered by pytest for all package tests.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.models.document import Document
|
||||
from extensions import ext_redis
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
@@ -25,24 +34,10 @@ def get_example_document(doc_id: str) -> Document:
|
||||
return doc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_mock_redis():
|
||||
# get
|
||||
ext_redis.redis_client.get = MagicMock(return_value=None)
|
||||
|
||||
# set
|
||||
ext_redis.redis_client.set = MagicMock(return_value=None)
|
||||
|
||||
# lock
|
||||
mock_redis_lock = MagicMock()
|
||||
mock_redis_lock.__enter__ = MagicMock()
|
||||
mock_redis_lock.__exit__ = MagicMock()
|
||||
ext_redis.redis_client.lock = mock_redis_lock
|
||||
|
||||
|
||||
class AbstractVectorTest:
|
||||
vector: BaseVector
|
||||
|
||||
def __init__(self):
|
||||
self.vector = None
|
||||
self.dataset_id = str(uuid.uuid4())
|
||||
self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test"
|
||||
self.example_doc_id = str(uuid.uuid4())
|
||||
@@ -671,6 +671,29 @@ class Workflow(Base): # bug
|
||||
return str(d)
|
||||
|
||||
|
||||
class WorkflowRunDict(TypedDict):
|
||||
id: str
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
workflow_id: str
|
||||
type: WorkflowType
|
||||
triggered_from: WorkflowRunTriggeredFrom
|
||||
version: str
|
||||
graph: Mapping[str, Any]
|
||||
inputs: Mapping[str, Any]
|
||||
status: WorkflowExecutionStatus
|
||||
outputs: Mapping[str, Any]
|
||||
error: str | None
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
total_steps: int
|
||||
created_by_role: CreatorUserRole
|
||||
created_by: str
|
||||
created_at: datetime
|
||||
finished_at: datetime | None
|
||||
exceptions_count: int
|
||||
|
||||
|
||||
class WorkflowRun(Base):
|
||||
"""
|
||||
Workflow Run
|
||||
@@ -790,29 +813,29 @@ class WorkflowRun(Base):
|
||||
def workflow(self):
|
||||
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"tenant_id": self.tenant_id,
|
||||
"app_id": self.app_id,
|
||||
"workflow_id": self.workflow_id,
|
||||
"type": self.type,
|
||||
"triggered_from": self.triggered_from,
|
||||
"version": self.version,
|
||||
"graph": self.graph_dict,
|
||||
"inputs": self.inputs_dict,
|
||||
"status": self.status,
|
||||
"outputs": self.outputs_dict,
|
||||
"error": self.error,
|
||||
"elapsed_time": self.elapsed_time,
|
||||
"total_tokens": self.total_tokens,
|
||||
"total_steps": self.total_steps,
|
||||
"created_by_role": self.created_by_role,
|
||||
"created_by": self.created_by,
|
||||
"created_at": self.created_at,
|
||||
"finished_at": self.finished_at,
|
||||
"exceptions_count": self.exceptions_count,
|
||||
}
|
||||
def to_dict(self) -> WorkflowRunDict:
|
||||
return WorkflowRunDict(
|
||||
id=self.id,
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
workflow_id=self.workflow_id,
|
||||
type=self.type,
|
||||
triggered_from=self.triggered_from,
|
||||
version=self.version,
|
||||
graph=self.graph_dict,
|
||||
inputs=self.inputs_dict,
|
||||
status=self.status,
|
||||
outputs=self.outputs_dict,
|
||||
error=self.error,
|
||||
elapsed_time=self.elapsed_time,
|
||||
total_tokens=self.total_tokens,
|
||||
total_steps=self.total_steps,
|
||||
created_by_role=self.created_by_role,
|
||||
created_by=self.created_by,
|
||||
created_at=self.created_at,
|
||||
finished_at=self.finished_at,
|
||||
exceptions_count=self.exceptions_count,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
|
||||
|
||||
12
api/providers/README.md
Normal file
12
api/providers/README.md
Normal file
@@ -0,0 +1,12 @@
|
||||
# Providers
|
||||
|
||||
This directory holds **optional workspace packages** that plug into Dify’s API core. Providers are responsible for implementing the interfaces and registering themselves to the API core. Provider mechanism allows building the software with selected set of providers so as to enhance the security and flexibility of distributions.
|
||||
|
||||
## Developing Providers
|
||||
|
||||
- [VDB Providers](vdb/README.md)
|
||||
|
||||
## Tests
|
||||
|
||||
Provider tests often live next to the package, e.g. `providers/<type>/<backend>/tests/unit_tests/`. Shared fixtures may live under `providers/` (e.g. `conftest.py`).
|
||||
|
||||
58
api/providers/vdb/README.md
Normal file
58
api/providers/vdb/README.md
Normal file
@@ -0,0 +1,58 @@
|
||||
# VDB providers
|
||||
|
||||
This directory contains all VDB providers.
|
||||
|
||||
## Architecture
|
||||
1. **Core** (`api/core/rag/datasource/vdb/`) defines the contracts and loads plugins.
|
||||
2. **Each provider** (`api/providers/vdb/<backend>/`) implements those contracts and registers an entry point.
|
||||
3. At runtime, **`importlib.metadata.entry_points`** resolves the backend name (e.g. `pgvector`) to a factory class. The registry caches loaded classes (see `vector_backend_registry.py`).
|
||||
|
||||
### Interfaces
|
||||
|
||||
| Piece | Role |
|
||||
|--------|----------|
|
||||
| `AbstractVectorFactory` | You subclass this. Implement `init_vector(dataset, attributes, embeddings) -> BaseVector`. Optionally use `gen_index_struct_dict()` for new datasets. |
|
||||
| `BaseVector` | Your store class subclasses this: `create`, `add_texts`, `search_by_vector`, `delete`, etc. |
|
||||
| `VectorType` | `StrEnum` of supported backend **string ids**. Add a member when you introduce a new backend that should be selectable like existing ones. |
|
||||
| Discovery | Loads `dify.vector_backends` entry points and caches `get_vector_factory_class(vector_type)`. |
|
||||
|
||||
The high-level caller is `Vector` in `vector_factory.py`: it reads the configured or dataset-specific vector type, calls `get_vector_factory_class`, instantiates the factory, and uses the returned `BaseVector` implementation.
|
||||
|
||||
### Entry point name must match the vector type string
|
||||
|
||||
Entry points are registered under the group **`dify.vector_backends`**. The **entry point name** (left-hand side) must be exactly the string used as `vector_type` everywhere else—typically the **`VectorType` enum value** (e.g. `PGVECTOR = "pgvector"` → entry point name `pgvector`; `TIDB_ON_QDRANT = "tidb_on_qdrant"` → `tidb_on_qdrant`).
|
||||
|
||||
In `pyproject.toml`:
|
||||
|
||||
```toml
|
||||
[project.entry-points."dify.vector_backends"]
|
||||
pgvector = "dify_vdb_pgvector.pgvector:PGVectorFactory"
|
||||
```
|
||||
|
||||
The value is **`module:attribute`**: a importable module path and the class implementing `AbstractVectorFactory`.
|
||||
|
||||
### How registration works
|
||||
|
||||
1. On first use, `get_vector_factory_class(vector_type)` looks up `vector_type` in a process cache.
|
||||
2. If missing, it scans **`entry_points().select(group="dify.vector_backends")`** for an entry whose **`name` equals `vector_type`**.
|
||||
3. It loads that entry (`ep.load()`), which must return the **factory class** (not an instance).
|
||||
4. There is an optional internal map `_BUILTIN_VECTOR_FACTORY_TARGETS` for non-distribution builtins; **normal VDB plugins use entry points only**.
|
||||
|
||||
After you change a provider’s `pyproject.toml` (entry points or dependencies), run **`uv sync`** in `api/` so the installed environment’s dist-info matches the project metadata.
|
||||
|
||||
### Package layout (VDB)
|
||||
|
||||
Each backend usually follows:
|
||||
|
||||
- `api/providers/vdb/<backend>/pyproject.toml` — project name `dify-vdb-<backend>`, dependencies, entry points.
|
||||
- `api/providers/vdb/<backend>/src/dify_vdb_<python_package>/` — implementation (e.g. `PGVector`, `PGVectorFactory`).
|
||||
|
||||
See `vdb/pgvector/` as a reference implementation.
|
||||
|
||||
### Wiring a new backend into the API workspace
|
||||
|
||||
The API uses a **uv workspace** (`api/pyproject.toml`):
|
||||
|
||||
1. **`[tool.uv.workspace]`** — `members = ["providers/vdb/*"]` already includes every subdirectory under `vdb/`; new folders there are workspace members.
|
||||
2. **`[tool.uv.sources]`** — add a line for your package: `dify-vdb-mine = { workspace = true }`.
|
||||
3. **`[project.optional-dependencies]`** — add a group such as `vdb-mine = ["dify-vdb-mine"]`, and list `dify-vdb-mine` under `vdb-all` if it should install with the default bundle.
|
||||
22
api/providers/vdb/conftest.py
Normal file
22
api/providers/vdb/conftest.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from extensions import ext_redis
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _init_mock_redis():
|
||||
"""Ensure redis_client has a backing client so __getattr__ never raises."""
|
||||
if ext_redis.redis_client._client is None:
|
||||
ext_redis.redis_client.initialize(MagicMock())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_mock_redis(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(ext_redis.redis_client, "get", MagicMock(return_value=None))
|
||||
monkeypatch.setattr(ext_redis.redis_client, "set", MagicMock(return_value=None))
|
||||
mock_redis_lock = MagicMock()
|
||||
mock_redis_lock.__enter__ = MagicMock()
|
||||
mock_redis_lock.__exit__ = MagicMock()
|
||||
monkeypatch.setattr(ext_redis.redis_client, "lock", mock_redis_lock)
|
||||
13
api/providers/vdb/vdb-alibabacloud-mysql/pyproject.toml
Normal file
13
api/providers/vdb/vdb-alibabacloud-mysql/pyproject.toml
Normal file
@@ -0,0 +1,13 @@
|
||||
[project]
|
||||
name = "dify-vdb-alibabacloud-mysql"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"mysql-connector-python>=9.3.0",
|
||||
]
|
||||
description = "Dify vector store backend (dify-vdb-alibabacloud-mysql)."
|
||||
|
||||
[project.entry-points."dify.vector_backends"]
|
||||
alibabacloud_mysql = "dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector:AlibabaCloudMySQLVectorFactory"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -1,10 +1,9 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module
|
||||
import pytest
|
||||
|
||||
import core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module
|
||||
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory
|
||||
from dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory
|
||||
|
||||
|
||||
def test_validate_distance_function_accepts_supported_values():
|
||||
@@ -3,11 +3,11 @@ import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
|
||||
from dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector import (
|
||||
AlibabaCloudMySQLVector,
|
||||
AlibabaCloudMySQLVectorConfig,
|
||||
)
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
try:
|
||||
@@ -49,9 +49,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
# Sample embeddings
|
||||
self.sample_embeddings = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_init(self, mock_pool_class):
|
||||
"""Test AlibabaCloudMySQLVector initialization."""
|
||||
# Mock the connection pool
|
||||
@@ -76,10 +74,8 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
assert alibabacloud_mysql_vector.distance_function == "cosine"
|
||||
assert alibabacloud_mysql_vector.pool is not None
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
|
||||
def test_create_collection(self, mock_redis, mock_pool_class):
|
||||
"""Test collection creation."""
|
||||
# Mock Redis operations
|
||||
@@ -110,9 +106,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
assert mock_cursor.execute.call_count >= 3 # CREATE TABLE + 2 indexes
|
||||
mock_redis.set.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_vector_support_check_success(self, mock_pool_class):
|
||||
"""Test successful vector support check."""
|
||||
# Mock the connection pool
|
||||
@@ -129,9 +123,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
|
||||
assert vector_store is not None
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_vector_support_check_failure(self, mock_pool_class):
|
||||
"""Test vector support check failure."""
|
||||
# Mock the connection pool
|
||||
@@ -149,9 +141,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
|
||||
assert "RDS MySQL Vector functions are not available" in str(context.value)
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_vector_support_check_function_error(self, mock_pool_class):
|
||||
"""Test vector support check with function not found error."""
|
||||
# Mock the connection pool
|
||||
@@ -170,10 +160,8 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
|
||||
assert "RDS MySQL Vector functions are not available" in str(context.value)
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
|
||||
def test_create_documents(self, mock_redis, mock_pool_class):
|
||||
"""Test creating documents with embeddings."""
|
||||
# Setup mocks
|
||||
@@ -186,9 +174,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
assert "doc1" in result
|
||||
assert "doc2" in result
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_add_texts(self, mock_pool_class):
|
||||
"""Test adding texts to the vector store."""
|
||||
# Mock the connection pool
|
||||
@@ -207,9 +193,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
assert len(result) == 2
|
||||
mock_cursor.executemany.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_text_exists(self, mock_pool_class):
|
||||
"""Test checking if text exists."""
|
||||
# Mock the connection pool
|
||||
@@ -236,9 +220,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
assert "SELECT id FROM" in last_call[0][0]
|
||||
assert last_call[0][1] == ("doc1",)
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_text_not_exists(self, mock_pool_class):
|
||||
"""Test checking if text does not exist."""
|
||||
# Mock the connection pool
|
||||
@@ -260,9 +242,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
|
||||
assert not exists
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_get_by_ids(self, mock_pool_class):
|
||||
"""Test getting documents by IDs."""
|
||||
# Mock the connection pool
|
||||
@@ -288,9 +268,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
assert docs[0].page_content == "Test document 1"
|
||||
assert docs[1].page_content == "Test document 2"
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_get_by_ids_empty_list(self, mock_pool_class):
|
||||
"""Test getting documents with empty ID list."""
|
||||
# Mock the connection pool
|
||||
@@ -308,9 +286,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
|
||||
assert len(docs) == 0
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_delete_by_ids(self, mock_pool_class):
|
||||
"""Test deleting documents by IDs."""
|
||||
# Mock the connection pool
|
||||
@@ -334,9 +310,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
assert "DELETE FROM" in delete_call[0][0]
|
||||
assert delete_call[0][1] == ["doc1", "doc2"]
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_delete_by_ids_empty_list(self, mock_pool_class):
|
||||
"""Test deleting with empty ID list."""
|
||||
# Mock the connection pool
|
||||
@@ -357,9 +331,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
|
||||
assert len(delete_calls) == 0
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_delete_by_ids_table_not_exists(self, mock_pool_class):
|
||||
"""Test deleting when table doesn't exist."""
|
||||
# Mock the connection pool
|
||||
@@ -384,9 +356,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
# Should not raise an exception
|
||||
vector_store.delete_by_ids(["doc1"])
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_delete_by_metadata_field(self, mock_pool_class):
|
||||
"""Test deleting documents by metadata field."""
|
||||
# Mock the connection pool
|
||||
@@ -410,9 +380,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
assert "JSON_UNQUOTE(JSON_EXTRACT(meta" in delete_call[0][0]
|
||||
assert delete_call[0][1] == ("$.document_id", "dataset1")
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_search_by_vector_cosine(self, mock_pool_class):
|
||||
"""Test vector search with cosine distance."""
|
||||
# Mock the connection pool
|
||||
@@ -437,9 +405,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
assert abs(docs[0].metadata["score"] - 0.9) < 0.1 # 1 - 0.1 = 0.9
|
||||
assert docs[0].metadata["distance"] == 0.1
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_search_by_vector_euclidean(self, mock_pool_class):
|
||||
"""Test vector search with euclidean distance."""
|
||||
config = AlibabaCloudMySQLVectorConfig(
|
||||
@@ -472,9 +438,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
assert len(docs) == 1
|
||||
assert abs(docs[0].metadata["score"] - 1.0 / 3.0) < 0.01 # 1/(1+2) = 1/3
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_search_by_vector_with_filter(self, mock_pool_class):
|
||||
"""Test vector search with document ID filter."""
|
||||
# Mock the connection pool
|
||||
@@ -499,9 +463,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
search_call = search_calls[0]
|
||||
assert "WHERE JSON_UNQUOTE" in search_call[0][0]
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_search_by_vector_with_score_threshold(self, mock_pool_class):
|
||||
"""Test vector search with score threshold."""
|
||||
# Mock the connection pool
|
||||
@@ -536,9 +498,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "High similarity document"
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_search_by_vector_invalid_top_k(self, mock_pool_class):
|
||||
"""Test vector search with invalid top_k."""
|
||||
# Mock the connection pool
|
||||
@@ -560,9 +520,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
with pytest.raises(ValueError):
|
||||
vector_store.search_by_vector(query_vector, top_k="invalid")
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_search_by_full_text(self, mock_pool_class):
|
||||
"""Test full-text search."""
|
||||
# Mock the connection pool
|
||||
@@ -591,9 +549,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
assert docs[0].page_content == "This document contains machine learning content"
|
||||
assert docs[0].metadata["score"] == 1.5
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_search_by_full_text_with_filter(self, mock_pool_class):
|
||||
"""Test full-text search with document ID filter."""
|
||||
# Mock the connection pool
|
||||
@@ -617,9 +573,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
search_call = search_calls[0]
|
||||
assert "AND JSON_UNQUOTE" in search_call[0][0]
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_search_by_full_text_invalid_top_k(self, mock_pool_class):
|
||||
"""Test full-text search with invalid top_k."""
|
||||
# Mock the connection pool
|
||||
@@ -640,9 +594,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
with pytest.raises(ValueError):
|
||||
vector_store.search_by_full_text("test", top_k="invalid")
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_delete_collection(self, mock_pool_class):
|
||||
"""Test deleting the entire collection."""
|
||||
# Mock the connection pool
|
||||
@@ -665,9 +617,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
drop_call = drop_calls[0]
|
||||
assert f"DROP TABLE IF EXISTS {self.collection_name.lower()}" in drop_call[0][0]
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_unsupported_distance_function(self, mock_pool_class):
|
||||
"""Test that Pydantic validation rejects unsupported distance functions."""
|
||||
# Test that creating config with unsupported distance function raises ValidationError
|
||||
15
api/providers/vdb/vdb-analyticdb/pyproject.toml
Normal file
15
api/providers/vdb/vdb-analyticdb/pyproject.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
[project]
|
||||
name = "dify-vdb-analyticdb"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"alibabacloud_gpdb20160503~=5.2.0",
|
||||
"alibabacloud_tea_openapi~=0.4.3",
|
||||
"clickhouse-connect~=0.15.0",
|
||||
]
|
||||
description = "Dify vector store backend (dify-vdb-analyticdb)."
|
||||
|
||||
[project.entry-points."dify.vector_backends"]
|
||||
analyticdb = "dify_vdb_analyticdb.analyticdb_vector:AnalyticdbVectorFactory"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -2,16 +2,16 @@ import json
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
|
||||
AnalyticdbVectorOpenAPI,
|
||||
AnalyticdbVectorOpenAPIConfig,
|
||||
)
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from dify_vdb_analyticdb.analyticdb_vector_openapi import (
|
||||
AnalyticdbVectorOpenAPI,
|
||||
AnalyticdbVectorOpenAPIConfig,
|
||||
)
|
||||
from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
|
||||
from models.dataset import Dataset
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest
|
||||
from dify_vdb_analyticdb.analyticdb_vector import AnalyticdbVector
|
||||
from dify_vdb_analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
|
||||
from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
|
||||
|
||||
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
|
||||
from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest
|
||||
|
||||
|
||||
class AnalyticdbVectorTest(AbstractVectorTest):
|
||||
@@ -1,12 +1,12 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import dify_vdb_analyticdb.analyticdb_vector as analyticdb_module
|
||||
import pytest
|
||||
from dify_vdb_analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory
|
||||
from dify_vdb_analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
|
||||
from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
|
||||
|
||||
import core.rag.datasource.vdb.analyticdb.analyticdb_vector as analyticdb_module
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
@@ -4,13 +4,13 @@ import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import dify_vdb_analyticdb.analyticdb_vector_openapi as openapi_module
|
||||
import pytest
|
||||
|
||||
import core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi as openapi_module
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
|
||||
from dify_vdb_analyticdb.analyticdb_vector_openapi import (
|
||||
AnalyticdbVectorOpenAPI,
|
||||
AnalyticdbVectorOpenAPIConfig,
|
||||
)
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
@@ -2,14 +2,14 @@ from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import dify_vdb_analyticdb.analyticdb_vector_sql as sql_module
|
||||
import psycopg2.errors
|
||||
import pytest
|
||||
|
||||
import core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql as sql_module
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import (
|
||||
from dify_vdb_analyticdb.analyticdb_vector_sql import (
|
||||
AnalyticdbVectorBySql,
|
||||
AnalyticdbVectorBySqlConfig,
|
||||
)
|
||||
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
13
api/providers/vdb/vdb-baidu/pyproject.toml
Normal file
13
api/providers/vdb/vdb-baidu/pyproject.toml
Normal file
@@ -0,0 +1,13 @@
|
||||
[project]
|
||||
name = "dify-vdb-baidu"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"pymochow==2.4.0",
|
||||
]
|
||||
description = "Dify vector store backend (dify-vdb-baidu)."
|
||||
|
||||
[project.entry-points."dify.vector_backends"]
|
||||
baidu = "dify_vdb_baidu.baidu_vector:BaiduVectorFactory"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -1,10 +1,6 @@
|
||||
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text
|
||||
from dify_vdb_baidu.baidu_vector import BaiduConfig, BaiduVector
|
||||
|
||||
pytest_plugins = (
|
||||
"tests.integration_tests.vdb.test_vector_store",
|
||||
"tests.integration_tests.vdb.__mock.baiduvectordb",
|
||||
)
|
||||
from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
|
||||
|
||||
|
||||
class BaiduVectorTest(AbstractVectorTest):
|
||||
@@ -124,7 +124,7 @@ def _build_fake_pymochow_modules():
|
||||
def baidu_module(monkeypatch):
|
||||
for name, module in _build_fake_pymochow_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
import core.rag.datasource.vdb.baidu.baidu_vector as module
|
||||
import dify_vdb_baidu.baidu_vector as module
|
||||
|
||||
return importlib.reload(module)
|
||||
|
||||
13
api/providers/vdb/vdb-chroma/pyproject.toml
Normal file
13
api/providers/vdb/vdb-chroma/pyproject.toml
Normal file
@@ -0,0 +1,13 @@
|
||||
[project]
|
||||
name = "dify-vdb-chroma"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"chromadb==0.5.20",
|
||||
]
|
||||
description = "Dify vector store backend (dify-vdb-chroma)."
|
||||
|
||||
[project.entry-points."dify.vector_backends"]
|
||||
chroma = "dify_vdb_chroma.chroma_vector:ChromaVectorFactory"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -1,13 +1,11 @@
|
||||
import chromadb
|
||||
from dify_vdb_chroma.chroma_vector import ChromaConfig, ChromaVector
|
||||
|
||||
from core.rag.datasource.vdb.chroma.chroma_vector import ChromaConfig, ChromaVector
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
from core.rag.datasource.vdb.vector_integration_test_support import (
|
||||
AbstractVectorTest,
|
||||
get_example_text,
|
||||
)
|
||||
|
||||
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
|
||||
|
||||
|
||||
class ChromaVectorTest(AbstractVectorTest):
|
||||
def __init__(self):
|
||||
@@ -47,7 +47,7 @@ def _build_fake_chroma_modules():
|
||||
def chroma_module(monkeypatch):
|
||||
fake_chroma = _build_fake_chroma_modules()
|
||||
monkeypatch.setitem(sys.modules, "chromadb", fake_chroma)
|
||||
import core.rag.datasource.vdb.chroma.chroma_vector as module
|
||||
import dify_vdb_chroma.chroma_vector as module
|
||||
|
||||
return importlib.reload(module)
|
||||
|
||||
@@ -198,4 +198,4 @@ Clickzetta supports advanced full-text search with multiple analyzers:
|
||||
|
||||
- [Clickzetta Vector Search Documentation](https://yunqi.tech/documents/vector-search)
|
||||
- [Clickzetta Inverted Index Documentation](https://yunqi.tech/documents/inverted-index)
|
||||
- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference)
|
||||
- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference)
|
||||
14
api/providers/vdb/vdb-clickzetta/pyproject.toml
Normal file
14
api/providers/vdb/vdb-clickzetta/pyproject.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[project]
|
||||
name = "dify-vdb-clickzetta"
|
||||
version = "0.0.1"
|
||||
|
||||
dependencies = [
|
||||
"clickzetta-connector-python>=0.8.102",
|
||||
]
|
||||
description = "Dify vector store backend (dify-vdb-clickzetta)."
|
||||
|
||||
[project.entry-points."dify.vector_backends"]
|
||||
clickzetta = "dify_vdb_clickzetta.clickzetta_vector:ClickzettaVectorFactory"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -2,10 +2,10 @@ import contextlib
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from dify_vdb_clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector
|
||||
|
||||
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector
|
||||
from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
|
||||
from core.rag.models.document import Document
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
|
||||
|
||||
|
||||
class TestClickzettaVector(AbstractVectorTest):
|
||||
@@ -14,9 +14,8 @@ class TestClickzettaVector(AbstractVectorTest):
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def vector_store(self):
|
||||
def vector_store(self, setup_mock_redis):
|
||||
"""Create a Clickzetta vector store instance for testing."""
|
||||
# Skip test if Clickzetta credentials are not configured
|
||||
if not os.getenv("CLICKZETTA_USERNAME"):
|
||||
pytest.skip("CLICKZETTA_USERNAME is not configured")
|
||||
if not os.getenv("CLICKZETTA_PASSWORD"):
|
||||
@@ -32,21 +31,19 @@ class TestClickzettaVector(AbstractVectorTest):
|
||||
workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
|
||||
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
|
||||
schema=os.getenv("CLICKZETTA_SCHEMA", "dify_test"),
|
||||
batch_size=10, # Small batch size for testing
|
||||
batch_size=10,
|
||||
enable_inverted_index=True,
|
||||
analyzer_type="chinese",
|
||||
analyzer_mode="smart",
|
||||
vector_distance_function="cosine_distance",
|
||||
)
|
||||
|
||||
with setup_mock_redis():
|
||||
vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config)
|
||||
vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config)
|
||||
|
||||
yield vector
|
||||
yield vector
|
||||
|
||||
# Cleanup: delete the test collection
|
||||
with contextlib.suppress(Exception):
|
||||
vector.delete()
|
||||
with contextlib.suppress(Exception):
|
||||
vector.delete()
|
||||
|
||||
def test_clickzetta_vector_basic_operations(self, vector_store):
|
||||
"""Test basic CRUD operations on Clickzetta vector store."""
|
||||
@@ -3,16 +3,19 @@
|
||||
Test Clickzetta integration in Docker environment
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
import httpx
|
||||
from clickzetta import connect
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def test_clickzetta_connection():
|
||||
"""Test direct connection to Clickzetta"""
|
||||
print("=== Testing direct Clickzetta connection ===")
|
||||
logger.info("=== Testing direct Clickzetta connection ===")
|
||||
try:
|
||||
conn = connect(
|
||||
username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
|
||||
@@ -25,100 +28,93 @@ def test_clickzetta_connection():
|
||||
)
|
||||
|
||||
with conn.cursor() as cursor:
|
||||
# Test basic connectivity
|
||||
cursor.execute("SELECT 1 as test")
|
||||
result = cursor.fetchone()
|
||||
print(f"✓ Connection test: {result}")
|
||||
logger.info("✓ Connection test: %s", result)
|
||||
|
||||
# Check if our test table exists
|
||||
cursor.execute("SHOW TABLES IN dify")
|
||||
tables = cursor.fetchall()
|
||||
print(f"✓ Existing tables: {[t[1] for t in tables if t[0] == 'dify']}")
|
||||
logger.info("✓ Existing tables: %s", [t[1] for t in tables if t[0] == "dify"])
|
||||
|
||||
# Check if test collection exists
|
||||
test_collection = "collection_test_dataset"
|
||||
if test_collection in [t[1] for t in tables if t[0] == "dify"]:
|
||||
cursor.execute(f"DESCRIBE dify.{test_collection}")
|
||||
columns = cursor.fetchall()
|
||||
print(f"✓ Table structure for {test_collection}:")
|
||||
logger.info("✓ Table structure for %s:", test_collection)
|
||||
for col in columns:
|
||||
print(f" - {col[0]}: {col[1]}")
|
||||
logger.info(" - %s: %s", col[0], col[1])
|
||||
|
||||
# Check for indexes
|
||||
cursor.execute(f"SHOW INDEXES IN dify.{test_collection}")
|
||||
indexes = cursor.fetchall()
|
||||
print(f"✓ Indexes on {test_collection}:")
|
||||
logger.info("✓ Indexes on %s:", test_collection)
|
||||
for idx in indexes:
|
||||
print(f" - {idx}")
|
||||
logger.info(" - %s", idx)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"✗ Connection test failed: {e}")
|
||||
except Exception:
|
||||
logger.exception("✗ Connection test failed")
|
||||
return False
|
||||
|
||||
|
||||
def test_dify_api():
|
||||
"""Test Dify API with Clickzetta backend"""
|
||||
print("\n=== Testing Dify API ===")
|
||||
logger.info("\n=== Testing Dify API ===")
|
||||
base_url = "http://localhost:5001"
|
||||
|
||||
# Wait for API to be ready
|
||||
max_retries = 30
|
||||
for i in range(max_retries):
|
||||
try:
|
||||
response = httpx.get(f"{base_url}/console/api/health")
|
||||
if response.status_code == 200:
|
||||
print("✓ Dify API is ready")
|
||||
logger.info("✓ Dify API is ready")
|
||||
break
|
||||
except:
|
||||
if i == max_retries - 1:
|
||||
print("✗ Dify API is not responding")
|
||||
logger.exception("✗ Dify API is not responding")
|
||||
return False
|
||||
time.sleep(2)
|
||||
|
||||
# Check vector store configuration
|
||||
try:
|
||||
# This is a simplified check - in production, you'd use proper auth
|
||||
print("✓ Dify is configured to use Clickzetta as vector store")
|
||||
logger.info("✓ Dify is configured to use Clickzetta as vector store")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"✗ API test failed: {e}")
|
||||
except Exception:
|
||||
logger.exception("✗ API test failed")
|
||||
return False
|
||||
|
||||
|
||||
def verify_table_structure():
|
||||
"""Verify the table structure meets Dify requirements"""
|
||||
print("\n=== Verifying Table Structure ===")
|
||||
logger.info("\n=== Verifying Table Structure ===")
|
||||
|
||||
expected_columns = {
|
||||
"id": "VARCHAR",
|
||||
"page_content": "VARCHAR",
|
||||
"metadata": "VARCHAR", # JSON stored as VARCHAR in Clickzetta
|
||||
"metadata": "VARCHAR",
|
||||
"vector": "ARRAY<FLOAT>",
|
||||
}
|
||||
|
||||
expected_metadata_fields = ["doc_id", "doc_hash", "document_id", "dataset_id"]
|
||||
|
||||
print("✓ Expected table structure:")
|
||||
logger.info("✓ Expected table structure:")
|
||||
for col, dtype in expected_columns.items():
|
||||
print(f" - {col}: {dtype}")
|
||||
logger.info(" - %s: %s", col, dtype)
|
||||
|
||||
print("\n✓ Required metadata fields:")
|
||||
logger.info("\n✓ Required metadata fields:")
|
||||
for field in expected_metadata_fields:
|
||||
print(f" - {field}")
|
||||
logger.info(" - %s", field)
|
||||
|
||||
print("\n✓ Index requirements:")
|
||||
print(" - Vector index (HNSW) on 'vector' column")
|
||||
print(" - Full-text index on 'page_content' (optional)")
|
||||
print(" - Functional index on metadata->>'$.doc_id' (recommended)")
|
||||
print(" - Functional index on metadata->>'$.document_id' (recommended)")
|
||||
logger.info("\n✓ Index requirements:")
|
||||
logger.info(" - Vector index (HNSW) on 'vector' column")
|
||||
logger.info(" - Full-text index on 'page_content' (optional)")
|
||||
logger.info(" - Functional index on metadata->>'$.doc_id' (recommended)")
|
||||
logger.info(" - Functional index on metadata->>'$.document_id' (recommended)")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
print("Starting Clickzetta integration tests for Dify Docker\n")
|
||||
logger.info("Starting Clickzetta integration tests for Dify Docker\n")
|
||||
|
||||
tests = [
|
||||
("Direct Clickzetta Connection", test_clickzetta_connection),
|
||||
@@ -131,33 +127,34 @@ def main():
|
||||
try:
|
||||
success = test_func()
|
||||
results.append((test_name, success))
|
||||
except Exception as e:
|
||||
print(f"\n✗ {test_name} crashed: {e}")
|
||||
except Exception:
|
||||
logger.exception("\n✗ %s crashed", test_name)
|
||||
results.append((test_name, False))
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 50)
|
||||
print("Test Summary:")
|
||||
print("=" * 50)
|
||||
logger.info("\n%s", "=" * 50)
|
||||
logger.info("Test Summary:")
|
||||
logger.info("=" * 50)
|
||||
|
||||
passed = sum(1 for _, success in results if success)
|
||||
total = len(results)
|
||||
|
||||
for test_name, success in results:
|
||||
status = "✅ PASSED" if success else "❌ FAILED"
|
||||
print(f"{test_name}: {status}")
|
||||
logger.info("%s: %s", test_name, status)
|
||||
|
||||
print(f"\nTotal: {passed}/{total} tests passed")
|
||||
logger.info("\nTotal: %s/%s tests passed", passed, total)
|
||||
|
||||
if passed == total:
|
||||
print("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.")
|
||||
print("\nNext steps:")
|
||||
print("1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d")
|
||||
print("2. Access Dify at http://localhost:3000")
|
||||
print("3. Create a dataset and test vector storage with Clickzetta")
|
||||
logger.info("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.")
|
||||
logger.info("\nNext steps:")
|
||||
logger.info(
|
||||
"1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d"
|
||||
)
|
||||
logger.info("2. Access Dify at http://localhost:3000")
|
||||
logger.info("3. Create a dataset and test vector storage with Clickzetta")
|
||||
return 0
|
||||
else:
|
||||
print("\n⚠️ Some tests failed. Please check the errors above.")
|
||||
logger.error("\n⚠️ Some tests failed. Please check the errors above.")
|
||||
return 1
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ def _build_fake_clickzetta_module():
|
||||
@pytest.fixture
|
||||
def clickzetta_module(monkeypatch):
|
||||
monkeypatch.setitem(sys.modules, "clickzetta", _build_fake_clickzetta_module())
|
||||
import core.rag.datasource.vdb.clickzetta.clickzetta_vector as module
|
||||
import dify_vdb_clickzetta.clickzetta_vector as module
|
||||
|
||||
return importlib.reload(module)
|
||||
|
||||
14
api/providers/vdb/vdb-couchbase/pyproject.toml
Normal file
14
api/providers/vdb/vdb-couchbase/pyproject.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[project]
|
||||
name = "dify-vdb-couchbase"
|
||||
version = "0.0.1"
|
||||
|
||||
dependencies = [
|
||||
"couchbase~=4.6.0",
|
||||
]
|
||||
description = "Dify vector store backend (dify-vdb-couchbase)."
|
||||
|
||||
[project.entry-points."dify.vector_backends"]
|
||||
couchbase = "dify_vdb_couchbase.couchbase_vector:CouchbaseVectorFactory"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -1,12 +1,14 @@
|
||||
import logging
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
from dify_vdb_couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector
|
||||
|
||||
from core.rag.datasource.vdb.vector_integration_test_support import (
|
||||
AbstractVectorTest,
|
||||
)
|
||||
|
||||
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def wait_for_healthy_container(service_name="couchbase-server", timeout=300):
|
||||
@@ -16,10 +18,10 @@ def wait_for_healthy_container(service_name="couchbase-server", timeout=300):
|
||||
["docker", "inspect", "--format", "{{.State.Health.Status}}", service_name], capture_output=True, text=True
|
||||
)
|
||||
if result.stdout.strip() == "healthy":
|
||||
print(f"{service_name} is healthy!")
|
||||
logger.info("%s is healthy!", service_name)
|
||||
return True
|
||||
else:
|
||||
print(f"Waiting for {service_name} to be healthy...")
|
||||
logger.info("Waiting for %s to be healthy...", service_name)
|
||||
time.sleep(10)
|
||||
raise TimeoutError(f"{service_name} did not become healthy in time")
|
||||
|
||||
@@ -154,7 +154,7 @@ def couchbase_module(monkeypatch):
|
||||
for name, module in _build_fake_couchbase_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
import core.rag.datasource.vdb.couchbase.couchbase_vector as module
|
||||
import dify_vdb_couchbase.couchbase_vector as module
|
||||
|
||||
return importlib.reload(module)
|
||||
|
||||
15
api/providers/vdb/vdb-elasticsearch/pyproject.toml
Normal file
15
api/providers/vdb/vdb-elasticsearch/pyproject.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
[project]
|
||||
name = "dify-vdb-elasticsearch"
|
||||
version = "0.0.1"
|
||||
|
||||
dependencies = [
|
||||
"elasticsearch==8.14.0",
|
||||
]
|
||||
description = "Dify vector store backend (dify-vdb-elasticsearch)."
|
||||
|
||||
[project.entry-points."dify.vector_backends"]
|
||||
elasticsearch = "dify_vdb_elasticsearch.elasticsearch_vector:ElasticSearchVectorFactory"
|
||||
elasticsearch-ja = "dify_vdb_elasticsearch.elasticsearch_ja_vector:ElasticSearchJaVectorFactory"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -4,14 +4,14 @@ from typing import Any
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from dify_vdb_elasticsearch.elasticsearch_vector import (
|
||||
ElasticSearchConfig,
|
||||
ElasticSearchVector,
|
||||
ElasticSearchVectorFactory,
|
||||
)
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
from dify_vdb_elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector
|
||||
|
||||
from core.rag.datasource.vdb.vector_integration_test_support import (
|
||||
AbstractVectorTest,
|
||||
)
|
||||
|
||||
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
|
||||
|
||||
|
||||
class ElasticSearchVectorTest(AbstractVectorTest):
|
||||
def __init__(self):
|
||||
@@ -32,8 +32,8 @@ def elasticsearch_ja_module(monkeypatch):
|
||||
for name, module in _build_fake_elasticsearch_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
import core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector as ja_module
|
||||
import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as base_module
|
||||
import dify_vdb_elasticsearch.elasticsearch_ja_vector as ja_module
|
||||
import dify_vdb_elasticsearch.elasticsearch_vector as base_module
|
||||
|
||||
importlib.reload(base_module)
|
||||
return importlib.reload(ja_module)
|
||||
@@ -42,7 +42,7 @@ def elasticsearch_module(monkeypatch):
|
||||
for name, module in _build_fake_elasticsearch_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as module
|
||||
import dify_vdb_elasticsearch.elasticsearch_vector as module
|
||||
|
||||
return importlib.reload(module)
|
||||
|
||||
14
api/providers/vdb/vdb-hologres/pyproject.toml
Normal file
14
api/providers/vdb/vdb-hologres/pyproject.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[project]
|
||||
name = "dify-vdb-hologres"
|
||||
version = "0.0.1"
|
||||
|
||||
dependencies = [
|
||||
"holo-search-sdk>=0.4.2",
|
||||
]
|
||||
description = "Dify vector store backend (dify-vdb-hologres)."
|
||||
|
||||
[project.entry-points."dify.vector_backends"]
|
||||
hologres = "dify_vdb_hologres.hologres_vector:HologresVectorFactory"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import holo_search_sdk as holo # type: ignore
|
||||
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
|
||||
@@ -351,9 +351,9 @@ class HologresVectorFactory(AbstractVectorFactory):
|
||||
access_key_id=dify_config.HOLOGRES_ACCESS_KEY_ID or "",
|
||||
access_key_secret=dify_config.HOLOGRES_ACCESS_KEY_SECRET or "",
|
||||
schema_name=dify_config.HOLOGRES_SCHEMA,
|
||||
tokenizer=dify_config.HOLOGRES_TOKENIZER,
|
||||
distance_method=dify_config.HOLOGRES_DISTANCE_METHOD,
|
||||
base_quantization_type=dify_config.HOLOGRES_BASE_QUANTIZATION_TYPE,
|
||||
tokenizer=cast(TokenizerType, dify_config.HOLOGRES_TOKENIZER),
|
||||
distance_method=cast(DistanceType, dify_config.HOLOGRES_DISTANCE_METHOD),
|
||||
base_quantization_type=cast(BaseQuantizationType, dify_config.HOLOGRES_BASE_QUANTIZATION_TYPE),
|
||||
max_degree=dify_config.HOLOGRES_MAX_DEGREE,
|
||||
ef_construction=dify_config.HOLOGRES_EF_CONSTRUCTION,
|
||||
),
|
||||
@@ -7,13 +7,10 @@ import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from psycopg import sql as psql
|
||||
|
||||
# Shared in-memory storage: {table_name: {doc_id: {"id", "text", "meta", "embedding"}}}
|
||||
_mock_tables: dict[str, dict[str, dict[str, Any]]] = {}
|
||||
|
||||
|
||||
class MockSearchQuery:
|
||||
"""Mock query builder for search_vector and search_text results."""
|
||||
|
||||
def __init__(self, table_name: str, search_type: str):
|
||||
self._table_name = table_name
|
||||
self._search_type = search_type
|
||||
@@ -32,17 +29,13 @@ class MockSearchQuery:
|
||||
return self
|
||||
|
||||
def _apply_filter(self, row: dict[str, Any]) -> bool:
|
||||
"""Apply the filter SQL to check if a row matches."""
|
||||
if self._filter_sql is None:
|
||||
return True
|
||||
|
||||
# Extract literals (the document IDs) from the filter SQL
|
||||
# Filter format: meta->>'document_id' IN ('doc1', 'doc2')
|
||||
literals = [v for t, v in _extract_identifiers_and_literals(self._filter_sql) if t == "literal"]
|
||||
if not literals:
|
||||
return True
|
||||
|
||||
# Get the document_id from the row's meta field
|
||||
meta = row.get("meta", "{}")
|
||||
if isinstance(meta, str):
|
||||
meta = json.loads(meta)
|
||||
@@ -54,22 +47,17 @@ class MockSearchQuery:
|
||||
data = _mock_tables.get(self._table_name, {})
|
||||
results = []
|
||||
for row in list(data.values())[: self._limit_val]:
|
||||
# Apply filter if present
|
||||
if not self._apply_filter(row):
|
||||
continue
|
||||
|
||||
if self._search_type == "vector":
|
||||
# row format expected by _process_vector_results: (distance, id, text, meta)
|
||||
results.append((0.1, row["id"], row["text"], row["meta"]))
|
||||
else:
|
||||
# row format expected by _process_full_text_results: (id, text, meta, embedding, score)
|
||||
results.append((row["id"], row["text"], row["meta"], row.get("embedding", []), 0.9))
|
||||
return results
|
||||
|
||||
|
||||
class MockTable:
|
||||
"""Mock table object returned by client.open_table()."""
|
||||
|
||||
def __init__(self, table_name: str):
|
||||
self._table_name = table_name
|
||||
|
||||
@@ -97,7 +85,6 @@ class MockTable:
|
||||
|
||||
|
||||
def _extract_sql_template(query) -> str:
|
||||
"""Extract the SQL template string from a psycopg Composed object."""
|
||||
if isinstance(query, psql.Composed):
|
||||
for part in query:
|
||||
if isinstance(part, psql.SQL):
|
||||
@@ -108,7 +95,6 @@ def _extract_sql_template(query) -> str:
|
||||
|
||||
|
||||
def _extract_identifiers_and_literals(query) -> list[Any]:
|
||||
"""Extract Identifier and Literal values from a psycopg Composed object."""
|
||||
values: list[Any] = []
|
||||
if isinstance(query, psql.Composed):
|
||||
for part in query:
|
||||
@@ -117,7 +103,6 @@ def _extract_identifiers_and_literals(query) -> list[Any]:
|
||||
elif isinstance(part, psql.Literal):
|
||||
values.append(("literal", part._obj))
|
||||
elif isinstance(part, psql.Composed):
|
||||
# Handles SQL(...).join(...) for IN clauses
|
||||
for sub in part:
|
||||
if isinstance(sub, psql.Literal):
|
||||
values.append(("literal", sub._obj))
|
||||
@@ -125,8 +110,6 @@ def _extract_identifiers_and_literals(query) -> list[Any]:
|
||||
|
||||
|
||||
class MockHologresClient:
|
||||
"""Mock holo_search_sdk client that stores data in memory."""
|
||||
|
||||
def connect(self):
|
||||
pass
|
||||
|
||||
@@ -141,21 +124,18 @@ class MockHologresClient:
|
||||
params = _extract_identifiers_and_literals(query)
|
||||
|
||||
if "CREATE TABLE" in template.upper():
|
||||
# Extract table name from first identifier
|
||||
table_name = next((v for t, v in params if t == "ident"), "unknown")
|
||||
if table_name not in _mock_tables:
|
||||
_mock_tables[table_name] = {}
|
||||
return None
|
||||
|
||||
if "SELECT 1" in template:
|
||||
# text_exists: SELECT 1 FROM {table} WHERE id = {id} LIMIT 1
|
||||
table_name = next((v for t, v in params if t == "ident"), "")
|
||||
doc_id = next((v for t, v in params if t == "literal"), "")
|
||||
data = _mock_tables.get(table_name, {})
|
||||
return [(1,)] if doc_id in data else []
|
||||
|
||||
if "SELECT id" in template:
|
||||
# get_ids_by_metadata_field: SELECT id FROM {table} WHERE meta->>{key} = {value}
|
||||
table_name = next((v for t, v in params if t == "ident"), "")
|
||||
literals = [v for t, v in params if t == "literal"]
|
||||
key = literals[0] if len(literals) > 0 else ""
|
||||
@@ -166,12 +146,10 @@ class MockHologresClient:
|
||||
if "DELETE" in template.upper():
|
||||
table_name = next((v for t, v in params if t == "ident"), "")
|
||||
if "id IN" in template:
|
||||
# delete_by_ids
|
||||
ids_to_delete = [v for t, v in params if t == "literal"]
|
||||
for did in ids_to_delete:
|
||||
_mock_tables.get(table_name, {}).pop(did, None)
|
||||
elif "meta->>" in template:
|
||||
# delete_by_metadata_field
|
||||
literals = [v for t, v in params if t == "literal"]
|
||||
key = literals[0] if len(literals) > 0 else ""
|
||||
value = literals[1] if len(literals) > 1 else ""
|
||||
@@ -190,7 +168,6 @@ class MockHologresClient:
|
||||
|
||||
|
||||
def mock_connect(**kwargs):
|
||||
"""Replacement for holo_search_sdk.connect() that returns a mock client."""
|
||||
return MockHologresClient()
|
||||
|
||||
|
||||
@@ -2,16 +2,11 @@ import os
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
from dify_vdb_hologres.hologres_vector import HologresVector, HologresVectorConfig
|
||||
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
|
||||
|
||||
from core.rag.datasource.vdb.hologres.hologres_vector import HologresVector, HologresVectorConfig
|
||||
from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
|
||||
from core.rag.models.document import Document
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text
|
||||
|
||||
pytest_plugins = (
|
||||
"tests.integration_tests.vdb.test_vector_store",
|
||||
"tests.integration_tests.vdb.__mock.hologres",
|
||||
)
|
||||
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
@@ -42,7 +42,7 @@ def hologres_module(monkeypatch):
|
||||
for name, module in _build_fake_hologres_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
import core.rag.datasource.vdb.hologres.hologres_vector as module
|
||||
import dify_vdb_hologres.hologres_vector as module
|
||||
|
||||
return importlib.reload(module)
|
||||
|
||||
14
api/providers/vdb/vdb-huawei-cloud/pyproject.toml
Normal file
14
api/providers/vdb/vdb-huawei-cloud/pyproject.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[project]
|
||||
name = "dify-vdb-huawei-cloud"
|
||||
version = "0.0.1"
|
||||
|
||||
dependencies = [
|
||||
"elasticsearch==8.14.0",
|
||||
]
|
||||
description = "Dify vector store backend (dify-vdb-huawei-cloud)."
|
||||
|
||||
[project.entry-points."dify.vector_backends"]
|
||||
huawei_cloud = "dify_vdb_huawei_cloud.huawei_cloud_vector:HuaweiCloudVectorFactory"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -1,10 +1,6 @@
|
||||
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVector, HuaweiCloudVectorConfig
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text
|
||||
from dify_vdb_huawei_cloud.huawei_cloud_vector import HuaweiCloudVector, HuaweiCloudVectorConfig
|
||||
|
||||
pytest_plugins = (
|
||||
"tests.integration_tests.vdb.test_vector_store",
|
||||
"tests.integration_tests.vdb.__mock.huaweicloudvectordb",
|
||||
)
|
||||
from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
|
||||
|
||||
|
||||
class HuaweiCloudVectorTest(AbstractVectorTest):
|
||||
@@ -33,7 +33,7 @@ def huawei_module(monkeypatch):
|
||||
for name, module in _build_fake_elasticsearch_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
import core.rag.datasource.vdb.huawei.huawei_cloud_vector as module
|
||||
import dify_vdb_huawei_cloud.huawei_cloud_vector as module
|
||||
|
||||
return importlib.reload(module)
|
||||
|
||||
14
api/providers/vdb/vdb-iris/pyproject.toml
Normal file
14
api/providers/vdb/vdb-iris/pyproject.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[project]
|
||||
name = "dify-vdb-iris"
|
||||
version = "0.0.1"
|
||||
|
||||
dependencies = [
|
||||
"intersystems-irispython>=5.1.0",
|
||||
]
|
||||
description = "Dify vector store backend (dify-vdb-iris)."
|
||||
|
||||
[project.entry-points."dify.vector_backends"]
|
||||
iris = "dify_vdb_iris.iris_vector:IrisVectorFactory"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -1,12 +1,11 @@
|
||||
"""Integration tests for IRIS vector database."""
|
||||
|
||||
from core.rag.datasource.vdb.iris.iris_vector import IrisVector, IrisVectorConfig
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
from dify_vdb_iris.iris_vector import IrisVector, IrisVectorConfig
|
||||
|
||||
from core.rag.datasource.vdb.vector_integration_test_support import (
|
||||
AbstractVectorTest,
|
||||
)
|
||||
|
||||
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
|
||||
|
||||
|
||||
class IrisVectorTest(AbstractVectorTest):
|
||||
"""Test suite for IRIS vector store implementation."""
|
||||
@@ -26,7 +26,7 @@ def _build_fake_iris_module():
|
||||
def iris_module(monkeypatch):
|
||||
monkeypatch.setitem(sys.modules, "iris", _build_fake_iris_module())
|
||||
|
||||
import core.rag.datasource.vdb.iris.iris_vector as module
|
||||
import dify_vdb_iris.iris_vector as module
|
||||
|
||||
reloaded = importlib.reload(module)
|
||||
reloaded._pool_instance = None
|
||||
15
api/providers/vdb/vdb-lindorm/pyproject.toml
Normal file
15
api/providers/vdb/vdb-lindorm/pyproject.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
[project]
|
||||
name = "dify-vdb-lindorm"
|
||||
version = "0.0.1"
|
||||
|
||||
dependencies = [
|
||||
"opensearch-py==3.1.0",
|
||||
"tenacity>=8.0.0",
|
||||
]
|
||||
description = "Dify vector store backend (dify-vdb-lindorm)."
|
||||
|
||||
[project.entry-points."dify.vector_backends"]
|
||||
lindorm = "dify_vdb_lindorm.lindorm_vector:LindormVectorStoreFactory"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -1,9 +1,8 @@
|
||||
import os
|
||||
|
||||
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest
|
||||
from dify_vdb_lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig
|
||||
|
||||
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
|
||||
from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest
|
||||
|
||||
|
||||
class Config:
|
||||
@@ -51,7 +51,7 @@ def lindorm_module(monkeypatch):
|
||||
for name, module in _build_fake_opensearch_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
import core.rag.datasource.vdb.lindorm.lindorm_vector as module
|
||||
import dify_vdb_lindorm.lindorm_vector as module
|
||||
|
||||
return importlib.reload(module)
|
||||
|
||||
14
api/providers/vdb/vdb-matrixone/pyproject.toml
Normal file
14
api/providers/vdb/vdb-matrixone/pyproject.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[project]
|
||||
name = "dify-vdb-matrixone"
|
||||
version = "0.0.1"
|
||||
|
||||
dependencies = [
|
||||
"mo-vector~=0.1.13",
|
||||
]
|
||||
description = "Dify vector store backend (dify-vdb-matrixone)."
|
||||
|
||||
[project.entry-points."dify.vector_backends"]
|
||||
matrixone = "dify_vdb_matrixone.matrixone_vector:MatrixoneVectorFactory"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -1,10 +1,9 @@
|
||||
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneConfig, MatrixoneVector
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
from dify_vdb_matrixone.matrixone_vector import MatrixoneConfig, MatrixoneVector
|
||||
|
||||
from core.rag.datasource.vdb.vector_integration_test_support import (
|
||||
AbstractVectorTest,
|
||||
)
|
||||
|
||||
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
|
||||
|
||||
|
||||
class MatrixoneVectorTest(AbstractVectorTest):
|
||||
def __init__(self):
|
||||
@@ -36,7 +36,7 @@ def matrixone_module(monkeypatch):
|
||||
for name, module in _build_fake_mo_vector_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
import core.rag.datasource.vdb.matrixone.matrixone_vector as module
|
||||
import dify_vdb_matrixone.matrixone_vector as module
|
||||
|
||||
return importlib.reload(module)
|
||||
|
||||
14
api/providers/vdb/vdb-milvus/pyproject.toml
Normal file
14
api/providers/vdb/vdb-milvus/pyproject.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[project]
|
||||
name = "dify-vdb-milvus"
|
||||
version = "0.0.1"
|
||||
|
||||
dependencies = [
|
||||
"pymilvus~=2.6.12",
|
||||
]
|
||||
description = "Dify vector store backend (dify-vdb-milvus)."
|
||||
|
||||
[project.entry-points."dify.vector_backends"]
|
||||
milvus = "dify_vdb_milvus.milvus_vector:MilvusVectorFactory"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -1,11 +1,10 @@
|
||||
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
from dify_vdb_milvus.milvus_vector import MilvusConfig, MilvusVector
|
||||
|
||||
from core.rag.datasource.vdb.vector_integration_test_support import (
|
||||
AbstractVectorTest,
|
||||
get_example_text,
|
||||
)
|
||||
|
||||
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
|
||||
|
||||
|
||||
class MilvusVectorTest(AbstractVectorTest):
|
||||
def __init__(self):
|
||||
@@ -103,7 +103,7 @@ def milvus_module(monkeypatch):
|
||||
for name, module in _build_fake_pymilvus_modules().items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
import core.rag.datasource.vdb.milvus.milvus_vector as module
|
||||
import dify_vdb_milvus.milvus_vector as module
|
||||
|
||||
return importlib.reload(module)
|
||||
|
||||
14
api/providers/vdb/vdb-myscale/pyproject.toml
Normal file
14
api/providers/vdb/vdb-myscale/pyproject.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[project]
|
||||
name = "dify-vdb-myscale"
|
||||
version = "0.0.1"
|
||||
|
||||
dependencies = [
|
||||
"clickhouse-connect~=0.15.0",
|
||||
]
|
||||
description = "Dify vector store backend (dify-vdb-myscale)."
|
||||
|
||||
[project.entry-points."dify.vector_backends"]
|
||||
myscale = "dify_vdb_myscale.myscale_vector:MyScaleVectorFactory"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -1,10 +1,9 @@
|
||||
from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleConfig, MyScaleVector
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
from dify_vdb_myscale.myscale_vector import MyScaleConfig, MyScaleVector
|
||||
|
||||
from core.rag.datasource.vdb.vector_integration_test_support import (
|
||||
AbstractVectorTest,
|
||||
)
|
||||
|
||||
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
|
||||
|
||||
|
||||
class MyScaleVectorTest(AbstractVectorTest):
|
||||
def __init__(self):
|
||||
@@ -42,7 +42,7 @@ def myscale_module(monkeypatch):
|
||||
fake_module = _build_fake_clickhouse_connect_module()
|
||||
monkeypatch.setitem(sys.modules, "clickhouse_connect", fake_module)
|
||||
|
||||
import core.rag.datasource.vdb.myscale.myscale_vector as module
|
||||
import dify_vdb_myscale.myscale_vector as module
|
||||
|
||||
return importlib.reload(module)
|
||||
|
||||
16
api/providers/vdb/vdb-oceanbase/pyproject.toml
Normal file
16
api/providers/vdb/vdb-oceanbase/pyproject.toml
Normal file
@@ -0,0 +1,16 @@
|
||||
[project]
|
||||
name = "dify-vdb-oceanbase"
|
||||
version = "0.0.1"
|
||||
|
||||
dependencies = [
|
||||
"pyobvector~=0.2.17",
|
||||
"mysql-connector-python>=9.3.0",
|
||||
]
|
||||
description = "Dify vector store backend (dify-vdb-oceanbase)."
|
||||
|
||||
[project.entry-points."dify.vector_backends"]
|
||||
oceanbase = "dify_vdb_oceanbase.oceanbase_vector:OceanBaseVectorFactory"
|
||||
seekdb = "dify_vdb_oceanbase.oceanbase_vector:OceanBaseVectorFactory"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
@@ -2,11 +2,12 @@
|
||||
Benchmark: OceanBase vector store — old (single-row) vs new (batch) insertion,
|
||||
metadata query with/without functional index, and vector search across metrics.
|
||||
|
||||
Usage:
|
||||
uv run --project api python -m tests.integration_tests.vdb.oceanbase.bench_oceanbase
|
||||
Usage (from repo root):
|
||||
uv run --project api python api/packages/dify-vdb-oceanbase/tests/bench_oceanbase.py
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import statistics
|
||||
import time
|
||||
@@ -16,6 +17,8 @@ from pyobvector import VECTOR, ObVecClient, cosine_distance, inner_product, l2_d
|
||||
from sqlalchemy import JSON, Column, String, text
|
||||
from sqlalchemy.dialects.mysql import LONGTEXT
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -114,7 +117,7 @@ def bench_metadata_query(client, table, doc_id, with_index=False):
|
||||
try:
|
||||
client.perform_raw_text_sql(f"CREATE INDEX idx_metadata_doc_id ON `{table}` ((metadata->>'$.document_id'))")
|
||||
except Exception:
|
||||
pass # already exists
|
||||
logger.debug("Index idx_metadata_doc_id already exists, skipping creation")
|
||||
|
||||
sql = text(f"SELECT id FROM `{table}` WHERE metadata->>'$.document_id' = :val")
|
||||
times = []
|
||||
@@ -164,11 +167,11 @@ def main():
|
||||
client = _make_client()
|
||||
client_pooled = _make_client(pool_size=5, max_overflow=10, pool_recycle=3600, pool_pre_ping=True)
|
||||
|
||||
print("=" * 70)
|
||||
print("OceanBase Vector Store — Performance Benchmark")
|
||||
print(f" Endpoint : {HOST}:{PORT}")
|
||||
print(f" Vec dim : {VEC_DIM}")
|
||||
print("=" * 70)
|
||||
logger.info("=" * 70)
|
||||
logger.info("OceanBase Vector Store — Performance Benchmark")
|
||||
logger.info(" Endpoint : %s:%s", HOST, PORT)
|
||||
logger.info(" Vec dim : %s", VEC_DIM)
|
||||
logger.info("=" * 70)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Insertion benchmark
|
||||
@@ -187,10 +190,10 @@ def main():
|
||||
t_batch = bench_insert_batch(client_pooled, tbl_batch, rows, batch_size=100)
|
||||
|
||||
speedup = t_single / t_batch if t_batch > 0 else float("inf")
|
||||
print(f"\n[Insert {n_docs} docs]")
|
||||
print(f" Single-row : {t_single:.2f}s")
|
||||
print(f" Batch(100) : {t_batch:.2f}s")
|
||||
print(f" Speedup : {speedup:.1f}x")
|
||||
logger.info("\n[Insert %s docs]", n_docs)
|
||||
logger.info(" Single-row : %.2fs", t_single)
|
||||
logger.info(" Batch(100) : %.2fs", t_batch)
|
||||
logger.info(" Speedup : %.1fx", speedup)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 2. Metadata query benchmark (use the 1000-doc batch table)
|
||||
@@ -203,16 +206,16 @@ def main():
|
||||
res = conn.execute(text(f"SELECT metadata->>'$.document_id' FROM `{tbl_meta}` LIMIT 1"))
|
||||
doc_id_1000 = res.fetchone()[0]
|
||||
|
||||
print("\n[Metadata filter query — 1000 rows, by document_id]")
|
||||
logger.info("\n[Metadata filter query — 1000 rows, by document_id]")
|
||||
times_no_idx = bench_metadata_query(client, tbl_meta, doc_id_1000, with_index=False)
|
||||
print(f" Without index : {_fmt(times_no_idx)}")
|
||||
logger.info(" Without index : %s", _fmt(times_no_idx))
|
||||
times_with_idx = bench_metadata_query(client, tbl_meta, doc_id_1000, with_index=True)
|
||||
print(f" With index : {_fmt(times_with_idx)}")
|
||||
logger.info(" With index : %s", _fmt(times_with_idx))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 3. Vector search benchmark — across metrics
|
||||
# ------------------------------------------------------------------
|
||||
print("\n[Vector search — top-10, 20 queries each, on 1000 rows]")
|
||||
logger.info("\n[Vector search — top-10, 20 queries each, on 1000 rows]")
|
||||
|
||||
for metric in ["l2", "cosine", "inner_product"]:
|
||||
tbl_vs = f"bench_vs_{metric}"
|
||||
@@ -222,7 +225,7 @@ def main():
|
||||
rows_vs, _ = _gen_rows(1000)
|
||||
bench_insert_batch(client_pooled, tbl_vs, rows_vs, batch_size=100)
|
||||
times = bench_vector_search(client_pooled, tbl_vs, metric, topk=10, n_queries=20)
|
||||
print(f" {metric:15s}: {_fmt(times)}")
|
||||
logger.info(" %-15s: %s", metric, _fmt(times))
|
||||
_drop(client_pooled, tbl_vs)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -232,9 +235,9 @@ def main():
|
||||
_drop(client, f"bench_single_{n}")
|
||||
_drop(client, f"bench_batch_{n}")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("Benchmark complete.")
|
||||
print("=" * 70)
|
||||
logger.info("\n%s", "=" * 70)
|
||||
logger.info("Benchmark complete.")
|
||||
logger.info("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -1,15 +1,13 @@
|
||||
import pytest
|
||||
|
||||
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import (
|
||||
from dify_vdb_oceanbase.oceanbase_vector import (
|
||||
OceanBaseVector,
|
||||
OceanBaseVectorConfig,
|
||||
)
|
||||
from tests.integration_tests.vdb.test_vector_store import (
|
||||
|
||||
from core.rag.datasource.vdb.vector_integration_test_support import (
|
||||
AbstractVectorTest,
|
||||
)
|
||||
|
||||
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oceanbase_vector():
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user