Mirror of https://github.com/langgenius/dify.git (synced 2026-04-13 12:02:44 +00:00)

Compare commits: dependabot...feat/evalu

174 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 1712a2732a | |
| | 46bc76bae3 | |
| | ae898652b2 | |
| | c34f67495c | |
| | 815c536e05 | |
| | fc64427ae1 | |
| | 11c518478e | |
| | e823635ce1 | |
| | 98e74c8fde | |
| | 29bfa33d59 | |
| | 8c6dda125f | |
| | f6047aafe8 | |
| | 3ead0beeb1 | |
| | 2108c44c8b | |
| | b0079e55b4 | |
| | d9f54f8bd7 | |
| | dce5715982 | |
| | 5a446f8200 | |
| | f4d5e2f43d | |
| | 9121f24181 | |
| | 7dd507af04 | |
| | 3b9aad2ba7 | |
| | ea9f74b581 | |
| | e37aaa482d | |
| | a3170f744c | |
| | ced3780787 | |
| | 6faf26683c | |
| | 8ac9cbf733 | |
| | 098ed34469 | |
| | 6cf4d1002f | |
| | a111d56ea3 | |
| | 8436470fcb | |
| | 17da0e4146 | |
| | ea41e9ab4e | |
| | 5770b5feef | |
| | b5259a3a85 | |
| | 596559efc9 | |
| | ea910b8e7d | |
| | e51af66d95 | |
| | f93b287949 | |
| | 627fbd2e86 | |
| | e4c056a57a | |
| | 23291398ec | |
| | 79fc352a5a | |
| | 8b6b3cddea | |
| | d1ca468c1e | |
| | ce28ad771c | |
| | ba951b01de | |
| | 670ab16ea1 | |
| | 4680535ecd | |
| | f96e63460e | |
| | 2df79c0404 | |
| | acef9630d5 | |
| | 12c3b2e0cd | |
| | 577707ae50 | |
| | 03325e9750 | |
| | a7ef8f9c12 | |
| | 40284d9f95 | |
| | 5efe8b8bd7 | |
| | 8dc6d736ee | |
| | 5316372772 | |
| | 4d1499ef75 | |
| | 0438285277 | |
| | 4879ea5cd5 | |
| | 2a1761ac06 | |
| | c29245c1cb | |
| | 5069694bba | |
| | d1a80a85c0 | |
| | 5c93d74dec | |
| | e52dbd49be | |
| | ccc8a5f278 | |
| | cfb5b9dfea | |
| | 73d95245f8 | |
| | fb91984fcb | |
| | 29cb1fa12e | |
| | 78240ed199 | |
| | 8f8707fd77 | |
| | ed3db06154 | |
| | 7c05a68876 | |
| | 6cfc0dd8e1 | |
| | 81baeae5c4 | |
| | a3010bdc0b | |
| | 8133e550ed | |
| | 2bb0eab636 | |
| | 5311b5d00d | |
| | 9b02ccdd12 | |
| | 231783eebe | |
| | 756606f478 | |
| | 6651c1c5da | |
| | 61e257b2a8 | |
| | 3ac4caf735 | |
| | 268ae1751d | |
| | 015cbf850b | |
| | 873e13c2fb | |
| | 688bf7e7a1 | |
| | a6ffff3b39 | |
| | 023fc55bd5 | |
| | 351b909a53 | |
| | 6bec4f65c9 | |
| | 74f87ce152 | |
| | 92c472ccc7 | |
| | b92b8becd1 | |
| | 23d0d6a65d | |
| | 1660067d6e | |
| | 0642475b85 | |
| | 8cb634c9bc | |
| | 768b41c3cf | |
| | ca88516d54 | |
| | 871a2a149f | |
| | 60e381eff0 | |
| | 768b3eb6f9 | |
| | 2f88da4a6d | |
| | a8cdf6964c | |
| | 985c3db4fd | |
| | 9636472db7 | |
| | 0ad268aa7d | |
| | a4ea33167d | |
| | 0f13aabea8 | |
| | 1e76ef5ccb | |
| | e6e3229d17 | |
| | dccf8e723a | |
| | c41ba7d627 | |
| | a6e9316de3 | |
| | 559d326cbd | |
| | abedf2506f | |
| | d01428b5bc | |
| | 0de1f17e5c | |
| | 17d07a5a43 | |
| | 3bdbea99a3 | |
| | b7683aedb1 | |
| | 515036e758 | |
| | 22b382527f | |
| | 2cfe4b5b86 | |
| | 6876c8041c | |
| | 7de45584ce | |
| | 5572d7c7e8 | |
| | db0a2fe52e | |
| | f0ae8d6167 | |
| | 2514e181ba | |
| | be2e6e9a14 | |
| | 875e2eac1b | |
| | c3c73ceb1f | |
| | 6318bf0a2a | |
| | 5e1f252046 | |
| | df3b960505 | |
| | 26bc108bf1 | |
| | a5cff32743 | |
| | d418dd8eec | |
| | 61702fe346 | |
| | 43f0c780c3 | |
| | 30ebf2bfa9 | |
| | 7e3027b5f7 | |
| | b3acf83090 | |
| | 36c3d6e48a | |
| | f782ac6b3c | |
| | feef2dd1fa | |
| | a716d8789d | |
| | 6816f89189 | |
| | bfcac64a9d | |
| | 664eb601a2 | |
| | 8e5cc4e0aa | |
| | 9f28575903 | |
| | 4b9a26a5e6 | |
| | 7b85adf1cc | |
| | c964708ebe | |
| | 883eb498c0 | |
| | 4d3738d225 | |
| | dd0dee739d | |
| | 4d19914fcb | |
| | 887c7710e9 | |
| | 7a722773c7 | |
| | a763aff58b | |
| | c1011f4e5c | |
| | f7afa103a5 | |
100 .github/dependabot.yml (vendored)

@@ -1,106 +1,6 @@
 version: 2
 updates:
-  - package-ecosystem: "pip"
-    directory: "/api"
-    open-pull-requests-limit: 10
-    schedule:
-      interval: "weekly"
-    groups:
-      flask:
-        patterns:
-          - "flask"
-          - "flask-*"
-          - "werkzeug"
-          - "gunicorn"
-      google:
-        patterns:
-          - "google-*"
-          - "googleapis-*"
-      opentelemetry:
-        patterns:
-          - "opentelemetry-*"
-      pydantic:
-        patterns:
-          - "pydantic"
-          - "pydantic-*"
-      llm:
-        patterns:
-          - "langfuse"
-          - "langsmith"
-          - "litellm"
-          - "mlflow*"
-          - "opik"
-          - "weave*"
-          - "arize*"
-          - "tiktoken"
-          - "transformers"
-      database:
-        patterns:
-          - "sqlalchemy"
-          - "psycopg2*"
-          - "psycogreen"
-          - "redis*"
-          - "alembic*"
-      storage:
-        patterns:
-          - "boto3*"
-          - "botocore*"
-          - "azure-*"
-          - "bce-*"
-          - "cos-python-*"
-          - "esdk-obs-*"
-          - "google-cloud-storage"
-          - "opendal"
-          - "oss2"
-          - "supabase*"
-          - "tos*"
-      vdb:
-        patterns:
-          - "alibabacloud*"
-          - "chromadb"
-          - "clickhouse-*"
-          - "clickzetta-*"
-          - "couchbase"
-          - "elasticsearch"
-          - "opensearch-py"
-          - "oracledb"
-          - "pgvect*"
-          - "pymilvus"
-          - "pymochow"
-          - "pyobvector"
-          - "qdrant-client"
-          - "intersystems-*"
-          - "tablestore"
-          - "tcvectordb"
-          - "tidb-vector"
-          - "upstash-*"
-          - "volcengine-*"
-          - "weaviate-*"
-          - "xinference-*"
-          - "mo-vector"
-          - "mysql-connector-*"
-      dev:
-        patterns:
-          - "coverage"
-          - "dotenv-linter"
-          - "faker"
-          - "lxml-stubs"
-          - "basedpyright"
-          - "ruff"
-          - "pytest*"
-          - "types-*"
-          - "boto3-stubs"
-          - "hypothesis"
-          - "pandas-stubs"
-          - "scipy-stubs"
-          - "import-linter"
-          - "celery-types"
-          - "mypy*"
-          - "pyrefly"
-      python-packages:
-        patterns:
-          - "*"
+  - package-ecosystem: "uv"
+    directory: "/api"
+    open-pull-requests-limit: 10
8 .github/pull_request_template.md (vendored)

@@ -18,7 +18,7 @@
 ## Checklist

 - [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
-- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
-- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
-- [x] I've updated the documentation accordingly.
-- [x] I ran `make lint` and `make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods
+- [ ] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
+- [ ] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
+- [ ] I've updated the documentation accordingly.
+- [ ] I ran `make lint && make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods
1 .github/workflows/main-ci.yml (vendored)

@@ -92,6 +92,7 @@ jobs:
             vdb:
               - 'api/core/rag/datasource/**'
               - 'api/tests/integration_tests/vdb/**'
+              - 'api/providers/vdb/*/tests/**'
               - '.github/workflows/vdb-tests.yml'
               - '.github/workflows/expose_service_ports.sh'
               - 'docker/.env.example'
6 .github/workflows/stale.yml (vendored)

@@ -23,8 +23,8 @@ jobs:
           days-before-issue-stale: 15
           days-before-issue-close: 3
           repo-token: ${{ secrets.GITHUB_TOKEN }}
-          stale-issue-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
-          stale-pr-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
+          stale-issue-message: "Closed due to inactivity. If you have any questions, you can reopen it."
+          stale-pr-message: "Closed due to inactivity. If you have any questions, you can reopen it."
           stale-issue-label: 'no-issue-activity'
           stale-pr-label: 'no-pr-activity'
-          any-of-labels: 'duplicate,question,invalid,wontfix,no-issue-activity,no-pr-activity,enhancement,cant-reproduce,help-wanted'
+          any-of-labels: '🌚 invalid,🙋♂️ question,wont-fix,no-issue-activity,no-pr-activity,💪 enhancement,🤔 cant-reproduce,🙏 help wanted'
2 .github/workflows/vdb-tests-full.yml (vendored)

@@ -89,7 +89,7 @@ jobs:
           cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env

       # - name: Check VDB Ready (TiDB)
-      #   run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
+      #   run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py

       - name: Test Vector Stores
         run: uv run --project api bash dev/pytest/pytest_vdb.sh
10 .github/workflows/vdb-tests.yml (vendored)

@@ -81,12 +81,12 @@
           cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env

       # - name: Check VDB Ready (TiDB)
-      #   run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
+      #   run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py

       - name: Test Vector Stores
         run: |
           uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
-            api/tests/integration_tests/vdb/chroma \
-            api/tests/integration_tests/vdb/pgvector \
-            api/tests/integration_tests/vdb/qdrant \
-            api/tests/integration_tests/vdb/weaviate
+            api/providers/vdb/vdb-chroma/tests/integration_tests \
+            api/providers/vdb/vdb-pgvector/tests/integration_tests \
+            api/providers/vdb/vdb-qdrant/tests/integration_tests \
+            api/providers/vdb/vdb-weaviate/tests/integration_tests
@@ -77,7 +77,7 @@ if $web_modified; then
 fi

 cd ./web || exit 1
-vp staged
+pnpm exec vp staged

 if $web_ts_modified; then
     echo "Running TypeScript type-check:tsgo"
@@ -21,8 +21,9 @@ RUN apt-get update \
     # for building gmpy2
     libmpfr-dev libmpc-dev

-# Install Python dependencies
+# Install Python dependencies (workspace members under providers/vdb/)
 COPY pyproject.toml uv.lock ./
+COPY providers ./providers
 RUN uv sync --locked --no-dev

 # production stage
@@ -341,11 +341,10 @@ def add_qdrant_index(field: str):
         click.echo(click.style("No dataset collection bindings found.", fg="red"))
         return
     import qdrant_client
+    from dify_vdb_qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
     from qdrant_client.http.exceptions import UnexpectedResponse
     from qdrant_client.http.models import PayloadSchemaType

-    from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
-
     for binding in bindings:
         if dify_config.QDRANT_URL is None:
             raise ValueError("Qdrant URL is required.")
@@ -1,4 +1,3 @@
-from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
 from pydantic import Field
 from pydantic_settings import BaseSettings

@@ -42,17 +41,17 @@ class HologresConfig(BaseSettings):
         default="public",
     )

-    HOLOGRES_TOKENIZER: TokenizerType = Field(
+    HOLOGRES_TOKENIZER: str = Field(
         description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
         default="jieba",
     )

-    HOLOGRES_DISTANCE_METHOD: DistanceType = Field(
+    HOLOGRES_DISTANCE_METHOD: str = Field(
         description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
         default="Cosine",
     )

-    HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field(
+    HOLOGRES_BASE_QUANTIZATION_TYPE: str = Field(
         description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').",
         default="rabitq",
     )
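The Hologres change loosens the typed `holo_search_sdk` fields to plain `str`, removing the SDK dependency at config-load time while documenting the accepted values in the field descriptions. A minimal sketch of the resulting pattern, assuming pydantic-settings is installed (the class name here is illustrative, not the real Dify config):

```python
from pydantic import Field
from pydantic_settings import BaseSettings


class ExampleHologresConfig(BaseSettings):
    HOLOGRES_TOKENIZER: str = Field(
        description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
        default="jieba",
    )
    HOLOGRES_DISTANCE_METHOD: str = Field(
        description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
        default="Cosine",
    )


config = ExampleHologresConfig()  # reads HOLOGRES_* from the environment when set
print(config.HOLOGRES_TOKENIZER)  # -> "jieba" unless overridden
```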
@@ -1,4 +1,5 @@
 from typing import Any, Literal
+from uuid import UUID

 from pydantic import BaseModel, Field, model_validator

@@ -23,9 +24,9 @@ class ConversationRenamePayload(BaseModel):


 class MessageListQuery(BaseModel):
-    conversation_id: UUIDStrOrEmpty
-    first_id: UUIDStrOrEmpty | None = None
-    limit: int = Field(default=20, ge=1, le=100)
+    conversation_id: UUIDStrOrEmpty = Field(description="Conversation UUID")
+    first_id: UUIDStrOrEmpty | None = Field(default=None, description="First message ID for pagination")
+    limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")


 class MessageFeedbackPayload(BaseModel):

@@ -69,11 +70,35 @@ class WorkflowUpdatePayload(BaseModel):
     marked_comment: str | None = Field(default=None, max_length=100)


+# --- Dataset schemas ---
+
+DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
+
+
+class ChildChunkCreatePayload(BaseModel):
+    content: str
+
+
+class ChildChunkUpdatePayload(BaseModel):
+    content: str
+
+
+class DocumentBatchDownloadZipPayload(BaseModel):
+    """Request payload for bulk downloading documents as a zip archive."""
+
+    document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
+
+
+class MetadataUpdatePayload(BaseModel):
+    name: str
+
+
+# --- Audio schemas ---
+
+
 class TextToAudioPayload(BaseModel):
-    message_id: str | None = None
-    voice: str | None = None
-    text: str | None = None
-    streaming: bool | None = None
+    message_id: str | None = Field(default=None, description="Message ID")
+    voice: str | None = Field(default=None, description="Voice to use for TTS")
+    text: str | None = Field(default=None, description="Text to convert to audio")
+    streaming: bool | None = Field(default=None, description="Enable streaming response")
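The shared schemas gain `Field(description=...)` metadata, which Pydantic carries into the generated JSON schema and hence into the API docs. A small self-contained sketch of the mechanism, independent of Dify's types:

```python
from pydantic import BaseModel, Field


class ExampleQuery(BaseModel):
    limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")


# The description lands in the JSON schema the Swagger docs are generated from.
schema = ExampleQuery.model_json_schema()
print(schema["properties"]["limit"]["description"])  # -> "Number of messages to return (1-100)"
```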
@@ -1,12 +1,16 @@
+from datetime import datetime
+
 import flask_restx
-from flask_restx import Resource, fields, marshal_with
+from flask_restx import Resource
+from flask_restx._http import HTTPStatus
+from pydantic import field_validator
 from sqlalchemy import delete, func, select
 from sqlalchemy.orm import sessionmaker
 from werkzeug.exceptions import Forbidden

+from controllers.common.schema import register_schema_models
 from extensions.ext_database import db
-from libs.helper import TimestampField
+from fields.base import ResponseModel
 from libs.login import current_account_with_tenant, login_required
 from models.dataset import Dataset
 from models.enums import ApiTokenType

@@ -16,21 +20,31 @@ from services.api_token_service import ApiTokenCache

 from . import console_ns
 from .wraps import account_initialization_required, edit_permission_required, setup_required

-api_key_fields = {
-    "id": fields.String,
-    "type": fields.String,
-    "token": fields.String,
-    "last_used_at": TimestampField,
-    "created_at": TimestampField,
-}
-
-api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
-
-api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
-
-api_key_list_model = console_ns.model(
-    "ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
-)
+def _to_timestamp(value: datetime | int | None) -> int | None:
+    if isinstance(value, datetime):
+        return int(value.timestamp())
+    return value
+
+
+class ApiKeyItem(ResponseModel):
+    id: str
+    type: str
+    token: str
+    last_used_at: int | None = None
+    created_at: int | None = None
+
+    @field_validator("last_used_at", "created_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
+
+
+class ApiKeyList(ResponseModel):
+    data: list[ApiKeyItem]
+
+
+register_schema_models(console_ns, ApiKeyItem, ApiKeyList)


 def _get_resource(resource_id, tenant_id, resource_model):

@@ -54,7 +68,6 @@ class BaseApiKeyListResource(Resource):
     token_prefix: str | None = None
     max_keys = 10

-    @marshal_with(api_key_list_model)
     def get(self, resource_id):
         assert self.resource_id_field is not None, "resource_id_field must be set"
         resource_id = str(resource_id)

@@ -66,9 +79,8 @@ class BaseApiKeyListResource(Resource):
                 ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
             )
         ).all()
-        return {"items": keys}
+        return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")

-    @marshal_with(api_key_item_model)
     @edit_permission_required
     def post(self, resource_id):
         assert self.resource_id_field is not None, "resource_id_field must be set"

@@ -100,7 +112,7 @@ class BaseApiKeyListResource(Resource):
         api_token.type = self.resource_type
         db.session.add(api_token)
         db.session.commit()
-        return api_token, 201
+        return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201


 class BaseApiKeyResource(Resource):

@@ -147,7 +159,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
     @console_ns.doc("get_app_api_keys")
     @console_ns.doc(description="Get all API keys for an app")
     @console_ns.doc(params={"resource_id": "App ID"})
-    @console_ns.response(200, "Success", api_key_list_model)
+    @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
     def get(self, resource_id):  # type: ignore
         """Get all API keys for an app"""
         return super().get(resource_id)

@@ -155,7 +167,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
     @console_ns.doc("create_app_api_key")
     @console_ns.doc(description="Create a new API key for an app")
     @console_ns.doc(params={"resource_id": "App ID"})
-    @console_ns.response(201, "API key created successfully", api_key_item_model)
+    @console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
     @console_ns.response(400, "Maximum keys exceeded")
     def post(self, resource_id):  # type: ignore
         """Create a new API key for an app"""

@@ -187,7 +199,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
     @console_ns.doc("get_dataset_api_keys")
     @console_ns.doc(description="Get all API keys for a dataset")
     @console_ns.doc(params={"resource_id": "Dataset ID"})
-    @console_ns.response(200, "Success", api_key_list_model)
+    @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
     def get(self, resource_id):  # type: ignore
         """Get all API keys for a dataset"""
         return super().get(resource_id)

@@ -195,7 +207,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
     @console_ns.doc("create_dataset_api_key")
     @console_ns.doc(description="Create a new API key for a dataset")
     @console_ns.doc(params={"resource_id": "Dataset ID"})
-    @console_ns.response(201, "API key created successfully", api_key_item_model)
+    @console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
     @console_ns.response(400, "Maximum keys exceeded")
     def post(self, resource_id):  # type: ignore
         """Create a new API key for a dataset"""
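This refactor swaps flask-restx `marshal_with` field maps for Pydantic response models; datetime columns are normalized to epoch seconds with a `mode="before"` validator. A self-contained sketch of that normalization pattern (the model name and sample data are illustrative):

```python
from datetime import datetime

from pydantic import BaseModel, field_validator


class ExampleApiKeyItem(BaseModel):
    token: str
    created_at: int | None = None

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
        # ORM rows carry datetimes; the API contract is epoch seconds.
        if isinstance(value, datetime):
            return int(value.timestamp())
        return value


row = {"token": "app-xxxx", "created_at": datetime(2026, 1, 1)}
print(ExampleApiKeyItem.model_validate(row).model_dump(mode="json"))
# -> {'token': 'app-xxxx', 'created_at': 1767225600} (exact value depends on timezone)
```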
@@ -25,7 +25,13 @@ from fields.annotation_fields import (
 )
 from libs.helper import uuid_value
 from libs.login import login_required
-from services.annotation_service import AppAnnotationService
+from services.annotation_service import (
+    AppAnnotationService,
+    EnableAnnotationArgs,
+    UpdateAnnotationArgs,
+    UpdateAnnotationSettingArgs,
+    UpsertAnnotationArgs,
+)

 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"

@@ -120,7 +126,12 @@ class AnnotationReplyActionApi(Resource):
         args = AnnotationReplyPayload.model_validate(console_ns.payload)
         match action:
             case "enable":
-                result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
+                enable_args: EnableAnnotationArgs = {
+                    "score_threshold": args.score_threshold,
+                    "embedding_provider_name": args.embedding_provider_name,
+                    "embedding_model_name": args.embedding_model_name,
+                }
+                result = AppAnnotationService.enable_app_annotation(enable_args, app_id)
             case "disable":
                 result = AppAnnotationService.disable_app_annotation(app_id)
         return result, 200

@@ -161,7 +172,8 @@ class AppAnnotationSettingUpdateApi(Resource):

         args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)

-        result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
+        setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold}
+        result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, setting_args)
         return result, 200


@@ -237,8 +249,16 @@ class AnnotationApi(Resource):
     def post(self, app_id):
         app_id = str(app_id)
         args = CreateAnnotationPayload.model_validate(console_ns.payload)
-        data = args.model_dump(exclude_none=True)
-        annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
+        upsert_args: UpsertAnnotationArgs = {}
+        if args.answer is not None:
+            upsert_args["answer"] = args.answer
+        if args.content is not None:
+            upsert_args["content"] = args.content
+        if args.message_id is not None:
+            upsert_args["message_id"] = args.message_id
+        if args.question is not None:
+            upsert_args["question"] = args.question
+        annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, app_id)
         return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")

     @setup_required

@@ -315,9 +335,12 @@ class AnnotationUpdateDeleteApi(Resource):
         app_id = str(app_id)
         annotation_id = str(annotation_id)
         args = UpdateAnnotationPayload.model_validate(console_ns.payload)
-        annotation = AppAnnotationService.update_app_annotation_directly(
-            args.model_dump(exclude_none=True), app_id, annotation_id
-        )
+        update_args: UpdateAnnotationArgs = {}
+        if args.answer is not None:
+            update_args["answer"] = args.answer
+        if args.question is not None:
+            update_args["question"] = args.question
+        annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_id, annotation_id)
         return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")

     @setup_required
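Raw `model_dump()` dicts are replaced with typed argument mappings (`EnableAnnotationArgs`, `UpsertAnnotationArgs`, ...), built field by field so absent values stay absent. Their definitions are not shown in this diff; assuming they are `TypedDict`s with optional keys, the pattern looks like this:

```python
from typing import TypedDict


class UpsertArgs(TypedDict, total=False):
    # total=False: every key is optional, so omitted fields simply don't exist.
    question: str
    answer: str
    message_id: str


def build_args(question: str | None, answer: str | None) -> UpsertArgs:
    args: UpsertArgs = {}
    if question is not None:
        args["question"] = question
    if answer is not None:
        args["answer"] = answer
    return args


print(build_args("What is Dify?", None))  # -> {'question': 'What is Dify?'}
```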
@@ -1,7 +1,8 @@
-from flask_restx import Resource, fields, marshal_with
+from flask_restx import Resource
 from pydantic import BaseModel, Field
 from sqlalchemy.orm import sessionmaker

+from controllers.common.schema import register_schema_models
 from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import (
     account_initialization_required,

@@ -10,35 +11,15 @@ from controllers.console.wraps import (
     setup_required,
 )
 from extensions.ext_database import db
-from fields.app_fields import (
-    app_import_check_dependencies_fields,
-    app_import_fields,
-    leaked_dependency_fields,
-)
 from libs.login import current_account_with_tenant, login_required
 from models.model import App
-from services.app_dsl_service import AppDslService
+from services.app_dsl_service import AppDslService, Import
 from services.enterprise.enterprise_service import EnterpriseService
-from services.entities.dsl_entities import ImportStatus
+from services.entities.dsl_entities import CheckDependenciesResult, ImportStatus
 from services.feature_service import FeatureService

 from .. import console_ns

-# Register models for flask_restx to avoid dict type issues in Swagger
-# Register base model first
-leaked_dependency_model = console_ns.model("LeakedDependency", leaked_dependency_fields)
-
-app_import_model = console_ns.model("AppImport", app_import_fields)
-
-# For nested models, need to replace nested dict with registered model
-app_import_check_dependencies_fields_copy = app_import_check_dependencies_fields.copy()
-app_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(fields.Nested(leaked_dependency_model))
-app_import_check_dependencies_model = console_ns.model(
-    "AppImportCheckDependencies", app_import_check_dependencies_fields_copy
-)
-
 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


 class AppImportPayload(BaseModel):
     mode: str = Field(..., description="Import mode")

@@ -52,18 +33,18 @@ class AppImportPayload(BaseModel):
     app_id: str | None = Field(None)


-console_ns.schema_model(
-    AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
-)
+register_schema_models(console_ns, AppImportPayload, Import, CheckDependenciesResult)


 @console_ns.route("/apps/imports")
 class AppImportApi(Resource):
     @console_ns.expect(console_ns.models[AppImportPayload.__name__])
+    @console_ns.response(200, "Import completed", console_ns.models[Import.__name__])
+    @console_ns.response(202, "Import pending confirmation", console_ns.models[Import.__name__])
+    @console_ns.response(400, "Import failed", console_ns.models[Import.__name__])
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(app_import_model)
     @cloud_edition_billing_resource_check("apps")
     @edit_permission_required
     def post(self):

@@ -104,10 +85,11 @@ class AppImportApi(Resource):

 @console_ns.route("/apps/imports/<string:import_id>/confirm")
 class AppImportConfirmApi(Resource):
+    @console_ns.response(200, "Import confirmed", console_ns.models[Import.__name__])
+    @console_ns.response(400, "Import failed", console_ns.models[Import.__name__])
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(app_import_model)
     @edit_permission_required
     def post(self, import_id):
         # Check user role first

@@ -128,11 +110,11 @@ class AppImportConfirmApi(Resource):

 @console_ns.route("/apps/imports/<string:app_id>/check-dependencies")
 class AppImportCheckDependenciesApi(Resource):
+    @console_ns.response(200, "Dependencies checked", console_ns.models[CheckDependenciesResult.__name__])
     @setup_required
     @login_required
     @get_app_model
     @account_initialization_required
-    @marshal_with(app_import_check_dependencies_model)
     @edit_permission_required
     def get(self, app_model: App):
         with sessionmaker(db.engine).begin() as session:
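Hand-built `console_ns.model(...)` field maps give way to registering Pydantic JSON schemas on the namespace. The underlying mechanism is flask-restx's `schema_model` plus Pydantic's `ref_template`; a sketch of what a helper like `register_schema_models` plausibly wraps (the helper itself is Dify-internal, so this is an assumption):

```python
from flask_restx import Namespace
from pydantic import BaseModel, Field

DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"


class ExamplePayload(BaseModel):
    mode: str = Field(..., description="Import mode")


ns = Namespace("example")
# schema_model registers a raw JSON schema under the model's name, so Swagger
# references like #/definitions/ExamplePayload resolve.
ns.schema_model(
    ExamplePayload.__name__,
    ExamplePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
print(ExamplePayload.__name__ in ns.models)  # -> True
```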
@@ -1,23 +1,27 @@
 import json
+from datetime import datetime
 from typing import Any

-from flask_restx import Resource, marshal_with
-from pydantic import BaseModel, Field
+from flask_restx import Resource
+from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import select
 from werkzeug.exceptions import NotFound

+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
 from extensions.ext_database import db
-from fields.app_fields import app_server_fields
+from fields.base import ResponseModel
 from libs.login import current_account_with_tenant, login_required
 from models.enums import AppMCPServerStatus
 from models.model import AppMCPServer

-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-
-# Register model for flask_restx to avoid dict type issues in Swagger
-app_server_model = console_ns.model("AppServer", app_server_fields)
+def _to_timestamp(value: datetime | int | None) -> int | None:
+    if isinstance(value, datetime):
+        return int(value.timestamp())
+    return value


 class MCPServerCreatePayload(BaseModel):

@@ -32,8 +36,33 @@ class MCPServerUpdatePayload(BaseModel):
     status: str | None = Field(default=None, description="Server status")


-for model in (MCPServerCreatePayload, MCPServerUpdatePayload):
-    console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+class AppMCPServerResponse(ResponseModel):
+    id: str
+    name: str
+    server_code: str
+    description: str
+    status: str
+    parameters: dict[str, Any] | list[Any] | str
+    created_at: int | None = None
+    updated_at: int | None = None
+
+    @field_validator("parameters", mode="before")
+    @classmethod
+    def _parse_json_string(cls, value: Any) -> Any:
+        if isinstance(value, str):
+            try:
+                return json.loads(value)
+            except (json.JSONDecodeError, TypeError):
+                return value
+        return value
+
+    @field_validator("created_at", "updated_at", mode="before")
+    @classmethod
+    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
+        return _to_timestamp(value)
+
+
+register_schema_models(console_ns, MCPServerCreatePayload, MCPServerUpdatePayload, AppMCPServerResponse)


 @console_ns.route("/apps/<uuid:app_id>/server")

@@ -41,27 +70,27 @@ class AppMCPServerController(Resource):
     @console_ns.doc("get_app_mcp_server")
     @console_ns.doc(description="Get MCP server configuration for an application")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.response(200, "MCP server configuration retrieved successfully", app_server_model)
+    @console_ns.response(200, "Server configuration", console_ns.models[AppMCPServerResponse.__name__])
     @login_required
     @account_initialization_required
     @setup_required
     @get_app_model
-    @marshal_with(app_server_model)
     def get(self, app_model):
         server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
-        return server
+        if server is None:
+            return {}
+        return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")

     @console_ns.doc("create_app_mcp_server")
     @console_ns.doc(description="Create MCP server configuration for an application")
     @console_ns.doc(params={"app_id": "Application ID"})
     @console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
-    @console_ns.response(201, "MCP server configuration created successfully", app_server_model)
+    @console_ns.response(200, "Server created", console_ns.models[AppMCPServerResponse.__name__])
     @console_ns.response(403, "Insufficient permissions")
     @account_initialization_required
     @get_app_model
     @login_required
     @setup_required
-    @marshal_with(app_server_model)
     @edit_permission_required
     def post(self, app_model):
         _, current_tenant_id = current_account_with_tenant()

@@ -82,20 +111,19 @@ class AppMCPServerController(Resource):
         )
         db.session.add(server)
         db.session.commit()
-        return server
+        return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")

     @console_ns.doc("update_app_mcp_server")
     @console_ns.doc(description="Update MCP server configuration for an application")
     @console_ns.doc(params={"app_id": "Application ID"})
     @console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
-    @console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
+    @console_ns.response(200, "Server updated", console_ns.models[AppMCPServerResponse.__name__])
     @console_ns.response(403, "Insufficient permissions")
     @console_ns.response(404, "Server not found")
     @get_app_model
     @login_required
     @setup_required
     @account_initialization_required
-    @marshal_with(app_server_model)
     @edit_permission_required
     def put(self, app_model):
         payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})

@@ -118,7 +146,7 @@ class AppMCPServerController(Resource):
         except ValueError:
             raise ValueError("Invalid status")
         db.session.commit()
-        return server
+        return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")


 @console_ns.route("/apps/<uuid:server_id>/server/refresh")

@@ -126,13 +154,12 @@ class AppMCPServerRefreshController(Resource):
     @console_ns.doc("refresh_app_mcp_server")
     @console_ns.doc(description="Refresh MCP server configuration and regenerate server code")
     @console_ns.doc(params={"server_id": "Server ID"})
-    @console_ns.response(200, "MCP server refreshed successfully", app_server_model)
+    @console_ns.response(200, "Server refreshed", console_ns.models[AppMCPServerResponse.__name__])
     @console_ns.response(403, "Insufficient permissions")
     @console_ns.response(404, "Server not found")
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(app_server_model)
     @edit_permission_required
     def get(self, server_id):
         _, current_tenant_id = current_account_with_tenant()

@@ -145,4 +172,4 @@ class AppMCPServerRefreshController(Resource):
             raise NotFound()
         server.server_code = AppMCPServer.generate_server_code(16)
         db.session.commit()
-        return server
+        return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
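`AppMCPServerResponse.parameters` may arrive from the database as a JSON-encoded string, so a `mode="before"` validator decodes it and falls back to the raw string when decoding fails. A standalone sketch of that idea (class name illustrative):

```python
import json
from typing import Any

from pydantic import BaseModel, field_validator


class ExampleServerResponse(BaseModel):
    parameters: dict[str, Any] | list[Any] | str

    @field_validator("parameters", mode="before")
    @classmethod
    def _parse_json_string(cls, value: Any) -> Any:
        # Stored as TEXT in the DB; decode when it parses as JSON, else pass through.
        if isinstance(value, str):
            try:
                return json.loads(value)
            except (json.JSONDecodeError, TypeError):
                return value
        return value


print(ExampleServerResponse.model_validate({"parameters": '{"tool": "search"}'}).parameters)
# -> {'tool': 'search'}
```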
@@ -1,11 +1,12 @@
 from typing import Literal

-from flask_restx import Resource, marshal_with
+from flask_restx import Resource
 from pydantic import BaseModel, Field, field_validator
 from sqlalchemy import select
 from werkzeug.exceptions import NotFound

 from constants.languages import supported_language
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import (

@@ -15,13 +16,11 @@ from controllers.console.wraps import (
     setup_required,
 )
 from extensions.ext_database import db
-from fields.app_fields import app_site_fields
+from fields.base import ResponseModel
 from libs.datetime_utils import naive_utc_now
 from libs.login import current_account_with_tenant, login_required
 from models import Site

-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-

 class AppSiteUpdatePayload(BaseModel):
     title: str | None = Field(default=None)

@@ -49,13 +48,26 @@ class AppSiteUpdatePayload(BaseModel):
         return supported_language(value)


-console_ns.schema_model(
-    AppSiteUpdatePayload.__name__,
-    AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
-)
+class AppSiteResponse(ResponseModel):
+    app_id: str
+    access_token: str | None = Field(default=None, validation_alias="code")
+    code: str | None = None
+    title: str
+    icon: str | None = None
+    icon_background: str | None = None
+    description: str | None = None
+    default_language: str
+    customize_domain: str | None = None
+    copyright: str | None = None
+    privacy_policy: str | None = None
+    custom_disclaimer: str | None = None
+    customize_token_strategy: str
+    prompt_public: bool
+    show_workflow_steps: bool
+    use_icon_as_answer_icon: bool

-# Register model for flask_restx to avoid dict type issues in Swagger
-app_site_model = console_ns.model("AppSite", app_site_fields)
+register_schema_models(console_ns, AppSiteUpdatePayload, AppSiteResponse)


 @console_ns.route("/apps/<uuid:app_id>/site")

@@ -64,7 +76,7 @@ class AppSite(Resource):
     @console_ns.doc(description="Update application site configuration")
     @console_ns.doc(params={"app_id": "Application ID"})
     @console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
-    @console_ns.response(200, "Site configuration updated successfully", app_site_model)
+    @console_ns.response(200, "Site configuration updated successfully", console_ns.models[AppSiteResponse.__name__])
     @console_ns.response(403, "Insufficient permissions")
     @console_ns.response(404, "App not found")
     @setup_required

@@ -72,7 +84,6 @@ class AppSite(Resource):
     @edit_permission_required
     @account_initialization_required
     @get_app_model
-    @marshal_with(app_site_model)
     def post(self, app_model):
         args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
         current_user, _ = current_account_with_tenant()

@@ -106,7 +117,7 @@ class AppSite(Resource):
         site.updated_at = naive_utc_now()
         db.session.commit()

-        return site
+        return AppSiteResponse.model_validate(site, from_attributes=True).model_dump(mode="json")


 @console_ns.route("/apps/<uuid:app_id>/site/access-token-reset")

@@ -114,7 +125,7 @@ class AppSiteAccessTokenReset(Resource):
     @console_ns.doc("reset_app_site_access_token")
     @console_ns.doc(description="Reset access token for application site")
     @console_ns.doc(params={"app_id": "Application ID"})
-    @console_ns.response(200, "Access token reset successfully", app_site_model)
+    @console_ns.response(200, "Access token reset successfully", console_ns.models[AppSiteResponse.__name__])
     @console_ns.response(403, "Insufficient permissions (admin/owner required)")
     @console_ns.response(404, "App or site not found")
     @setup_required

@@ -122,7 +133,6 @@ class AppSiteAccessTokenReset(Resource):
     @is_admin_or_owner_required
     @account_initialization_required
     @get_app_model
-    @marshal_with(app_site_model)
     def post(self, app_model):
         current_user, _ = current_account_with_tenant()
         site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))

@@ -135,4 +145,4 @@ class AppSiteAccessTokenReset(Resource):
         site.updated_at = naive_utc_now()
         db.session.commit()

-        return site
+        return AppSiteResponse.model_validate(site, from_attributes=True).model_dump(mode="json")
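In `AppSiteResponse`, `access_token` is populated from the ORM attribute `code` via `validation_alias`, so the response exposes both names without touching the database column. A minimal sketch of the aliasing (names illustrative):

```python
from pydantic import BaseModel, Field


class ExampleSiteResponse(BaseModel):
    # validation_alias reads the value from the input key/attribute "code".
    access_token: str | None = Field(default=None, validation_alias="code")
    code: str | None = None


site = ExampleSiteResponse.model_validate({"code": "abc123"})
print(site.access_token, site.code)  # -> abc123 abc123
```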
@@ -1,8 +1,9 @@
 from flask import request
-from flask_restx import Resource, fields
+from flask_restx import Resource
 from pydantic import BaseModel, Field, field_validator

 from constants.languages import supported_language
+from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.error import AlreadyActivateError
 from extensions.ext_database import db

@@ -11,8 +12,6 @@ from libs.helper import EmailStr, timezone
 from models import AccountStatus
 from services.account_service import RegisterService

-DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
-

 class ActivateCheckQuery(BaseModel):
     workspace_id: str | None = Field(default=None)

@@ -39,8 +38,16 @@ class ActivatePayload(BaseModel):
         return timezone(value)


-for model in (ActivateCheckQuery, ActivatePayload):
-    console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
+class ActivationCheckResponse(BaseModel):
+    is_valid: bool = Field(description="Whether token is valid")
+    data: dict | None = Field(default=None, description="Activation data if valid")
+
+
+class ActivationResponse(BaseModel):
+    result: str = Field(description="Operation result")
+
+
+register_schema_models(console_ns, ActivateCheckQuery, ActivatePayload, ActivationCheckResponse, ActivationResponse)


 @console_ns.route("/activate/check")

@@ -51,13 +58,7 @@ class ActivateCheckApi(Resource):
     @console_ns.response(
         200,
         "Success",
-        console_ns.model(
-            "ActivationCheckResponse",
-            {
-                "is_valid": fields.Boolean(description="Whether token is valid"),
-                "data": fields.Raw(description="Activation data if valid"),
-            },
-        ),
+        console_ns.models[ActivationCheckResponse.__name__],
     )
     def get(self):
         args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

@@ -95,12 +96,7 @@ class ActivateApi(Resource):
     @console_ns.response(
         200,
         "Account activated successfully",
-        console_ns.model(
-            "ActivationResponse",
-            {
-                "result": fields.String(description="Operation result"),
-            },
-        ),
+        console_ns.models[ActivationResponse.__name__],
     )
     @console_ns.response(400, "Already activated or invalid token")
     def post(self):
@@ -11,10 +11,7 @@ import services
 from configs import dify_config
 from controllers.common.schema import get_or_create_model, register_schema_models
 from controllers.console import console_ns
-from controllers.console.apikey import (
-    api_key_item_model,
-    api_key_list_model,
-)
+from controllers.console.apikey import ApiKeyItem, ApiKeyList
 from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
 from controllers.console.wraps import (

@@ -785,23 +782,23 @@ class DatasetApiKeyApi(Resource):

     @console_ns.doc("get_dataset_api_keys")
     @console_ns.doc(description="Get dataset API keys")
-    @console_ns.response(200, "API keys retrieved successfully", api_key_list_model)
+    @console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(api_key_list_model)
     def get(self):
         _, current_tenant_id = current_account_with_tenant()
         keys = db.session.scalars(
             select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
         ).all()
-        return {"items": keys}
+        return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")

+    @console_ns.response(200, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
+    @console_ns.response(400, "Maximum keys exceeded")
     @setup_required
     @login_required
     @is_admin_or_owner_required
     @account_initialization_required
-    @marshal_with(api_key_item_model)
     def post(self):
         _, current_tenant_id = current_account_with_tenant()

@@ -828,7 +825,7 @@ class DatasetApiKeyApi(Resource):
         api_token.type = self.resource_type
         db.session.add(api_token)
         db.session.commit()
-        return api_token, 200
+        return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 200


 @console_ns.route("/datasets/api-keys/<uuid:api_key_id>")
@@ -4,7 +4,6 @@ from argparse import ArgumentTypeError
 from collections.abc import Sequence
 from contextlib import ExitStack
 from typing import Any, Literal, cast
-from uuid import UUID

 import sqlalchemy as sa
 from flask import request, send_file

@@ -16,6 +15,7 @@ from sqlalchemy import asc, desc, func, select
 from werkzeug.exceptions import Forbidden, NotFound

 import services
+from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
 from controllers.common.schema import get_or_create_model, register_schema_models
 from controllers.console import console_ns
 from core.errors.error import (

@@ -71,9 +71,6 @@ from ..wraps import (

 logger = logging.getLogger(__name__)

-# NOTE: Keep constants near the top of the module for discoverability.
-DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
-

 # Register models for flask_restx to avoid dict type issues in Swagger
 dataset_model = get_or_create_model("Dataset", dataset_fields)

@@ -110,12 +107,6 @@ class GenerateSummaryPayload(BaseModel):
     document_list: list[str]


-class DocumentBatchDownloadZipPayload(BaseModel):
-    """Request payload for bulk downloading documents as a zip archive."""
-
-    document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
-
-
 class DocumentDatasetListParam(BaseModel):
     page: int = Field(1, title="Page", description="Page number.")
     limit: int = Field(20, title="Limit", description="Page size.")
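The now-shared `DocumentBatchDownloadZipPayload` bounds the batch with `min_length`/`max_length` on a `list[UUID]` field, so Pydantic rejects out-of-range payloads before any handler code runs. A quick sketch:

```python
from uuid import UUID, uuid4

from pydantic import BaseModel, Field, ValidationError


class ExampleZipPayload(BaseModel):
    document_ids: list[UUID] = Field(..., min_length=1, max_length=100)


print(len(ExampleZipPayload(document_ids=[uuid4()]).document_ids))  # -> 1
try:
    ExampleZipPayload(document_ids=[])  # empty list violates min_length=1
except ValidationError as e:
    print(e.errors()[0]["type"])  # -> "too_short"
```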
@@ -10,6 +10,7 @@ from werkzeug.exceptions import Forbidden, NotFound

 import services
 from configs import dify_config
+from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
 from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.app.error import ProviderNotInitializeError

@@ -82,14 +83,6 @@ class BatchImportPayload(BaseModel):
     upload_file_id: str


-class ChildChunkCreatePayload(BaseModel):
-    content: str
-
-
-class ChildChunkUpdatePayload(BaseModel):
-    content: str
-
-
 class ChildChunkBatchUpdatePayload(BaseModel):
     chunks: list[ChildChunkUpdateArgs]
@@ -1,9 +1,9 @@
 from typing import Literal

 from flask_restx import Resource, marshal_with
-from pydantic import BaseModel
 from werkzeug.exceptions import NotFound

+from controllers.common.controller_schemas import MetadataUpdatePayload
 from controllers.common.schema import register_schema_models
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required

@@ -18,11 +18,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
 )
 from services.metadata_service import MetadataService


-class MetadataUpdatePayload(BaseModel):
-    name: str
-
-
 register_schema_models(
     console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
 )
@@ -94,10 +94,9 @@ def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:


 def plugin_data[**P, R](
-    view: Callable[P, R] | None = None,
     *,
     payload_type: type[BaseModel],
-) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
+) -> Callable[[Callable[P, R]], Callable[P, R]]:
     def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
         @wraps(view_func)
         def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:

@@ -116,7 +115,4 @@ def plugin_data[**P, R](

         return decorated_view

-    if view is None:
-        return decorator
-    else:
-        return decorator(view)
+    return decorator
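`plugin_data` drops its optional positional `view` argument and now always behaves as a decorator factory, so call sites must write `@plugin_data(payload_type=...)` with parentheses and the old `if view is None` dual-mode branch becomes dead code. A reduced sketch of the resulting shape, assuming Python 3.12+ for the PEP 695 syntax (the payload validation itself is elided):

```python
from collections.abc import Callable
from functools import wraps


def plugin_data[**P, R](*, payload_type: type) -> Callable[[Callable[P, R]], Callable[P, R]]:
    # Keyword-only argument: bare @plugin_data is impossible, so this is
    # unambiguously a factory that returns the real decorator.
    def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
        @wraps(view_func)
        def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
            # Real code would validate the request body against payload_type here.
            return view_func(*args, **kwargs)

        return decorated_view

    return decorator


@plugin_data(payload_type=dict)
def handler() -> str:
    return "ok"


print(handler())  # -> ok
```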
@@ -12,7 +12,12 @@ from controllers.service_api.wraps import validate_app_token
 from extensions.ext_redis import redis_client
 from fields.annotation_fields import Annotation, AnnotationList
 from models.model import App
-from services.annotation_service import AppAnnotationService
+from services.annotation_service import (
+    AppAnnotationService,
+    EnableAnnotationArgs,
+    InsertAnnotationArgs,
+    UpdateAnnotationArgs,
+)


 class AnnotationCreatePayload(BaseModel):

@@ -46,10 +51,15 @@ class AnnotationReplyActionApi(Resource):
     @validate_app_token
     def post(self, app_model: App, action: Literal["enable", "disable"]):
         """Enable or disable annotation reply feature."""
-        args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
+        payload = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {})
         match action:
             case "enable":
-                result = AppAnnotationService.enable_app_annotation(args, app_model.id)
+                enable_args: EnableAnnotationArgs = {
+                    "score_threshold": payload.score_threshold,
+                    "embedding_provider_name": payload.embedding_provider_name,
+                    "embedding_model_name": payload.embedding_model_name,
+                }
+                result = AppAnnotationService.enable_app_annotation(enable_args, app_model.id)
             case "disable":
                 result = AppAnnotationService.disable_app_annotation(app_model.id)
         return result, 200

@@ -135,8 +145,9 @@ class AnnotationListApi(Resource):
     @validate_app_token
     def post(self, app_model: App):
         """Create a new annotation."""
-        args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
-        annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
+        payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
+        insert_args: InsertAnnotationArgs = {"question": payload.question, "answer": payload.answer}
+        annotation = AppAnnotationService.insert_app_annotation_directly(insert_args, app_model.id)
         response = Annotation.model_validate(annotation, from_attributes=True)
         return response.model_dump(mode="json"), HTTPStatus.CREATED

@@ -164,8 +175,9 @@ class AnnotationUpdateDeleteApi(Resource):
     @edit_permission_required
     def put(self, app_model: App, annotation_id: str):
         """Update an existing annotation."""
-        args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
-        annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
+        payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
+        update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer}
+        annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id)
         response = Annotation.model_validate(annotation, from_attributes=True)
         return response.model_dump(mode="json")
@@ -3,10 +3,10 @@ import logging
 from flask import request
 from flask_restx import Resource
 from graphon.model_runtime.errors.invoke import InvokeError
-from pydantic import BaseModel, Field
 from werkzeug.exceptions import InternalServerError

 import services
+from controllers.common.controller_schemas import TextToAudioPayload
 from controllers.common.schema import register_schema_model
 from controllers.service_api import service_api_ns
 from controllers.service_api.app.error import (

@@ -86,13 +86,6 @@ class AudioApi(Resource):
         raise InternalServerError()


-class TextToAudioPayload(BaseModel):
-    message_id: str | None = Field(default=None, description="Message ID")
-    voice: str | None = Field(default=None, description="Voice to use for TTS")
-    text: str | None = Field(default=None, description="Text to convert to audio")
-    streaming: bool | None = Field(default=None, description="Enable streaming response")
-
-
 register_schema_model(service_api_ns, TextToAudioPayload)
@@ -10,6 +10,7 @@ from sqlalchemy import desc, func, select
 from werkzeug.exceptions import Forbidden, NotFound

 import services
+from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
 from controllers.common.errors import (
     FilenameNotExistsError,
     FileTooLargeError,

@@ -100,15 +101,6 @@ class DocumentListQuery(BaseModel):
     status: str | None = Field(default=None, description="Document status filter")


-DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
-
-
-class DocumentBatchDownloadZipPayload(BaseModel):
-    """Request payload for bulk downloading uploaded documents as a ZIP archive."""
-
-    document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
-
-
 register_enum_models(service_api_ns, RetrievalMethod)

 register_schema_models(
@@ -2,9 +2,9 @@ from typing import Literal

 from flask_login import current_user
 from flask_restx import marshal
-from pydantic import BaseModel
 from werkzeug.exceptions import NotFound

+from controllers.common.controller_schemas import MetadataUpdatePayload
 from controllers.common.schema import register_schema_model, register_schema_models
 from controllers.service_api import service_api_ns
 from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check

@@ -18,11 +18,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
 )
 from services.metadata_service import MetadataService


-class MetadataUpdatePayload(BaseModel):
-    name: str
-
-
 register_schema_model(service_api_ns, MetadataUpdatePayload)
 register_schema_models(
     service_api_ns,
@@ -8,6 +8,7 @@ from sqlalchemy import select
 from werkzeug.exceptions import NotFound

 from configs import dify_config
+from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
 from controllers.common.schema import register_schema_models
 from controllers.service_api import service_api_ns
 from controllers.service_api.app.error import ProviderNotInitializeError

@@ -69,20 +70,12 @@ class SegmentUpdatePayload(BaseModel):
     segment: SegmentUpdateArgs


-class ChildChunkCreatePayload(BaseModel):
-    content: str
-
-
 class ChildChunkListQuery(BaseModel):
     limit: int = Field(default=20, ge=1)
     keyword: str | None = None
     page: int = Field(default=1, ge=1)


-class ChildChunkUpdatePayload(BaseModel):
-    content: str
-
-
 register_schema_models(
     service_api_ns,
     SegmentCreatePayload,
@@ -92,7 +92,7 @@ class HumanInputFormApi(Resource):
        _FORM_ACCESS_RATE_LIMITER.increment_rate_limit(ip_address)

        service = HumanInputService(db.engine)
        # TODO(QuantumGhost): forbid submision for form tokens
        # TODO(QuantumGhost): forbid submission for form tokens
        # that are only for console.
        form = service.get_form_by_token(form_token)

@@ -3,10 +3,10 @@ from typing import Literal

from flask import request
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import InternalServerError, NotFound

from controllers.common.controller_schemas import MessageFeedbackPayload
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
@@ -25,7 +25,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from fields.conversation_fields import ResultResponse
from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
from libs import helper
from libs.helper import uuid_value
from models.enums import FeedbackRating
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@@ -41,19 +40,6 @@ from services.message_service import MessageService
logger = logging.getLogger(__name__)


class MessageListQuery(BaseModel):
    conversation_id: str = Field(description="Conversation UUID")
    first_id: str | None = Field(default=None, description="First message ID for pagination")
    limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")

    @field_validator("conversation_id", "first_id")
    @classmethod
    def validate_uuid(cls, value: str | None) -> str | None:
        if value is None:
            return value
        return uuid_value(value)


class MessageMoreLikeThisQuery(BaseModel):
    response_mode: Literal["blocking", "streaming"] = Field(
        description="Response mode",

@@ -59,6 +59,24 @@ class LangFuseDataTrace(BaseTraceInstance):
        )
        self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")

    @staticmethod
    def _get_completion_start_time(
        start_time: datetime | None, time_to_first_token: float | int | None
    ) -> datetime | None:
        """Convert a relative TTFT value in seconds into Langfuse's absolute completion start time."""
        if start_time is None or time_to_first_token is None:
            return None

        try:
            ttft_seconds = float(time_to_first_token)
        except (TypeError, ValueError):
            return None

        if ttft_seconds < 0:
            return None

        return start_time + timedelta(seconds=ttft_seconds)

    def trace(self, trace_info: BaseTraceInfo):
        if isinstance(trace_info, WorkflowTraceInfo):
            self.workflow_trace(trace_info)
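
To make the conversion concrete, here is a minimal standalone sketch of the same arithmetic the new helper performs (only `datetime`/`timedelta` are involved; the helper additionally guards against `None`, non-numeric, and negative inputs as shown above):

```python
from datetime import datetime, timedelta, timezone

# The generation started at t0; the model emitted its first token 1.25 s later.
start_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
time_to_first_token = 1.25  # relative seconds, as reported in usage data

# Same conversion as _get_completion_start_time: relative TTFT -> absolute timestamp.
completion_start_time = start_time + timedelta(seconds=float(time_to_first_token))
print(completion_start_time)  # 2024-01-01 12:00:01.250000+00:00
```
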
@@ -189,10 +207,18 @@ class LangFuseDataTrace(BaseTraceInstance):
        total_token = metadata.get("total_tokens", 0)
        prompt_tokens = 0
        completion_tokens = 0
        completion_start_time = None
        try:
            usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
            usage_data = process_data.get("usage")
            if not isinstance(usage_data, dict):
                usage_data = outputs.get("usage")
            if not isinstance(usage_data, dict):
                usage_data = {}
            prompt_tokens = usage_data.get("prompt_tokens", 0)
            completion_tokens = usage_data.get("completion_tokens", 0)
            completion_start_time = self._get_completion_start_time(
                created_at, usage_data.get("time_to_first_token")
            )
        except Exception:
            logger.error("Failed to extract usage", exc_info=True)

@@ -210,6 +236,7 @@ class LangFuseDataTrace(BaseTraceInstance):
            trace_id=trace_id,
            model=process_data.get("model_name"),
            start_time=created_at,
            completion_start_time=completion_start_time,
            end_time=finished_at,
            input=inputs,
            output=outputs,
@@ -290,11 +317,16 @@ class LangFuseDataTrace(BaseTraceInstance):
            unit=UnitEnum.TOKENS,
            totalCost=message_data.total_price,
        )
        completion_start_time = self._get_completion_start_time(
            trace_info.start_time,
            trace_info.gen_ai_server_time_to_first_token,
        )

        langfuse_generation_data = LangfuseGeneration(
            name="llm",
            trace_id=trace_id,
            start_time=trace_info.start_time,
            completion_start_time=completion_start_time,
            end_time=trace_info.end_time,
            model=message_data.model_id,
            input=trace_info.inputs,

api/core/rag/datasource/vdb/vector_backend_registry.py (new file, 87 lines)
@@ -0,0 +1,87 @@
"""Vector store backend discovery.

Backends live in workspace packages under ``api/packages/dify-vdb-*/src/dify_vdb_*``. Each package
declares third-party dependencies and registers ``importlib`` entry points in group
``dify.vector_backends`` (see each package's ``pyproject.toml``).

Shared types and the :class:`~core.rag.datasource.vdb.vector_factory.AbstractVectorFactory` protocol
remain in this package (``vector_base``, ``vector_factory``, ``vector_type``, ``field``).

Optional **built-in** targets in ``_BUILTIN_VECTOR_FACTORY_TARGETS`` (normally empty) load without a
distribution; entry points take precedence when both exist.

After changing packages, run ``uv sync`` so installed dist-info entry points match ``pyproject.toml``.
"""

from __future__ import annotations

import importlib
import logging
from importlib.metadata import entry_points
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory

logger = logging.getLogger(__name__)

_VECTOR_FACTORY_CACHE: dict[str, type[AbstractVectorFactory]] = {}

# module_path:class_name — optional fallback when no distribution registers the backend.
_BUILTIN_VECTOR_FACTORY_TARGETS: dict[str, str] = {}


def clear_vector_factory_cache() -> None:
    """Drop lazily loaded factories (for tests or plugin reload)."""
    _VECTOR_FACTORY_CACHE.clear()


def _vector_backend_entry_points():
    return entry_points().select(group="dify.vector_backends")


def _load_plugin_factory(vector_type: str) -> type[AbstractVectorFactory] | None:
    for ep in _vector_backend_entry_points():
        if ep.name != vector_type:
            continue
        try:
            loaded = ep.load()
        except Exception:
            logger.exception("Failed to load vector backend entry point %s", ep.name)
            raise
        return loaded  # type: ignore[return-value]
    return None


def _unsupported(vector_type: str) -> ValueError:
    installed = sorted(ep.name for ep in _vector_backend_entry_points())
    available_msg = f" Installed backends: {', '.join(installed)}." if installed else " No backends installed."
    return ValueError(
        f"Vector store {vector_type!r} is not supported.{available_msg} "
        "Install a plugin (uv sync --group vdb-all, or vdb-<backend> per api/pyproject.toml), "
        "or register a dify.vector_backends entry point."
    )


def _load_builtin_factory(vector_type: str) -> type[AbstractVectorFactory]:
    target = _BUILTIN_VECTOR_FACTORY_TARGETS.get(vector_type)
    if not target:
        raise _unsupported(vector_type)
    module_path, _, attr = target.partition(":")
    module = importlib.import_module(module_path)
    return getattr(module, attr)  # type: ignore[no-any-return]


def get_vector_factory_class(vector_type: str) -> type[AbstractVectorFactory]:
    """Resolve :class:`AbstractVectorFactory` for a :class:`~VectorType` string value."""
    if vector_type in _VECTOR_FACTORY_CACHE:
        return _VECTOR_FACTORY_CACHE[vector_type]

    plugin_cls = _load_plugin_factory(vector_type)
    if plugin_cls is not None:
        _VECTOR_FACTORY_CACHE[vector_type] = plugin_cls
        return plugin_cls

    cls = _load_builtin_factory(vector_type)
    _VECTOR_FACTORY_CACHE[vector_type] = cls
    return cls
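
As a usage sketch of the new resolver (the backend name is illustrative; `dataset`, `attributes`, and `embeddings` are supplied by the `Vector` wrapper shown below):

```python
from core.rag.datasource.vdb.vector_backend_registry import get_vector_factory_class

# "pgvector" must match the name of an installed dify.vector_backends entry point.
factory_cls = get_vector_factory_class("pgvector")  # cached after the first lookup
factory = factory_cls()
# factory.init_vector(dataset, attributes, embeddings) then returns the
# backend-specific BaseVector implementation used for indexing and search.
```
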
@@ -9,6 +9,7 @@ from sqlalchemy import select

from configs import dify_config
from core.model_manager import ModelManager
from core.rag.datasource.vdb.vector_backend_registry import get_vector_factory_class
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding
@@ -41,7 +42,23 @@ class AbstractVectorFactory(ABC):
class Vector:
    def __init__(self, dataset: Dataset, attributes: list | None = None):
        if attributes is None:
            attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
            # `is_summary` and `original_chunk_id` are stored on summary vectors
            # by `SummaryIndexService` and read back by `RetrievalService` to
            # route summary hits through their original parent chunks. They
            # must be listed here so vector backends that use this list as an
            # explicit return-properties projection (notably Weaviate) actually
            # return those fields; without them, summary hits silently
            # collapse into `is_summary = False` branches and the summary
            # retrieval path is a no-op. See #34884.
            attributes = [
                "doc_id",
                "dataset_id",
                "document_id",
                "doc_hash",
                "doc_type",
                "is_summary",
                "original_chunk_id",
            ]
        self._dataset = dataset
        self._embeddings = self._get_embeddings()
        self._attributes = attributes

@@ -69,137 +86,7 @@ class Vector:

    @staticmethod
    def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
        match vector_type:
            case VectorType.CHROMA:
                from core.rag.datasource.vdb.chroma.chroma_vector import ChromaVectorFactory

                return ChromaVectorFactory
            case VectorType.MILVUS:
                from core.rag.datasource.vdb.milvus.milvus_vector import MilvusVectorFactory

                return MilvusVectorFactory
            case VectorType.ALIBABACLOUD_MYSQL:
                from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
                    AlibabaCloudMySQLVectorFactory,
                )

                return AlibabaCloudMySQLVectorFactory
            case VectorType.MYSCALE:
                from core.rag.datasource.vdb.myscale.myscale_vector import MyScaleVectorFactory

                return MyScaleVectorFactory
            case VectorType.PGVECTOR:
                from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory

                return PGVectorFactory
            case VectorType.VASTBASE:
                from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVectorFactory

                return VastbaseVectorFactory
            case VectorType.PGVECTO_RS:
                from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory

                return PGVectoRSFactory
            case VectorType.QDRANT:
                from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantVectorFactory

                return QdrantVectorFactory
            case VectorType.RELYT:
                from core.rag.datasource.vdb.relyt.relyt_vector import RelytVectorFactory

                return RelytVectorFactory
            case VectorType.ELASTICSEARCH:
                from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory

                return ElasticSearchVectorFactory
            case VectorType.ELASTICSEARCH_JA:
                from core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector import (
                    ElasticSearchJaVectorFactory,
                )

                return ElasticSearchJaVectorFactory
            case VectorType.TIDB_VECTOR:
                from core.rag.datasource.vdb.tidb_vector.tidb_vector import TiDBVectorFactory

                return TiDBVectorFactory
            case VectorType.WEAVIATE:
                from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateVectorFactory

                return WeaviateVectorFactory
            case VectorType.TENCENT:
                from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory

                return TencentVectorFactory
            case VectorType.ORACLE:
                from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory

                return OracleVectorFactory
            case VectorType.OPENSEARCH:
                from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory

                return OpenSearchVectorFactory
            case VectorType.ANALYTICDB:
                from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory

                return AnalyticdbVectorFactory
            case VectorType.COUCHBASE:
                from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseVectorFactory

                return CouchbaseVectorFactory
            case VectorType.BAIDU:
                from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory

                return BaiduVectorFactory
            case VectorType.VIKINGDB:
                from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBVectorFactory

                return VikingDBVectorFactory
            case VectorType.UPSTASH:
                from core.rag.datasource.vdb.upstash.upstash_vector import UpstashVectorFactory

                return UpstashVectorFactory
            case VectorType.TIDB_ON_QDRANT:
                from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory

                return TidbOnQdrantVectorFactory
            case VectorType.LINDORM:
                from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory

                return LindormVectorStoreFactory
            case VectorType.OCEANBASE | VectorType.SEEKDB:
                from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory

                return OceanBaseVectorFactory
            case VectorType.OPENGAUSS:
                from core.rag.datasource.vdb.opengauss.opengauss import OpenGaussFactory

                return OpenGaussFactory
            case VectorType.TABLESTORE:
                from core.rag.datasource.vdb.tablestore.tablestore_vector import TableStoreVectorFactory

                return TableStoreVectorFactory
            case VectorType.HUAWEI_CLOUD:
                from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory

                return HuaweiCloudVectorFactory
            case VectorType.MATRIXONE:
                from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory

                return MatrixoneVectorFactory
            case VectorType.CLICKZETTA:
                from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory

                return ClickzettaVectorFactory
            case VectorType.IRIS:
                from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory

                return IrisVectorFactory
            case VectorType.HOLOGRES:
                from core.rag.datasource.vdb.hologres.hologres_vector import HologresVectorFactory

                return HologresVectorFactory
            case _:
                raise ValueError(f"Vector store {vector_type} is not supported.")
        return get_vector_factory_class(vector_type)

    def create(self, texts: list | None = None, **kwargs):
        if texts:

@@ -1,10 +1,19 @@
"""Shared helpers for vector DB integration tests (used by workspace packages under ``api/packages``).

:class:`AbstractVectorTest` and helper functions live here so package tests can import
``core.rag.datasource.vdb.vector_integration_test_support`` without relying on the
``tests.*`` package.

The ``setup_mock_redis`` fixture lives in ``api/packages/conftest.py`` and is
auto-discovered by pytest for all package tests.
"""

import uuid
from unittest.mock import MagicMock

import pytest

from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions import ext_redis
from models.dataset import Dataset

@@ -25,24 +34,10 @@ def get_example_document(doc_id: str) -> Document:
    return doc


@pytest.fixture
def setup_mock_redis():
    # get
    ext_redis.redis_client.get = MagicMock(return_value=None)

    # set
    ext_redis.redis_client.set = MagicMock(return_value=None)

    # lock
    mock_redis_lock = MagicMock()
    mock_redis_lock.__enter__ = MagicMock()
    mock_redis_lock.__exit__ = MagicMock()
    ext_redis.redis_client.lock = mock_redis_lock


class AbstractVectorTest:
    vector: BaseVector

    def __init__(self):
        self.vector = None
        self.dataset_id = str(uuid.uuid4())
        self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test"
        self.example_doc_id = str(uuid.uuid4())

@@ -244,7 +244,7 @@ class DatasetDocumentStore:
        return document_segment

    def add_multimodel_documents_binding(self, segment_id: str, multimodel_documents: list[AttachmentDocument] | None):
        if multimodel_documents:
        if multimodel_documents and self._document_id is not None:
            for multimodel_document in multimodel_documents:
                binding = SegmentAttachmentBinding(
                    tenant_id=self._dataset.tenant_id,

@@ -4,7 +4,12 @@ from core.rag.entities.event import DatasourceCompletedEvent, DatasourceErrorEve
from core.rag.entities.index_entities import EconomySetting, EmbeddingSetting, IndexMethod
from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition, SupportedComparisonOperator
from core.rag.entities.processing_entities import ParentMode, PreProcessingRule, Rule, Segmentation
from core.rag.entities.retrieval_settings import KeywordSetting, VectorSetting, WeightedScoreConfig
from core.rag.entities.retrieval_settings import (
    KeywordSetting,
    RerankingModelConfig,
    VectorSetting,
    WeightedScoreConfig,
)

__all__ = [
    "Condition",
@@ -19,6 +24,7 @@ __all__ = [
    "MetadataFilteringCondition",
    "ParentMode",
    "PreProcessingRule",
    "RerankingModelConfig",
    "RetrievalSourceMetadata",
    "Rule",
    "Segmentation",

@@ -1,4 +1,27 @@
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict, Field


class RerankingModelConfig(BaseModel):
    """
    Canonical reranking model configuration.

    Accepts both naming conventions:
    - reranking_provider_name / reranking_model_name (services layer)
    - provider / model (workflow layer via validation_alias)
    """

    model_config = ConfigDict(populate_by_name=True)

    reranking_provider_name: str = Field(validation_alias="provider")
    reranking_model_name: str = Field(validation_alias="model")

    @property
    def provider(self) -> str:
        return self.reranking_provider_name

    @property
    def model(self) -> str:
        return self.reranking_model_name
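
Because the canonical model accepts either key set, callers from both layers validate to the same object. A quick sketch (the provider/model values are illustrative):

```python
from core.rag.entities.retrieval_settings import RerankingModelConfig

# Workflow-layer keys, accepted via validation_alias:
a = RerankingModelConfig.model_validate({"provider": "cohere", "model": "rerank-v3"})
# Services-layer keys, accepted via populate_by_name:
b = RerankingModelConfig.model_validate(
    {"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v3"}
)
assert a.provider == b.provider == "cohere"
assert a.model == b.model == "rerank-v3"
```
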


class VectorSetting(BaseModel):

@@ -4,21 +4,12 @@ from graphon.entities.base_node_data import BaseNodeData
from graphon.enums import NodeType
from pydantic import BaseModel

from core.rag.entities.retrieval_settings import WeightedScoreConfig
from core.rag.entities import RerankingModelConfig, WeightedScoreConfig
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE


class RerankingModelConfig(BaseModel):
    """
    Reranking Model Config.
    """

    reranking_provider_name: str
    reranking_model_name: str


class RetrievalSetting(BaseModel):
    """
    Retrieval Setting.

@@ -5,20 +5,11 @@ from graphon.enums import BuiltinNodeTypes, NodeType
from graphon.nodes.llm.entities import ModelConfig, VisionConfig
from pydantic import BaseModel, Field

from core.rag.entities import Condition, MetadataFilteringCondition, WeightedScoreConfig
from core.rag.entities import Condition, MetadataFilteringCondition, RerankingModelConfig, WeightedScoreConfig

__all__ = ["Condition"]


class RerankingModelConfig(BaseModel):
    """
    Reranking Model Config.
    """

    provider: str
    model: str


class MultipleRetrievalConfig(BaseModel):
    """
    Multiple Retrieval Config.

@@ -17,7 +17,6 @@ def http_status_message(code):


def register_external_error_handlers(api: Api):
    @api.errorhandler(HTTPException)
    def handle_http_exception(e: HTTPException):
        got_request_exception.send(current_app, exception=e)

@@ -74,27 +73,18 @@ def register_external_error_handlers(api: Api):
            headers["Set-Cookie"] = build_force_logout_cookie_headers()
        return data, status_code, headers

    _ = handle_http_exception

    @api.errorhandler(ValueError)
    def handle_value_error(e: ValueError):
        got_request_exception.send(current_app, exception=e)
        status_code = 400
        data = {"code": "invalid_param", "message": str(e), "status": status_code}
        return data, status_code

    _ = handle_value_error

    @api.errorhandler(AppInvokeQuotaExceededError)
    def handle_quota_exceeded(e: AppInvokeQuotaExceededError):
        got_request_exception.send(current_app, exception=e)
        status_code = 429
        data = {"code": "too_many_requests", "message": str(e), "status": status_code}
        return data, status_code

    _ = handle_quota_exceeded

    @api.errorhandler(Exception)
    def handle_general_exception(e: Exception):
        got_request_exception.send(current_app, exception=e)

@@ -113,7 +103,10 @@ def register_external_error_handlers(api: Api):

        return data, status_code

    _ = handle_general_exception
    api.errorhandler(HTTPException)(handle_http_exception)
    api.errorhandler(ValueError)(handle_value_error)
    api.errorhandler(AppInvokeQuotaExceededError)(handle_quota_exceeded)
    api.errorhandler(Exception)(handle_general_exception)


class ExternalApi(Api):

@@ -1688,7 +1688,7 @@ class PipelineRecommendedPlugin(TypeBase):
    )


class SegmentAttachmentBinding(Base):
class SegmentAttachmentBinding(TypeBase):
    __tablename__ = "segment_attachment_bindings"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="segment_attachment_binding_pkey"),
@@ -1701,13 +1701,17 @@ class SegmentAttachmentBinding(Base):
        ),
        sa.Index("segment_attachment_binding_attachment_idx", "attachment_id"),
    )
    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
    id: Mapped[str] = mapped_column(
        StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
    )
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
    )
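
The Base→TypeBase conversions here (and below) follow SQLAlchemy 2.0's MappedAsDataclass conventions. A self-contained sketch of the same id/created_at pattern, with illustrative table and column types:

```python
import uuid
from datetime import datetime

from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column


class TypeBase(MappedAsDataclass, DeclarativeBase):
    """Dataclass-mapped declarative base (stand-in for the project's TypeBase)."""


class ExampleBinding(TypeBase):
    __tablename__ = "example_bindings"

    # init=False keeps generated columns out of __init__; default_factory fills
    # the dataclass value, while insert_default covers plain Core INSERTs.
    id: Mapped[str] = mapped_column(
        String(36),
        primary_key=True,
        insert_default=lambda: str(uuid.uuid4()),
        default_factory=lambda: str(uuid.uuid4()),
        init=False,
    )
    tenant_id: Mapped[str] = mapped_column(String(36), nullable=False)
    created_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), init=False
    )


# Only non-generated fields are constructor arguments now:
binding = ExampleBinding(tenant_id=str(uuid.uuid4()))
```
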


class DocumentSegmentSummary(Base):

@@ -838,7 +838,7 @@ class AppModelConfig(TypeBase):
        return self


class RecommendedApp(Base):  # bug
class RecommendedApp(TypeBase):
    __tablename__ = "recommended_apps"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
@@ -846,20 +846,37 @@ class RecommendedApp(Base):  # bug
        sa.Index("recommended_app_is_listed_idx", "is_listed", "language"),
    )

    id = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()))
    app_id = mapped_column(StringUUID, nullable=False)
    description = mapped_column(sa.JSON, nullable=False)
    id: Mapped[str] = mapped_column(
        StringUUID,
        primary_key=True,
        insert_default=lambda: str(uuid4()),
        default_factory=lambda: str(uuid4()),
        init=False,
    )
    app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    description: Mapped[Any] = mapped_column(sa.JSON, nullable=False)
    copyright: Mapped[str] = mapped_column(String(255), nullable=False)
    privacy_policy: Mapped[str] = mapped_column(String(255), nullable=False)
    custom_disclaimer: Mapped[str] = mapped_column(LongText, default="")
    category: Mapped[str] = mapped_column(String(255), nullable=False)
    custom_disclaimer: Mapped[str] = mapped_column(LongText, default="")
    position: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
    is_listed: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True)
    install_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
    language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'"))
    created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at = mapped_column(
        sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    language: Mapped[str] = mapped_column(
        String(255),
        nullable=False,
        server_default=sa.text("'en-US'"),
        default="en-US",
    )
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
    )
    updated_at: Mapped[datetime] = mapped_column(
        sa.DateTime,
        nullable=False,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
        init=False,
    )

    @property

@@ -1061,7 +1078,7 @@ class Conversation(Base):

    messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all")
    message_annotations = db.relationship(
        "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all"
        lambda: MessageAnnotation, backref="conversation", lazy="select", passive_deletes="all"
    )

    is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))

@@ -1820,7 +1837,7 @@ class MessageFile(TypeBase):
    )


class MessageAnnotation(Base):
class MessageAnnotation(TypeBase):
    __tablename__ = "message_annotations"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
@@ -1829,17 +1846,25 @@ class MessageAnnotation(Base):
        sa.Index("message_annotation_message_idx", "message_id"),
    )

    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
    id: Mapped[str] = mapped_column(
        StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False
    )
    app_id: Mapped[str] = mapped_column(StringUUID)
    conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"))
    message_id: Mapped[str | None] = mapped_column(StringUUID)
    question: Mapped[str] = mapped_column(LongText, nullable=False)
    content: Mapped[str] = mapped_column(LongText, nullable=False)
    hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
    account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
    conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), default=None)
    message_id: Mapped[str | None] = mapped_column(StringUUID, default=None)
    hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"), default=0)
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
    )
    updated_at: Mapped[datetime] = mapped_column(
        sa.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
        sa.DateTime,
        nullable=False,
        server_default=func.current_timestamp(),
        onupdate=func.current_timestamp(),
        init=False,
    )

    @property

@@ -61,7 +61,7 @@ from factories import variable_factory
from libs import helper

from .account import Account
from .base import Base, DefaultFieldsMixin, TypeBase
from .base import Base, DefaultFieldsDCMixin, TypeBase
from .engine import db
from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom
from .types import EnumText, LongText, StringUUID

@@ -671,6 +671,29 @@ class Workflow(Base):  # bug
        return str(d)


class WorkflowRunDict(TypedDict):
    id: str
    tenant_id: str
    app_id: str
    workflow_id: str
    type: WorkflowType
    triggered_from: WorkflowRunTriggeredFrom
    version: str
    graph: Mapping[str, Any]
    inputs: Mapping[str, Any]
    status: WorkflowExecutionStatus
    outputs: Mapping[str, Any]
    error: str | None
    elapsed_time: float
    total_tokens: int
    total_steps: int
    created_by_role: CreatorUserRole
    created_by: str
    created_at: datetime
    finished_at: datetime | None
    exceptions_count: int


class WorkflowRun(Base):
    """
    Workflow Run

@@ -742,8 +765,8 @@ class WorkflowRun(Base):
    exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)

    pause: Mapped[Optional["WorkflowPause"]] = orm.relationship(
        "WorkflowPause",
        primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
        lambda: WorkflowPause,
        primaryjoin=lambda: WorkflowRun.id == orm.foreign(WorkflowPause.workflow_run_id),
        uselist=False,
        # require explicit preloading.
        lazy="raise",

@@ -790,29 +813,29 @@ class WorkflowRun(Base):
    def workflow(self):
        return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))

    def to_dict(self):
        return {
            "id": self.id,
            "tenant_id": self.tenant_id,
            "app_id": self.app_id,
            "workflow_id": self.workflow_id,
            "type": self.type,
            "triggered_from": self.triggered_from,
            "version": self.version,
            "graph": self.graph_dict,
            "inputs": self.inputs_dict,
            "status": self.status,
            "outputs": self.outputs_dict,
            "error": self.error,
            "elapsed_time": self.elapsed_time,
            "total_tokens": self.total_tokens,
            "total_steps": self.total_steps,
            "created_by_role": self.created_by_role,
            "created_by": self.created_by,
            "created_at": self.created_at,
            "finished_at": self.finished_at,
            "exceptions_count": self.exceptions_count,
        }
    def to_dict(self) -> WorkflowRunDict:
        return WorkflowRunDict(
            id=self.id,
            tenant_id=self.tenant_id,
            app_id=self.app_id,
            workflow_id=self.workflow_id,
            type=self.type,
            triggered_from=self.triggered_from,
            version=self.version,
            graph=self.graph_dict,
            inputs=self.inputs_dict,
            status=self.status,
            outputs=self.outputs_dict,
            error=self.error,
            elapsed_time=self.elapsed_time,
            total_tokens=self.total_tokens,
            total_steps=self.total_steps,
            created_by_role=self.created_by_role,
            created_by=self.created_by,
            created_at=self.created_at,
            finished_at=self.finished_at,
            exceptions_count=self.exceptions_count,
        )
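
What the TypedDict return type buys: a static checker now verifies every key and value type in the constructor call above, which a plain dict literal never did. A reduced sketch:

```python
from datetime import datetime
from typing import TypedDict


class WorkflowRunDictMini(TypedDict):
    """Illustrative subset of WorkflowRunDict."""

    id: str
    elapsed_time: float
    finished_at: datetime | None


# OK: every key present, with the right types.
run: WorkflowRunDictMini = {"id": "abc", "elapsed_time": 1.5, "finished_at": None}

# mypy/pyright reject these at analysis time (runtime would not complain):
# bad1: WorkflowRunDictMini = {"id": "abc"}  # missing keys
# bad2: WorkflowRunDictMini = {"id": "abc", "elapsd_time": 1.5, "finished_at": None}  # misspelled key
```
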

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":

@@ -1196,6 +1219,18 @@ class WorkflowAppLogCreatedFrom(StrEnum):
        raise ValueError(f"invalid workflow app log created from value {value}")


class WorkflowAppLogDict(TypedDict):
    id: str
    tenant_id: str
    app_id: str
    workflow_id: str
    workflow_run_id: str
    created_from: WorkflowAppLogCreatedFrom
    created_by_role: CreatorUserRole
    created_by: str
    created_at: datetime


class WorkflowAppLog(TypeBase):
    """
    Workflow App execution log, excluding workflow debugging records.
@@ -1273,8 +1308,8 @@ class WorkflowAppLog(TypeBase):
        created_by_role = CreatorUserRole(self.created_by_role)
        return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None

    def to_dict(self):
        return {
    def to_dict(self) -> WorkflowAppLogDict:
        result: WorkflowAppLogDict = {
            "id": self.id,
            "tenant_id": self.tenant_id,
            "app_id": self.app_id,
@@ -1285,6 +1320,7 @@ class WorkflowAppLog(TypeBase):
            "created_by": self.created_by,
            "created_at": self.created_at,
        }
        return result


class WorkflowArchiveLog(TypeBase):

@@ -1941,7 +1977,7 @@ def is_system_variable_editable(name: str) -> bool:
    return name in _EDITABLE_SYSTEM_VARIABLE


class WorkflowPause(DefaultFieldsMixin, Base):
class WorkflowPause(DefaultFieldsDCMixin, TypeBase):
    """
    WorkflowPause records the paused state and related metadata for a specific workflow run.

@@ -1980,6 +2016,11 @@ class WorkflowPause(DefaultFieldsMixin, Base):
        nullable=False,
    )

    # state_object_key stores the object key referencing the serialized runtime state
    # of the `GraphEngine`. This object captures the complete execution context of the
    # workflow at the moment it was paused, enabling accurate resumption.
    state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)

    # `resumed_at` records the timestamp when the suspended workflow was resumed.
    # It is set to `NULL` if the workflow has not been resumed.
    #
@@ -1988,25 +2029,23 @@ class WorkflowPause(DefaultFieldsMixin, Base):
    resumed_at: Mapped[datetime | None] = mapped_column(
        sa.DateTime,
        nullable=True,
        default=None,
    )

    # state_object_key stores the object key referencing the serialized runtime state
    # of the `GraphEngine`. This object captures the complete execution context of the
    # workflow at the moment it was paused, enabling accurate resumption.
    state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)

    # Relationship to WorkflowRun
    # Relationship to WorkflowRun (uses lambda to resolve across Base/TypeBase registries)
    workflow_run: Mapped["WorkflowRun"] = orm.relationship(
        lambda: WorkflowRun,
        foreign_keys=[workflow_run_id],
        # require explicit preloading.
        lazy="raise",
        uselist=False,
        primaryjoin="WorkflowPause.workflow_run_id == WorkflowRun.id",
        primaryjoin=lambda: WorkflowPause.workflow_run_id == WorkflowRun.id,
        back_populates="pause",
        init=False,
    )


class WorkflowPauseReason(DefaultFieldsMixin, Base):
class WorkflowPauseReason(DefaultFieldsDCMixin, TypeBase):
    __tablename__ = "workflow_pause_reasons"

    # `pause_id` represents the identifier of the pause,
@@ -2049,16 +2088,20 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base):
        lazy="raise",
        uselist=False,
        primaryjoin="WorkflowPauseReason.pause_id == WorkflowPause.id",
        init=False,
    )

    @classmethod
    def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
    def from_entity(cls, *, pause_id: str, pause_reason: PauseReason) -> "WorkflowPauseReason":
        if isinstance(pause_reason, HumanInputRequired):
            return cls(
                type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id
                pause_id=pause_id,
                type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
                form_id=pause_reason.form_id,
                node_id=pause_reason.node_id,
            )
        elif isinstance(pause_reason, SchedulingPause):
            return cls(type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message, node_id="")
            return cls(pause_id=pause_id, type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message)
        else:
            raise AssertionError(f"Unknown pause reason type: {pause_reason}")
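
Several relationships in this diff switch from string targets and string primaryjoins to lambdas. A string like "WorkflowRun" is resolved against a single declarative registry, which breaks once related classes live on different bases (Base vs TypeBase, per the comment above); a lambda is plain Python evaluated at mapper-configuration time, so it resolves the actual class object regardless of registry. A minimal sketch of the deferred-reference pattern (names illustrative):

```python
from sqlalchemy import ForeignKey, Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship


class Base(DeclarativeBase):
    pass


class Parent(Base):
    __tablename__ = "parents"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # `Child` is not defined yet; the lambda defers evaluation until mappers
    # are configured, so no registry string lookup is involved.
    children = relationship(lambda: Child, back_populates="parent")


class Child(Base):
    __tablename__ = "children"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    parent_id: Mapped[int] = mapped_column(ForeignKey("parents.id"))
    parent = relationship(lambda: Parent, back_populates="children")
```
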
api/providers/README.md (new file, 12 lines)
@@ -0,0 +1,12 @@
# Providers

This directory holds **optional workspace packages** that plug into Dify’s API core. Providers implement the core’s interfaces and register themselves with it. The provider mechanism lets a distribution be built with only a selected set of providers, which improves the security and flexibility of distributions.

## Developing Providers

- [VDB Providers](vdb/README.md)

## Tests

Provider tests often live next to the package, e.g. `providers/<type>/<backend>/tests/unit_tests/`. Shared fixtures may live under `providers/` (e.g. `conftest.py`).

api/providers/vdb/README.md (new file, 58 lines)
@@ -0,0 +1,58 @@
# VDB providers

This directory contains all VDB providers.

## Architecture

1. **Core** (`api/core/rag/datasource/vdb/`) defines the contracts and loads plugins.
2. **Each provider** (`api/providers/vdb/<backend>/`) implements those contracts and registers an entry point.
3. At runtime, **`importlib.metadata.entry_points`** resolves the backend name (e.g. `pgvector`) to a factory class. The registry caches loaded classes (see `vector_backend_registry.py`).

### Interfaces

| Piece | Role |
|--------|----------|
| `AbstractVectorFactory` | You subclass this. Implement `init_vector(dataset, attributes, embeddings) -> BaseVector`. Optionally use `gen_index_struct_dict()` for new datasets. |
| `BaseVector` | Your store class subclasses this: `create`, `add_texts`, `search_by_vector`, `delete`, etc. |
| `VectorType` | `StrEnum` of supported backend **string ids**. Add a member when you introduce a new backend that should be selectable like existing ones. |
| Discovery | Loads `dify.vector_backends` entry points and caches `get_vector_factory_class(vector_type)`. |

The high-level caller is `Vector` in `vector_factory.py`: it reads the configured or dataset-specific vector type, calls `get_vector_factory_class`, instantiates the factory, and uses the returned `BaseVector` implementation.

### Entry point name must match the vector type string

Entry points are registered under the group **`dify.vector_backends`**. The **entry point name** (left-hand side) must be exactly the string used as `vector_type` everywhere else—typically the **`VectorType` enum value** (e.g. `PGVECTOR = "pgvector"` → entry point name `pgvector`; `TIDB_ON_QDRANT = "tidb_on_qdrant"` → `tidb_on_qdrant`).

In `pyproject.toml`:

```toml
[project.entry-points."dify.vector_backends"]
pgvector = "dify_vdb_pgvector.pgvector:PGVectorFactory"
```

The value is **`module:attribute`**: an importable module path and the class implementing `AbstractVectorFactory`.

### How registration works

1. On first use, `get_vector_factory_class(vector_type)` looks up `vector_type` in a process cache.
2. If missing, it scans **`entry_points().select(group="dify.vector_backends")`** for an entry whose **`name` equals `vector_type`**.
3. It loads that entry (`ep.load()`), which must return the **factory class** (not an instance).
4. There is an optional internal map `_BUILTIN_VECTOR_FACTORY_TARGETS` for non-distribution builtins; **normal VDB plugins use entry points only**. Steps 1–3 are sketched below.
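
In code, the lookup reduces to a few lines (a standalone sketch; the real registry adds caching and the builtin fallback):

```python
from importlib.metadata import entry_points


def resolve_backend(vector_type: str):
    for ep in entry_points().select(group="dify.vector_backends"):
        if ep.name == vector_type:  # entry point name == vector type string
            return ep.load()        # must return the factory *class*
    raise ValueError(f"Vector store {vector_type!r} is not supported.")
```
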
After you change a provider’s `pyproject.toml` (entry points or dependencies), run **`uv sync`** in `api/` so the installed environment’s dist-info matches the project metadata.

### Package layout (VDB)

Each backend usually follows:

- `api/providers/vdb/<backend>/pyproject.toml` — project name `dify-vdb-<backend>`, dependencies, entry points.
- `api/providers/vdb/<backend>/src/dify_vdb_<python_package>/` — implementation (e.g. `PGVector`, `PGVectorFactory`).

See `vdb/pgvector/` as a reference implementation.

### Wiring a new backend into the API workspace

The API uses a **uv workspace** (`api/pyproject.toml`):

1. **`[tool.uv.workspace]`** — `members = ["providers/vdb/*"]` already includes every subdirectory under `vdb/`; new folders there are workspace members.
2. **`[tool.uv.sources]`** — add a line for your package: `dify-vdb-mine = { workspace = true }`.
3. **`[project.optional-dependencies]`** — add a group such as `vdb-mine = ["dify-vdb-mine"]`, and list `dify-vdb-mine` under `vdb-all` if it should install with the default bundle.

api/providers/vdb/conftest.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from unittest.mock import MagicMock

import pytest

from extensions import ext_redis


@pytest.fixture(autouse=True)
def _init_mock_redis():
    """Ensure redis_client has a backing client so __getattr__ never raises."""
    if ext_redis.redis_client._client is None:
        ext_redis.redis_client.initialize(MagicMock())


@pytest.fixture
def setup_mock_redis(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setattr(ext_redis.redis_client, "get", MagicMock(return_value=None))
    monkeypatch.setattr(ext_redis.redis_client, "set", MagicMock(return_value=None))
    mock_redis_lock = MagicMock()
    mock_redis_lock.__enter__ = MagicMock()
    mock_redis_lock.__exit__ = MagicMock()
    monkeypatch.setattr(ext_redis.redis_client, "lock", mock_redis_lock)
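
Unlike the old fixture in `vector_integration_test_support` (which assigned MagicMocks directly and leaked them into later tests), monkeypatch undoes every setattr at teardown. A standalone sketch of the difference (names illustrative):

```python
import pytest


class Client:
    def get(self, key):
        return "real"


client = Client()


def test_mocked(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setattr(client, "get", lambda key: None)
    assert client.get("k") is None
    # teardown: monkeypatch restores the original method


def test_sees_real_again():
    assert client.get("k") == "real"
```
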
api/providers/vdb/vdb-alibabacloud-mysql/pyproject.toml (new file, 13 lines)
@@ -0,0 +1,13 @@
[project]
name = "dify-vdb-alibabacloud-mysql"
version = "0.0.1"
dependencies = [
    "mysql-connector-python>=9.3.0",
]
description = "Dify vector store backend (dify-vdb-alibabacloud-mysql)."

[project.entry-points."dify.vector_backends"]
alibabacloud_mysql = "dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector:AlibabaCloudMySQLVectorFactory"

[tool.setuptools.packages.find]
where = ["src"]

@@ -1,10 +1,9 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module
import pytest

import core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector as alibaba_module
from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory
from dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector import AlibabaCloudMySQLVectorFactory


def test_validate_distance_function_accepts_supported_values():

@@ -3,11 +3,11 @@ import unittest
from unittest.mock import MagicMock, patch

import pytest

from core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector import (
from dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector import (
    AlibabaCloudMySQLVector,
    AlibabaCloudMySQLVectorConfig,
)

from core.rag.models.document import Document

try:
@@ -49,9 +49,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
        # Sample embeddings
        self.sample_embeddings = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_init(self, mock_pool_class):
        """Test AlibabaCloudMySQLVector initialization."""
        # Mock the connection pool
@@ -76,10 +74,8 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
        assert alibabacloud_mysql_vector.distance_function == "cosine"
        assert alibabacloud_mysql_vector.pool is not None

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
    def test_create_collection(self, mock_redis, mock_pool_class):
        """Test collection creation."""
        # Mock Redis operations
@@ -110,9 +106,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
        assert mock_cursor.execute.call_count >= 3  # CREATE TABLE + 2 indexes
        mock_redis.set.assert_called_once()

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_vector_support_check_success(self, mock_pool_class):
        """Test successful vector support check."""
        # Mock the connection pool
@@ -129,9 +123,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
        vector_store = AlibabaCloudMySQLVector(self.collection_name, self.config)
        assert vector_store is not None

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_vector_support_check_failure(self, mock_pool_class):
        """Test vector support check failure."""
        # Mock the connection pool
@@ -149,9 +141,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):

        assert "RDS MySQL Vector functions are not available" in str(context.value)

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_vector_support_check_function_error(self, mock_pool_class):
        """Test vector support check with function not found error."""
        # Mock the connection pool
@@ -170,10 +160,8 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):

        assert "RDS MySQL Vector functions are not available" in str(context.value)

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.redis_client")
    def test_create_documents(self, mock_redis, mock_pool_class):
        """Test creating documents with embeddings."""
        # Setup mocks
@@ -186,9 +174,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
        assert "doc1" in result
        assert "doc2" in result

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_add_texts(self, mock_pool_class):
        """Test adding texts to the vector store."""
        # Mock the connection pool
@@ -207,9 +193,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
        assert len(result) == 2
        mock_cursor.executemany.assert_called_once()

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_text_exists(self, mock_pool_class):
        """Test checking if text exists."""
        # Mock the connection pool
@@ -236,9 +220,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
        assert "SELECT id FROM" in last_call[0][0]
        assert last_call[0][1] == ("doc1",)

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_text_not_exists(self, mock_pool_class):
        """Test checking if text does not exist."""
        # Mock the connection pool
@@ -260,9 +242,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):

        assert not exists

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_get_by_ids(self, mock_pool_class):
        """Test getting documents by IDs."""
        # Mock the connection pool
@@ -288,9 +268,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
        assert docs[0].page_content == "Test document 1"
        assert docs[1].page_content == "Test document 2"

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_get_by_ids_empty_list(self, mock_pool_class):
        """Test getting documents with empty ID list."""
        # Mock the connection pool
@@ -308,9 +286,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):

        assert len(docs) == 0

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_delete_by_ids(self, mock_pool_class):
        """Test deleting documents by IDs."""
        # Mock the connection pool
@@ -334,9 +310,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
        assert "DELETE FROM" in delete_call[0][0]
        assert delete_call[0][1] == ["doc1", "doc2"]

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_delete_by_ids_empty_list(self, mock_pool_class):
        """Test deleting with empty ID list."""
        # Mock the connection pool
@@ -357,9 +331,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
        delete_calls = [call for call in execute_calls if "DELETE" in str(call)]
        assert len(delete_calls) == 0

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_delete_by_ids_table_not_exists(self, mock_pool_class):
        """Test deleting when table doesn't exist."""
        # Mock the connection pool
@@ -384,9 +356,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
        # Should not raise an exception
        vector_store.delete_by_ids(["doc1"])

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_delete_by_metadata_field(self, mock_pool_class):
        """Test deleting documents by metadata field."""
        # Mock the connection pool
@@ -410,9 +380,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
        assert "JSON_UNQUOTE(JSON_EXTRACT(meta" in delete_call[0][0]
        assert delete_call[0][1] == ("$.document_id", "dataset1")

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_search_by_vector_cosine(self, mock_pool_class):
        """Test vector search with cosine distance."""
        # Mock the connection pool
@@ -437,9 +405,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
        assert abs(docs[0].metadata["score"] - 0.9) < 0.1  # 1 - 0.1 = 0.9
        assert docs[0].metadata["distance"] == 0.1

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_search_by_vector_euclidean(self, mock_pool_class):
        """Test vector search with euclidean distance."""
        config = AlibabaCloudMySQLVectorConfig(
@@ -472,9 +438,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
        assert len(docs) == 1
        assert abs(docs[0].metadata["score"] - 1.0 / 3.0) < 0.01  # 1/(1+2) = 1/3

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_search_by_vector_with_filter(self, mock_pool_class):
        """Test vector search with document ID filter."""
        # Mock the connection pool
@@ -499,9 +463,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
        search_call = search_calls[0]
        assert "WHERE JSON_UNQUOTE" in search_call[0][0]

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_search_by_vector_with_score_threshold(self, mock_pool_class):
        """Test vector search with score threshold."""
        # Mock the connection pool
@@ -536,9 +498,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
        assert len(docs) == 1
        assert docs[0].page_content == "High similarity document"

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_search_by_vector_invalid_top_k(self, mock_pool_class):
        """Test vector search with invalid top_k."""
        # Mock the connection pool
@@ -560,9 +520,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
        with pytest.raises(ValueError):
            vector_store.search_by_vector(query_vector, top_k="invalid")

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_search_by_full_text(self, mock_pool_class):
        """Test full-text search."""
        # Mock the connection pool
@@ -591,9 +549,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
        assert docs[0].page_content == "This document contains machine learning content"
        assert docs[0].metadata["score"] == 1.5

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_search_by_full_text_with_filter(self, mock_pool_class):
        """Test full-text search with document ID filter."""
        # Mock the connection pool
@@ -617,9 +573,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
        search_call = search_calls[0]
        assert "AND JSON_UNQUOTE" in search_call[0][0]

    @patch(
        "core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
    )
    @patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
    def test_search_by_full_text_invalid_top_k(self, mock_pool_class):
|
||||
"""Test full-text search with invalid top_k."""
|
||||
# Mock the connection pool
|
||||
@@ -640,9 +594,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
with pytest.raises(ValueError):
|
||||
vector_store.search_by_full_text("test", top_k="invalid")
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_delete_collection(self, mock_pool_class):
|
||||
"""Test deleting the entire collection."""
|
||||
# Mock the connection pool
|
||||
@@ -665,9 +617,7 @@ class TestAlibabaCloudMySQLVector(unittest.TestCase):
|
||||
drop_call = drop_calls[0]
|
||||
assert f"DROP TABLE IF EXISTS {self.collection_name.lower()}" in drop_call[0][0]
|
||||
|
||||
@patch(
|
||||
"core.rag.datasource.vdb.alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool"
|
||||
)
|
||||
@patch("dify_vdb_alibabacloud_mysql.alibabacloud_mysql_vector.mysql.connector.pooling.MySQLConnectionPool")
|
||||
def test_unsupported_distance_function(self, mock_pool_class):
|
||||
"""Test that Pydantic validation rejects unsupported distance functions."""
|
||||
# Test that creating config with unsupported distance function raises ValidationError
|
||||
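A note on the recurring decorator change above: `unittest.mock.patch` replaces a name in the namespace where it is looked up, so once the implementation moves into the `dify_vdb_alibabacloud_mysql` package, the patch target has to move with it. A minimal, self-contained sketch of that rule (the `toy_provider` module is invented for illustration):

```python
import sys
import types
from unittest.mock import MagicMock, patch

# Toy stand-in for a provider module that binds a connection-pool class at import time.
provider = types.ModuleType("toy_provider")
provider.MySQLConnectionPool = MagicMock  # the name as looked up *inside* the provider
sys.modules["toy_provider"] = provider

# patch() must target the name where it is resolved -- "toy_provider.MySQLConnectionPool" --
# mirroring the dify_vdb_alibabacloud_mysql targets used in the tests above.
with patch("toy_provider.MySQLConnectionPool") as mock_pool_class:
    pool = provider.MySQLConnectionPool()  # resolves to the mock
    assert pool is mock_pool_class.return_value
```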
15  api/providers/vdb/vdb-analyticdb/pyproject.toml  Normal file
@@ -0,0 +1,15 @@
+[project]
+name = "dify-vdb-analyticdb"
+version = "0.0.1"
+dependencies = [
+    "alibabacloud_gpdb20160503~=5.2.0",
+    "alibabacloud_tea_openapi~=0.4.3",
+    "clickhouse-connect~=0.15.0",
+]
+description = "Dify vector store backend (dify-vdb-analyticdb)."
+
+[project.entry-points."dify.vector_backends"]
+analyticdb = "dify_vdb_analyticdb.analyticdb_vector:AnalyticdbVectorFactory"
+
+[tool.setuptools.packages.find]
+where = ["src"]
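These new `pyproject.toml` files all follow the same shape: each provider becomes an installable package whose factory class is registered under the `dify.vector_backends` entry-point group. A hedged sketch of how a host application could discover such backends; the `load_vector_backend` helper is an assumption for illustration, not Dify's actual loader:

```python
from importlib.metadata import entry_points

def load_vector_backend(name: str):
    # Hypothetical loader: scan the "dify.vector_backends" group declared in the
    # provider pyproject.toml files and import the matching factory class.
    for ep in entry_points(group="dify.vector_backends"):
        if ep.name == name:
            return ep.load()
    raise ValueError(f"unknown vector backend: {name}")

# e.g. load_vector_backend("analyticdb") -> AnalyticdbVectorFactory (if installed)
```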
@@ -2,16 +2,16 @@ import json
 from typing import Any
 
 from configs import dify_config
-from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
-    AnalyticdbVectorOpenAPI,
-    AnalyticdbVectorOpenAPIConfig,
-)
-from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.embedding.embedding_base import Embeddings
 from core.rag.models.document import Document
+from dify_vdb_analyticdb.analyticdb_vector_openapi import (
+    AnalyticdbVectorOpenAPI,
+    AnalyticdbVectorOpenAPIConfig,
+)
+from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
 from models.dataset import Dataset
@@ -1,9 +1,8 @@
-from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector
-from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
-from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
-from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest
+from dify_vdb_analyticdb.analyticdb_vector import AnalyticdbVector
+from dify_vdb_analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
+from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
 
-pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
+from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest
 
 
 class AnalyticdbVectorTest(AbstractVectorTest):
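The same two-part pattern repeats through the test diffs below: the `pytest_plugins = (...)` registration is dropped, and the shared base class is imported from `core.rag.datasource.vdb.vector_integration_test_support` instead. For reference, the removed mechanism worked roughly like this (illustrative; pytest imports the named modules at startup so their fixtures become visible to the tests):

```python
# The pattern being removed: registering a test-support module as a pytest
# plugin so its fixtures are injected, instead of importing it directly.
pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
```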
@@ -1,12 +1,12 @@
 from types import SimpleNamespace
 from unittest.mock import MagicMock, patch
 
+import dify_vdb_analyticdb.analyticdb_vector as analyticdb_module
 import pytest
+from dify_vdb_analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory
+from dify_vdb_analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
+from dify_vdb_analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
 
-import core.rag.datasource.vdb.analyticdb.analyticdb_vector as analyticdb_module
-from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector, AnalyticdbVectorFactory
-from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
-from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
 from core.rag.models.document import Document
@@ -4,13 +4,13 @@ import types
 from types import SimpleNamespace
 from unittest.mock import MagicMock
 
+import dify_vdb_analyticdb.analyticdb_vector_openapi as openapi_module
 import pytest
 
-import core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi as openapi_module
-from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
+from dify_vdb_analyticdb.analyticdb_vector_openapi import (
     AnalyticdbVectorOpenAPI,
     AnalyticdbVectorOpenAPIConfig,
 )
-
 from core.rag.models.document import Document
@@ -2,14 +2,14 @@ from contextlib import contextmanager
 from types import SimpleNamespace
 from unittest.mock import MagicMock
 
+import dify_vdb_analyticdb.analyticdb_vector_sql as sql_module
 import psycopg2.errors
 import pytest
 
-import core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql as sql_module
-from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import (
+from dify_vdb_analyticdb.analyticdb_vector_sql import (
     AnalyticdbVectorBySql,
     AnalyticdbVectorBySqlConfig,
 )
-
 from core.rag.models.document import Document
13  api/providers/vdb/vdb-baidu/pyproject.toml  Normal file
@@ -0,0 +1,13 @@
+[project]
+name = "dify-vdb-baidu"
+version = "0.0.1"
+dependencies = [
+    "pymochow==2.4.0",
+]
+description = "Dify vector store backend (dify-vdb-baidu)."
+
+[project.entry-points."dify.vector_backends"]
+baidu = "dify_vdb_baidu.baidu_vector:BaiduVectorFactory"
+
+[tool.setuptools.packages.find]
+where = ["src"]
@@ -1,10 +1,6 @@
-from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector
-from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text
+from dify_vdb_baidu.baidu_vector import BaiduConfig, BaiduVector
 
-pytest_plugins = (
-    "tests.integration_tests.vdb.test_vector_store",
-    "tests.integration_tests.vdb.__mock.baiduvectordb",
-)
+from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
 
 
 class BaiduVectorTest(AbstractVectorTest):
@@ -124,7 +124,7 @@ def _build_fake_pymochow_modules():
 def baidu_module(monkeypatch):
     for name, module in _build_fake_pymochow_modules().items():
         monkeypatch.setitem(sys.modules, name, module)
-    import core.rag.datasource.vdb.baidu.baidu_vector as module
+    import dify_vdb_baidu.baidu_vector as module
 
     return importlib.reload(module)
13  api/providers/vdb/vdb-chroma/pyproject.toml  Normal file
@@ -0,0 +1,13 @@
+[project]
+name = "dify-vdb-chroma"
+version = "0.0.1"
+dependencies = [
+    "chromadb==0.5.20",
+]
+description = "Dify vector store backend (dify-vdb-chroma)."
+
+[project.entry-points."dify.vector_backends"]
+chroma = "dify_vdb_chroma.chroma_vector:ChromaVectorFactory"
+
+[tool.setuptools.packages.find]
+where = ["src"]
@@ -1,13 +1,11 @@
 import chromadb
+from dify_vdb_chroma.chroma_vector import ChromaConfig, ChromaVector
 
-from core.rag.datasource.vdb.chroma.chroma_vector import ChromaConfig, ChromaVector
-from tests.integration_tests.vdb.test_vector_store import (
+from core.rag.datasource.vdb.vector_integration_test_support import (
     AbstractVectorTest,
     get_example_text,
 )
 
-pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
-
 
 class ChromaVectorTest(AbstractVectorTest):
     def __init__(self):
@@ -47,7 +47,7 @@ def _build_fake_chroma_modules():
 def chroma_module(monkeypatch):
     fake_chroma = _build_fake_chroma_modules()
     monkeypatch.setitem(sys.modules, "chromadb", fake_chroma)
-    import core.rag.datasource.vdb.chroma.chroma_vector as module
+    import dify_vdb_chroma.chroma_vector as module
 
     return importlib.reload(module)
@@ -198,4 +198,4 @@ Clickzetta supports advanced full-text search with multiple analyzers:
 
 - [Clickzetta Vector Search Documentation](https://yunqi.tech/documents/vector-search)
 - [Clickzetta Inverted Index Documentation](https://yunqi.tech/documents/inverted-index)
-- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference)
\ No newline at end of file
+- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference)
14  api/providers/vdb/vdb-clickzetta/pyproject.toml  Normal file
@@ -0,0 +1,14 @@
+[project]
+name = "dify-vdb-clickzetta"
+version = "0.0.1"
+
+dependencies = [
+    "clickzetta-connector-python>=0.8.102",
+]
+description = "Dify vector store backend (dify-vdb-clickzetta)."
+
+[project.entry-points."dify.vector_backends"]
+clickzetta = "dify_vdb_clickzetta.clickzetta_vector:ClickzettaVectorFactory"
+
+[tool.setuptools.packages.find]
+where = ["src"]
@@ -2,10 +2,10 @@ import contextlib
 import os
 
 import pytest
+from dify_vdb_clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector
 
-from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector
+from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
 from core.rag.models.document import Document
-from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
 
 
 class TestClickzettaVector(AbstractVectorTest):
@@ -14,9 +14,7 @@ class TestClickzettaVector(AbstractVectorTest):
     """
 
     @pytest.fixture
-    def vector_store(self):
+    def vector_store(self, setup_mock_redis):
         """Create a Clickzetta vector store instance for testing."""
         # Skip test if Clickzetta credentials are not configured
         if not os.getenv("CLICKZETTA_USERNAME"):
             pytest.skip("CLICKZETTA_USERNAME is not configured")
         if not os.getenv("CLICKZETTA_PASSWORD"):
@@ -32,21 +31,19 @@ class TestClickzettaVector(AbstractVectorTest):
             workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
             vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
             schema=os.getenv("CLICKZETTA_SCHEMA", "dify_test"),
-            batch_size=10,  # Small batch size for testing
+            batch_size=10,
             enable_inverted_index=True,
             analyzer_type="chinese",
             analyzer_mode="smart",
             vector_distance_function="cosine_distance",
         )
 
-        with setup_mock_redis():
-            vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config)
+        vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config)
 
-            yield vector
+        yield vector
 
-            # Cleanup: delete the test collection
-            with contextlib.suppress(Exception):
-                vector.delete()
+        # Cleanup: delete the test collection
+        with contextlib.suppress(Exception):
+            vector.delete()
 
     def test_clickzetta_vector_basic_operations(self, vector_store):
         """Test basic CRUD operations on Clickzetta vector store."""
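The fixture rewrite above replaces an imported `setup_mock_redis()` context manager with a pytest fixture of the same name, requested as a parameter so pytest owns both setup and teardown. A generic sketch of the pattern (all names here are illustrative, not the actual support-module code):

```python
import pytest

@pytest.fixture
def setup_mock_redis():
    # Illustrative only: stand up a fake dependency before the test...
    fake_redis = {}
    yield fake_redis
    # ...and tear it down afterwards, even if the test body raises.
    fake_redis.clear()

@pytest.fixture
def vector_store(setup_mock_redis):
    # Requesting the fixture by parameter name replaces the old
    # `with setup_mock_redis():` block around the store's construction.
    store = {"collection": "test_collection"}
    yield store
```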
@@ -3,16 +3,19 @@
 Test Clickzetta integration in Docker environment
 """
 
+import logging
 import os
 import time
 
 import httpx
 from clickzetta import connect
 
+logger = logging.getLogger(__name__)
+
 
 def test_clickzetta_connection():
     """Test direct connection to Clickzetta"""
-    print("=== Testing direct Clickzetta connection ===")
+    logger.info("=== Testing direct Clickzetta connection ===")
     try:
         conn = connect(
             username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
@@ -25,100 +28,93 @@ def test_clickzetta_connection():
         )
 
         with conn.cursor() as cursor:
             # Test basic connectivity
             cursor.execute("SELECT 1 as test")
             result = cursor.fetchone()
-            print(f"✓ Connection test: {result}")
+            logger.info("✓ Connection test: %s", result)
 
             # Check if our test table exists
             cursor.execute("SHOW TABLES IN dify")
             tables = cursor.fetchall()
-            print(f"✓ Existing tables: {[t[1] for t in tables if t[0] == 'dify']}")
+            logger.info("✓ Existing tables: %s", [t[1] for t in tables if t[0] == "dify"])
 
             # Check if test collection exists
             test_collection = "collection_test_dataset"
             if test_collection in [t[1] for t in tables if t[0] == "dify"]:
                 cursor.execute(f"DESCRIBE dify.{test_collection}")
                 columns = cursor.fetchall()
-                print(f"✓ Table structure for {test_collection}:")
+                logger.info("✓ Table structure for %s:", test_collection)
                 for col in columns:
-                    print(f"  - {col[0]}: {col[1]}")
+                    logger.info("  - %s: %s", col[0], col[1])
 
                 # Check for indexes
                 cursor.execute(f"SHOW INDEXES IN dify.{test_collection}")
                 indexes = cursor.fetchall()
-                print(f"✓ Indexes on {test_collection}:")
+                logger.info("✓ Indexes on %s:", test_collection)
                 for idx in indexes:
-                    print(f"  - {idx}")
+                    logger.info("  - %s", idx)
 
         return True
-    except Exception as e:
-        print(f"✗ Connection test failed: {e}")
+    except Exception:
+        logger.exception("✗ Connection test failed")
         return False
 
 
 def test_dify_api():
     """Test Dify API with Clickzetta backend"""
-    print("\n=== Testing Dify API ===")
+    logger.info("\n=== Testing Dify API ===")
     base_url = "http://localhost:5001"
 
     # Wait for API to be ready
     max_retries = 30
     for i in range(max_retries):
         try:
             response = httpx.get(f"{base_url}/console/api/health")
             if response.status_code == 200:
-                print("✓ Dify API is ready")
+                logger.info("✓ Dify API is ready")
                 break
         except:
             if i == max_retries - 1:
-                print("✗ Dify API is not responding")
+                logger.exception("✗ Dify API is not responding")
                 return False
             time.sleep(2)
 
     # Check vector store configuration
     try:
         # This is a simplified check - in production, you'd use proper auth
-        print("✓ Dify is configured to use Clickzetta as vector store")
+        logger.info("✓ Dify is configured to use Clickzetta as vector store")
         return True
-    except Exception as e:
-        print(f"✗ API test failed: {e}")
+    except Exception:
+        logger.exception("✗ API test failed")
         return False
 
 
 def verify_table_structure():
     """Verify the table structure meets Dify requirements"""
-    print("\n=== Verifying Table Structure ===")
+    logger.info("\n=== Verifying Table Structure ===")
 
     expected_columns = {
         "id": "VARCHAR",
         "page_content": "VARCHAR",
-        "metadata": "VARCHAR",  # JSON stored as VARCHAR in Clickzetta
+        "metadata": "VARCHAR",
         "vector": "ARRAY<FLOAT>",
     }
 
     expected_metadata_fields = ["doc_id", "doc_hash", "document_id", "dataset_id"]
 
-    print("✓ Expected table structure:")
+    logger.info("✓ Expected table structure:")
     for col, dtype in expected_columns.items():
-        print(f"  - {col}: {dtype}")
+        logger.info("  - %s: %s", col, dtype)
 
-    print("\n✓ Required metadata fields:")
+    logger.info("\n✓ Required metadata fields:")
     for field in expected_metadata_fields:
-        print(f"  - {field}")
+        logger.info("  - %s", field)
 
-    print("\n✓ Index requirements:")
-    print("  - Vector index (HNSW) on 'vector' column")
-    print("  - Full-text index on 'page_content' (optional)")
-    print("  - Functional index on metadata->>'$.doc_id' (recommended)")
-    print("  - Functional index on metadata->>'$.document_id' (recommended)")
+    logger.info("\n✓ Index requirements:")
+    logger.info("  - Vector index (HNSW) on 'vector' column")
+    logger.info("  - Full-text index on 'page_content' (optional)")
+    logger.info("  - Functional index on metadata->>'$.doc_id' (recommended)")
+    logger.info("  - Functional index on metadata->>'$.document_id' (recommended)")
 
     return True
 
 
 def main():
     """Run all tests"""
-    print("Starting Clickzetta integration tests for Dify Docker\n")
+    logger.info("Starting Clickzetta integration tests for Dify Docker\n")
 
     tests = [
         ("Direct Clickzetta Connection", test_clickzetta_connection),
@@ -131,33 +127,34 @@ def main():
         try:
             success = test_func()
             results.append((test_name, success))
-        except Exception as e:
-            print(f"\n✗ {test_name} crashed: {e}")
+        except Exception:
+            logger.exception("\n✗ %s crashed", test_name)
             results.append((test_name, False))
 
     # Summary
-    print("\n" + "=" * 50)
-    print("Test Summary:")
-    print("=" * 50)
+    logger.info("\n%s", "=" * 50)
+    logger.info("Test Summary:")
+    logger.info("=" * 50)
 
     passed = sum(1 for _, success in results if success)
     total = len(results)
 
     for test_name, success in results:
         status = "✅ PASSED" if success else "❌ FAILED"
-        print(f"{test_name}: {status}")
+        logger.info("%s: %s", test_name, status)
 
-    print(f"\nTotal: {passed}/{total} tests passed")
+    logger.info("\nTotal: %s/%s tests passed", passed, total)
 
     if passed == total:
-        print("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.")
-        print("\nNext steps:")
-        print("1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d")
-        print("2. Access Dify at http://localhost:3000")
-        print("3. Create a dataset and test vector storage with Clickzetta")
+        logger.info("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.")
+        logger.info("\nNext steps:")
+        logger.info(
+            "1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d"
+        )
+        logger.info("2. Access Dify at http://localhost:3000")
+        logger.info("3. Create a dataset and test vector storage with Clickzetta")
         return 0
     else:
-        print("\n⚠️ Some tests failed. Please check the errors above.")
+        logger.error("\n⚠️ Some tests failed. Please check the errors above.")
         return 1
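Beyond swapping `print` for `logger`, the conversion above moves from eager f-strings to logging's lazy `%`-style arguments, so formatting only happens when the record is actually emitted, and it lets `logger.exception` attach the traceback instead of formatting the exception by hand (which is why the `as e` bindings were dropped). A small, runnable illustration:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

passed, total = 4, 5
# Arguments are interpolated by the logging framework only if the level is enabled.
logger.info("Total: %s/%s tests passed", passed, total)

try:
    1 / 0
except Exception:
    # logger.exception logs at ERROR level and appends the current traceback.
    logger.exception("Connection test failed")
```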
@@ -47,7 +47,7 @@ def _build_fake_clickzetta_module():
 @pytest.fixture
 def clickzetta_module(monkeypatch):
     monkeypatch.setitem(sys.modules, "clickzetta", _build_fake_clickzetta_module())
-    import core.rag.datasource.vdb.clickzetta.clickzetta_vector as module
+    import dify_vdb_clickzetta.clickzetta_vector as module
 
     return importlib.reload(module)
14  api/providers/vdb/vdb-couchbase/pyproject.toml  Normal file
@@ -0,0 +1,14 @@
+[project]
+name = "dify-vdb-couchbase"
+version = "0.0.1"
+
+dependencies = [
+    "couchbase~=4.6.0",
+]
+description = "Dify vector store backend (dify-vdb-couchbase)."
+
+[project.entry-points."dify.vector_backends"]
+couchbase = "dify_vdb_couchbase.couchbase_vector:CouchbaseVectorFactory"
+
+[tool.setuptools.packages.find]
+where = ["src"]
@@ -1,12 +1,14 @@
+import logging
 import subprocess
 import time
 
-from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector
-from tests.integration_tests.vdb.test_vector_store import (
+from dify_vdb_couchbase.couchbase_vector import CouchbaseConfig, CouchbaseVector
+
+from core.rag.datasource.vdb.vector_integration_test_support import (
     AbstractVectorTest,
 )
 
-pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
+logger = logging.getLogger(__name__)
 
 
 def wait_for_healthy_container(service_name="couchbase-server", timeout=300):
@@ -16,10 +18,10 @@ def wait_for_healthy_container(service_name="couchbase-server", timeout=300):
             ["docker", "inspect", "--format", "{{.State.Health.Status}}", service_name], capture_output=True, text=True
         )
         if result.stdout.strip() == "healthy":
-            print(f"{service_name} is healthy!")
+            logger.info("%s is healthy!", service_name)
             return True
         else:
-            print(f"Waiting for {service_name} to be healthy...")
+            logger.info("Waiting for %s to be healthy...", service_name)
             time.sleep(10)
     raise TimeoutError(f"{service_name} did not become healthy in time")
@@ -154,7 +154,7 @@ def couchbase_module(monkeypatch):
     for name, module in _build_fake_couchbase_modules().items():
         monkeypatch.setitem(sys.modules, name, module)
 
-    import core.rag.datasource.vdb.couchbase.couchbase_vector as module
+    import dify_vdb_couchbase.couchbase_vector as module
 
     return importlib.reload(module)
15  api/providers/vdb/vdb-elasticsearch/pyproject.toml  Normal file
@@ -0,0 +1,15 @@
+[project]
+name = "dify-vdb-elasticsearch"
+version = "0.0.1"
+
+dependencies = [
+    "elasticsearch==8.14.0",
+]
+description = "Dify vector store backend (dify-vdb-elasticsearch)."
+
+[project.entry-points."dify.vector_backends"]
+elasticsearch = "dify_vdb_elasticsearch.elasticsearch_vector:ElasticSearchVectorFactory"
+elasticsearch-ja = "dify_vdb_elasticsearch.elasticsearch_ja_vector:ElasticSearchJaVectorFactory"
+
+[tool.setuptools.packages.find]
+where = ["src"]
@@ -4,14 +4,14 @@ from typing import Any
 
 from flask import current_app
 
-from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import (
-    ElasticSearchConfig,
-    ElasticSearchVector,
-    ElasticSearchVectorFactory,
-)
 from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.embedding.embedding_base import Embeddings
+from dify_vdb_elasticsearch.elasticsearch_vector import (
+    ElasticSearchConfig,
+    ElasticSearchVector,
+    ElasticSearchVectorFactory,
+)
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset
@@ -1,10 +1,9 @@
-from core.rag.datasource.vdb.elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector
-from tests.integration_tests.vdb.test_vector_store import (
+from dify_vdb_elasticsearch.elasticsearch_vector import ElasticSearchConfig, ElasticSearchVector
+
+from core.rag.datasource.vdb.vector_integration_test_support import (
     AbstractVectorTest,
 )
 
-pytest_plugins = ("tests.integration_tests.vdb.test_vector_store",)
-
 
 class ElasticSearchVectorTest(AbstractVectorTest):
     def __init__(self):
@@ -32,8 +32,8 @@ def elasticsearch_ja_module(monkeypatch):
     for name, module in _build_fake_elasticsearch_modules().items():
         monkeypatch.setitem(sys.modules, name, module)
 
-    import core.rag.datasource.vdb.elasticsearch.elasticsearch_ja_vector as ja_module
-    import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as base_module
+    import dify_vdb_elasticsearch.elasticsearch_ja_vector as ja_module
+    import dify_vdb_elasticsearch.elasticsearch_vector as base_module
 
     importlib.reload(base_module)
     return importlib.reload(ja_module)
@@ -42,7 +42,7 @@ def elasticsearch_module(monkeypatch):
     for name, module in _build_fake_elasticsearch_modules().items():
         monkeypatch.setitem(sys.modules, name, module)
 
-    import core.rag.datasource.vdb.elasticsearch.elasticsearch_vector as module
+    import dify_vdb_elasticsearch.elasticsearch_vector as module
 
     return importlib.reload(module)
14  api/providers/vdb/vdb-hologres/pyproject.toml  Normal file
@@ -0,0 +1,14 @@
+[project]
+name = "dify-vdb-hologres"
+version = "0.0.1"
+
+dependencies = [
+    "holo-search-sdk>=0.4.2",
+]
+description = "Dify vector store backend (dify-vdb-hologres)."
+
+[project.entry-points."dify.vector_backends"]
+hologres = "dify_vdb_hologres.hologres_vector:HologresVectorFactory"
+
+[tool.setuptools.packages.find]
+where = ["src"]
@@ -1,7 +1,7 @@
 import json
 import logging
 import time
-from typing import Any
+from typing import Any, cast
 
 import holo_search_sdk as holo  # type: ignore
 from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
@@ -351,9 +351,9 @@ class HologresVectorFactory(AbstractVectorFactory):
             access_key_id=dify_config.HOLOGRES_ACCESS_KEY_ID or "",
             access_key_secret=dify_config.HOLOGRES_ACCESS_KEY_SECRET or "",
             schema_name=dify_config.HOLOGRES_SCHEMA,
-            tokenizer=dify_config.HOLOGRES_TOKENIZER,
-            distance_method=dify_config.HOLOGRES_DISTANCE_METHOD,
-            base_quantization_type=dify_config.HOLOGRES_BASE_QUANTIZATION_TYPE,
+            tokenizer=cast(TokenizerType, dify_config.HOLOGRES_TOKENIZER),
+            distance_method=cast(DistanceType, dify_config.HOLOGRES_DISTANCE_METHOD),
+            base_quantization_type=cast(BaseQuantizationType, dify_config.HOLOGRES_BASE_QUANTIZATION_TYPE),
             max_degree=dify_config.HOLOGRES_MAX_DEGREE,
             ef_construction=dify_config.HOLOGRES_EF_CONSTRUCTION,
         ),
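The `cast(...)` wrappers added above only narrow the configuration strings to the SDK's expected types for the type checker; `typing.cast` performs no conversion or validation at runtime. For example (the `Literal` alias is a stand-in for illustration, not the SDK's real definition):

```python
from typing import Literal, cast

DistanceType = Literal["Euclidean", "InnerProduct", "Cosine"]  # stand-in alias

raw: str = "Cosine"               # e.g. a value read from configuration
method = cast(DistanceType, raw)  # no-op at runtime; satisfies the type checker
assert method is raw              # cast returns its argument unchanged
```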
@@ -7,13 +7,10 @@ import pytest
 from _pytest.monkeypatch import MonkeyPatch
 from psycopg import sql as psql
 
-# Shared in-memory storage: {table_name: {doc_id: {"id", "text", "meta", "embedding"}}}
 _mock_tables: dict[str, dict[str, dict[str, Any]]] = {}
 
 
 class MockSearchQuery:
-    """Mock query builder for search_vector and search_text results."""
-
     def __init__(self, table_name: str, search_type: str):
         self._table_name = table_name
         self._search_type = search_type
@@ -32,17 +29,13 @@ class MockSearchQuery:
         return self
 
     def _apply_filter(self, row: dict[str, Any]) -> bool:
-        """Apply the filter SQL to check if a row matches."""
         if self._filter_sql is None:
             return True
 
-        # Extract literals (the document IDs) from the filter SQL
-        # Filter format: meta->>'document_id' IN ('doc1', 'doc2')
         literals = [v for t, v in _extract_identifiers_and_literals(self._filter_sql) if t == "literal"]
         if not literals:
             return True
 
-        # Get the document_id from the row's meta field
         meta = row.get("meta", "{}")
         if isinstance(meta, str):
             meta = json.loads(meta)
@@ -54,22 +47,17 @@ class MockSearchQuery:
         data = _mock_tables.get(self._table_name, {})
         results = []
         for row in list(data.values())[: self._limit_val]:
-            # Apply filter if present
             if not self._apply_filter(row):
                 continue
 
             if self._search_type == "vector":
-                # row format expected by _process_vector_results: (distance, id, text, meta)
                 results.append((0.1, row["id"], row["text"], row["meta"]))
             else:
-                # row format expected by _process_full_text_results: (id, text, meta, embedding, score)
                 results.append((row["id"], row["text"], row["meta"], row.get("embedding", []), 0.9))
         return results
 
 
 class MockTable:
-    """Mock table object returned by client.open_table()."""
-
     def __init__(self, table_name: str):
         self._table_name = table_name
@@ -97,7 +85,6 @@ class MockTable:
 
 
 def _extract_sql_template(query) -> str:
-    """Extract the SQL template string from a psycopg Composed object."""
     if isinstance(query, psql.Composed):
         for part in query:
             if isinstance(part, psql.SQL):
@@ -108,7 +95,6 @@ def _extract_sql_template(query) -> str:
 
 
 def _extract_identifiers_and_literals(query) -> list[Any]:
-    """Extract Identifier and Literal values from a psycopg Composed object."""
     values: list[Any] = []
     if isinstance(query, psql.Composed):
         for part in query:
@@ -117,7 +103,6 @@ def _extract_identifiers_and_literals(query) -> list[Any]:
         elif isinstance(part, psql.Literal):
             values.append(("literal", part._obj))
         elif isinstance(part, psql.Composed):
-            # Handles SQL(...).join(...) for IN clauses
             for sub in part:
                 if isinstance(sub, psql.Literal):
                     values.append(("literal", sub._obj))
@@ -125,8 +110,6 @@ def _extract_identifiers_and_literals(query) -> list[Any]:
 
 
 class MockHologresClient:
-    """Mock holo_search_sdk client that stores data in memory."""
-
     def connect(self):
         pass
 
@@ -141,21 +124,18 @@ class MockHologresClient:
         params = _extract_identifiers_and_literals(query)
 
         if "CREATE TABLE" in template.upper():
-            # Extract table name from first identifier
             table_name = next((v for t, v in params if t == "ident"), "unknown")
             if table_name not in _mock_tables:
                 _mock_tables[table_name] = {}
             return None
 
         if "SELECT 1" in template:
-            # text_exists: SELECT 1 FROM {table} WHERE id = {id} LIMIT 1
             table_name = next((v for t, v in params if t == "ident"), "")
             doc_id = next((v for t, v in params if t == "literal"), "")
             data = _mock_tables.get(table_name, {})
             return [(1,)] if doc_id in data else []
 
         if "SELECT id" in template:
-            # get_ids_by_metadata_field: SELECT id FROM {table} WHERE meta->>{key} = {value}
             table_name = next((v for t, v in params if t == "ident"), "")
             literals = [v for t, v in params if t == "literal"]
             key = literals[0] if len(literals) > 0 else ""
@@ -166,12 +146,10 @@ class MockHologresClient:
         if "DELETE" in template.upper():
             table_name = next((v for t, v in params if t == "ident"), "")
             if "id IN" in template:
-                # delete_by_ids
                 ids_to_delete = [v for t, v in params if t == "literal"]
                 for did in ids_to_delete:
                     _mock_tables.get(table_name, {}).pop(did, None)
             elif "meta->>" in template:
-                # delete_by_metadata_field
                 literals = [v for t, v in params if t == "literal"]
                 key = literals[0] if len(literals) > 0 else ""
                 value = literals[1] if len(literals) > 1 else ""
@@ -190,7 +168,6 @@ class MockHologresClient:
 
 
 def mock_connect(**kwargs):
-    """Replacement for holo_search_sdk.connect() that returns a mock client."""
     return MockHologresClient()
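The `mock_connect` helper above returns the in-memory client; presumably it is swapped in for the SDK's `connect()` during tests, as the other providers do with their fake modules. A hedged sketch of how that wiring is typically done with pytest's `monkeypatch` (the fixture below is an illustration, not the repository's actual conftest):

```python
import sys
import types

import pytest

def mock_connect(**kwargs):
    return "MockHologresClient()"  # stand-in for the real mock client

@pytest.fixture
def holo_sdk_stub(monkeypatch):
    # Hypothetical wiring: install a stub holo_search_sdk module whose connect()
    # returns the in-memory mock, so code under test never opens a live connection.
    stub = types.ModuleType("holo_search_sdk")
    stub.connect = mock_connect
    monkeypatch.setitem(sys.modules, "holo_search_sdk", stub)
    return stub
```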
@@ -2,16 +2,11 @@ import os
 import uuid
 from typing import cast
 
+from dify_vdb_hologres.hologres_vector import HologresVector, HologresVectorConfig
+from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
 
-from core.rag.datasource.vdb.hologres.hologres_vector import HologresVector, HologresVectorConfig
+from core.rag.datasource.vdb.vector_integration_test_support import AbstractVectorTest, get_example_text
 from core.rag.models.document import Document
-from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text
 
-pytest_plugins = (
-    "tests.integration_tests.vdb.test_vector_store",
-    "tests.integration_tests.vdb.__mock.hologres",
-)
 
 MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@@ -42,7 +42,7 @@ def hologres_module(monkeypatch):
     for name, module in _build_fake_hologres_modules().items():
         monkeypatch.setitem(sys.modules, name, module)
 
-    import core.rag.datasource.vdb.hologres.hologres_vector as module
+    import dify_vdb_hologres.hologres_vector as module
 
     return importlib.reload(module)
14  api/providers/vdb/vdb-huawei-cloud/pyproject.toml  Normal file
@@ -0,0 +1,14 @@
+[project]
+name = "dify-vdb-huawei-cloud"
+version = "0.0.1"
+
+dependencies = [
+    "elasticsearch==8.14.0",
+]
+description = "Dify vector store backend (dify-vdb-huawei-cloud)."
+
+[project.entry-points."dify.vector_backends"]
+huawei_cloud = "dify_vdb_huawei_cloud.huawei_cloud_vector:HuaweiCloudVectorFactory"
+
+[tool.setuptools.packages.find]
+where = ["src"]
Some files were not shown because too many files have changed in this diff.