Mirror of https://github.com/langgenius/dify.git (synced 2026-03-31 21:16:50 +00:00)

Compare commits: codex/forc...fix/1.13.2 (17 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 138d497418 | |
| | 5760914fb3 | |
| | 8f79989172 | |
| | 9d7ea953ea | |
| | 5b5b21502b | |
| | 44c356258f | |
| | 44fb3cd2af | |
| | 2bf6728951 | |
| | fcfa11a71a | |
| | 1730f900c1 | |
| | 12178e7aec | |
| | afe23a029b | |
| | c8560bacb3 | |
| | 0f1b8bf5f9 | |
| | 652211ad96 | |
| | c049249bc1 | |
| | 138083dfc8 | |
@@ -180,7 +180,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
 COOKIE_DOMAIN=

 # Vector database configuration
-# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `hologres`.
+# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `hologres`.
 VECTOR_STORE=weaviate
 # Prefix used to create collection name in vector database
 VECTOR_INDEX_NAME_PREFIX=Vector_index
@@ -321,20 +321,6 @@ CHROMA_DATABASE=default_database
 CHROMA_AUTH_PROVIDER=chromadb.auth.token_authn.TokenAuthenticationServerProvider
 CHROMA_AUTH_CREDENTIALS=difyai123456

-# AnalyticDB configuration
-ANALYTICDB_KEY_ID=your-ak
-ANALYTICDB_KEY_SECRET=your-sk
-ANALYTICDB_REGION_ID=cn-hangzhou
-ANALYTICDB_INSTANCE_ID=gp-ab123456
-ANALYTICDB_ACCOUNT=testaccount
-ANALYTICDB_PASSWORD=testpassword
-ANALYTICDB_NAMESPACE=dify
-ANALYTICDB_NAMESPACE_PASSWORD=difypassword
-ANALYTICDB_HOST=gp-test.aliyuncs.com
-ANALYTICDB_PORT=5432
-ANALYTICDB_MIN_CONNECTION=1
-ANALYTICDB_MAX_CONNECTION=5
-
 # OpenSearch configuration
 OPENSEARCH_HOST=127.0.0.1
 OPENSEARCH_PORT=9200
@@ -608,7 +608,7 @@ def migrate_oss(
         click.style(
             "Target STORAGE_TYPE must be a cloud OSS (not 'local' or 'opendal').\n"
             "Please set STORAGE_TYPE to one of: s3, aliyun-oss, azure-blob, google-storage, tencent-cos, \n"
-            "volcengine-tos, supabase, oci-storage, huawei-obs, baidu-obs, clickzetta-volume.",
+            "volcengine-tos, supabase, oci-storage, huawei-obs, baidu-obs.",
             fg="red",
         )
     )
@@ -155,11 +155,9 @@ def migrate_knowledge_vector_database():
         VectorType.ORACLE,
         VectorType.ELASTICSEARCH,
         VectorType.OPENGAUSS,
-        VectorType.TABLESTORE,
         VectorType.MATRIXONE,
     }
     lower_collection_vector_types = {
-        VectorType.ANALYTICDB,
         VectorType.HOLOGRES,
         VectorType.CHROMA,
         VectorType.MYSCALE,
@@ -11,7 +11,6 @@ from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
 from .storage.amazon_s3_storage_config import S3StorageConfig
 from .storage.azure_blob_storage_config import AzureBlobStorageConfig
 from .storage.baidu_obs_storage_config import BaiduOBSStorageConfig
-from .storage.clickzetta_volume_storage_config import ClickZettaVolumeStorageConfig
 from .storage.google_cloud_storage_config import GoogleCloudStorageConfig
 from .storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
 from .storage.oci_storage_config import OCIStorageConfig
@@ -20,10 +19,8 @@ from .storage.supabase_storage_config import SupabaseStorageConfig
 from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
 from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
 from .vdb.alibabacloud_mysql_config import AlibabaCloudMySQLConfig
-from .vdb.analyticdb_config import AnalyticdbConfig
 from .vdb.baidu_vector_config import BaiduVectorDBConfig
 from .vdb.chroma_config import ChromaConfig
-from .vdb.clickzetta_config import ClickzettaConfig
 from .vdb.couchbase_config import CouchbaseConfig
 from .vdb.elasticsearch_config import ElasticsearchConfig
 from .vdb.hologres_config import HologresConfig
@@ -41,7 +38,6 @@ from .vdb.pgvector_config import PGVectorConfig
 from .vdb.pgvectors_config import PGVectoRSConfig
 from .vdb.qdrant_config import QdrantConfig
 from .vdb.relyt_config import RelytConfig
-from .vdb.tablestore_config import TableStoreConfig
 from .vdb.tencent_vector_config import TencentVectorDBConfig
 from .vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
 from .vdb.tidb_vector_config import TiDBVectorConfig
@@ -58,7 +54,6 @@ class StorageConfig(BaseSettings):
         "aliyun-oss",
         "azure-blob",
         "baidu-obs",
-        "clickzetta-volume",
         "google-storage",
         "huawei-obs",
         "oci-storage",
@@ -69,7 +64,7 @@ class StorageConfig(BaseSettings):
     ] = Field(
         description="Type of storage to use."
         " Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', "
-        "'clickzetta-volume', 'google-storage', 'huawei-obs', 'oci-storage', 'tencent-cos', "
+        "'google-storage', 'huawei-obs', 'oci-storage', 'tencent-cos', "
         "'volcengine-tos', 'supabase'. Default is 'opendal'.",
         default="opendal",
     )
@@ -334,7 +329,6 @@ class MiddlewareConfig(
     AliyunOSSStorageConfig,
     AzureBlobStorageConfig,
     BaiduOBSStorageConfig,
-    ClickZettaVolumeStorageConfig,
     GoogleCloudStorageConfig,
     HuaweiCloudOBSStorageConfig,
     OCIStorageConfig,
@@ -345,9 +339,7 @@ class MiddlewareConfig(
     VolcengineTOSStorageConfig,
     # configs of vdb and vdb providers
     VectorStoreConfig,
-    AnalyticdbConfig,
     ChromaConfig,
-    ClickzettaConfig,
     HologresConfig,
     HuaweiCloudConfig,
     IrisVectorConfig,
@@ -374,7 +366,6 @@ class MiddlewareConfig(
     OceanBaseVectorConfig,
     BaiduVectorDBConfig,
     OpenGaussConfig,
-    TableStoreConfig,
     DatasetQueueMonitorConfig,
     MatrixoneConfig,
 ):
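Because `STORAGE_TYPE` is a `Literal`-typed pydantic-settings field, dropping `'clickzetta-volume'` from the `Literal` means a leftover `STORAGE_TYPE=clickzetta-volume` in `.env` now fails validation at startup instead of being silently accepted. A minimal sketch of that behavior (class and option names here are illustrative, not Dify's actual config):

```python
# Sketch only: how a Literal-typed pydantic-settings field rejects removed options.
from typing import Literal

from pydantic import Field, ValidationError
from pydantic_settings import BaseSettings


class StorageConfigSketch(BaseSettings):
    # Hypothetical reduced option set for illustration.
    STORAGE_TYPE: Literal["opendal", "s3", "aliyun-oss"] = Field(default="opendal")


try:
    # A value that is no longer in the Literal fails at construction time.
    StorageConfigSketch(STORAGE_TYPE="clickzetta-volume")  # type: ignore[arg-type]
except ValidationError as e:
    print(e.errors()[0]["type"])  # "literal_error"
```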
api/configs/middleware/cache/redis_config.py (vendored, 12 changes)
@@ -1,4 +1,4 @@
-from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
+from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, field_validator
 from pydantic_settings import BaseSettings


@@ -116,3 +116,13 @@ class RedisConfig(BaseSettings):
         description="Maximum connections in the Redis connection pool (unset for library default)",
         default=None,
     )
+
+    @field_validator("REDIS_MAX_CONNECTIONS", mode="before")
+    @classmethod
+    def _empty_string_to_none_for_max_conns(cls, v):
+        """Allow empty string in env/.env to mean 'unset' (None)."""
+        if v is None:
+            return None
+        if isinstance(v, str) and v.strip() == "":
+            return None
+        return v
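The new validator is the standard pydantic v2 idiom for optional numeric env vars: `mode="before"` runs ahead of type coercion, so an empty `REDIS_MAX_CONNECTIONS=` line in `.env` becomes `None` ("use the library default") instead of failing integer parsing. A self-contained sketch of the same idiom (class and field names illustrative):

```python
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings


class CacheSettingsSketch(BaseSettings):
    # None means "use the client library's default pool size".
    MAX_CONNECTIONS: int | None = Field(default=None)

    @field_validator("MAX_CONNECTIONS", mode="before")
    @classmethod
    def _empty_to_none(cls, v):
        # Runs before int coercion, so "" never reaches the int parser.
        if isinstance(v, str) and v.strip() == "":
            return None
        return v


print(CacheSettingsSketch(MAX_CONNECTIONS="").MAX_CONNECTIONS)    # None
print(CacheSettingsSketch(MAX_CONNECTIONS="50").MAX_CONNECTIONS)  # 50
```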
api/configs/middleware/storage/clickzetta_volume_storage_config.py (deleted)
@@ -1,63 +0,0 @@
-"""ClickZetta Volume Storage Configuration"""
-
-from pydantic import Field
-from pydantic_settings import BaseSettings
-
-
-class ClickZettaVolumeStorageConfig(BaseSettings):
-    """Configuration for ClickZetta Volume storage."""
-
-    CLICKZETTA_VOLUME_USERNAME: str | None = Field(
-        description="Username for ClickZetta Volume authentication",
-        default=None,
-    )
-
-    CLICKZETTA_VOLUME_PASSWORD: str | None = Field(
-        description="Password for ClickZetta Volume authentication",
-        default=None,
-    )
-
-    CLICKZETTA_VOLUME_INSTANCE: str | None = Field(
-        description="ClickZetta instance identifier",
-        default=None,
-    )
-
-    CLICKZETTA_VOLUME_SERVICE: str = Field(
-        description="ClickZetta service endpoint",
-        default="api.clickzetta.com",
-    )
-
-    CLICKZETTA_VOLUME_WORKSPACE: str = Field(
-        description="ClickZetta workspace name",
-        default="quick_start",
-    )
-
-    CLICKZETTA_VOLUME_VCLUSTER: str = Field(
-        description="ClickZetta virtual cluster name",
-        default="default_ap",
-    )
-
-    CLICKZETTA_VOLUME_SCHEMA: str = Field(
-        description="ClickZetta schema name",
-        default="dify",
-    )
-
-    CLICKZETTA_VOLUME_TYPE: str = Field(
-        description="ClickZetta volume type (table|user|external)",
-        default="user",
-    )
-
-    CLICKZETTA_VOLUME_NAME: str | None = Field(
-        description="ClickZetta volume name for external volumes",
-        default=None,
-    )
-
-    CLICKZETTA_VOLUME_TABLE_PREFIX: str = Field(
-        description="Prefix for ClickZetta volume table names",
-        default="dataset_",
-    )
-
-    CLICKZETTA_VOLUME_DIFY_PREFIX: str = Field(
-        description="Directory prefix for User Volume to organize Dify files",
-        default="dify_km",
-    )
api/configs/middleware/vdb/analyticdb_config.py (deleted)
@@ -1,49 +0,0 @@
-from pydantic import Field, PositiveInt
-from pydantic_settings import BaseSettings
-
-
-class AnalyticdbConfig(BaseSettings):
-    """
-    Configuration for connecting to Alibaba Cloud AnalyticDB for PostgreSQL.
-    Refer to the following documentation for details on obtaining credentials:
-    https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
-    """
-
-    ANALYTICDB_KEY_ID: str | None = Field(
-        default=None, description="The Access Key ID provided by Alibaba Cloud for API authentication."
-    )
-    ANALYTICDB_KEY_SECRET: str | None = Field(
-        default=None, description="The Secret Access Key corresponding to the Access Key ID for secure API access."
-    )
-    ANALYTICDB_REGION_ID: str | None = Field(
-        default=None,
-        description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou', 'ap-southeast-1').",
-    )
-    ANALYTICDB_INSTANCE_ID: str | None = Field(
-        default=None,
-        description="The unique identifier of the AnalyticDB instance you want to connect to.",
-    )
-    ANALYTICDB_ACCOUNT: str | None = Field(
-        default=None,
-        description="The account name used to log in to the AnalyticDB instance"
-        " (usually the initial account created with the instance).",
-    )
-    ANALYTICDB_PASSWORD: str | None = Field(
-        default=None, description="The password associated with the AnalyticDB account for database authentication."
-    )
-    ANALYTICDB_NAMESPACE: str | None = Field(
-        default=None, description="The namespace within AnalyticDB for schema isolation (if using namespace feature)."
-    )
-    ANALYTICDB_NAMESPACE_PASSWORD: str | None = Field(
-        default=None,
-        description="The password for accessing the specified namespace within the AnalyticDB instance"
-        " (if namespace feature is enabled).",
-    )
-    ANALYTICDB_HOST: str | None = Field(
-        default=None, description="The host of the AnalyticDB instance you want to connect to."
-    )
-    ANALYTICDB_PORT: PositiveInt = Field(
-        default=5432, description="The port of the AnalyticDB instance you want to connect to."
-    )
-    ANALYTICDB_MIN_CONNECTION: PositiveInt = Field(default=1, description="Min connection of the AnalyticDB database.")
-    ANALYTICDB_MAX_CONNECTION: PositiveInt = Field(default=5, description="Max connection of the AnalyticDB database.")
api/configs/middleware/vdb/clickzetta_config.py (deleted)
@@ -1,68 +0,0 @@
-from pydantic import Field
-from pydantic_settings import BaseSettings
-
-
-class ClickzettaConfig(BaseSettings):
-    """
-    Clickzetta Lakehouse vector database configuration
-    """
-
-    CLICKZETTA_USERNAME: str | None = Field(
-        description="Username for authenticating with Clickzetta Lakehouse",
-        default=None,
-    )
-
-    CLICKZETTA_PASSWORD: str | None = Field(
-        description="Password for authenticating with Clickzetta Lakehouse",
-        default=None,
-    )
-
-    CLICKZETTA_INSTANCE: str | None = Field(
-        description="Clickzetta Lakehouse instance ID",
-        default=None,
-    )
-
-    CLICKZETTA_SERVICE: str | None = Field(
-        description="Clickzetta API service endpoint (e.g., 'api.clickzetta.com')",
-        default="api.clickzetta.com",
-    )
-
-    CLICKZETTA_WORKSPACE: str | None = Field(
-        description="Clickzetta workspace name",
-        default="default",
-    )
-
-    CLICKZETTA_VCLUSTER: str | None = Field(
-        description="Clickzetta virtual cluster name",
-        default="default_ap",
-    )
-
-    CLICKZETTA_SCHEMA: str | None = Field(
-        description="Database schema name in Clickzetta",
-        default="public",
-    )
-
-    CLICKZETTA_BATCH_SIZE: int | None = Field(
-        description="Batch size for bulk insert operations",
-        default=100,
-    )
-
-    CLICKZETTA_ENABLE_INVERTED_INDEX: bool | None = Field(
-        description="Enable inverted index for full-text search capabilities",
-        default=True,
-    )
-
-    CLICKZETTA_ANALYZER_TYPE: str | None = Field(
-        description="Analyzer type for full-text search: keyword, english, chinese, unicode",
-        default="chinese",
-    )
-
-    CLICKZETTA_ANALYZER_MODE: str | None = Field(
-        description="Analyzer mode for tokenization: max_word (fine-grained) or smart (intelligent)",
-        default="smart",
-    )
-
-    CLICKZETTA_VECTOR_DISTANCE_FUNCTION: str | None = Field(
-        description="Distance function for vector similarity: l2_distance or cosine_distance",
-        default="cosine_distance",
-    )
api/configs/middleware/vdb/tablestore_config.py (deleted)
@@ -1,33 +0,0 @@
-from pydantic import Field
-from pydantic_settings import BaseSettings
-
-
-class TableStoreConfig(BaseSettings):
-    """
-    Configuration settings for TableStore.
-    """
-
-    TABLESTORE_ENDPOINT: str | None = Field(
-        description="Endpoint address of the TableStore server (e.g. 'https://instance-name.cn-hangzhou.ots.aliyuncs.com')",
-        default=None,
-    )
-
-    TABLESTORE_INSTANCE_NAME: str | None = Field(
-        description="Instance name to access TableStore server (eg. 'instance-name')",
-        default=None,
-    )
-
-    TABLESTORE_ACCESS_KEY_ID: str | None = Field(
-        description="AccessKey id for the instance name",
-        default=None,
-    )
-
-    TABLESTORE_ACCESS_KEY_SECRET: str | None = Field(
-        description="AccessKey secret for the instance name",
-        default=None,
-    )
-
-    TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: bool = Field(
-        description="Whether to normalize full-text search scores to [0, 1]",
-        default=False,
-    )
@@ -242,7 +242,6 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
         VectorType.QDRANT,
         VectorType.WEAVIATE,
         VectorType.OPENSEARCH,
-        VectorType.ANALYTICDB,
         VectorType.MYSCALE,
         VectorType.ORACLE,
         VectorType.ELASTICSEARCH,
@@ -255,11 +254,9 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
         VectorType.OPENGAUSS,
         VectorType.OCEANBASE,
         VectorType.SEEKDB,
-        VectorType.TABLESTORE,
         VectorType.HUAWEI_CLOUD,
         VectorType.TENCENT,
         VectorType.MATRIXONE,
-        VectorType.CLICKZETTA,
         VectorType.BAIDU,
         VectorType.ALIBABACLOUD_MYSQL,
         VectorType.IRIS,
@@ -297,6 +297,7 @@ class DatasetDocumentListApi(Resource):
         if sort == "hit_count":
             sub_query = (
                 sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
+                .where(DocumentSegment.dataset_id == str(dataset_id))
                .group_by(DocumentSegment.document_id)
                .subquery()
            )
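The added `.where(...)` scopes the hit-count aggregate to the current dataset; without it the subquery sums `hit_count` over segments from every dataset. A minimal SQLAlchemy 2.x sketch of the pattern (table and column names here are illustrative, not Dify's models):

```python
import sqlalchemy as sa

# Illustrative stand-ins for Dify's DocumentSegment/Document models.
metadata = sa.MetaData()
segments = sa.Table(
    "segments", metadata,
    sa.Column("document_id", sa.String),
    sa.Column("dataset_id", sa.String),
    sa.Column("hit_count", sa.Integer),
)
documents = sa.Table(
    "documents", metadata,
    sa.Column("id", sa.String),
    sa.Column("dataset_id", sa.String),
)

dataset_id = "ds-123"  # hypothetical value

# Aggregate per document, scoped to one dataset (the fix above).
sub_query = (
    sa.select(segments.c.document_id, sa.func.sum(segments.c.hit_count).label("total_hit_count"))
    .where(segments.c.dataset_id == dataset_id)
    .group_by(segments.c.document_id)
    .subquery()
)

# Typical use: outer-join so documents with no segments still appear.
stmt = (
    sa.select(documents)
    .outerjoin(sub_query, documents.c.id == sub_query.c.document_id)
    .order_by(sa.desc(sa.func.coalesce(sub_query.c.total_hit_count, 0)))
)
```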
@@ -11,7 +11,6 @@ class TracingProviderEnum(StrEnum):
     LANGFUSE = "langfuse"
     LANGSMITH = "langsmith"
     OPIK = "opik"
-    WEAVE = "weave"
     ALIYUN = "aliyun"
     MLFLOW = "mlflow"
     DATABRICKS = "databricks"
@@ -145,31 +144,6 @@ class OpikConfig(BaseTracingConfig):
         return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")


-class WeaveConfig(BaseTracingConfig):
-    """
-    Model class for Weave tracing config.
-    """
-
-    api_key: str
-    entity: str | None = None
-    project: str
-    endpoint: str = "https://trace.wandb.ai"
-    host: str | None = None
-
-    @field_validator("endpoint")
-    @classmethod
-    def endpoint_validator(cls, v, info: ValidationInfo):
-        # Weave only allows HTTPS for endpoint
-        return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
-
-    @field_validator("host")
-    @classmethod
-    def host_validator(cls, v, info: ValidationInfo):
-        if v is not None and v.strip() != "":
-            return validate_url(v, v, allowed_schemes=("https", "http"))
-        return v
-
-
 class AliyunConfig(BaseTracingConfig):
     """
     Model class for Aliyun tracing config.
@@ -76,16 +76,6 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
                     "trace_instance": OpikDataTrace,
                 }

-            case TracingProviderEnum.WEAVE:
-                from core.ops.entities.config_entity import WeaveConfig
-                from core.ops.weave_trace.weave_trace import WeaveDataTrace
-
-                return {
-                    "config_class": WeaveConfig,
-                    "secret_keys": ["api_key"],
-                    "other_keys": ["project", "entity", "endpoint", "host"],
-                    "trace_instance": WeaveDataTrace,
-                }
             case TracingProviderEnum.ARIZE:
                 from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
                 from core.ops.entities.config_entity import ArizeConfig
api/core/ops/weave_trace/entities/weave_trace_entity.py (deleted)
@@ -1,98 +0,0 @@
-from collections.abc import Mapping
-from typing import Any, Union
-
-from pydantic import BaseModel, Field, field_validator
-from pydantic_core.core_schema import ValidationInfo
-
-from core.ops.utils import replace_text_with_content
-
-
-class WeaveTokenUsage(BaseModel):
-    input_tokens: int | None = None
-    output_tokens: int | None = None
-    total_tokens: int | None = None
-
-
-class WeaveMultiModel(BaseModel):
-    file_list: list[str] | None = Field(None, description="List of files")
-
-
-class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
-    id: str = Field(..., description="ID of the trace")
-    op: str = Field(..., description="Name of the operation")
-    inputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Inputs of the trace")
-    outputs: Union[str, Mapping[str, Any], list, None] | None = Field(None, description="Outputs of the trace")
-    attributes: Union[str, dict[str, Any], list, None] | None = Field(
-        None, description="Metadata and attributes associated with trace"
-    )
-    exception: str | None = Field(None, description="Exception message of the trace")
-
-    @field_validator("inputs", "outputs")
-    @classmethod
-    def ensure_dict(cls, v, info: ValidationInfo):
-        field_name = info.field_name
-        values = info.data
-        if v == {} or v is None:
-            return v
-        usage_metadata = {
-            "input_tokens": values.get("input_tokens", 0),
-            "output_tokens": values.get("output_tokens", 0),
-            "total_tokens": values.get("total_tokens", 0),
-        }
-        file_list = values.get("file_list", [])
-        if isinstance(v, str):
-            if field_name == "inputs":
-                return {
-                    "messages": {
-                        "role": "user",
-                        "content": v,
-                        "usage_metadata": usage_metadata,
-                        "file_list": file_list,
-                    },
-                }
-            elif field_name == "outputs":
-                return {
-                    "choices": {
-                        "role": "ai",
-                        "content": v,
-                        "usage_metadata": usage_metadata,
-                        "file_list": file_list,
-                    },
-                }
-        elif isinstance(v, list):
-            data = {}
-            if len(v) > 0 and isinstance(v[0], dict):
-                # rename text to content
-                v = replace_text_with_content(data=v)
-                if field_name == "inputs":
-                    data = {
-                        "messages": [
-                            dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) for msg in v
-                        ]
-                        if isinstance(v, list)
-                        else v,
-                    }
-                elif field_name == "outputs":
-                    data = {
-                        "choices": {
-                            "role": "ai",
-                            "content": v,
-                            "usage_metadata": usage_metadata,
-                            "file_list": file_list,
-                        },
-                    }
-                return data
-            else:
-                return {
-                    "choices": {
-                        "role": "ai" if field_name == "outputs" else "user",
-                        "content": str(v),
-                        "usage_metadata": usage_metadata,
-                        "file_list": file_list,
-                    },
-                }
-        if isinstance(v, dict):
-            v["usage_metadata"] = usage_metadata
-            v["file_list"] = file_list
-            return v
-        return v
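The `ensure_dict` validator in the deleted model is a normalization idiom: bare-string `inputs`/`outputs` get wrapped into role-tagged message mappings so the trace backend always receives a dict. A stripped-down sketch of the same idiom (class name and keys are illustrative, not the deleted class):

```python
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo


class TracePayloadSketch(BaseModel):
    inputs: str | dict | None = None
    outputs: str | dict | None = None

    @field_validator("inputs", "outputs")
    @classmethod
    def ensure_dict(cls, v, info: ValidationInfo):
        # Wrap bare strings so downstream consumers always see a mapping;
        # inputs become user "messages", outputs become ai "choices".
        if isinstance(v, str):
            key = "messages" if info.field_name == "inputs" else "choices"
            role = "user" if info.field_name == "inputs" else "ai"
            return {key: {"role": role, "content": v}}
        return v


print(TracePayloadSketch(inputs="hi").inputs)    # {'messages': {'role': 'user', 'content': 'hi'}}
print(TracePayloadSketch(outputs="ok").outputs)  # {'choices': {'role': 'ai', 'content': 'ok'}}
```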
api/core/ops/weave_trace/weave_trace.py (deleted)
@@ -1,523 +0,0 @@
-import logging
-import os
-import uuid
-from datetime import UTC, datetime, timedelta
-from typing import Any, cast
-
-import wandb
-import weave
-from sqlalchemy.orm import sessionmaker
-from weave.trace_server.trace_server_interface import (
-    CallEndReq,
-    CallStartReq,
-    EndedCallSchemaForInsert,
-    StartedCallSchemaForInsert,
-    SummaryInsertMap,
-    TraceStatus,
-)
-
-from core.ops.base_trace_instance import BaseTraceInstance
-from core.ops.entities.config_entity import WeaveConfig
-from core.ops.entities.trace_entity import (
-    BaseTraceInfo,
-    DatasetRetrievalTraceInfo,
-    GenerateNameTraceInfo,
-    MessageTraceInfo,
-    ModerationTraceInfo,
-    SuggestedQuestionTraceInfo,
-    ToolTraceInfo,
-    TraceTaskName,
-    WorkflowTraceInfo,
-)
-from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
-from core.repositories import DifyCoreRepositoryFactory
-from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
-from extensions.ext_database import db
-from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
-
-logger = logging.getLogger(__name__)
-
-
-class WeaveDataTrace(BaseTraceInstance):
-    def __init__(
-        self,
-        weave_config: WeaveConfig,
-    ):
-        super().__init__(weave_config)
-        self.weave_api_key = weave_config.api_key
-        self.project_name = weave_config.project
-        self.entity = weave_config.entity
-        self.host = weave_config.host
-
-        # Login with API key first, including host if provided
-        if self.host:
-            login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host)
-        else:
-            login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
-
-        if not login_status:
-            logger.error("Failed to login to Weights & Biases with the provided API key")
-            raise ValueError("Weave login failed")
-
-        # Then initialize weave client
-        self.weave_client = weave.init(
-            project_name=(f"{self.entity}/{self.project_name}" if self.entity else self.project_name)
-        )
-        self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
-        self.calls: dict[str, Any] = {}
-        self.project_id = f"{self.weave_client.entity}/{self.weave_client.project}"
-
-    def get_project_url(
-        self,
-    ):
-        try:
-            project_identifier = f"{self.entity}/{self.project_name}" if self.entity else self.project_name
-            project_url = f"https://wandb.ai/{project_identifier}"
-            return project_url
-        except Exception as e:
-            logger.debug("Weave get run url failed: %s", str(e))
-            raise ValueError(f"Weave get run url failed: {str(e)}")
-
-    def trace(self, trace_info: BaseTraceInfo):
-        logger.debug("Trace info: %s", trace_info)
-        if isinstance(trace_info, WorkflowTraceInfo):
-            self.workflow_trace(trace_info)
-        if isinstance(trace_info, MessageTraceInfo):
-            self.message_trace(trace_info)
-        if isinstance(trace_info, ModerationTraceInfo):
-            self.moderation_trace(trace_info)
-        if isinstance(trace_info, SuggestedQuestionTraceInfo):
-            self.suggested_question_trace(trace_info)
-        if isinstance(trace_info, DatasetRetrievalTraceInfo):
-            self.dataset_retrieval_trace(trace_info)
-        if isinstance(trace_info, ToolTraceInfo):
-            self.tool_trace(trace_info)
-        if isinstance(trace_info, GenerateNameTraceInfo):
-            self.generate_name_trace(trace_info)
-
-    def workflow_trace(self, trace_info: WorkflowTraceInfo):
-        trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
-        if trace_info.start_time is None:
-            trace_info.start_time = datetime.now()
-
-        if trace_info.message_id:
-            message_attributes = trace_info.metadata
-            message_attributes["workflow_app_log_id"] = trace_info.workflow_app_log_id
-
-            message_attributes["message_id"] = trace_info.message_id
-            message_attributes["workflow_run_id"] = trace_info.workflow_run_id
-            message_attributes["trace_id"] = trace_id
-            message_attributes["start_time"] = trace_info.start_time
-            message_attributes["end_time"] = trace_info.end_time
-            message_attributes["tags"] = ["message", "workflow"]
-
-            message_run = WeaveTraceModel(
-                id=trace_info.message_id,
-                op=str(TraceTaskName.MESSAGE_TRACE),
-                inputs=dict(trace_info.workflow_run_inputs),
-                outputs=dict(trace_info.workflow_run_outputs),
-                total_tokens=trace_info.total_tokens,
-                attributes=message_attributes,
-                exception=trace_info.error,
-                file_list=[],
-            )
-            self.start_call(message_run, parent_run_id=trace_info.workflow_run_id)
-            self.finish_call(message_run)
-
-        workflow_attributes = trace_info.metadata
-        workflow_attributes["workflow_run_id"] = trace_info.workflow_run_id
-        workflow_attributes["trace_id"] = trace_id
-        workflow_attributes["start_time"] = trace_info.start_time
-        workflow_attributes["end_time"] = trace_info.end_time
-        workflow_attributes["tags"] = ["dify_workflow"]
-
-        workflow_run = WeaveTraceModel(
-            file_list=trace_info.file_list,
-            total_tokens=trace_info.total_tokens,
-            id=trace_info.workflow_run_id,
-            op=str(TraceTaskName.WORKFLOW_TRACE),
-            inputs=dict(trace_info.workflow_run_inputs),
-            outputs=dict(trace_info.workflow_run_outputs),
-            attributes=workflow_attributes,
-            exception=trace_info.error,
-        )
-
-        self.start_call(workflow_run, parent_run_id=trace_info.message_id)
-
-        # through workflow_run_id get all_nodes_execution using repository
-        session_factory = sessionmaker(bind=db.engine)
-        # Find the app's creator account
-        app_id = trace_info.metadata.get("app_id")
-        if not app_id:
-            raise ValueError("No app_id found in trace_info metadata")
-
-        service_account = self.get_service_account_with_tenant(app_id)
-
-        workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
-            session_factory=session_factory,
-            user=service_account,
-            app_id=app_id,
-            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
-        )
-
-        # Get all executions for this workflow run
-        workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
-            workflow_run_id=trace_info.workflow_run_id
-        )
-
-        # rearrange workflow_node_executions by starting time
-        workflow_node_executions = sorted(workflow_node_executions, key=lambda x: x.created_at)
-
-        for node_execution in workflow_node_executions:
-            node_execution_id = node_execution.id
-            tenant_id = trace_info.tenant_id  # Use from trace_info instead
-            app_id = trace_info.metadata.get("app_id")  # Use from trace_info instead
-            node_name = node_execution.title
-            node_type = node_execution.node_type
-            status = node_execution.status
-            if node_type == BuiltinNodeTypes.LLM:
-                inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
-            else:
-                inputs = node_execution.inputs or {}
-            outputs = node_execution.outputs or {}
-            created_at = node_execution.created_at or datetime.now()
-            elapsed_time = node_execution.elapsed_time
-            finished_at = created_at + timedelta(seconds=elapsed_time)
-
-            execution_metadata = node_execution.metadata or {}
-            node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0
-            attributes = {str(k): v for k, v in execution_metadata.items()}
-            attributes.update(
-                {
-                    "workflow_run_id": trace_info.workflow_run_id,
-                    "node_execution_id": node_execution_id,
-                    "tenant_id": tenant_id,
-                    "app_id": app_id,
-                    "app_name": node_name,
-                    "node_type": node_type,
-                    "status": status,
-                }
-            )
-
-            process_data = node_execution.process_data or {}
-            if process_data and process_data.get("model_mode") == "chat":
-                attributes.update(
-                    {
-                        "ls_provider": process_data.get("model_provider", ""),
-                        "ls_model_name": process_data.get("model_name", ""),
-                    }
-                )
-            attributes["tags"] = ["node_execution"]
-            attributes["start_time"] = created_at
-            attributes["end_time"] = finished_at
-            attributes["elapsed_time"] = elapsed_time
-            attributes["workflow_run_id"] = trace_info.workflow_run_id
-            attributes["trace_id"] = trace_id
-            node_run = WeaveTraceModel(
-                total_tokens=node_total_tokens,
-                op=node_type,
-                inputs=inputs,
-                outputs=outputs,
-                file_list=trace_info.file_list,
-                attributes=attributes,
-                id=node_execution_id,
-                exception=None,
-            )
-
-            self.start_call(node_run, parent_run_id=trace_info.workflow_run_id)
-            self.finish_call(node_run)
-
-        self.finish_call(workflow_run)
-
-    def message_trace(self, trace_info: MessageTraceInfo):
-        # get message file data
-        file_list = cast(list[str], trace_info.file_list) or []
-        message_file_data: MessageFile | None = trace_info.message_file_data
-        file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
-        file_list.append(file_url)
-        attributes = trace_info.metadata
-        message_data = trace_info.message_data
-        if message_data is None:
-            return
-        message_id = message_data.id
-
-        user_id = message_data.from_account_id
-        attributes["user_id"] = user_id
-
-        if message_data.from_end_user_id:
-            end_user_data: EndUser | None = (
-                db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
-            )
-            if end_user_data is not None:
-                end_user_id = end_user_data.session_id
-                attributes["end_user_id"] = end_user_id
-
-        attributes["message_id"] = message_id
-        attributes["start_time"] = trace_info.start_time
-        attributes["end_time"] = trace_info.end_time
-        attributes["tags"] = ["message", str(trace_info.conversation_mode)]
-
-        trace_id = trace_info.trace_id or message_id
-        attributes["trace_id"] = trace_id
-
-        message_run = WeaveTraceModel(
-            id=trace_id,
-            op=str(TraceTaskName.MESSAGE_TRACE),
-            input_tokens=trace_info.message_tokens,
-            output_tokens=trace_info.answer_tokens,
-            total_tokens=trace_info.total_tokens,
-            inputs=trace_info.inputs,
-            outputs=trace_info.outputs,
-            exception=trace_info.error,
-            file_list=file_list,
-            attributes=attributes,
-        )
-        self.start_call(message_run)
-
-        # create llm run parented to message run
-        llm_run = WeaveTraceModel(
-            id=str(uuid.uuid4()),
-            input_tokens=trace_info.message_tokens,
-            output_tokens=trace_info.answer_tokens,
-            total_tokens=trace_info.total_tokens,
-            op="llm",
-            inputs=trace_info.inputs,
-            outputs=trace_info.outputs,
-            attributes=attributes,
-            file_list=[],
-            exception=None,
-        )
-        self.start_call(
-            llm_run,
-            parent_run_id=trace_id,
-        )
-        self.finish_call(llm_run)
-        self.finish_call(message_run)
-
-    def moderation_trace(self, trace_info: ModerationTraceInfo):
-        if trace_info.message_data is None:
-            return
-
-        attributes = trace_info.metadata
-        attributes["tags"] = ["moderation"]
-        attributes["message_id"] = trace_info.message_id
-        attributes["start_time"] = trace_info.start_time or trace_info.message_data.created_at
-        attributes["end_time"] = trace_info.end_time or trace_info.message_data.updated_at
-
-        trace_id = trace_info.trace_id or trace_info.message_id
-        attributes["trace_id"] = trace_id
-
-        moderation_run = WeaveTraceModel(
-            id=str(uuid.uuid4()),
-            op=str(TraceTaskName.MODERATION_TRACE),
-            inputs=trace_info.inputs,
-            outputs={
-                "action": trace_info.action,
-                "flagged": trace_info.flagged,
-                "preset_response": trace_info.preset_response,
-                "inputs": trace_info.inputs,
-            },
-            attributes=attributes,
-            exception=getattr(trace_info, "error", None),
-            file_list=[],
-        )
-        self.start_call(moderation_run, parent_run_id=trace_id)
-        self.finish_call(moderation_run)
-
-    def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
-        message_data = trace_info.message_data
-        if message_data is None:
-            return
-        attributes = trace_info.metadata
-        attributes["message_id"] = trace_info.message_id
-        attributes["tags"] = ["suggested_question"]
-        attributes["start_time"] = (trace_info.start_time or message_data.created_at,)
-        attributes["end_time"] = (trace_info.end_time or message_data.updated_at,)
-
-        trace_id = trace_info.trace_id or trace_info.message_id
-        attributes["trace_id"] = trace_id
-
-        suggested_question_run = WeaveTraceModel(
-            id=str(uuid.uuid4()),
-            op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE),
-            inputs=trace_info.inputs,
-            outputs=trace_info.suggested_question,
-            attributes=attributes,
-            exception=trace_info.error,
-            file_list=[],
-        )
-
-        self.start_call(suggested_question_run, parent_run_id=trace_id)
-        self.finish_call(suggested_question_run)
-
-    def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
-        if trace_info.message_data is None:
-            return
-        attributes = trace_info.metadata
-        attributes["message_id"] = trace_info.message_id
-        attributes["tags"] = ["dataset_retrieval"]
-        attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,)
-        attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,)
-
-        trace_id = trace_info.trace_id or trace_info.message_id
-        attributes["trace_id"] = trace_id
-
-        dataset_retrieval_run = WeaveTraceModel(
-            id=str(uuid.uuid4()),
-            op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE),
-            inputs=trace_info.inputs,
-            outputs={"documents": trace_info.documents},
-            attributes=attributes,
-            exception=getattr(trace_info, "error", None),
-            file_list=[],
-        )
-
-        self.start_call(dataset_retrieval_run, parent_run_id=trace_id)
-        self.finish_call(dataset_retrieval_run)
-
-    def tool_trace(self, trace_info: ToolTraceInfo):
-        attributes = trace_info.metadata
-        attributes["tags"] = ["tool", trace_info.tool_name]
-        attributes["start_time"] = trace_info.start_time
-        attributes["end_time"] = trace_info.end_time
-
-        message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None)
-        message_id = message_id or None
-        trace_id = trace_info.trace_id or message_id
-        attributes["trace_id"] = trace_id
-
-        tool_run = WeaveTraceModel(
-            id=str(uuid.uuid4()),
-            op=trace_info.tool_name,
-            inputs=trace_info.tool_inputs,
-            outputs=trace_info.tool_outputs,
-            file_list=[cast(str, trace_info.file_url)] if trace_info.file_url else [],
-            attributes=attributes,
-            exception=trace_info.error,
-        )
-        self.start_call(tool_run, parent_run_id=trace_id)
-        self.finish_call(tool_run)
-
-    def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
-        attributes = trace_info.metadata
-        attributes["tags"] = ["generate_name"]
-        attributes["start_time"] = trace_info.start_time
-        attributes["end_time"] = trace_info.end_time
-
-        name_run = WeaveTraceModel(
-            id=str(uuid.uuid4()),
-            op=str(TraceTaskName.GENERATE_NAME_TRACE),
-            inputs=trace_info.inputs,
-            outputs=trace_info.outputs,
-            attributes=attributes,
-            exception=getattr(trace_info, "error", None),
-            file_list=[],
-        )
-
-        self.start_call(name_run)
-        self.finish_call(name_run)
-
-    def api_check(self):
-        try:
-            if self.host:
-                login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host)
-            else:
-                login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
-
-            if not login_status:
-                raise ValueError("Weave login failed")
-            else:
-                logger.info("Weave login successful")
-                return True
-        except Exception as e:
-            logger.debug("Weave API check failed: %s", str(e))
-            raise ValueError(f"Weave API check failed: {str(e)}")
-
-    def _normalize_time(self, dt: datetime | None) -> datetime:
-        if dt is None:
-            return datetime.now(UTC)
-        if dt.tzinfo is None:
-            return dt.replace(tzinfo=UTC)
-        return dt
-
-    def start_call(self, run_data: WeaveTraceModel, parent_run_id: str | None = None):
-        inputs = run_data.inputs
-        if inputs is None:
-            inputs = {}
-        elif not isinstance(inputs, dict):
-            inputs = {"inputs": str(inputs)}
-
-        attributes = run_data.attributes
-        if attributes is None:
-            attributes = {}
-        elif not isinstance(attributes, dict):
-            attributes = {"attributes": str(attributes)}
-
-        start_time = attributes.get("start_time") if isinstance(attributes, dict) else None
-        started_at = self._normalize_time(start_time if isinstance(start_time, datetime) else None)
-        trace_id = attributes.get("trace_id") if isinstance(attributes, dict) else None
-        if trace_id is None:
-            trace_id = run_data.id
-
-        call_start_req = CallStartReq(
-            start=StartedCallSchemaForInsert(
-                project_id=self.project_id,
-                id=run_data.id,
-                op_name=str(run_data.op),
-                trace_id=trace_id,
-                parent_id=parent_run_id,
-                started_at=started_at,
-                attributes=attributes,
-                inputs=inputs,
-                wb_user_id=None,
-            )
-        )
-        self.weave_client.server.call_start(call_start_req)
-        self.calls[run_data.id] = {"trace_id": trace_id, "parent_id": parent_run_id}
-
-    def finish_call(self, run_data: WeaveTraceModel):
-        call_meta = self.calls.get(run_data.id)
-        if not call_meta:
-            raise ValueError(f"Call with id {run_data.id} not found")
-
-        attributes = run_data.attributes
-        if attributes is None:
-            attributes = {}
-        elif not isinstance(attributes, dict):
-            attributes = {"attributes": str(attributes)}
-
-        start_time = attributes.get("start_time") if isinstance(attributes, dict) else None
-        end_time = attributes.get("end_time") if isinstance(attributes, dict) else None
-        started_at = self._normalize_time(start_time if isinstance(start_time, datetime) else None)
-        ended_at = self._normalize_time(end_time if isinstance(end_time, datetime) else None)
-        elapsed_ms = int((ended_at - started_at).total_seconds() * 1000)
-        if elapsed_ms < 0:
-            elapsed_ms = 0
-
-        status_counts = {
-            TraceStatus.SUCCESS: 0,
-            TraceStatus.ERROR: 0,
-        }
-        if run_data.exception:
-            status_counts[TraceStatus.ERROR] = 1
-        else:
-            status_counts[TraceStatus.SUCCESS] = 1
-
-        summary: dict[str, Any] = {
-            "status_counts": status_counts,
-            "weave": {"latency_ms": elapsed_ms},
-        }
-
-        exception_str = str(run_data.exception) if run_data.exception else None
-
-        call_end_req = CallEndReq(
-            end=EndedCallSchemaForInsert(
-                project_id=self.project_id,
-                id=run_data.id,
-                ended_at=ended_at,
-                exception=exception_str,
-                output=run_data.outputs,
-                summary=cast(SummaryInsertMap, summary),
-            )
-        )
-        self.weave_client.server.call_end(call_end_req)
api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py (deleted)
@@ -1,104 +0,0 @@
-import json
-from typing import Any
-
-from configs import dify_config
-from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import (
-    AnalyticdbVectorOpenAPI,
-    AnalyticdbVectorOpenAPIConfig,
-)
-from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySql, AnalyticdbVectorBySqlConfig
-from core.rag.datasource.vdb.vector_base import BaseVector
-from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
-from core.rag.datasource.vdb.vector_type import VectorType
-from core.rag.embedding.embedding_base import Embeddings
-from core.rag.models.document import Document
-from models.dataset import Dataset
-
-
-class AnalyticdbVector(BaseVector):
-    def __init__(
-        self,
-        collection_name: str,
-        api_config: AnalyticdbVectorOpenAPIConfig | None,
-        sql_config: AnalyticdbVectorBySqlConfig | None,
-    ):
-        super().__init__(collection_name)
-        if api_config is not None:
-            self.analyticdb_vector: AnalyticdbVectorOpenAPI | AnalyticdbVectorBySql = AnalyticdbVectorOpenAPI(
-                collection_name, api_config
-            )
-        else:
-            if sql_config is None:
-                raise ValueError("Either api_config or sql_config must be provided")
-            self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)
-
-    def get_type(self) -> str:
-        return VectorType.ANALYTICDB
-
-    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
-        dimension = len(embeddings[0])
-        self.analyticdb_vector._create_collection_if_not_exists(dimension)
-        self.analyticdb_vector.add_texts(texts, embeddings)
-
-    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
-        self.analyticdb_vector.add_texts(documents, embeddings)
-
-    def text_exists(self, id: str) -> bool:
-        return self.analyticdb_vector.text_exists(id)
-
-    def delete_by_ids(self, ids: list[str]):
-        self.analyticdb_vector.delete_by_ids(ids)
-
-    def delete_by_metadata_field(self, key: str, value: str):
-        self.analyticdb_vector.delete_by_metadata_field(key, value)
-
-    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        return self.analyticdb_vector.search_by_vector(query_vector, **kwargs)
-
-    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
-        return self.analyticdb_vector.search_by_full_text(query, **kwargs)
-
-    def delete(self):
-        self.analyticdb_vector.delete()
-
-
-class AnalyticdbVectorFactory(AbstractVectorFactory):
-    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> AnalyticdbVector:
-        if dataset.index_struct_dict:
-            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
-            collection_name = class_prefix.lower()
-        else:
-            dataset_id = dataset.id
-            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
-            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name))
-
-        if dify_config.ANALYTICDB_HOST is None:
-            # implemented through OpenAPI
-            apiConfig = AnalyticdbVectorOpenAPIConfig(
-                access_key_id=dify_config.ANALYTICDB_KEY_ID or "",
-                access_key_secret=dify_config.ANALYTICDB_KEY_SECRET or "",
-                region_id=dify_config.ANALYTICDB_REGION_ID or "",
-                instance_id=dify_config.ANALYTICDB_INSTANCE_ID or "",
-                account=dify_config.ANALYTICDB_ACCOUNT or "",
-                account_password=dify_config.ANALYTICDB_PASSWORD or "",
-                namespace=dify_config.ANALYTICDB_NAMESPACE or "",
-                namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
-            )
-            sqlConfig = None
-        else:
-            # implemented through sql
-            sqlConfig = AnalyticdbVectorBySqlConfig(
-                host=dify_config.ANALYTICDB_HOST,
-                port=dify_config.ANALYTICDB_PORT,
-                account=dify_config.ANALYTICDB_ACCOUNT or "",
-                account_password=dify_config.ANALYTICDB_PASSWORD or "",
-                min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
-                max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
-                namespace=dify_config.ANALYTICDB_NAMESPACE or "",
-            )
-            apiConfig = None
-        return AnalyticdbVector(
-            collection_name,
-            apiConfig,
-            sqlConfig,
-        )
api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py (deleted)
@@ -1,321 +0,0 @@
-import json
-from typing import Any
-
-from pydantic import BaseModel, model_validator
-
-_import_err_msg = (
-    "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
-    "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
-)
-
-from core.rag.models.document import Document
-from extensions.ext_redis import redis_client
-
-
-class AnalyticdbVectorOpenAPIConfig(BaseModel):
-    access_key_id: str
-    access_key_secret: str
-    region_id: str
-    instance_id: str
-    account: str
-    account_password: str
-    namespace: str = "dify"
-    namespace_password: str | None = None
-    metrics: str = "cosine"
-    read_timeout: int = 60000
-
-    @model_validator(mode="before")
-    @classmethod
-    def validate_config(cls, values: dict):
-        if not values["access_key_id"]:
-            raise ValueError("config ANALYTICDB_KEY_ID is required")
-        if not values["access_key_secret"]:
-            raise ValueError("config ANALYTICDB_KEY_SECRET is required")
-        if not values["region_id"]:
-            raise ValueError("config ANALYTICDB_REGION_ID is required")
-        if not values["instance_id"]:
-            raise ValueError("config ANALYTICDB_INSTANCE_ID is required")
-        if not values["account"]:
-            raise ValueError("config ANALYTICDB_ACCOUNT is required")
-        if not values["account_password"]:
-            raise ValueError("config ANALYTICDB_PASSWORD is required")
-        if not values["namespace_password"]:
-            raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
-        return values
-
-    def to_analyticdb_client_params(self):
-        return {
-            "access_key_id": self.access_key_id,
-            "access_key_secret": self.access_key_secret,
-            "region_id": self.region_id,
-            "read_timeout": self.read_timeout,
-        }
-
-
-class AnalyticdbVectorOpenAPI:
-    def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
-        try:
-            from alibabacloud_gpdb20160503.client import Client  # type: ignore
-            from alibabacloud_tea_openapi import models as open_api_models  # type: ignore
-        except:
-            raise ImportError(_import_err_msg)
-        self._collection_name = collection_name.lower()
-        self.config = config
-        self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
-        self._client = Client(self._client_config)
-        self._initialize()
-
-    def _initialize(self):
-        cache_key = f"vector_initialize_{self.config.instance_id}"
-        lock_name = f"{cache_key}_lock"
-        with redis_client.lock(lock_name, timeout=20):
-            database_exist_cache_key = f"vector_initialize_{self.config.instance_id}"
-            if redis_client.get(database_exist_cache_key):
-                return
-            self._initialize_vector_database()
-            self._create_namespace_if_not_exists()
-            redis_client.set(database_exist_cache_key, 1, ex=3600)
-
-    def _initialize_vector_database(self):
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models  # type: ignore
-
-        request = gpdb_20160503_models.InitVectorDatabaseRequest(
-            dbinstance_id=self.config.instance_id,
-            region_id=self.config.region_id,
-            manager_account=self.config.account,
-            manager_account_password=self.config.account_password,
-        )
-        self._client.init_vector_database(request)
-
-    def _create_namespace_if_not_exists(self):
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-        from Tea.exceptions import TeaException  # type: ignore
-
-        try:
-            request = gpdb_20160503_models.DescribeNamespaceRequest(
-                dbinstance_id=self.config.instance_id,
-                region_id=self.config.region_id,
-                namespace=self.config.namespace,
-                manager_account=self.config.account,
-                manager_account_password=self.config.account_password,
-            )
-            self._client.describe_namespace(request)
-        except TeaException as e:
-            if e.statusCode == 404:
-                request = gpdb_20160503_models.CreateNamespaceRequest(
-                    dbinstance_id=self.config.instance_id,
-                    region_id=self.config.region_id,
-                    manager_account=self.config.account,
-                    manager_account_password=self.config.account_password,
-                    namespace=self.config.namespace,
-                    namespace_password=self.config.namespace_password,
-                )
-                self._client.create_namespace(request)
-            else:
-                raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
-
-    def _create_collection_if_not_exists(self, embedding_dimension: int):
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-        from Tea.exceptions import TeaException
-
-        cache_key = f"vector_indexing_{self._collection_name}"
-        lock_name = f"{cache_key}_lock"
-        with redis_client.lock(lock_name, timeout=20):
-            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
-            if redis_client.get(collection_exist_cache_key):
-                return
-            try:
-                request = gpdb_20160503_models.DescribeCollectionRequest(
-                    dbinstance_id=self.config.instance_id,
-                    region_id=self.config.region_id,
-                    namespace=self.config.namespace,
-                    namespace_password=self.config.namespace_password,
-                    collection=self._collection_name,
-                )
-                self._client.describe_collection(request)
-            except TeaException as e:
-                if e.statusCode == 404:
-                    metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
-                    full_text_retrieval_fields = "page_content"
-                    request = gpdb_20160503_models.CreateCollectionRequest(
-                        dbinstance_id=self.config.instance_id,
-                        region_id=self.config.region_id,
-                        manager_account=self.config.account,
-                        manager_account_password=self.config.account_password,
-                        namespace=self.config.namespace,
-                        collection=self._collection_name,
-                        dimension=embedding_dimension,
-                        metrics=self.config.metrics,
-                        metadata=metadata,
-                        full_text_retrieval_fields=full_text_retrieval_fields,
-                    )
-                    self._client.create_collection(request)
-                else:
-                    raise ValueError(f"failed to create collection {self._collection_name}: {e}")
-            redis_client.set(collection_exist_cache_key, 1, ex=3600)
-
-    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-
-        rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
-        for doc, embedding in zip(documents, embeddings, strict=True):
-            if doc.metadata is not None:
-                metadata = {
-                    "ref_doc_id": doc.metadata["doc_id"],
-                    "page_content": doc.page_content,
-                    "metadata_": json.dumps(doc.metadata),
-                }
-                rows.append(
-                    gpdb_20160503_models.UpsertCollectionDataRequestRows(
-                        vector=embedding,
-                        metadata=metadata,
-                    )
-                )
-        request = gpdb_20160503_models.UpsertCollectionDataRequest(
-            dbinstance_id=self.config.instance_id,
-            region_id=self.config.region_id,
-            namespace=self.config.namespace,
-            namespace_password=self.config.namespace_password,
-            collection=self._collection_name,
-            rows=rows,
-        )
-        self._client.upsert_collection_data(request)
-
-    def text_exists(self, id: str) -> bool:
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-
-        request = gpdb_20160503_models.QueryCollectionDataRequest(
-            dbinstance_id=self.config.instance_id,
-            region_id=self.config.region_id,
-            namespace=self.config.namespace,
-            namespace_password=self.config.namespace_password,
-            collection=self._collection_name,
-            metrics=self.config.metrics,
-            include_values=True,
-            vector=None,
-            content=None,
-            top_k=1,
-            filter=f"ref_doc_id='{id}'",
-        )
-        response = self._client.query_collection_data(request)
-        return len(response.body.matches.match) > 0
-
-    def delete_by_ids(self, ids: list[str]):
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-
-        ids_str = ",".join(f"'{id}'" for id in ids)
-        ids_str = f"({ids_str})"
-        request = gpdb_20160503_models.DeleteCollectionDataRequest(
-            dbinstance_id=self.config.instance_id,
-            region_id=self.config.region_id,
-            namespace=self.config.namespace,
-            namespace_password=self.config.namespace_password,
-            collection=self._collection_name,
-            collection_data=None,
-            collection_data_filter=f"ref_doc_id IN {ids_str}",
-        )
-        self._client.delete_collection_data(request)
-
-    def delete_by_metadata_field(self, key: str, value: str):
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-
-        request = gpdb_20160503_models.DeleteCollectionDataRequest(
-            dbinstance_id=self.config.instance_id,
-            region_id=self.config.region_id,
-            namespace=self.config.namespace,
-            namespace_password=self.config.namespace_password,
-            collection=self._collection_name,
-            collection_data=None,
-            collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
-        )
-        self._client.delete_collection_data(request)
-
-    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-
-        document_ids_filter = kwargs.get("document_ids_filter")
-        where_clause = ""
-        if document_ids_filter:
-            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
-            where_clause += f"metadata_->>'document_id' IN ({document_ids})"
-
-        score_threshold = kwargs.get("score_threshold") or 0.0
-        request = gpdb_20160503_models.QueryCollectionDataRequest(
-            dbinstance_id=self.config.instance_id,
-            region_id=self.config.region_id,
-            namespace=self.config.namespace,
-            namespace_password=self.config.namespace_password,
-            collection=self._collection_name,
-            include_values=kwargs.pop("include_values", True),
-            metrics=self.config.metrics,
-            vector=query_vector,
-            content=None,
-            top_k=kwargs.get("top_k", 4),
-            filter=where_clause,
-        )
-        response = self._client.query_collection_data(request)
-        documents = []
-        for match in response.body.matches.match:
-            if match.score >= score_threshold:
-                metadata = json.loads(match.metadata.get("metadata_"))
-                metadata["score"] = match.score
-                doc = Document(
-                    page_content=match.metadata.get("page_content"),
-                    vector=match.values.value,
-                    metadata=metadata,
-                )
-                documents.append(doc)
-        documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
-        return documents
-
-    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
-        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-
-        document_ids_filter = kwargs.get("document_ids_filter")
-        where_clause = ""
-        if document_ids_filter:
-            document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
-            where_clause += f"metadata_->>'document_id' IN ({document_ids})"
-        score_threshold = float(kwargs.get("score_threshold") or 0.0)
-        request = gpdb_20160503_models.QueryCollectionDataRequest(
-            dbinstance_id=self.config.instance_id,
-            region_id=self.config.region_id,
-            namespace=self.config.namespace,
-            namespace_password=self.config.namespace_password,
-            collection=self._collection_name,
-            include_values=kwargs.pop("include_values", True),
-            metrics=self.config.metrics,
-            vector=None,
-            content=query,
-            top_k=kwargs.get("top_k", 4),
-            filter=where_clause,
-        )
-        response = self._client.query_collection_data(request)
-        documents = []
-        for match in response.body.matches.match:
-            if match.score >= score_threshold:
-                metadata = json.loads(match.metadata.get("metadata_"))
-                metadata["score"] = match.score
-                doc = Document(
-                    page_content=match.metadata.get("page_content"),
-                    vector=match.values.value,
-                    metadata=metadata,
-                )
-                documents.append(doc)
-        documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
-        return documents
-
-    def delete(self):
-        try:
-            from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-
-            request = gpdb_20160503_models.DeleteCollectionRequest(
-                collection=self._collection_name,
-                dbinstance_id=self.config.instance_id,
-                namespace=self.config.namespace,
-                namespace_password=self.config.namespace_password,
-                region_id=self.config.region_id,
-            )
-            self._client.delete_collection(request)
-        except Exception as e:
-            raise e
@@ -1,275 +0,0 @@
|
||||
import json
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import psycopg2.extras
|
||||
import psycopg2.pool
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class AnalyticdbVectorBySqlConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
account: str
|
||||
account_password: str
|
||||
min_connection: int
|
||||
max_connection: int
|
||||
namespace: str = "dify"
|
||||
metrics: str = "cosine"
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict):
|
||||
if not values["host"]:
|
||||
raise ValueError("config ANALYTICDB_HOST is required")
|
||||
if not values["port"]:
|
||||
raise ValueError("config ANALYTICDB_PORT is required")
|
||||
if not values["account"]:
|
||||
raise ValueError("config ANALYTICDB_ACCOUNT is required")
|
||||
if not values["account_password"]:
|
||||
raise ValueError("config ANALYTICDB_PASSWORD is required")
|
||||
if not values["min_connection"]:
|
||||
raise ValueError("config ANALYTICDB_MIN_CONNECTION is required")
|
||||
if not values["max_connection"]:
|
||||
raise ValueError("config ANALYTICDB_MAX_CONNECTION is required")
|
||||
if values["min_connection"] > values["max_connection"]:
|
||||
raise ValueError("config ANALYTICDB_MIN_CONNECTION should less than ANALYTICDB_MAX_CONNECTION")
|
||||
return values
|
||||
|
||||
|
||||
class AnalyticdbVectorBySql:
|
||||
def __init__(self, collection_name: str, config: AnalyticdbVectorBySqlConfig):
|
||||
self._collection_name = collection_name.lower()
|
||||
self.databaseName = "knowledgebase"
|
||||
self.config = config
|
||||
self.table_name = f"{self.config.namespace}.{self._collection_name}"
|
||||
self.pool = None
|
||||
self._initialize()
|
||||
if not self.pool:
|
||||
self.pool = self._create_connection_pool()
|
||||
|
||||
def _initialize(self):
|
||||
cache_key = f"vector_initialize_{self.config.host}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
database_exist_cache_key = f"vector_initialize_{self.config.host}"
|
||||
if redis_client.get(database_exist_cache_key):
|
||||
return
|
||||
self._initialize_vector_database()
|
||||
redis_client.set(database_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def _create_connection_pool(self):
|
||||
return psycopg2.pool.SimpleConnectionPool(
|
||||
self.config.min_connection,
|
||||
self.config.max_connection,
|
||||
host=self.config.host,
|
||||
port=self.config.port,
|
||||
user=self.config.account,
|
||||
password=self.config.account_password,
|
||||
database=self.databaseName,
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def _get_cursor(self):
|
||||
assert self.pool is not None, "Connection pool is not initialized"
|
||||
conn = self.pool.getconn()
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
yield cur
|
||||
finally:
|
||||
cur.close()
|
||||
conn.commit()
|
||||
self.pool.putconn(conn)
|
||||
|
||||
def _initialize_vector_database(self):
|
||||
conn = psycopg2.connect(
|
||||
host=self.config.host,
|
||||
port=self.config.port,
|
||||
user=self.config.account,
|
||||
password=self.config.account_password,
|
||||
database="postgres",
|
||||
)
|
||||
conn.autocommit = True
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
cur.execute(f"CREATE DATABASE {self.databaseName}")
|
||||
except Exception as e:
|
||||
if "already exists" not in str(e):
|
||||
raise e
|
||||
finally:
|
||||
cur.close()
|
||||
conn.close()
|
||||
self.pool = self._create_connection_pool()
|
||||
with self._get_cursor() as cur:
|
||||
conn = cur.connection
|
||||
try:
|
||||
cur.execute("CREATE EXTENSION IF NOT EXISTS zhparser;")
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
raise RuntimeError(
|
||||
"Failed to create zhparser extension. Please ensure it is available in your AnalyticDB."
|
||||
) from e
|
||||
try:
|
||||
cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)")
|
||||
cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple")
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
if "already exists" not in str(e):
|
||||
raise e
|
||||
cur.execute(
|
||||
"CREATE OR REPLACE FUNCTION "
|
||||
"public.to_tsquery_from_text(txt text, lang regconfig DEFAULT 'english'::regconfig) "
|
||||
"RETURNS tsquery LANGUAGE sql IMMUTABLE STRICT AS $function$ "
|
||||
"SELECT to_tsquery(lang, COALESCE(string_agg(split_part(word, ':', 1), ' | '), '')) "
|
||||
"FROM (SELECT unnest(string_to_array(to_tsvector(lang, txt)::text, ' ')) AS word) "
|
||||
"AS words_only;$function$"
|
||||
)
|
||||
cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")
|
||||
|
||||
def _create_collection_if_not_exists(self, embedding_dimension: int):
|
||||
cache_key = f"vector_indexing_{self._collection_name}"
|
||||
lock_name = f"{cache_key}_lock"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
return
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"CREATE TABLE IF NOT EXISTS {self.table_name}("
|
||||
f"id text PRIMARY KEY,"
|
||||
f"vector real[], ref_doc_id text, page_content text, metadata_ jsonb, "
|
||||
f"to_tsvector TSVECTOR"
|
||||
f") WITH (fillfactor=70) DISTRIBUTED BY (id);"
|
||||
)
|
||||
if embedding_dimension is not None:
|
||||
index_name = f"{self._collection_name}_embedding_idx"
|
||||
try:
|
||||
cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN")
|
||||
cur.execute(
|
||||
f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
|
||||
f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', "
|
||||
f"pq_enable=0, external_storage=0)"
|
||||
)
|
||||
cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
|
||||
except Exception as e:
|
||||
if "already exists" not in str(e):
|
||||
raise e
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
values = []
|
||||
id_prefix = str(uuid.uuid4()) + "_"
|
||||
sql = f"""
|
||||
INSERT INTO {self.table_name}
|
||||
(id, ref_doc_id, vector, page_content, metadata_, to_tsvector)
|
||||
VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
|
||||
"""
|
||||
for i, doc in enumerate(documents):
|
||||
if doc.metadata is not None:
|
||||
values.append(
|
||||
(
|
||||
id_prefix + str(i),
|
||||
doc.metadata.get("doc_id", str(uuid.uuid4())),
|
||||
embeddings[i],
|
||||
doc.page_content,
|
||||
json.dumps(doc.metadata),
|
||||
doc.page_content,
|
||||
)
|
||||
)
|
||||
with self._get_cursor() as cur:
|
||||
psycopg2.extras.execute_batch(cur, sql, values)
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"SELECT id FROM {self.table_name} WHERE ref_doc_id = %s", (id,))
|
||||
return cur.fetchone() is not None
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
if not ids:
|
||||
return
|
||||
with self._get_cursor() as cur:
|
||||
try:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id = ANY(%s)", (ids,))
|
||||
except Exception as e:
|
||||
if "does not exist" not in str(e):
|
||||
raise e
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
with self._get_cursor() as cur:
|
||||
try:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE metadata_->>%s = %s", (key, value))
|
||||
except Exception as e:
|
||||
if "does not exist" not in str(e):
|
||||
raise e
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = "WHERE 1=1"
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
with self._get_cursor() as cur:
|
||||
query_vector_str = json.dumps(query_vector)
|
||||
query_vector_str = "{" + query_vector_str[1:-1] + "}"
|
||||
cur.execute(
|
||||
f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
|
||||
f"t.page_content as page_content, t.metadata_ AS metadata_ "
|
||||
f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
|
||||
f"FROM {self.table_name} {where_clause} ORDER BY score LIMIT {top_k} ) t",
|
||||
(query_vector_str,),
|
||||
)
|
||||
documents = []
|
||||
for record in cur:
|
||||
_, vector, score, page_content, metadata = record
|
||||
if score >= score_threshold:
|
||||
metadata["score"] = score
|
||||
doc = Document(
|
||||
page_content=page_content,
|
||||
vector=vector,
|
||||
metadata=metadata,
|
||||
)
|
||||
documents.append(doc)
|
||||
return documents
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
where_clause = ""
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(
|
||||
f"""SELECT id, vector, page_content, metadata_,
|
||||
ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
|
||||
FROM {self.table_name}
|
||||
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
|
||||
ORDER BY score DESC, id DESC
|
||||
LIMIT {top_k}""",
|
||||
(f"'{query}'", f"'{query}'"),
|
||||
)
|
||||
documents = []
|
||||
for record in cur:
|
||||
_, vector, page_content, metadata, score = record
|
||||
metadata["score"] = score
|
||||
doc = Document(
|
||||
page_content=page_content,
|
||||
vector=vector,
|
||||
metadata=metadata,
|
||||
)
|
||||
documents.append(doc)
|
||||
return documents
|
||||
|
||||
def delete(self):
|
||||
with self._get_cursor() as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||
@@ -1,201 +0,0 @@
|
||||
# Clickzetta Vector Database Integration
|
||||
|
||||
This module provides integration with Clickzetta Lakehouse as a vector database for Dify.
|
||||
|
||||
## Features
|
||||
|
||||
- **Vector Storage**: Store and retrieve high-dimensional vectors using Clickzetta's native VECTOR type
|
||||
- **Vector Search**: Efficient similarity search using HNSW algorithm
|
||||
- **Full-Text Search**: Leverage Clickzetta's inverted index for powerful text search capabilities
|
||||
- **Hybrid Search**: Combine vector similarity and full-text search for better results
|
||||
- **Multi-language Support**: Built-in support for Chinese, English, and Unicode text processing
|
||||
- **Scalable**: Leverage Clickzetta's distributed architecture for large-scale deployments
|
||||
|
||||
## Configuration
|
||||
|
||||
### Required Environment Variables
|
||||
|
||||
All seven configuration parameters are required:
|
||||
|
||||
```bash
|
||||
# Authentication
|
||||
CLICKZETTA_USERNAME=your_username
|
||||
CLICKZETTA_PASSWORD=your_password
|
||||
|
||||
# Instance configuration
|
||||
CLICKZETTA_INSTANCE=your_instance_id
|
||||
CLICKZETTA_SERVICE=api.clickzetta.com
|
||||
CLICKZETTA_WORKSPACE=your_workspace
|
||||
CLICKZETTA_VCLUSTER=your_vcluster
|
||||
CLICKZETTA_SCHEMA=your_schema
|
||||
```
|
||||
|
||||
### Optional Configuration
|
||||
|
||||
```bash
|
||||
# Batch processing
|
||||
CLICKZETTA_BATCH_SIZE=100
|
||||
|
||||
# Full-text search configuration
|
||||
CLICKZETTA_ENABLE_INVERTED_INDEX=true
|
||||
CLICKZETTA_ANALYZER_TYPE=chinese # Options: keyword, english, chinese, unicode
|
||||
CLICKZETTA_ANALYZER_MODE=smart # Options: max_word, smart
|
||||
|
||||
# Vector search configuration
|
||||
CLICKZETTA_VECTOR_DISTANCE_FUNCTION=cosine_distance # Options: l2_distance, cosine_distance
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Set Clickzetta as the Vector Store
|
||||
|
||||
In your Dify configuration, set:
|
||||
|
||||
```bash
|
||||
VECTOR_STORE=clickzetta
|
||||
```
|
||||
|
||||
### 2. Table Structure
|
||||
|
||||
Clickzetta will automatically create tables with the following structure:
|
||||
|
||||
```sql
|
||||
CREATE TABLE <collection_name> (
|
||||
id STRING NOT NULL,
|
||||
content STRING NOT NULL,
|
||||
metadata JSON,
|
||||
vector VECTOR(FLOAT, <dimension>) NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
);
|
||||
|
||||
-- Vector index for similarity search
|
||||
CREATE VECTOR INDEX idx_<collection_name>_vec
|
||||
ON TABLE <schema>.<collection_name>(vector)
|
||||
PROPERTIES (
|
||||
"distance.function" = "cosine_distance",
|
||||
"scalar.type" = "f32"
|
||||
);
|
||||
|
||||
-- Inverted index for full-text search (if enabled)
|
||||
CREATE INVERTED INDEX idx_<collection_name>_text
|
||||
ON <schema>.<collection_name>(content)
|
||||
PROPERTIES (
|
||||
"analyzer" = "chinese",
|
||||
"mode" = "smart"
|
||||
);
|
||||
```
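
For orientation, here is a minimal sketch of querying such a table from Python with the `clickzetta` connector (the same connector the storage backend later in this diff uses). The connection parameters, the `docs` table name, and the `vector(...)` literal syntax are assumptions for illustration, not confirmed Clickzetta API details:

```python
import clickzetta  # connector also used by ClickZettaVolumeStorage below

# Hypothetical connection values; substitute your own instance settings.
conn = clickzetta.connect(
    username="your_username",
    password="your_password",
    instance="your_instance_id",
    service="api.clickzetta.com",
    workspace="your_workspace",
    vcluster="your_vcluster",
    schema="dify",
)

query_vector = [0.1, 0.2, 0.3, 0.4]  # normally produced by an embedding model
# Assumed literal form for a query vector; check your Clickzetta version's syntax.
vector_literal = "vector(" + ",".join(str(x) for x in query_vector) + ")"

with conn.cursor() as cur:
    # Rank by the same distance function the vector index was built with.
    cur.execute(
        f"SELECT id, content, cosine_distance(vector, {vector_literal}) AS distance "
        f"FROM docs ORDER BY distance LIMIT 4"
    )
    for row in cur.fetchall():
        print(row)
```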

## Full-Text Search Capabilities

Clickzetta supports advanced full-text search with multiple analyzers:

### Analyzer Types

1. **keyword**: No tokenization; treats the entire string as a single token

   - Best for: Exact matching, IDs, codes

1. **english**: Designed for English text

   - Features: Recognizes ASCII letters and numbers, converts to lowercase
   - Best for: English content

1. **chinese**: Chinese text tokenizer

   - Features: Recognizes Chinese and English characters, removes punctuation
   - Best for: Chinese or mixed Chinese-English content

1. **unicode**: Multi-language tokenizer based on Unicode

   - Features: Recognizes text boundaries in multiple languages
   - Best for: Multi-language content

### Analyzer Modes

- **max_word**: Fine-grained tokenization (more tokens)
- **smart**: Intelligent tokenization (balanced)

### Full-Text Search Functions

- `MATCH_ALL(column, query)`: All terms must be present
- `MATCH_ANY(column, query)`: At least one term must be present
- `MATCH_PHRASE(column, query)`: Exact phrase matching
- `MATCH_PHRASE_PREFIX(column, query)`: Phrase prefix matching
- `MATCH_REGEXP(column, pattern)`: Regular expression matching
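
As a quick illustration, a minimal sketch of calling one of these predicates through the Python connector; the `conn` object and `docs` table are the placeholders from the earlier sketch, and only the `MATCH_ALL` predicate itself is taken from the list above:

```python
def full_text_search(conn, query: str, top_k: int = 4):
    """Return rows whose content contains every term of `query`."""
    with conn.cursor() as cur:
        # MATCH_ALL requires every token of the query to appear in `content`.
        cur.execute(
            f"SELECT id, content FROM docs WHERE MATCH_ALL(content, '{query}') LIMIT {top_k}"
        )
        return cur.fetchall()
```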

## Performance Optimization

### Vector Search

1. **Adjust the exploration factor** for an accuracy vs. speed trade-off:

   ```sql
   SET cz.vector.index.search.ef=64;
   ```

1. **Use appropriate distance functions**:

   - `cosine_distance`: Best for normalized embeddings (e.g., from language models)
   - `l2_distance`: Best for raw feature vectors

### Full-Text Search

1. **Choose the right analyzer**:

   - Use `keyword` for exact matching
   - Use language-specific analyzers for better tokenization

1. **Combine with vector search**:

   - Pre-filter with full-text search for better performance
   - Use hybrid search for improved relevance (see the sketch below)
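
One way such a hybrid query can be expressed, sketched under the same placeholder schema and assumed `vector(...)` literal as above; whether this exact combination performs well depends on your data and Clickzetta version:

```python
def hybrid_search(conn, query: str, query_vector: list[float], top_k: int = 4):
    """Pre-filter candidates with the inverted index, then rank by vector distance."""
    vector_literal = "vector(" + ",".join(str(x) for x in query_vector) + ")"
    with conn.cursor() as cur:
        cur.execute(
            f"SELECT id, content, cosine_distance(vector, {vector_literal}) AS distance "
            f"FROM docs "
            f"WHERE MATCH_ANY(content, '{query}') "  # cheap full-text pre-filter
            f"ORDER BY distance LIMIT {top_k}"
        )
        return cur.fetchall()
```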

## Troubleshooting

### Connection Issues

1. Verify all 7 required configuration parameters are set
1. Check network connectivity to the Clickzetta service
1. Ensure the user has proper permissions on the schema

### Search Performance

1. Verify the vector index exists:

   ```sql
   SHOW INDEX FROM <schema>.<table_name>;
   ```

1. Check whether the vector index is being used:

   ```sql
   EXPLAIN SELECT ... WHERE l2_distance(...) < threshold;
   ```

   Look for `vector_index_search_type` in the execution plan.

### Full-Text Search Not Working

1. Verify the inverted index is created
1. Check that the analyzer configuration matches your content language
1. Use the `TOKENIZE()` function to test tokenization:

   ```sql
   SELECT TOKENIZE('your text', map('analyzer', 'chinese', 'mode', 'smart'));
   ```

## Limitations

1. Vector operations don't support `ORDER BY` or `GROUP BY` directly on vector columns
1. Full-text search relevance scores are not provided by Clickzetta
1. Inverted index creation may fail for very large existing tables (the implementation continues without raising an error)
1. Index naming constraints:
   - Index names must be unique within a schema
   - Only one vector index can be created per column
   - The implementation uses timestamps to ensure unique index names
1. A column can only have one vector index at a time

## References

- [Clickzetta Vector Search Documentation](https://yunqi.tech/documents/vector-search)
- [Clickzetta Inverted Index Documentation](https://yunqi.tech/documents/inverted-index)
- [Clickzetta SQL Functions](https://yunqi.tech/documents/sql-reference)
@@ -1 +0,0 @@
# Clickzetta Vector Database Integration for Dify
File diff suppressed because it is too large
@@ -1,413 +0,0 @@
import json
import logging
import math
from collections.abc import Iterable
from typing import Any

import tablestore  # type: ignore
from pydantic import BaseModel, model_validator
from tablestore import BatchGetRowRequest, TableInBatchGetRowItem

from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models import Dataset

logger = logging.getLogger(__name__)


class TableStoreConfig(BaseModel):
    access_key_id: str | None = None
    access_key_secret: str | None = None
    instance_name: str | None = None
    endpoint: str | None = None
    normalize_full_text_bm25_score: bool | None = False

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict):
        if not values["access_key_id"]:
            raise ValueError("config ACCESS_KEY_ID is required")
        if not values["access_key_secret"]:
            raise ValueError("config ACCESS_KEY_SECRET is required")
        if not values["instance_name"]:
            raise ValueError("config INSTANCE_NAME is required")
        if not values["endpoint"]:
            raise ValueError("config ENDPOINT is required")
        return values


class TableStoreVector(BaseVector):
    def __init__(self, collection_name: str, config: TableStoreConfig):
        super().__init__(collection_name)
        self._config = config
        self._tablestore_client = tablestore.OTSClient(
            config.endpoint,
            config.access_key_id,
            config.access_key_secret,
            config.instance_name,
        )
        self._normalize_full_text_bm25_score = config.normalize_full_text_bm25_score
        self._table_name = f"{collection_name}"
        self._index_name = f"{collection_name}_idx"
        self._tags_field = f"{Field.METADATA_KEY}_tags"

    def create_collection(self, embeddings: list[list[float]], **kwargs):
        dimension = len(embeddings[0])
        self._create_collection(dimension)

    def get_by_ids(self, ids: list[str]) -> list[Document]:
        docs = []
        request = BatchGetRowRequest()
        columns_to_get = [Field.METADATA_KEY, Field.CONTENT_KEY]
        rows_to_get = [[("id", _id)] for _id in ids]
        request.add(TableInBatchGetRowItem(self._table_name, rows_to_get, columns_to_get, None, 1))

        result = self._tablestore_client.batch_get_row(request)
        table_result = result.get_result_by_table(self._table_name)
        for item in table_result:
            if item.is_ok and item.row:
                kv = {k: v for k, v, _ in item.row.attribute_columns}
                docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=json.loads(kv[Field.METADATA_KEY])))
        return docs

    def get_type(self) -> str:
        return VectorType.TABLESTORE

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        dimension = len(embeddings[0])
        self._create_collection(dimension)
        self.add_texts(documents=texts, embeddings=embeddings, **kwargs)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        uuids = self._get_uuids(documents)

        for i in range(len(documents)):
            self._write_row(
                primary_key=uuids[i],
                attributes={
                    Field.CONTENT_KEY: documents[i].page_content,
                    Field.VECTOR: embeddings[i],
                    Field.METADATA_KEY: documents[i].metadata,
                },
            )
        return uuids

    def text_exists(self, id: str) -> bool:
        result = self._tablestore_client.get_row(
            table_name=self._table_name, primary_key=[("id", id)], columns_to_get=["id"]
        )
        assert isinstance(result, tuple | list)
        # Unpack the tuple result
        _, return_row, _ = result

        return return_row is not None

    def delete_by_ids(self, ids: list[str]):
        if not ids:
            return
        for id in ids:
            self._delete_row(id=id)

    def get_ids_by_metadata_field(self, key: str, value: str):
        return self._search_by_metadata(key, value)

    def delete_by_metadata_field(self, key: str, value: str):
        ids = self.get_ids_by_metadata_field(key, value)
        self.delete_by_ids(ids)

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 4)
        document_ids_filter = kwargs.get("document_ids_filter")
        filtered_list = None
        if document_ids_filter:
            filtered_list = ["document_id=" + item for item in document_ids_filter]
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        return self._search_by_vector(query_vector, filtered_list, top_k, score_threshold)

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 4)
        document_ids_filter = kwargs.get("document_ids_filter")
        filtered_list = None
        if document_ids_filter:
            filtered_list = ["document_id=" + item for item in document_ids_filter]
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        return self._search_by_full_text(query, filtered_list, top_k, score_threshold)

    def delete(self):
        self._delete_table_if_exist()

    def _create_collection(self, dimension: int):
        lock_name = f"vector_indexing_lock_{self._collection_name}"
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(collection_exist_cache_key):
                logger.info("Collection %s already exists.", self._collection_name)
                return

            self._create_table_if_not_exist()
            self._create_search_index_if_not_exist(dimension)
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def _create_table_if_not_exist(self):
        table_list = self._tablestore_client.list_table()
        if self._table_name in table_list:
            logger.info("Tablestore system table[%s] already exists", self._table_name)
            return None

        schema_of_primary_key = [("id", "STRING")]
        table_meta = tablestore.TableMeta(self._table_name, schema_of_primary_key)
        table_options = tablestore.TableOptions()
        reserved_throughput = tablestore.ReservedThroughput(tablestore.CapacityUnit(0, 0))
        self._tablestore_client.create_table(table_meta, table_options, reserved_throughput)
        logger.info("Tablestore create table[%s] successfully.", self._table_name)

    def _create_search_index_if_not_exist(self, dimension: int):
        search_index_list = self._tablestore_client.list_search_index(table_name=self._table_name)
        assert isinstance(search_index_list, Iterable)
        if self._index_name in [t[1] for t in search_index_list]:
            logger.info("Tablestore system index[%s] already exists", self._index_name)
            return None

        field_schemas = [
            tablestore.FieldSchema(
                Field.CONTENT_KEY,
                tablestore.FieldType.TEXT,
                analyzer=tablestore.AnalyzerType.MAXWORD,
                index=True,
                enable_sort_and_agg=False,
                store=False,
            ),
            tablestore.FieldSchema(
                Field.VECTOR,
                tablestore.FieldType.VECTOR,
                vector_options=tablestore.VectorOptions(
                    data_type=tablestore.VectorDataType.VD_FLOAT_32,
                    dimension=dimension,
                    metric_type=tablestore.VectorMetricType.VM_COSINE,
                ),
            ),
            tablestore.FieldSchema(
                Field.METADATA_KEY,
                tablestore.FieldType.KEYWORD,
                index=True,
                store=False,
            ),
            tablestore.FieldSchema(
                self._tags_field,
                tablestore.FieldType.KEYWORD,
                index=True,
                store=False,
                is_array=True,
            ),
        ]

        index_meta = tablestore.SearchIndexMeta(field_schemas)
        self._tablestore_client.create_search_index(self._table_name, self._index_name, index_meta)
        logger.info("Tablestore create system index[%s] successfully.", self._index_name)

    def _delete_table_if_exist(self):
        search_index_list = self._tablestore_client.list_search_index(table_name=self._table_name)
        assert isinstance(search_index_list, Iterable)
        for resp_tuple in search_index_list:
            self._tablestore_client.delete_search_index(resp_tuple[0], resp_tuple[1])
            logger.info("Tablestore delete index[%s] successfully.", self._index_name)

        self._tablestore_client.delete_table(self._table_name)
        logger.info("Tablestore delete system table[%s] successfully.", self._table_name)

    def _delete_search_index(self):
        self._tablestore_client.delete_search_index(self._table_name, self._index_name)
        logger.info("Tablestore delete index[%s] successfully.", self._index_name)

    def _write_row(self, primary_key: str, attributes: dict[str, Any]):
        pk = [("id", primary_key)]

        # Flatten metadata into "key=value" tag strings so metadata lookups
        # can use a keyword term query against the tags field.
        tags = []
        for key, value in attributes[Field.METADATA_KEY].items():
            tags.append(str(key) + "=" + str(value))

        attribute_columns = [
            (Field.CONTENT_KEY, attributes[Field.CONTENT_KEY]),
            (Field.VECTOR, json.dumps(attributes[Field.VECTOR])),
            (
                Field.METADATA_KEY,
                json.dumps(attributes[Field.METADATA_KEY]),
            ),
            (self._tags_field, json.dumps(tags)),
        ]
        row = tablestore.Row(pk, attribute_columns)
        self._tablestore_client.put_row(self._table_name, row)

    def _delete_row(self, id: str):
        primary_key = [("id", id)]
        row = tablestore.Row(primary_key)
        self._tablestore_client.delete_row(self._table_name, row, None)

    def _search_by_metadata(self, key: str, value: str) -> list[str]:
        query = tablestore.SearchQuery(
            tablestore.TermQuery(self._tags_field, str(key) + "=" + str(value)),
            limit=1000,
            get_total_count=False,
        )
        rows: list[str] = []
        next_token = None
        while True:
            if next_token is not None:
                query.next_token = next_token

            search_response = self._tablestore_client.search(
                table_name=self._table_name,
                index_name=self._index_name,
                search_query=query,
                columns_to_get=tablestore.ColumnsToGet(
                    column_names=[Field.PRIMARY_KEY], return_type=tablestore.ColumnReturnType.SPECIFIED
                ),
            )

            if search_response is not None:
                rows.extend([row[0][0][1] for row in list(search_response.rows)])

            if search_response is None or search_response.next_token == b"":
                break
            else:
                next_token = search_response.next_token

        return rows

    def _search_by_vector(
        self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float
    ) -> list[Document]:
        knn_vector_query = tablestore.KnnVectorQuery(
            field_name=Field.VECTOR,
            top_k=top_k,
            float32_query_vector=query_vector,
        )
        if document_ids_filter:
            knn_vector_query.filter = tablestore.TermsQuery(self._tags_field, document_ids_filter)

        sort = tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)])
        search_query = tablestore.SearchQuery(knn_vector_query, limit=top_k, get_total_count=False, sort=sort)

        search_response = self._tablestore_client.search(
            table_name=self._table_name,
            index_name=self._index_name,
            search_query=search_query,
            columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
        )
        documents = []
        for search_hit in search_response.search_hits:
            if search_hit.score >= score_threshold:
                ots_column_map = {}
                for col in search_hit.row[1]:
                    ots_column_map[col[0]] = col[1]

                vector_str = ots_column_map.get(Field.VECTOR)
                metadata_str = ots_column_map.get(Field.METADATA_KEY)

                vector = json.loads(vector_str) if vector_str else None
                metadata = json.loads(metadata_str) if metadata_str else {}

                metadata["score"] = search_hit.score

                documents.append(
                    Document(
                        page_content=ots_column_map.get(Field.CONTENT_KEY) or "",
                        vector=vector,
                        metadata=metadata,
                    )
                )
        documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
        return documents

    @staticmethod
    def _normalize_score_exp_decay(score: float, k: float = 0.15) -> float:
        """
        Args:
            score: BM25 search score.
            k: decay factor; the larger k is, the steeper the low-score end.
        """
        normalized_score = 1 - math.exp(-k * score)
        return max(0.0, min(1.0, normalized_score))
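
    # For intuition (hypothetical values): with the default k=0.15, a BM25
    # score of 5.0 maps to 1 - exp(-0.75) ≈ 0.53 and a score of 20.0 maps to
    # 1 - exp(-3.0) ≈ 0.95, so unbounded BM25 scores are compressed into
    # (0, 1) and become comparable with the score_threshold used below.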
|
||||
|
||||
def _search_by_full_text(
|
||||
self, query: str, document_ids_filter: list[str] | None, top_k: int, score_threshold: float
|
||||
) -> list[Document]:
|
||||
bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[])
|
||||
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY))
|
||||
|
||||
if document_ids_filter:
|
||||
bool_query.filter_queries.append(tablestore.TermsQuery(self._tags_field, document_ids_filter))
|
||||
|
||||
search_query = tablestore.SearchQuery(
|
||||
query=bool_query,
|
||||
sort=tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)]),
|
||||
limit=top_k,
|
||||
)
|
||||
search_response = self._tablestore_client.search(
|
||||
table_name=self._table_name,
|
||||
index_name=self._index_name,
|
||||
search_query=search_query,
|
||||
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
|
||||
)
|
||||
|
||||
documents = []
|
||||
for search_hit in search_response.search_hits:
|
||||
score = None
|
||||
if self._normalize_full_text_bm25_score:
|
||||
score = self._normalize_score_exp_decay(search_hit.score)
|
||||
|
||||
# skip when score is below threshold and use normalize score
|
||||
if score and score <= score_threshold:
|
||||
continue
|
||||
|
||||
ots_column_map = {}
|
||||
for col in search_hit.row[1]:
|
||||
ots_column_map[col[0]] = col[1]
|
||||
|
||||
metadata_str = ots_column_map.get(Field.METADATA_KEY)
|
||||
metadata = json.loads(metadata_str) if metadata_str else {}
|
||||
|
||||
vector_str = ots_column_map.get(Field.VECTOR)
|
||||
vector = json.loads(vector_str) if vector_str else None
|
||||
|
||||
if score:
|
||||
metadata["score"] = score
|
||||
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=ots_column_map.get(Field.CONTENT_KEY) or "",
|
||||
vector=vector,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
if self._normalize_full_text_bm25_score:
|
||||
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
|
||||
return documents
|
||||
|
||||
|
||||
class TableStoreVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TableStoreVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.TABLESTORE, collection_name))
|
||||
|
||||
return TableStoreVector(
|
||||
collection_name=collection_name,
|
||||
config=TableStoreConfig(
|
||||
endpoint=dify_config.TABLESTORE_ENDPOINT,
|
||||
instance_name=dify_config.TABLESTORE_INSTANCE_NAME,
|
||||
access_key_id=dify_config.TABLESTORE_ACCESS_KEY_ID,
|
||||
access_key_secret=dify_config.TABLESTORE_ACCESS_KEY_SECRET,
|
||||
normalize_full_text_bm25_score=dify_config.TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE,
|
||||
),
|
||||
)
|
||||
@@ -135,10 +135,6 @@ class Vector:
|
||||
from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory
|
||||
|
||||
return OpenSearchVectorFactory
|
||||
case VectorType.ANALYTICDB:
|
||||
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
|
||||
|
||||
return AnalyticdbVectorFactory
|
||||
case VectorType.COUCHBASE:
|
||||
from core.rag.datasource.vdb.couchbase.couchbase_vector import CouchbaseVectorFactory
|
||||
|
||||
@@ -171,10 +167,6 @@ class Vector:
|
||||
from core.rag.datasource.vdb.opengauss.opengauss import OpenGaussFactory
|
||||
|
||||
return OpenGaussFactory
|
||||
case VectorType.TABLESTORE:
|
||||
from core.rag.datasource.vdb.tablestore.tablestore_vector import TableStoreVectorFactory
|
||||
|
||||
return TableStoreVectorFactory
|
||||
case VectorType.HUAWEI_CLOUD:
|
||||
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory
|
||||
|
||||
@@ -183,10 +175,6 @@ class Vector:
|
||||
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory
|
||||
|
||||
return MatrixoneVectorFactory
|
||||
case VectorType.CLICKZETTA:
|
||||
from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaVectorFactory
|
||||
|
||||
return ClickzettaVectorFactory
|
||||
case VectorType.IRIS:
|
||||
from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from enum import StrEnum
|
||||
|
||||
class VectorType(StrEnum):
|
||||
ALIBABACLOUD_MYSQL = "alibabacloud_mysql"
|
||||
ANALYTICDB = "analyticdb"
|
||||
CHROMA = "chroma"
|
||||
MILVUS = "milvus"
|
||||
MYSCALE = "myscale"
|
||||
@@ -29,9 +28,7 @@ class VectorType(StrEnum):
|
||||
OCEANBASE = "oceanbase"
|
||||
SEEKDB = "seekdb"
|
||||
OPENGAUSS = "opengauss"
|
||||
TABLESTORE = "tablestore"
|
||||
HUAWEI_CLOUD = "huawei_cloud"
|
||||
MATRIXONE = "matrixone"
|
||||
CLICKZETTA = "clickzetta"
|
||||
IRIS = "iris"
|
||||
HOLOGRES = "hologres"
|
||||
|
||||
@@ -5,6 +5,7 @@ This module provides integration with Weaviate vector database for storing and r
|
||||
document embeddings used in retrieval-augmented generation workflows.
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
@@ -37,6 +38,32 @@ _weaviate_client: weaviate.WeaviateClient | None = None
|
||||
_weaviate_client_lock = threading.Lock()
|
||||
|
||||
|
||||
def _shutdown_weaviate_client() -> None:
|
||||
"""
|
||||
Best-effort shutdown hook to close the module-level Weaviate client.
|
||||
|
||||
This is registered with atexit so that HTTP/gRPC resources are released
|
||||
when the Python interpreter exits.
|
||||
"""
|
||||
global _weaviate_client
|
||||
|
||||
# Ensure thread-safety when accessing the shared client instance
|
||||
with _weaviate_client_lock:
|
||||
client = _weaviate_client
|
||||
_weaviate_client = None
|
||||
|
||||
if client is not None:
|
||||
try:
|
||||
client.close()
|
||||
except Exception:
|
||||
# Best-effort cleanup; log at debug level and ignore errors.
|
||||
logger.debug("Failed to close Weaviate client during shutdown", exc_info=True)
|
||||
|
||||
|
||||
# Register the shutdown hook once per process.
|
||||
atexit.register(_shutdown_weaviate_client)
|
||||
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
"""
|
||||
Configuration model for Weaviate connection settings.
|
||||
@@ -85,18 +112,6 @@ class WeaviateVector(BaseVector):
|
||||
self._client = self._init_client(config)
|
||||
self._attributes = attributes
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Destructor to properly close the Weaviate client connection.
|
||||
Prevents connection leaks and resource warnings.
|
||||
"""
|
||||
if hasattr(self, "_client") and self._client is not None:
|
||||
try:
|
||||
self._client.close()
|
||||
except Exception as e:
|
||||
# Ignore errors during cleanup as object is being destroyed
|
||||
logger.warning("Error closing Weaviate client %s", e, exc_info=True)
|
||||
|
||||
def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient:
|
||||
"""
|
||||
Initializes and returns a connected Weaviate client.
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Final
|
||||
TRIGGER_WEBHOOK_NODE_TYPE: Final[str] = "trigger-webhook"
|
||||
TRIGGER_SCHEDULE_NODE_TYPE: Final[str] = "trigger-schedule"
|
||||
TRIGGER_PLUGIN_NODE_TYPE: Final[str] = "trigger-plugin"
|
||||
TRIGGER_INFO_METADATA_KEY: Final[str] = "trigger_info"
|
||||
|
||||
TRIGGER_NODE_TYPES: Final[frozenset[str]] = frozenset(
|
||||
{
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from dify_graph.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey
|
||||
@@ -47,7 +47,7 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
|
||||
|
||||
# Get trigger data passed when workflow was triggered
|
||||
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
|
||||
cast(WorkflowNodeExecutionMetadataKey, TRIGGER_INFO_METADATA_KEY): {
|
||||
WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
|
||||
"provider_id": self.node_data.provider_id,
|
||||
"event_name": self.node_data.event_name,
|
||||
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,
|
||||
|
||||
@@ -245,6 +245,9 @@ _END_STATE = frozenset(
|
||||
class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
"""
|
||||
Node Run Metadata Key.
|
||||
|
||||
Values in this enum are persisted as execution metadata and must stay in sync
|
||||
with every node that writes `NodeRunResult.metadata`.
|
||||
"""
|
||||
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
@@ -266,6 +269,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
DATASOURCE_INFO = "datasource_info"
|
||||
TRIGGER_INFO = "trigger_info"
|
||||
COMPLETED_REASON = "completed_reason" # completed reason for loop node
|
||||
|
||||
|
||||
|
||||
@@ -101,7 +101,6 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
timeout=self._get_request_timeout(self.node_data),
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
http_request_config=self._http_request_config,
|
||||
max_retries=0,
|
||||
ssl_verify=self.node_data.ssl_verify,
|
||||
http_client=self._http_client,
|
||||
file_manager=self._file_manager,
|
||||
|
||||
@@ -256,9 +256,13 @@ def fetch_prompt_messages(
|
||||
):
|
||||
continue
|
||||
prompt_message_content.append(content_item)
|
||||
if prompt_message_content:
|
||||
if not prompt_message_content:
|
||||
continue
|
||||
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
|
||||
prompt_message.content = prompt_message_content[0].data
|
||||
else:
|
||||
prompt_message.content = prompt_message_content
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
elif not prompt_message.is_empty():
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
|
||||
|
||||
@@ -69,19 +69,6 @@ class Storage:
|
||||
from extensions.storage.supabase_storage import SupabaseStorage
|
||||
|
||||
return SupabaseStorage
|
||||
case StorageType.CLICKZETTA_VOLUME:
|
||||
from extensions.storage.clickzetta_volume.clickzetta_volume_storage import (
|
||||
ClickZettaVolumeConfig,
|
||||
ClickZettaVolumeStorage,
|
||||
)
|
||||
|
||||
def create_clickzetta_volume_storage():
|
||||
# ClickZettaVolumeConfig will automatically read from environment variables
|
||||
# and fallback to CLICKZETTA_* config if CLICKZETTA_VOLUME_* is not set
|
||||
volume_config = ClickZettaVolumeConfig()
|
||||
return ClickZettaVolumeStorage(volume_config)
|
||||
|
||||
return create_clickzetta_volume_storage
|
||||
case _:
|
||||
raise ValueError(f"unsupported storage type {storage_type}")
|
||||
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
"""ClickZetta Volume storage implementation."""
|
||||
|
||||
from .clickzetta_volume_storage import ClickZettaVolumeStorage
|
||||
|
||||
__all__ = ["ClickZettaVolumeStorage"]
|
||||
@@ -1,527 +0,0 @@
|
||||
"""ClickZetta Volume Storage Implementation
|
||||
|
||||
This module provides storage backend using ClickZetta Volume functionality.
|
||||
Supports Table Volume, User Volume, and External Volume types.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from collections.abc import Generator
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import clickzetta
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
|
||||
from .volume_permissions import VolumePermissionManager, check_volume_permission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClickZettaVolumeConfig(BaseModel):
|
||||
"""Configuration for ClickZetta Volume storage."""
|
||||
|
||||
username: str = ""
|
||||
password: str = ""
|
||||
instance: str = ""
|
||||
service: str = "api.clickzetta.com"
|
||||
workspace: str = "quick_start"
|
||||
vcluster: str = "default_ap"
|
||||
schema_name: str = "dify"
|
||||
volume_type: str = "table" # table|user|external
|
||||
volume_name: str | None = None # For external volumes
|
||||
table_prefix: str = "dataset_" # Prefix for table volume names
|
||||
dify_prefix: str = "dify_km" # Directory prefix for User Volume
|
||||
permission_check: bool = True # Enable/disable permission checking
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict):
|
||||
"""Validate the configuration values.
|
||||
|
||||
This method will first try to use CLICKZETTA_VOLUME_* environment variables,
|
||||
then fall back to CLICKZETTA_* environment variables (for vector DB config).
|
||||
"""
|
||||
|
||||
# Helper function to get environment variable with fallback
|
||||
def get_env_with_fallback(volume_key: str, fallback_key: str, default: str | None = None) -> str:
|
||||
# First try CLICKZETTA_VOLUME_* specific config
|
||||
volume_value = values.get(volume_key.lower().replace("clickzetta_volume_", ""))
|
||||
if volume_value:
|
||||
return str(volume_value)
|
||||
|
||||
# Then try environment variables
|
||||
volume_env = os.getenv(volume_key)
|
||||
if volume_env:
|
||||
return volume_env
|
||||
|
||||
# Fall back to existing CLICKZETTA_* config
|
||||
fallback_env = os.getenv(fallback_key)
|
||||
if fallback_env:
|
||||
return fallback_env
|
||||
|
||||
return default or ""
|
||||
|
||||
# Apply environment variables with fallback to existing CLICKZETTA_* config
|
||||
values.setdefault("username", get_env_with_fallback("CLICKZETTA_VOLUME_USERNAME", "CLICKZETTA_USERNAME"))
|
||||
values.setdefault("password", get_env_with_fallback("CLICKZETTA_VOLUME_PASSWORD", "CLICKZETTA_PASSWORD"))
|
||||
values.setdefault("instance", get_env_with_fallback("CLICKZETTA_VOLUME_INSTANCE", "CLICKZETTA_INSTANCE"))
|
||||
values.setdefault(
|
||||
"service", get_env_with_fallback("CLICKZETTA_VOLUME_SERVICE", "CLICKZETTA_SERVICE", "api.clickzetta.com")
|
||||
)
|
||||
values.setdefault(
|
||||
"workspace", get_env_with_fallback("CLICKZETTA_VOLUME_WORKSPACE", "CLICKZETTA_WORKSPACE", "quick_start")
|
||||
)
|
||||
values.setdefault(
|
||||
"vcluster", get_env_with_fallback("CLICKZETTA_VOLUME_VCLUSTER", "CLICKZETTA_VCLUSTER", "default_ap")
|
||||
)
|
||||
values.setdefault("schema_name", get_env_with_fallback("CLICKZETTA_VOLUME_SCHEMA", "CLICKZETTA_SCHEMA", "dify"))
|
||||
|
||||
# Volume-specific configurations (no fallback to vector DB config)
|
||||
values.setdefault("volume_type", os.getenv("CLICKZETTA_VOLUME_TYPE", "table"))
|
||||
values.setdefault("volume_name", os.getenv("CLICKZETTA_VOLUME_NAME"))
|
||||
values.setdefault("table_prefix", os.getenv("CLICKZETTA_VOLUME_TABLE_PREFIX", "dataset_"))
|
||||
values.setdefault("dify_prefix", os.getenv("CLICKZETTA_VOLUME_DIFY_PREFIX", "dify_km"))
|
||||
# Temporarily disable permission check feature, set directly to false
|
||||
values.setdefault("permission_check", False)
|
||||
|
||||
# Validate required fields
|
||||
if not values.get("username"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_USERNAME or CLICKZETTA_USERNAME is required")
|
||||
if not values.get("password"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_PASSWORD or CLICKZETTA_PASSWORD is required")
|
||||
if not values.get("instance"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_INSTANCE or CLICKZETTA_INSTANCE is required")
|
||||
|
||||
# Validate volume type
|
||||
volume_type = values["volume_type"]
|
||||
if volume_type not in ["table", "user", "external"]:
|
||||
raise ValueError("CLICKZETTA_VOLUME_TYPE must be one of: table, user, external")
|
||||
|
||||
if volume_type == "external" and not values.get("volume_name"):
|
||||
raise ValueError("CLICKZETTA_VOLUME_NAME is required for external volume type")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class ClickZettaVolumeStorage(BaseStorage):
|
||||
"""ClickZetta Volume storage implementation."""
|
||||
|
||||
def __init__(self, config: ClickZettaVolumeConfig):
|
||||
"""Initialize ClickZetta Volume storage.
|
||||
|
||||
Args:
|
||||
config: ClickZetta Volume configuration
|
||||
"""
|
||||
self._config = config
|
||||
self._connection = None
|
||||
self._permission_manager: VolumePermissionManager | None = None
|
||||
self._init_connection()
|
||||
self._init_permission_manager()
|
||||
|
||||
logger.info("ClickZetta Volume storage initialized with type: %s", config.volume_type)
|
||||
|
||||
def _init_connection(self):
|
||||
"""Initialize ClickZetta connection."""
|
||||
try:
|
||||
self._connection = clickzetta.connect(
|
||||
username=self._config.username,
|
||||
password=self._config.password,
|
||||
instance=self._config.instance,
|
||||
service=self._config.service,
|
||||
workspace=self._config.workspace,
|
||||
vcluster=self._config.vcluster,
|
||||
schema=self._config.schema_name,
|
||||
)
|
||||
logger.debug("ClickZetta connection established")
|
||||
except Exception:
|
||||
logger.exception("Failed to connect to ClickZetta")
|
||||
raise
|
||||
|
||||
def _init_permission_manager(self):
|
||||
"""Initialize permission manager."""
|
||||
try:
|
||||
self._permission_manager = VolumePermissionManager(
|
||||
self._connection, self._config.volume_type, self._config.volume_name
|
||||
)
|
||||
logger.debug("Permission manager initialized")
|
||||
except Exception:
|
||||
logger.exception("Failed to initialize permission manager")
|
||||
raise
|
||||
|
||||
def _get_volume_path(self, filename: str, dataset_id: str | None = None) -> str:
|
||||
"""Get the appropriate volume path based on volume type."""
|
||||
if self._config.volume_type == "user":
|
||||
# Add dify prefix for User Volume to organize files
|
||||
return f"{self._config.dify_prefix}/{filename}"
|
||||
elif self._config.volume_type == "table":
|
||||
# Check if this should use User Volume (special directories)
|
||||
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
# Use User Volume with dify prefix for special directories
|
||||
return f"{self._config.dify_prefix}/{filename}"
|
||||
|
||||
if dataset_id:
|
||||
return f"{self._config.table_prefix}{dataset_id}/{filename}"
|
||||
else:
|
||||
# Extract dataset_id from filename if not provided
|
||||
# Format: dataset_id/filename
|
||||
if "/" in filename:
|
||||
return filename
|
||||
else:
|
||||
raise ValueError("dataset_id is required for table volume or filename must include dataset_id/")
|
||||
elif self._config.volume_type == "external":
|
||||
return filename
|
||||
else:
|
||||
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
|
||||
|
||||
def _get_volume_sql_prefix(self, dataset_id: str | None = None) -> str:
|
||||
"""Get SQL prefix for volume operations."""
|
||||
if self._config.volume_type == "user":
|
||||
return "USER VOLUME"
|
||||
elif self._config.volume_type == "table":
|
||||
# For Dify's current file storage pattern, most files are stored in
|
||||
# paths like "upload_files/tenant_id/uuid.ext", "tools/tenant_id/uuid.ext"
|
||||
# These should use USER VOLUME for better compatibility
|
||||
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
return "USER VOLUME"
|
||||
|
||||
# Only use TABLE VOLUME for actual dataset-specific paths
|
||||
# like "dataset_12345/file.pdf" or paths with dataset_ prefix
|
||||
if dataset_id:
|
||||
table_name = f"{self._config.table_prefix}{dataset_id}"
|
||||
else:
|
||||
# Default table name for generic operations
|
||||
table_name = "default_dataset"
|
||||
return f"TABLE VOLUME {table_name}"
|
||||
elif self._config.volume_type == "external":
|
||||
return f"VOLUME {self._config.volume_name}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported volume type: {self._config.volume_type}")
|
||||
|
||||
def _execute_sql(self, sql: str, fetch: bool = False):
|
||||
"""Execute SQL command."""
|
||||
try:
|
||||
if self._connection is None:
|
||||
raise RuntimeError("Connection not initialized")
|
||||
with self._connection.cursor() as cursor:
|
||||
cursor.execute(sql)
|
||||
if fetch:
|
||||
return cursor.fetchall()
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("SQL execution failed: %s", sql)
|
||||
raise
|
||||
|
||||
def _ensure_table_volume_exists(self, dataset_id: str):
|
||||
"""Ensure table volume exists for the given dataset_id."""
|
||||
if self._config.volume_type != "table" or not dataset_id:
|
||||
return
|
||||
|
||||
# Skip for upload_files and other special directories that use USER VOLUME
|
||||
if dataset_id in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
|
||||
return
|
||||
|
||||
table_name = f"{self._config.table_prefix}{dataset_id}"
|
||||
|
||||
try:
|
||||
# Check if table exists
|
||||
check_sql = f"SHOW TABLES LIKE '{table_name}'"
|
||||
result = self._execute_sql(check_sql, fetch=True)
|
||||
|
||||
if not result:
|
||||
# Create table with volume
|
||||
create_sql = f"""
|
||||
CREATE TABLE {table_name} (
|
||||
id INT PRIMARY KEY AUTO_INCREMENT,
|
||||
filename VARCHAR(255) NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
INDEX idx_filename (filename)
|
||||
) WITH VOLUME
|
||||
"""
|
||||
self._execute_sql(create_sql)
|
||||
logger.info("Created table volume: %s", table_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to create table volume %s: %s", table_name, e)
|
||||
# Don't raise exception, let the operation continue
|
||||
# The table might exist but not be visible due to permissions
|
||||
|
||||
    def save(self, filename: str, data: bytes):
        """Save data to ClickZetta Volume.

        Args:
            filename: File path in volume
            data: File content as bytes
        """
        # Extract dataset_id from filename if present
        dataset_id = None
        if "/" in filename and self._config.volume_type == "table":
            parts = filename.split("/", 1)
            if parts[0].startswith(self._config.table_prefix):
                dataset_id = parts[0][len(self._config.table_prefix) :]
                filename = parts[1]
            else:
                dataset_id = parts[0]
                filename = parts[1]

        # Ensure the table volume exists (for table volumes)
        if dataset_id:
            self._ensure_table_volume_exists(dataset_id)

        # Check permissions (if enabled)
        if self._config.permission_check:
            # Skip the permission check for special directories that use USER VOLUME
            if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
                if self._permission_manager is not None:
                    check_volume_permission(self._permission_manager, "save", dataset_id)

        # Write the data to a temporary file
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            temp_file.write(data)
            temp_file_path = temp_file.name

        try:
            # Upload to the volume
            volume_prefix = self._get_volume_sql_prefix(dataset_id)

            # Get the actual volume path (may include the dify_km prefix)
            volume_path = self._get_volume_path(filename, dataset_id)

            # For User Volume, use the full path with the dify_km prefix
            if volume_prefix == "USER VOLUME":
                sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{volume_path}'"
            else:
                sql = f"PUT '{temp_file_path}' TO {volume_prefix} FILE '{filename}'"

            self._execute_sql(sql)
            logger.debug("File %s saved to ClickZetta Volume at path %s", filename, volume_path)
        finally:
            # Clean up the temporary file
            Path(temp_file_path).unlink(missing_ok=True)

    def load_once(self, filename: str) -> bytes:
        """Load file content from ClickZetta Volume.

        Args:
            filename: File path in volume

        Returns:
            File content as bytes
        """
        # Extract dataset_id from filename if present
        dataset_id = None
        if "/" in filename and self._config.volume_type == "table":
            parts = filename.split("/", 1)
            if parts[0].startswith(self._config.table_prefix):
                dataset_id = parts[0][len(self._config.table_prefix) :]
                filename = parts[1]
            else:
                dataset_id = parts[0]
                filename = parts[1]

        # Check permissions (if enabled)
        if self._config.permission_check:
            # Skip the permission check for special directories that use USER VOLUME
            if dataset_id not in ["upload_files", "temp", "cache", "tools", "website_files", "privkeys"]:
                if self._permission_manager is not None:
                    check_volume_permission(self._permission_manager, "load_once", dataset_id)

        # Download to a temporary directory
        with tempfile.TemporaryDirectory() as temp_dir:
            volume_prefix = self._get_volume_sql_prefix(dataset_id)

            # Get the actual volume path (may include the dify_km prefix)
            volume_path = self._get_volume_path(filename, dataset_id)

            # For User Volume, use the full path with the dify_km prefix
            if volume_prefix == "USER VOLUME":
                sql = f"GET {volume_prefix} FILE '{volume_path}' TO '{temp_dir}'"
            else:
                sql = f"GET {volume_prefix} FILE '{filename}' TO '{temp_dir}'"

            self._execute_sql(sql)

            # Find the downloaded file (it may be in a subdirectory)
            downloaded_file = None
            for root, _, files in os.walk(temp_dir):
                for file in files:
                    if file == filename or file == os.path.basename(filename):
                        downloaded_file = Path(root) / file
                        break
                if downloaded_file:
                    break

            if not downloaded_file or not downloaded_file.exists():
                raise FileNotFoundError(f"Downloaded file not found: {filename}")

            content = downloaded_file.read_bytes()

            logger.debug("File %s loaded from ClickZetta Volume", filename)
            return content

    def load_stream(self, filename: str) -> Generator:
        """Load a file as a stream from ClickZetta Volume.

        Args:
            filename: File path in volume

        Yields:
            File content chunks
        """
        content = self.load_once(filename)
        batch_size = 4096
        stream = BytesIO(content)

        while chunk := stream.read(batch_size):
            yield chunk

        logger.debug("File %s loaded as stream from ClickZetta Volume", filename)

    def download(self, filename: str, target_filepath: str):
        """Download a file from ClickZetta Volume to a local path.

        Args:
            filename: File path in volume
            target_filepath: Local target file path
        """
        content = self.load_once(filename)

        Path(target_filepath).write_bytes(content)

        logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath)

    def exists(self, filename: str) -> bool:
        """Check whether a file exists in ClickZetta Volume.

        Args:
            filename: File path in volume

        Returns:
            True if the file exists, False otherwise
        """
        try:
            # Extract dataset_id from filename if present
            dataset_id = None
            if "/" in filename and self._config.volume_type == "table":
                parts = filename.split("/", 1)
                if parts[0].startswith(self._config.table_prefix):
                    dataset_id = parts[0][len(self._config.table_prefix) :]
                    filename = parts[1]
                else:
                    dataset_id = parts[0]
                    filename = parts[1]

            volume_prefix = self._get_volume_sql_prefix(dataset_id)

            # Get the actual volume path (may include the dify_km prefix)
            volume_path = self._get_volume_path(filename, dataset_id)

            # For User Volume, use the full path with the dify_km prefix
            if volume_prefix == "USER VOLUME":
                sql = f"LIST {volume_prefix} REGEXP = '^{volume_path}$'"
            else:
                sql = f"LIST {volume_prefix} REGEXP = '^{filename}$'"

            rows = self._execute_sql(sql, fetch=True)

            exists = len(rows) > 0 if rows else False
            logger.debug("File %s exists check: %s", filename, exists)
            return exists
        except Exception as e:
            logger.warning("Error checking file existence for %s: %s", filename, e)
            return False

    def delete(self, filename: str):
        """Delete a file from ClickZetta Volume.

        Args:
            filename: File path in volume
        """
        if not self.exists(filename):
            logger.debug("File %s not found, skip delete", filename)
            return

        # Extract dataset_id from filename if present
        dataset_id = None
        if "/" in filename and self._config.volume_type == "table":
            parts = filename.split("/", 1)
            if parts[0].startswith(self._config.table_prefix):
                dataset_id = parts[0][len(self._config.table_prefix) :]
                filename = parts[1]
            else:
                dataset_id = parts[0]
                filename = parts[1]

        volume_prefix = self._get_volume_sql_prefix(dataset_id)

        # Get the actual volume path (may include the dify_km prefix)
        volume_path = self._get_volume_path(filename, dataset_id)

        # For User Volume, use the full path with the dify_km prefix
        if volume_prefix == "USER VOLUME":
            sql = f"REMOVE {volume_prefix} FILE '{volume_path}'"
        else:
            sql = f"REMOVE {volume_prefix} FILE '{filename}'"

        self._execute_sql(sql)

        logger.debug("File %s deleted from ClickZetta Volume", filename)

    def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
        """Scan files and directories in ClickZetta Volume.

        Args:
            path: Path to scan (dataset_id for table volumes)
            files: Include files in results
            directories: Include directories in results

        Returns:
            List of file/directory paths
        """
        try:
            # For table volumes, the path is treated as a dataset_id
            dataset_id = None
            if self._config.volume_type == "table":
                dataset_id = path
                path = ""  # Root of the table volume

            volume_prefix = self._get_volume_sql_prefix(dataset_id)

            # For User Volume, add the dify prefix to the path
            if volume_prefix == "USER VOLUME":
                if path:
                    scan_path = f"{self._config.dify_prefix}/{path}"
                    sql = f"LIST {volume_prefix} SUBDIRECTORY '{scan_path}'"
                else:
                    sql = f"LIST {volume_prefix} SUBDIRECTORY '{self._config.dify_prefix}'"
            else:
                if path:
                    sql = f"LIST {volume_prefix} SUBDIRECTORY '{path}'"
                else:
                    sql = f"LIST {volume_prefix}"

            rows = self._execute_sql(sql, fetch=True)

            result = []
            if rows:
                for row in rows:
                    file_path = row[0]  # relative_path column

                    # For User Volume, remove the dify prefix from results
                    dify_prefix_with_slash = f"{self._config.dify_prefix}/"
                    if volume_prefix == "USER VOLUME" and file_path.startswith(dify_prefix_with_slash):
                        file_path = file_path[len(dify_prefix_with_slash) :]  # Remove prefix

                    if (files and not file_path.endswith("/")) or (directories and file_path.endswith("/")):
                        result.append(file_path)

            logger.debug("Scanned %d items in path %s", len(result), path)
            return result

        except Exception:
            logger.exception("Error scanning path %s", path)
            return []
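As a quick orientation before the deleted modules below, a minimal usage sketch of this backend (not from the diff; the config values are placeholders, and the same flow appears in the deleted integration tests at the end of this comparison):

    config = ClickZettaVolumeConfig(
        username="user", password="pass", instance="inst",
        service="api.clickzetta.com", workspace="ws", vcluster="default_ap",
        schema_name="dify", volume_type="user",
    )
    storage = ClickZettaVolumeStorage(config)
    storage.save("upload_files/tenant-1/a.txt", b"hello")
    assert storage.exists("upload_files/tenant-1/a.txt")
    assert storage.load_once("upload_files/tenant-1/a.txt") == b"hello"
    storage.delete("upload_files/tenant-1/a.txt")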
@@ -1,518 +0,0 @@
"""ClickZetta Volume file lifecycle management.

This module provides file lifecycle management features including version control,
automatic cleanup, backup and restore.
Supports complete lifecycle management for knowledge base files.
"""

from __future__ import annotations

import json
import logging
import operator
from dataclasses import asdict, dataclass
from datetime import datetime
from enum import StrEnum, auto
from typing import Any

logger = logging.getLogger(__name__)


class FileStatus(StrEnum):
    """File status enumeration"""

    ACTIVE = auto()  # Active status
    ARCHIVED = auto()  # Archived
    DELETED = auto()  # Deleted (soft delete)
    BACKUP = auto()  # Backup file


@dataclass
class FileMetadata:
    """File metadata"""

    filename: str
    size: int | None
    created_at: datetime
    modified_at: datetime
    version: int | None
    status: FileStatus
    checksum: str | None = None
    tags: dict[str, str] | None = None
    parent_version: int | None = None

    def to_dict(self):
        """Convert to dictionary format"""
        data = asdict(self)
        data["created_at"] = self.created_at.isoformat()
        data["modified_at"] = self.modified_at.isoformat()
        data["status"] = self.status.value
        return data

    @classmethod
    def from_dict(cls, data: dict) -> FileMetadata:
        """Create an instance from a dictionary"""
        data = data.copy()
        data["created_at"] = datetime.fromisoformat(data["created_at"])
        data["modified_at"] = datetime.fromisoformat(data["modified_at"])
        data["status"] = FileStatus(data["status"])
        return cls(**data)
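A quick sketch (not from the diff) of the serialization round trip this dataclass supports; the values are illustrative:

    meta = FileMetadata(
        filename="doc.pdf", size=1024,
        created_at=datetime(2025, 1, 1), modified_at=datetime(2025, 1, 2),
        version=1, status=FileStatus.ACTIVE,
    )
    restored = FileMetadata.from_dict(meta.to_dict())  # datetimes and status survive the round trip
    assert restored == meta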
class FileLifecycleManager:
    """File lifecycle manager"""

    def __init__(self, storage, dataset_id: str | None = None):
        """Initialize the lifecycle manager

        Args:
            storage: ClickZetta Volume storage instance
            dataset_id: Dataset ID (for Table Volume)
        """
        self._storage = storage
        self._dataset_id = dataset_id
        self._metadata_file = ".dify_file_metadata.json"
        self._version_prefix = ".versions/"
        self._backup_prefix = ".backups/"
        self._deleted_prefix = ".deleted/"

        # Get the permission manager (if one exists)
        self._permission_manager: Any | None = getattr(storage, "_permission_manager", None)

    def save_with_lifecycle(self, filename: str, data: bytes, tags: dict[str, str] | None = None) -> FileMetadata:
        """Save a file and manage its lifecycle

        Args:
            filename: File name
            data: File content
            tags: File tags

        Returns:
            File metadata
        """
        # Permission check
        if not self._check_permission(filename, "save"):
            from .volume_permissions import VolumePermissionError

            raise VolumePermissionError(
                f"Permission denied for lifecycle save operation on file: {filename}",
                operation="save",
                volume_type=getattr(self._storage, "_config", {}).get("volume_type", "unknown"),
                dataset_id=self._dataset_id,
            )

        try:
            # 1. Check whether an old version exists
            metadata_dict = self._load_metadata()
            current_metadata = metadata_dict.get(filename)

            # 2. If an old version exists, create a version backup
            if current_metadata:
                self._create_version_backup(filename, current_metadata)

            # 3. Calculate file information
            now = datetime.now()
            checksum = self._calculate_checksum(data)
            new_version = (current_metadata["version"] + 1) if current_metadata else 1

            # 4. Save the new file
            self._storage.save(filename, data)

            # 5. Create metadata
            created_at = now
            parent_version = None

            if current_metadata:
                # If created_at is a string, convert it to datetime
                if isinstance(current_metadata["created_at"], str):
                    created_at = datetime.fromisoformat(current_metadata["created_at"])
                else:
                    created_at = current_metadata["created_at"]
                parent_version = current_metadata["version"]

            file_metadata = FileMetadata(
                filename=filename,
                size=len(data),
                created_at=created_at,
                modified_at=now,
                version=new_version,
                status=FileStatus.ACTIVE,
                checksum=checksum,
                tags=tags or {},
                parent_version=parent_version,
            )

            # 6. Update metadata
            metadata_dict[filename] = file_metadata.to_dict()
            self._save_metadata(metadata_dict)

            logger.info("File %s saved with lifecycle management, version %s", filename, new_version)
            return file_metadata

        except Exception:
            logger.exception("Failed to save file with lifecycle")
            raise

    def get_file_metadata(self, filename: str) -> FileMetadata | None:
        """Get file metadata

        Args:
            filename: File name

        Returns:
            File metadata, or None if it does not exist
        """
        try:
            metadata_dict = self._load_metadata()
            if filename in metadata_dict:
                return FileMetadata.from_dict(metadata_dict[filename])
            return None
        except Exception:
            logger.exception("Failed to get file metadata for %s", filename)
            return None

    def list_file_versions(self, filename: str) -> list[FileMetadata]:
        """List all versions of a file

        Args:
            filename: File name

        Returns:
            File version list, sorted by version number
        """
        try:
            versions = []

            # Get the current version
            current_metadata = self.get_file_metadata(filename)
            if current_metadata:
                versions.append(current_metadata)

            # Get historical versions
            try:
                version_files = self._storage.scan(self._dataset_id or "", files=True)
                for file_path in version_files:
                    if file_path.startswith(f"{self._version_prefix}{filename}.v"):
                        # Parse the version number
                        version_str = file_path.split(".v")[-1].split(".")[0]
                        try:
                            _ = int(version_str)
                            # Simplified processing here; this should really read
                            # metadata from the version file. For now only basic
                            # metadata information is created.
                        except ValueError:
                            continue
            except Exception:
                # If the version files cannot be scanned, return only the current version
                logger.exception("Failed to scan version files for %s", filename)

            return sorted(versions, key=lambda x: x.version or 0, reverse=True)

        except Exception:
            logger.exception("Failed to list file versions for %s", filename)
            return []

    def restore_version(self, filename: str, version: int) -> bool:
        """Restore a file to the specified version

        Args:
            filename: File name
            version: Version number to restore

        Returns:
            Whether the restore succeeded
        """
        try:
            version_filename = f"{self._version_prefix}{filename}.v{version}"

            # Check whether the version file exists
            if not self._storage.exists(version_filename):
                logger.warning("Version %s of %s not found", version, filename)
                return False

            # Read the version file content
            version_data = self._storage.load_once(version_filename)

            # Save the current version as a backup
            current_metadata = self.get_file_metadata(filename)
            if current_metadata:
                self._create_version_backup(filename, current_metadata.to_dict())

            # Restore the file
            self.save_with_lifecycle(filename, version_data, {"restored_from": str(version)})
            return True

        except Exception:
            logger.exception("Failed to restore %s to version %s", filename, version)
            return False

    def archive_file(self, filename: str) -> bool:
        """Archive a file

        Args:
            filename: File name

        Returns:
            Whether the archive succeeded
        """
        # Permission check
        if not self._check_permission(filename, "archive"):
            logger.warning("Permission denied for archive operation on file: %s", filename)
            return False

        try:
            # Update the file status to archived
            metadata_dict = self._load_metadata()
            if filename not in metadata_dict:
                logger.warning("File %s not found in metadata", filename)
                return False

            metadata_dict[filename]["status"] = FileStatus.ARCHIVED
            metadata_dict[filename]["modified_at"] = datetime.now().isoformat()

            self._save_metadata(metadata_dict)

            logger.info("File %s archived successfully", filename)
            return True

        except Exception:
            logger.exception("Failed to archive file %s", filename)
            return False

    def soft_delete_file(self, filename: str) -> bool:
        """Soft delete a file (move it to the deleted directory)

        Args:
            filename: File name

        Returns:
            Whether the delete succeeded
        """
        # Permission check
        if not self._check_permission(filename, "delete"):
            logger.warning("Permission denied for soft delete operation on file: %s", filename)
            return False

        try:
            # Check whether the file exists
            if not self._storage.exists(filename):
                logger.warning("File %s not found", filename)
                return False

            # Read the file content
            file_data = self._storage.load_once(filename)

            # Move it to the deleted directory
            deleted_filename = f"{self._deleted_prefix}{filename}.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
            self._storage.save(deleted_filename, file_data)

            # Delete the original file
            self._storage.delete(filename)

            # Update metadata
            metadata_dict = self._load_metadata()
            if filename in metadata_dict:
                metadata_dict[filename]["status"] = FileStatus.DELETED
                metadata_dict[filename]["modified_at"] = datetime.now().isoformat()
                self._save_metadata(metadata_dict)

            logger.info("File %s soft deleted successfully", filename)
            return True

        except Exception:
            logger.exception("Failed to soft delete file %s", filename)
            return False

    def cleanup_old_versions(self, max_versions: int = 5, max_age_days: int = 30) -> int:
        """Clean up old version files

        Args:
            max_versions: Maximum number of versions to keep
            max_age_days: Maximum retention days for version files

        Returns:
            Number of files cleaned
        """
        try:
            cleaned_count = 0

            # Get all version files
            try:
                all_files = self._storage.scan(self._dataset_id or "", files=True)
                version_files = [f for f in all_files if f.startswith(self._version_prefix)]

                # Group by file
                file_versions: dict[str, list[tuple[int, str]]] = {}
                for version_file in version_files:
                    # Parse the filename and version
                    parts = version_file[len(self._version_prefix) :].split(".v")
                    if len(parts) >= 2:
                        base_filename = parts[0]
                        version_part = parts[1].split(".")[0]
                        try:
                            version_num = int(version_part)
                            if base_filename not in file_versions:
                                file_versions[base_filename] = []
                            file_versions[base_filename].append((version_num, version_file))
                        except ValueError:
                            continue

                # Clean up old versions for each file
                for base_filename, versions in file_versions.items():
                    # Sort by version number
                    versions.sort(key=operator.itemgetter(0), reverse=True)

                    # Keep the newest max_versions versions and delete the rest
                    if len(versions) > max_versions:
                        to_delete = versions[max_versions:]
                        for version_num, version_file in to_delete:
                            self._storage.delete(version_file)
                            cleaned_count += 1
                            logger.debug("Cleaned old version: %s", version_file)

                logger.info("Cleaned %d old version files", cleaned_count)

            except Exception as e:
                logger.warning("Could not scan for version files: %s", e)

            return cleaned_count

        except Exception:
            logger.exception("Failed to cleanup old versions")
            return 0
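A small worked example (not from the diff) of the retention rule above, with hypothetical version files:

    versions = [(1, ".versions/a.txt.v1"), (3, ".versions/a.txt.v3"), (2, ".versions/a.txt.v2")]
    versions.sort(key=operator.itemgetter(0), reverse=True)  # newest first: v3, v2, v1
    max_versions = 2
    to_delete = versions[max_versions:]  # only (1, ".versions/a.txt.v1") is deleted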
    def get_storage_statistics(self) -> dict[str, Any]:
        """Get storage statistics

        Returns:
            Storage statistics dictionary
        """
        try:
            metadata_dict = self._load_metadata()

            stats: dict[str, Any] = {
                "total_files": len(metadata_dict),
                "active_files": 0,
                "archived_files": 0,
                "deleted_files": 0,
                "total_size": 0,
                "versions_count": 0,
                "oldest_file": None,
                "newest_file": None,
            }

            oldest_date = None
            newest_date = None

            for filename, metadata in metadata_dict.items():
                file_meta = FileMetadata.from_dict(metadata)

                # Count file status
                if file_meta.status == FileStatus.ACTIVE:
                    stats["active_files"] = (stats["active_files"] or 0) + 1
                elif file_meta.status == FileStatus.ARCHIVED:
                    stats["archived_files"] = (stats["archived_files"] or 0) + 1
                elif file_meta.status == FileStatus.DELETED:
                    stats["deleted_files"] = (stats["deleted_files"] or 0) + 1

                # Count size
                stats["total_size"] = (stats["total_size"] or 0) + (file_meta.size or 0)

                # Count versions
                stats["versions_count"] = (stats["versions_count"] or 0) + (file_meta.version or 0)

                # Find the newest and oldest files
                if oldest_date is None or file_meta.created_at < oldest_date:
                    oldest_date = file_meta.created_at
                    stats["oldest_file"] = filename

                if newest_date is None or file_meta.modified_at > newest_date:
                    newest_date = file_meta.modified_at
                    stats["newest_file"] = filename

            return stats

        except Exception:
            logger.exception("Failed to get storage statistics")
            return {}

    def _create_version_backup(self, filename: str, metadata: dict):
        """Create a version backup"""
        try:
            # Read the current file content
            current_data = self._storage.load_once(filename)

            # Save it as a version file
            version_filename = f"{self._version_prefix}{filename}.v{metadata['version']}"
            self._storage.save(version_filename, current_data)

            logger.debug("Created version backup: %s", version_filename)

        except Exception as e:
            logger.warning("Failed to create version backup for %s: %s", filename, e)

    def _load_metadata(self) -> dict[str, Any]:
        """Load the metadata file"""
        try:
            if self._storage.exists(self._metadata_file):
                metadata_content = self._storage.load_once(self._metadata_file)
                result = json.loads(metadata_content.decode("utf-8"))
                return dict(result) if result else {}
            else:
                return {}
        except Exception as e:
            logger.warning("Failed to load metadata: %s", e)
            return {}

    def _save_metadata(self, metadata_dict: dict):
        """Save the metadata file"""
        try:
            metadata_content = json.dumps(metadata_dict, indent=2, ensure_ascii=False)
            self._storage.save(self._metadata_file, metadata_content.encode("utf-8"))
            logger.debug("Metadata saved successfully")
        except Exception:
            logger.exception("Failed to save metadata")
            raise

    def _calculate_checksum(self, data: bytes) -> str:
        """Calculate the file checksum"""
        import hashlib

        return hashlib.md5(data).hexdigest()

    def _check_permission(self, filename: str, operation: str) -> bool:
        """Check file operation permission

        Args:
            filename: File name
            operation: Operation type

        Returns:
            True if permission is granted, False otherwise
        """
        # If there is no permission manager, allow by default
        if not self._permission_manager:
            return True

        try:
            # Map the operation type to a permission
            operation_mapping = {
                "save": "save",
                "load": "load_once",
                "delete": "delete",
                "archive": "delete",  # Archive requires delete permission
                "restore": "save",  # Restore requires write permission
                "cleanup": "delete",  # Cleanup requires delete permission
                "read": "load_once",
                "write": "save",
            }

            mapped_operation = operation_mapping.get(operation, operation)

            # Check the permission
            result = self._permission_manager.validate_operation(mapped_operation, self._dataset_id)
            return bool(result)

        except Exception:
            logger.exception("Permission check failed for %s operation %s", filename, operation)
            # Safe default: deny access when the permission check fails
            return False
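A brief usage sketch (not from the diff) tying the lifecycle manager to the storage backend above; the storage instance is assumed to be configured as shown earlier:

    manager = FileLifecycleManager(storage, dataset_id="12345")
    manager.save_with_lifecycle("doc.txt", b"v1")   # version 1
    manager.save_with_lifecycle("doc.txt", b"v2")   # version 2; v1 is backed up under .versions/
    manager.restore_version("doc.txt", 1)           # re-saves the v1 content as a new version
    manager.cleanup_old_versions(max_versions=5)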
@@ -1,649 +0,0 @@
"""ClickZetta Volume permission management mechanism.

This module provides Volume permission checking, validation and management features.
According to ClickZetta's permission model, different Volume types have different permission requirements.
"""

import logging
from enum import StrEnum

logger = logging.getLogger(__name__)


class VolumePermission(StrEnum):
    """Volume permission type enumeration"""

    READ = "SELECT"  # Corresponds to ClickZetta's SELECT permission
    WRITE = "INSERT,UPDATE,DELETE"  # Corresponds to ClickZetta's write permissions
    LIST = "SELECT"  # Listing files requires SELECT permission
    DELETE = "INSERT,UPDATE,DELETE"  # Deleting files requires write permissions
    USAGE = "USAGE"  # Basic permission required for External Volume
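Note that members sharing a value (READ/LIST and WRITE/DELETE) are aliases under Python's Enum semantics. As a quick illustration (not from the diff) of how these values are consumed below, the comma-separated value is split into a required-permission set:

    required = set(VolumePermission.WRITE.value.split(","))  # {"INSERT", "UPDATE", "DELETE"}
    granted = {"SELECT", "INSERT", "UPDATE", "DELETE"}
    assert required.issubset(granted)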
class VolumePermissionManager:
    """Volume permission manager"""

    def __init__(self, connection_or_config, volume_type: str | None = None, volume_name: str | None = None):
        """Initialize the permission manager

        Args:
            connection_or_config: ClickZetta connection object or configuration dictionary
            volume_type: Volume type (user|table|external)
            volume_name: Volume name (for external volume)
        """
        # Two initialization methods are supported: a connection object or a configuration dictionary
        if isinstance(connection_or_config, dict):
            # Create a connection from the configuration dictionary
            import clickzetta

            config = connection_or_config
            self._connection = clickzetta.connect(
                username=config.get("username"),
                password=config.get("password"),
                instance=config.get("instance"),
                service=config.get("service"),
                workspace=config.get("workspace"),
                vcluster=config.get("vcluster"),
                schema=config.get("schema") or config.get("database"),
            )
            self._volume_type = config.get("volume_type", volume_type)
            self._volume_name = config.get("volume_name", volume_name)
        else:
            # Use the connection object directly
            self._connection = connection_or_config
            self._volume_type = volume_type
            self._volume_name = volume_name

        if not self._connection:
            raise ValueError("Valid connection or config is required")
        if not self._volume_type:
            raise ValueError("volume_type is required")

        self._permission_cache: dict[str, set[str]] = {}
        self._current_username = None  # Will be fetched from the connection

    def check_permission(self, operation: VolumePermission, dataset_id: str | None = None) -> bool:
        """Check whether the user has permission to perform a specific operation

        Args:
            operation: Type of operation to perform
            dataset_id: Dataset ID (for table volume)

        Returns:
            True if the user has permission, False otherwise
        """
        try:
            if self._volume_type == "user":
                return self._check_user_volume_permission(operation)
            elif self._volume_type == "table":
                return self._check_table_volume_permission(operation, dataset_id)
            elif self._volume_type == "external":
                return self._check_external_volume_permission(operation)
            else:
                logger.warning("Unknown volume type: %s", self._volume_type)
                return False

        except Exception:
            logger.exception("Permission check failed")
            return False

    def _check_user_volume_permission(self, operation: VolumePermission) -> bool:
        """Check User Volume permission

        User Volume permission rules:
        - The user has full permissions on their own User Volume
        - As long as the user can connect to ClickZetta, they have basic User Volume permissions by default
        - The focus is on connection authentication rather than complex permission checking
        """
        try:
            # Get the current username
            current_user = self._get_current_username()

            # Check basic connection status
            with self._connection.cursor() as cursor:
                # Simple connection test: if a query can be executed, the user has basic permissions
                cursor.execute("SELECT 1")
                result = cursor.fetchone()

                if result:
                    logger.debug(
                        "User Volume permission check for %s, operation %s: granted (basic connection verified)",
                        current_user,
                        operation.name,
                    )
                    return True
                else:
                    logger.warning(
                        "User Volume permission check failed: cannot verify basic connection for %s", current_user
                    )
                    return False

        except Exception:
            logger.exception("User Volume permission check failed")
            # For User Volume, a failed permission check might be a configuration issue,
            # so provide a friendlier message
            logger.info("User Volume permission check failed, but permission checking is disabled in this version")
            return False

    def _check_table_volume_permission(self, operation: VolumePermission, dataset_id: str | None) -> bool:
        """Check Table Volume permission

        Table Volume permission rules:
        - Table Volume permissions inherit from the corresponding table permissions
        - SELECT permission -> can READ/LIST files
        - INSERT,UPDATE,DELETE permissions -> can WRITE/DELETE files
        """
        if not dataset_id:
            logger.warning("dataset_id is required for table volume permission check")
            return False

        table_name = f"dataset_{dataset_id}" if not dataset_id.startswith("dataset_") else dataset_id

        try:
            # Check table permissions
            permissions = self._get_table_permissions(table_name)
            required_permissions = set(operation.value.split(","))

            # Check whether all required permissions are present
            has_permission = required_permissions.issubset(permissions)

            logger.debug(
                "Table Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s",
                table_name,
                operation.name,
                required_permissions,
                permissions,
                has_permission,
            )

            return has_permission

        except Exception:
            logger.exception("Table volume permission check failed for %s", table_name)
            return False

    def _check_external_volume_permission(self, operation: VolumePermission) -> bool:
        """Check External Volume permission

        External Volume permission rules:
        - Try to get permissions for the External Volume
        - If the permission check fails, perform fallback verification
        - For development environments, provide more lenient permission checking
        """
        if not self._volume_name:
            logger.warning("volume_name is required for external volume permission check")
            return False

        try:
            # Check External Volume permissions
            permissions = self._get_external_volume_permissions(self._volume_name)

            # External Volume permission mapping: determine required permissions based on the operation type
            required_permissions = set()

            if operation in [VolumePermission.READ, VolumePermission.LIST]:
                required_permissions.add("read")
            elif operation in [VolumePermission.WRITE, VolumePermission.DELETE]:
                required_permissions.add("write")

            # Check whether all required permissions are present
            has_permission = required_permissions.issubset(permissions)

            logger.debug(
                "External Volume permission check for %s, operation %s: required=%s, has=%s, granted=%s",
                self._volume_name,
                operation.name,
                required_permissions,
                permissions,
                has_permission,
            )

            # If the permission check fails, try fallback verification
            if not has_permission:
                logger.info("Direct permission check failed for %s, trying fallback verification", self._volume_name)

                # Fallback verification: try listing volumes to verify basic access permissions
                try:
                    with self._connection.cursor() as cursor:
                        cursor.execute("SHOW VOLUMES")
                        volumes = cursor.fetchall()
                        for volume in volumes:
                            if len(volume) > 0 and volume[0] == self._volume_name:
                                logger.info("Fallback verification successful for %s", self._volume_name)
                                return True
                except Exception as fallback_e:
                    logger.warning("Fallback verification failed for %s: %s", self._volume_name, fallback_e)

            return has_permission

        except Exception:
            logger.exception("External volume permission check failed for %s", self._volume_name)
            logger.info("External Volume permission check failed, but permission checking is disabled in this version")
            return False

    def _get_table_permissions(self, table_name: str) -> set[str]:
        """Get the user's permissions for the specified table

        Args:
            table_name: Table name

        Returns:
            Set of user permissions for this table
        """
        cache_key = f"table:{table_name}"

        if cache_key in self._permission_cache:
            return self._permission_cache[cache_key]

        permissions = set()

        try:
            with self._connection.cursor() as cursor:
                # Use the correct ClickZetta syntax to check the current user's permissions
                cursor.execute("SHOW GRANTS")
                grants = cursor.fetchall()

                # Parse the permission results, looking for permissions on this table
                for grant in grants:
                    if len(grant) >= 3:  # Typical format: (privilege, object_type, object_name, ...)
                        privilege = grant[0].upper()
                        object_type = grant[1].upper() if len(grant) > 1 else ""
                        object_name = grant[2] if len(grant) > 2 else ""

                        # Check whether the grant applies to this table
                        if (
                            object_type == "TABLE"
                            and object_name == table_name
                        ) or (
                            object_type == "SCHEMA"
                            and object_name in table_name
                        ):
                            if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
                                if privilege == "ALL":
                                    permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
                                else:
                                    permissions.add(privilege)

                # If no explicit permissions were found, try a simple query to verify permissions
                if not permissions:
                    try:
                        cursor.execute(f"SELECT COUNT(*) FROM {table_name} LIMIT 1")
                        permissions.add("SELECT")
                    except Exception:
                        logger.debug("Cannot query table %s, no SELECT permission", table_name)

        except Exception as e:
            logger.warning("Could not check table permissions for %s: %s", table_name, e)
            # Safe default: deny access when the permission check fails

        # Cache the permission information
        self._permission_cache[cache_key] = permissions
        return permissions

    def _get_current_username(self) -> str:
        """Get the current username"""
        if self._current_username:
            return self._current_username

        try:
            with self._connection.cursor() as cursor:
                cursor.execute("SELECT CURRENT_USER()")
                result = cursor.fetchone()
                if result:
                    self._current_username = result[0]
                    return str(self._current_username)
        except Exception:
            logger.exception("Failed to get current username")

        return "unknown"

    def _get_user_permissions(self, username: str) -> set[str]:
        """Get the user's basic permission set"""
        cache_key = f"user_permissions:{username}"

        if cache_key in self._permission_cache:
            return self._permission_cache[cache_key]

        permissions = set()

        try:
            with self._connection.cursor() as cursor:
                # Use the correct ClickZetta syntax to check the current user's permissions
                cursor.execute("SHOW GRANTS")
                grants = cursor.fetchall()

                # Parse the permission results, looking for the user's basic permissions
                for grant in grants:
                    if len(grant) >= 3:  # Typical format: (privilege, object_type, object_name, ...)
                        privilege = grant[0].upper()
                        _ = grant[1].upper() if len(grant) > 1 else ""

                        # Collect all relevant permissions
                        if privilege in ["SELECT", "INSERT", "UPDATE", "DELETE", "ALL"]:
                            if privilege == "ALL":
                                permissions.update(["SELECT", "INSERT", "UPDATE", "DELETE"])
                            else:
                                permissions.add(privilege)

        except Exception as e:
            logger.warning("Could not check user permissions for %s: %s", username, e)
            # Safe default: deny access when the permission check fails

        # Cache the permission information
        self._permission_cache[cache_key] = permissions
        return permissions

    def _get_external_volume_permissions(self, volume_name: str) -> set[str]:
        """Get the user's permissions for the specified External Volume

        Args:
            volume_name: External Volume name

        Returns:
            Set of user permissions for this Volume
        """
        cache_key = f"external_volume:{volume_name}"

        if cache_key in self._permission_cache:
            return self._permission_cache[cache_key]

        permissions = set()

        try:
            with self._connection.cursor() as cursor:
                # Use the correct ClickZetta syntax to check Volume permissions
                logger.info("Checking permissions for volume: %s", volume_name)
                cursor.execute(f"SHOW GRANTS ON VOLUME {volume_name}")
                grants = cursor.fetchall()

                logger.info("Raw grants result for %s: %s", volume_name, grants)

                # Parse the permission results.
                # Format: (granted_type, privilege, conditions, granted_on, object_name, granted_to,
                #          grantee_name, grantor_name, grant_option, granted_time)
                for grant in grants:
                    logger.info("Processing grant: %s", grant)
                    if len(grant) >= 5:
                        granted_type = grant[0]
                        privilege = grant[1].upper()
                        granted_on = grant[3]
                        object_name = grant[4]

                        logger.info(
                            "Grant details - type: %s, privilege: %s, granted_on: %s, object_name: %s",
                            granted_type,
                            privilege,
                            granted_on,
                            object_name,
                        )

                        # Check whether the grant applies to this Volume, or is a hierarchical permission
                        if (
                            granted_type == "PRIVILEGE" and granted_on == "VOLUME" and object_name.endswith(volume_name)
                        ) or (granted_type == "OBJECT_HIERARCHY" and granted_on == "VOLUME"):
                            logger.info("Matching grant found for %s", volume_name)

                            if "READ" in privilege:
                                permissions.add("read")
                                logger.info("Added READ permission for %s", volume_name)
                            if "WRITE" in privilege:
                                permissions.add("write")
                                logger.info("Added WRITE permission for %s", volume_name)
                            if "ALTER" in privilege:
                                permissions.add("alter")
                                logger.info("Added ALTER permission for %s", volume_name)
                            if privilege == "ALL":
                                permissions.update(["read", "write", "alter"])
                                logger.info("Added ALL permissions for %s", volume_name)

                logger.info("Final permissions for %s: %s", volume_name, permissions)

                # If no explicit permissions were found, check the Volume list to verify basic permissions
                if not permissions:
                    try:
                        cursor.execute("SHOW VOLUMES")
                        volumes = cursor.fetchall()
                        for volume in volumes:
                            if len(volume) > 0 and volume[0] == volume_name:
                                permissions.add("read")  # At least has read permission
                                logger.debug("Volume %s found in SHOW VOLUMES, assuming read permission", volume_name)
                                break
                    except Exception:
                        logger.debug("Cannot access volume %s, no basic permission", volume_name)

        except Exception as e:
            logger.warning("Could not check external volume permissions for %s: %s", volume_name, e)
            # When the permission check fails, try basic Volume access verification
            try:
                with self._connection.cursor() as cursor:
                    cursor.execute("SHOW VOLUMES")
                    volumes = cursor.fetchall()
                    for volume in volumes:
                        if len(volume) > 0 and volume[0] == volume_name:
                            logger.info("Basic volume access verified for %s", volume_name)
                            permissions.add("read")
                            permissions.add("write")  # Assume write permission as well
                            break
            except Exception as basic_e:
                logger.warning("Basic volume access check failed for %s: %s", volume_name, basic_e)
                # Last fallback: assume basic permissions
                permissions.add("read")

        # Cache the permission information
        self._permission_cache[cache_key] = permissions
        return permissions

    def clear_permission_cache(self):
        """Clear the permission cache"""
        self._permission_cache.clear()
        logger.debug("Permission cache cleared")

    @property
    def volume_type(self) -> str | None:
        """Get the volume type."""
        return self._volume_type

    def get_permission_summary(self, dataset_id: str | None = None) -> dict[str, bool]:
        """Get a permission summary

        Args:
            dataset_id: Dataset ID (for table volume)

        Returns:
            Permission summary dictionary
        """
        summary = {}

        for operation in VolumePermission:
            summary[operation.name.lower()] = self.check_permission(operation, dataset_id)

        return summary

    def check_inherited_permission(self, file_path: str, operation: VolumePermission) -> bool:
        """Check permission inheritance for a file path

        Args:
            file_path: File path
            operation: Operation to perform

        Returns:
            True if the user has permission, False otherwise
        """
        try:
            # Parse the file path
            path_parts = file_path.strip("/").split("/")

            if not path_parts:
                logger.warning("Invalid file path for permission inheritance check")
                return False

            # For Table Volume, the first segment is the dataset_id
            if self._volume_type == "table":
                if len(path_parts) < 1:
                    return False

                dataset_id = path_parts[0]

                # Check permissions for the dataset
                has_dataset_permission = self.check_permission(operation, dataset_id)

                if not has_dataset_permission:
                    logger.debug("Permission denied for dataset %s", dataset_id)
                    return False

                # Check for path traversal attacks
                if self._contains_path_traversal(file_path):
                    logger.warning("Path traversal attack detected: %s", file_path)
                    return False

                # Check whether a sensitive directory is being accessed
                if self._is_sensitive_path(file_path):
                    logger.warning("Access to sensitive path denied: %s", file_path)
                    return False

                logger.debug("Permission inherited for path %s", file_path)
                return True

            elif self._volume_type == "user":
                # User Volume permission inheritance
                current_user = self._get_current_username()

                # Check whether the user is attempting to access another user's directory
                if len(path_parts) > 1 and path_parts[0] != current_user:
                    logger.warning("User %s attempted to access %s's directory", current_user, path_parts[0])
                    return False

                # Check basic permissions
                return self.check_permission(operation)

            elif self._volume_type == "external":
                # External Volume permission inheritance:
                # check permissions for the External Volume itself
                return self.check_permission(operation)

            else:
                logger.warning("Unknown volume type for permission inheritance: %s", self._volume_type)
                return False

        except Exception:
            logger.exception("Permission inheritance check failed")
            return False

    def _contains_path_traversal(self, file_path: str) -> bool:
        """Check whether a path contains a path traversal attack"""
        # Check common path traversal patterns
        traversal_patterns = [
            "../",
            "..\\",
            "..%2f",
            "..%2F",
            "..%5c",
            "..%5C",
            "%2e%2e%2f",
            "%2e%2e%5c",
            "....//",
            "....\\\\",
        ]

        file_path_lower = file_path.lower()

        for pattern in traversal_patterns:
            if pattern in file_path_lower:
                return True

        # Check for absolute paths
        if file_path.startswith("/") or file_path.startswith("\\"):
            return True

        # Check for Windows drive paths
        if len(file_path) >= 2 and file_path[1] == ":":
            return True

        return False
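A tiny illustration (not from the diff) of the traversal check above, on a hypothetical manager instance mgr. Note that because the path is lowercased before matching, the uppercase pattern variants can never match:

    mgr._contains_path_traversal("dataset_1/../secrets")  # True ("../" matches)
    mgr._contains_path_traversal("/etc/passwd")           # True (absolute path)
    mgr._contains_path_traversal("C:\\data\\file.txt")    # True (Windows drive path)
    mgr._contains_path_traversal("dataset_1/file.txt")    # False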
    def _is_sensitive_path(self, file_path: str) -> bool:
        """Check whether a path is a sensitive path"""
        sensitive_patterns = [
            "passwd",
            "shadow",
            "hosts",
            "config",
            "secrets",
            "private",
            "key",
            "certificate",
            "cert",
            "ssl",
            "database",
            "backup",
            "dump",
            "log",
            "tmp",
        ]

        file_path_lower = file_path.lower()

        return any(pattern in file_path_lower for pattern in sensitive_patterns)

    def validate_operation(self, operation: str, dataset_id: str | None = None) -> bool:
        """Validate operation permission

        Args:
            operation: Operation name (save|load|exists|delete|scan)
            dataset_id: Dataset ID

        Returns:
            True if the operation is allowed, False otherwise
        """
        operation_mapping = {
            "save": VolumePermission.WRITE,
            "load": VolumePermission.READ,
            "load_once": VolumePermission.READ,
            "load_stream": VolumePermission.READ,
            "download": VolumePermission.READ,
            "exists": VolumePermission.READ,
            "delete": VolumePermission.DELETE,
            "scan": VolumePermission.LIST,
        }

        if operation not in operation_mapping:
            logger.warning("Unknown operation: %s", operation)
            return False

        volume_permission = operation_mapping[operation]
        return self.check_permission(volume_permission, dataset_id)


class VolumePermissionError(Exception):
    """Volume permission error exception"""

    def __init__(self, message: str, operation: str, volume_type: str, dataset_id: str | None = None):
        self.operation = operation
        self.volume_type = volume_type
        self.dataset_id = dataset_id
        super().__init__(message)


def check_volume_permission(permission_manager: VolumePermissionManager, operation: str, dataset_id: str | None = None):
    """Permission check helper function

    Args:
        permission_manager: Permission manager
        operation: Operation name
        dataset_id: Dataset ID

    Raises:
        VolumePermissionError: If permission is denied
    """
    if not permission_manager.validate_operation(operation, dataset_id):
        error_message = f"Permission denied for operation '{operation}' on {permission_manager.volume_type} volume"
        if dataset_id:
            error_message += f" (dataset: {dataset_id})"

        raise VolumePermissionError(
            error_message,
            operation=operation,
            volume_type=permission_manager.volume_type or "unknown",
            dataset_id=dataset_id,
        )
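A short sketch (not from the diff) of how these pieces compose; conn is assumed to be an established clickzetta connection:

    mgr = VolumePermissionManager(conn, volume_type="table")
    if mgr.validate_operation("save", dataset_id="12345"):
        ...  # proceed with the write
    # Or let the helper raise on denial:
    check_volume_permission(mgr, "save", dataset_id="12345")  # raises VolumePermissionError if denied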
@@ -5,7 +5,6 @@ class StorageType(StrEnum):
    ALIYUN_OSS = "aliyun-oss"
    AZURE_BLOB = "azure-blob"
    BAIDU_OBS = "baidu-obs"
    CLICKZETTA_VOLUME = "clickzetta-volume"
    GOOGLE_STORAGE = "google-storage"
    HUAWEI_OBS = "huawei-obs"
    LOCAL = "local"
@@ -11,6 +11,13 @@ class CreatorUserRole(StrEnum):
    ACCOUNT = "account"
    END_USER = "end_user"

    @classmethod
    def _missing_(cls, value):
        if value == "end-user":
            return cls.END_USER
        else:
            return super()._missing_(value)


class WorkflowRunTriggeredFrom(StrEnum):
    DEBUGGING = "debugging"
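The _missing_ hook added above lets the legacy hyphenated value resolve to the underscored member, so both spellings construct the same enum member:

    assert CreatorUserRole("end-user") is CreatorUserRole.END_USER
    assert CreatorUserRole("end_user") is CreatorUserRole.END_USER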
@@ -22,14 +22,14 @@ from sqlalchemy import (
from sqlalchemy.orm import Mapped, mapped_column
from typing_extensions import deprecated

from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from dify_graph.constants import (
    CONVERSATION_VARIABLE_NODE_ID,
    SYSTEM_VARIABLE_NODE_ID,
)
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey
from dify_graph.file.constants import maybe_file_object
from dify_graph.file.models import File
from dify_graph.variables import utils as variable_utils
@@ -936,8 +936,11 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
        elif self.node_type == BuiltinNodeTypes.DATASOURCE and "datasource_info" in execution_metadata:
            datasource_info = execution_metadata["datasource_info"]
            extras["icon"] = datasource_info.get("icon")
        elif self.node_type == TRIGGER_PLUGIN_NODE_TYPE and TRIGGER_INFO_METADATA_KEY in execution_metadata:
            trigger_info = execution_metadata[TRIGGER_INFO_METADATA_KEY] or {}
        elif (
            self.node_type == TRIGGER_PLUGIN_NODE_TYPE
            and WorkflowNodeExecutionMetadataKey.TRIGGER_INFO in execution_metadata
        ):
            trigger_info = execution_metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] or {}
            provider_id = trigger_info.get("provider_id")
            if provider_id:
                extras["icon"] = TriggerManager.get_trigger_plugin_icon(
@@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.13.1"
version = "1.13.2"
requires-python = ">=3.11,<3.13"

dependencies = [
@@ -89,7 +89,6 @@ dependencies = [
    "croniter>=6.0.0",
    "weaviate-client==4.20.4",
    "apscheduler>=3.11.0",
    "weave>=0.52.16",
    "fastopenapi[flask]>=0.7.0",
    "bleach~=6.2.0",
]
@@ -101,6 +100,7 @@

[tool.uv]
default-groups = ["storage", "tools", "vdb"]
constraint-dependencies = ["cryptography>=46.0.5"]
package = false

[dependency-groups]
@@ -202,11 +202,8 @@ tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"]
# Required by vector store clients
############################################################
vdb = [
    "alibabacloud_gpdb20160503~=3.8.0",
    "alibabacloud_tea_openapi~=0.4.3",
    "chromadb==0.5.20",
    "clickhouse-connect~=0.14.1",
    "clickzetta-connector-python>=0.8.102",
    "couchbase~=4.5.0",
    "elasticsearch==8.14.0",
    "opensearch-py==3.1.0",
@@ -218,7 +215,6 @@
    "pyobvector~=0.2.17",
    "qdrant-client==1.9.0",
    "intersystems-irispython>=5.1.0",
    "tablestore==6.4.1",
    "tcvectordb~=2.0.0",
    "tidb-vector==0.0.15",
    "upstash-vector==0.8.0",
@@ -45,11 +45,8 @@ core/plugin/backwards_invocation/model.py
core/prompt/utils/extract_thread_messages.py
core/rag/datasource/keyword/jieba/jieba.py
core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
core/rag/datasource/vdb/baidu/baidu_vector.py
core/rag/datasource/vdb/chroma/chroma_vector.py
core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
core/rag/datasource/vdb/couchbase/couchbase_vector.py
core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
@@ -62,7 +59,6 @@ core/rag/datasource/vdb/opensearch/opensearch_vector.py
core/rag/datasource/vdb/oracle/oraclevector.py
core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
core/rag/datasource/vdb/relyt/relyt_vector.py
core/rag/datasource/vdb/tablestore/tablestore_vector.py
core/rag/datasource/vdb/tencent/tencent_vector.py
core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
@@ -144,8 +140,6 @@ extensions/storage/aliyun_oss_storage.py
extensions/storage/aws_s3_storage.py
extensions/storage/azure_blob_storage.py
extensions/storage/baidu_obs_storage.py
extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
extensions/storage/clickzetta_volume/file_lifecycle.py
extensions/storage/google_cloud_storage.py
extensions/storage/huawei_obs_storage.py
extensions/storage/opendal_storage.py
@@ -28,7 +28,6 @@
        "baidubce.auth.bce_credentials",
        "baidubce.bce_client_configuration",
        "baidubce.services.bos.bos_client",
        "clickzetta",
        "google.cloud",
        "obs",
        "qcloud_cos",
@@ -52,4 +51,4 @@
        "reportAttributeAccessIssue": "hint",
        "pythonVersion": "3.11",
        "pythonPlatform": "All"
    }
}
@@ -86,15 +86,6 @@ class OpsService:
                new_decrypt_tracing_config.update({"project_url": project_url})
            except Exception:
                new_decrypt_tracing_config.update({"project_url": "https://www.comet.com/opik/"})
        if tracing_provider == "weave" and (
            "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
        ):
            try:
                project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
                new_decrypt_tracing_config.update({"project_url": project_url})
            except Exception:
                new_decrypt_tracing_config.update({"project_url": "https://wandb.ai/"})

        if tracing_provider == "aliyun" and (
            "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
        ):
@@ -1,168 +0,0 @@
"""Integration tests for ClickZetta Volume Storage."""

import os
import tempfile
import unittest
from pathlib import Path

import pytest

from extensions.storage.clickzetta_volume.clickzetta_volume_storage import (
    ClickZettaVolumeConfig,
    ClickZettaVolumeStorage,
)


class TestClickZettaVolumeStorage(unittest.TestCase):
    """Test cases for ClickZetta Volume Storage."""

    def setUp(self):
        """Set up test environment."""
        self.config = ClickZettaVolumeConfig(
            username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
            password=os.getenv("CLICKZETTA_PASSWORD", "test_pass"),
            instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"),
            service=os.getenv("CLICKZETTA_SERVICE", "uat-api.clickzetta.com"),
            workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
            vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
            schema_name=os.getenv("CLICKZETTA_SCHEMA", "dify"),
            volume_type="table",
            table_prefix="test_dataset_",
        )

    @pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided")
    def test_user_volume_operations(self):
        """Test basic operations with User Volume."""
        config = self.config
        config.volume_type = "user"

        storage = ClickZettaVolumeStorage(config)

        # Test file operations
        test_filename = "test_file.txt"
        test_content = b"Hello, ClickZetta Volume!"

        # Save file
        storage.save(test_filename, test_content)

        # Check if file exists
        assert storage.exists(test_filename)

        # Load file
        loaded_content = storage.load_once(test_filename)
        assert loaded_content == test_content

        # Test streaming
        stream_content = b""
        for chunk in storage.load_stream(test_filename):
            stream_content += chunk
        assert stream_content == test_content

        # Test download
        with tempfile.NamedTemporaryFile() as temp_file:
            storage.download(test_filename, temp_file.name)
            downloaded_content = Path(temp_file.name).read_bytes()
            assert downloaded_content == test_content

        # Test scan
        files = storage.scan("", files=True, directories=False)
        assert test_filename in files

        # Delete file
        storage.delete(test_filename)
        assert not storage.exists(test_filename)

    @pytest.mark.skipif(not os.getenv("CLICKZETTA_USERNAME"), reason="ClickZetta credentials not provided")
    def test_table_volume_operations(self):
        """Test basic operations with Table Volume."""
        config = self.config
        config.volume_type = "table"

        storage = ClickZettaVolumeStorage(config)

        # Test file operations with dataset_id
        dataset_id = "12345"
        test_filename = f"{dataset_id}/test_file.txt"
        test_content = b"Hello, Table Volume!"

        # Save file
        storage.save(test_filename, test_content)

        # Check if file exists
        assert storage.exists(test_filename)

        # Load file
        loaded_content = storage.load_once(test_filename)
        assert loaded_content == test_content

        # Test scan for dataset
        files = storage.scan(dataset_id, files=True, directories=False)
        assert "test_file.txt" in files

        # Delete file
        storage.delete(test_filename)
        assert not storage.exists(test_filename)

    def test_config_validation(self):
        """Test configuration validation."""
        # Test missing required fields
        with pytest.raises(ValueError):
            ClickZettaVolumeConfig(
                username="",  # Empty username should fail
                password="pass",
                instance="instance",
            )

        # Test invalid volume type
        with pytest.raises(ValueError):
            ClickZettaVolumeConfig(username="user", password="pass", instance="instance", volume_type="invalid_type")

        # Test external volume without volume_name
        with pytest.raises(ValueError):
            ClickZettaVolumeConfig(
                username="user",
                password="pass",
                instance="instance",
                volume_type="external",
                # Missing volume_name
            )

    def test_volume_path_generation(self):
        """Test volume path generation for different types."""
        storage = ClickZettaVolumeStorage(self.config)

        # Test table volume path
        path = storage._get_volume_path("test.txt", "12345")
        assert path == "test_dataset_12345/test.txt"

        # Test path with existing dataset_id prefix
        path = storage._get_volume_path("12345/test.txt")
        assert path == "12345/test.txt"

        # Test user volume
        storage._config.volume_type = "user"
        path = storage._get_volume_path("test.txt")
        assert path == "test.txt"

    def test_sql_prefix_generation(self):
        """Test SQL prefix generation for different volume types."""
        storage = ClickZettaVolumeStorage(self.config)

        # Test table volume SQL prefix
        prefix = storage._get_volume_sql_prefix("12345")
        assert prefix == "TABLE VOLUME test_dataset_12345"

        # Test user volume SQL prefix
        storage._config.volume_type = "user"
        prefix = storage._get_volume_sql_prefix()
        assert prefix == "USER VOLUME"

        # Test external volume SQL prefix
        storage._config.volume_type = "external"
        storage._config.volume_name = "my_external_volume"
        prefix = storage._get_volume_sql_prefix()
        assert prefix == "VOLUME my_external_volume"


if __name__ == "__main__":
    unittest.main()
@@ -1,49 +0,0 @@
from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig
from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis


class AnalyticdbVectorTest(AbstractVectorTest):
    def __init__(self, config_type: str):
        super().__init__()
        # AnalyticDB requires collection_name to be shorter than 60 characters;
        # that is fine for normal usage.
        self.collection_name = self.collection_name.replace("_test", "")
        if config_type == "sql":
            self.vector = AnalyticdbVector(
                collection_name=self.collection_name,
                sql_config=AnalyticdbVectorBySqlConfig(
                    host="test_host",
                    port=5432,
                    account="test_account",
                    account_password="test_passwd",
                    namespace="difytest_namespace",
                ),
                api_config=None,
            )
        else:
            self.vector = AnalyticdbVector(
                collection_name=self.collection_name,
                sql_config=None,
                api_config=AnalyticdbVectorOpenAPIConfig(
                    access_key_id="test_key_id",
                    access_key_secret="test_key_secret",
                    region_id="test_region",
                    instance_id="test_id",
                    account="test_account",
                    account_password="test_passwd",
                    namespace="difytest_namespace",
                    collection="difytest_collection",
                    namespace_password="test_passwd",
                ),
            )

    def run_all_tests(self):
        self.vector.delete()
        return super().run_all_tests()


def test_analyticdb_vector(setup_mock_redis):
    AnalyticdbVectorTest("api").run_all_tests()
    AnalyticdbVectorTest("sql").run_all_tests()
@@ -1,25 +0,0 @@
# Clickzetta Integration Tests

## Running Tests

To run the Clickzetta integration tests, you need to set the following environment variables:

```bash
export CLICKZETTA_USERNAME=your_username
export CLICKZETTA_PASSWORD=your_password
export CLICKZETTA_INSTANCE=your_instance
export CLICKZETTA_SERVICE=api.clickzetta.com
export CLICKZETTA_WORKSPACE=your_workspace
export CLICKZETTA_VCLUSTER=your_vcluster
export CLICKZETTA_SCHEMA=dify
```

Then run the tests:

```bash
pytest api/tests/integration_tests/vdb/clickzetta/
```
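
To iterate on a single test, pytest's standard selectors work here as well. The `-k` expression below is only an illustration; substitute a test name from the suite in your checkout:

```bash
# -k selects tests whose names match the expression; -s streams print output
pytest api/tests/integration_tests/vdb/clickzetta/ -k "full_text_search" -s
```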

## Security Note

Never commit credentials to the repository. Always use environment variables or secure credential management systems.
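
A minimal sketch of one way to follow this rule locally is to keep the export lines in an untracked env file and source it before running the suite. The file name below is arbitrary and nothing in the repository depends on it:

```bash
# clickzetta.env holds the export lines shown above; keep it in .gitignore
set -a              # auto-export every variable assigned while sourcing
source clickzetta.env
set +a
pytest api/tests/integration_tests/vdb/clickzetta/
```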
@@ -1,223 +0,0 @@
import contextlib
import os

import pytest

from core.rag.datasource.vdb.clickzetta.clickzetta_vector import ClickzettaConfig, ClickzettaVector
from core.rag.models.document import Document
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis


class TestClickzettaVector(AbstractVectorTest):
    """
    Test cases for Clickzetta vector database integration.
    """

    @pytest.fixture
    def vector_store(self):
        """Create a Clickzetta vector store instance for testing."""
        # Skip test if Clickzetta credentials are not configured
        if not os.getenv("CLICKZETTA_USERNAME"):
            pytest.skip("CLICKZETTA_USERNAME is not configured")
        if not os.getenv("CLICKZETTA_PASSWORD"):
            pytest.skip("CLICKZETTA_PASSWORD is not configured")
        if not os.getenv("CLICKZETTA_INSTANCE"):
            pytest.skip("CLICKZETTA_INSTANCE is not configured")

        config = ClickzettaConfig(
            username=os.getenv("CLICKZETTA_USERNAME", ""),
            password=os.getenv("CLICKZETTA_PASSWORD", ""),
            instance=os.getenv("CLICKZETTA_INSTANCE", ""),
            service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
            workspace=os.getenv("CLICKZETTA_WORKSPACE", "quick_start"),
            vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default_ap"),
            schema=os.getenv("CLICKZETTA_SCHEMA", "dify_test"),
            batch_size=10,  # Small batch size for testing
            enable_inverted_index=True,
            analyzer_type="chinese",
            analyzer_mode="smart",
            vector_distance_function="cosine_distance",
        )

        with setup_mock_redis():
            vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config)

        yield vector

        # Cleanup: delete the test collection
        with contextlib.suppress(Exception):
            vector.delete()

    def test_clickzetta_vector_basic_operations(self, vector_store):
        """Test basic CRUD operations on Clickzetta vector store."""
        # Prepare test data
        texts = [
            "这是第一个测试文档,包含一些中文内容。",
            "This is the second test document with English content.",
            "第三个文档混合了English和中文内容。",
        ]
        embeddings = [
            [0.1, 0.2, 0.3, 0.4],
            [0.5, 0.6, 0.7, 0.8],
            [0.9, 1.0, 1.1, 1.2],
        ]
        documents = [
            Document(page_content=text, metadata={"doc_id": f"doc_{i}", "source": "test"})
            for i, text in enumerate(texts)
        ]

        # Test create (initial insert)
        vector_store.create(texts=documents, embeddings=embeddings)

        # Test text_exists
        assert vector_store.text_exists("doc_0")
        assert not vector_store.text_exists("doc_999")

        # Test search_by_vector
        query_vector = [0.1, 0.2, 0.3, 0.4]
        results = vector_store.search_by_vector(query_vector, top_k=2)
        assert len(results) > 0
        assert results[0].page_content == texts[0]  # Should match the first document

        # Test search_by_full_text (Chinese)
        results = vector_store.search_by_full_text("中文", top_k=3)
        assert len(results) >= 2  # Should find documents with Chinese content

        # Test search_by_full_text (English)
        results = vector_store.search_by_full_text("English", top_k=3)
        assert len(results) >= 2  # Should find documents with English content

        # Test delete_by_ids
        vector_store.delete_by_ids(["doc_0"])
        assert not vector_store.text_exists("doc_0")
        assert vector_store.text_exists("doc_1")

        # Test delete_by_metadata_field
        vector_store.delete_by_metadata_field("source", "test")
        assert not vector_store.text_exists("doc_1")
        assert not vector_store.text_exists("doc_2")

    def test_clickzetta_vector_advanced_search(self, vector_store):
        """Test advanced search features of Clickzetta vector store."""
        # Prepare test data with more complex metadata
        documents = []
        embeddings = []
        for i in range(10):
            doc = Document(
                page_content=f"Document {i}: " + get_example_text(),
                metadata={
                    "doc_id": f"adv_doc_{i}",
                    "category": "technical" if i % 2 == 0 else "general",
                    "document_id": f"doc_{i // 3}",  # Group documents
                    "importance": i,
                },
            )
            documents.append(doc)
            # Create varied embeddings
            embeddings.append([0.1 * i, 0.2 * i, 0.3 * i, 0.4 * i])

        vector_store.create(texts=documents, embeddings=embeddings)

        # Test vector search with document filter
        query_vector = [0.5, 1.0, 1.5, 2.0]
        results = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["doc_0", "doc_1"])
        assert len(results) > 0
        # All results should belong to doc_0 or doc_1 groups
        for result in results:
            assert result.metadata["document_id"] in ["doc_0", "doc_1"]

        # Test score threshold
        results = vector_store.search_by_vector(query_vector, top_k=10, score_threshold=0.5)
        # Check that all results have a score above threshold
        for result in results:
            assert result.metadata.get("score", 0) >= 0.5

    def test_clickzetta_batch_operations(self, vector_store):
        """Test batch insertion operations."""
        # Prepare large batch of documents
        batch_size = 25
        documents = []
        embeddings = []

        for i in range(batch_size):
            doc = Document(
                page_content=f"Batch document {i}: This is a test document for batch processing.",
                metadata={"doc_id": f"batch_doc_{i}", "batch": "test_batch"},
            )
            documents.append(doc)
            embeddings.append([0.1 * (i % 10), 0.2 * (i % 10), 0.3 * (i % 10), 0.4 * (i % 10)])

        # Test batch insert
        vector_store.add_texts(documents=documents, embeddings=embeddings)

        # Verify all documents were inserted
        for i in range(batch_size):
            assert vector_store.text_exists(f"batch_doc_{i}")

        # Clean up
        vector_store.delete_by_metadata_field("batch", "test_batch")

    def test_clickzetta_edge_cases(self, vector_store):
        """Test edge cases and error handling."""
        # Test empty operations
        vector_store.create(texts=[], embeddings=[])
        vector_store.add_texts(documents=[], embeddings=[])
        vector_store.delete_by_ids([])

        # Test special characters in content
        special_doc = Document(
            page_content="Special chars: 'quotes', \"double\", \\backslash, \n newline",
            metadata={"doc_id": "special_doc", "test": "edge_case"},
        )
        embeddings = [[0.1, 0.2, 0.3, 0.4]]

        vector_store.add_texts(documents=[special_doc], embeddings=embeddings)
        assert vector_store.text_exists("special_doc")

        # Test search with special characters
        results = vector_store.search_by_full_text("quotes", top_k=1)
        if results:  # Full-text search might not be available
            assert len(results) > 0

        # Clean up
        vector_store.delete_by_ids(["special_doc"])

    def test_clickzetta_full_text_search_modes(self, vector_store):
        """Test different full-text search capabilities."""
        # Prepare documents with various language content
        documents = [
            Document(
                page_content="云器科技提供强大的Lakehouse解决方案", metadata={"doc_id": "cn_doc_1", "lang": "chinese"}
            ),
            Document(
                page_content="Clickzetta provides powerful Lakehouse solutions",
                metadata={"doc_id": "en_doc_1", "lang": "english"},
            ),
            Document(
                page_content="Lakehouse是现代数据架构的重要组成部分", metadata={"doc_id": "cn_doc_2", "lang": "chinese"}
            ),
            Document(
                page_content="Modern data architecture includes Lakehouse technology",
                metadata={"doc_id": "en_doc_2", "lang": "english"},
            ),
        ]

        embeddings = [[0.1, 0.2, 0.3, 0.4] for _ in documents]

        vector_store.create(texts=documents, embeddings=embeddings)

        # Test Chinese full-text search
        results = vector_store.search_by_full_text("Lakehouse", top_k=4)
        assert len(results) >= 2  # Should find at least documents with "Lakehouse"

        # Test English full-text search
        results = vector_store.search_by_full_text("solutions", top_k=2)
        assert len(results) >= 1  # Should find English documents with "solutions"

        # Test mixed search
        results = vector_store.search_by_full_text("数据架构", top_k=2)
        assert len(results) >= 1  # Should find Chinese documents with this phrase

        # Clean up
        vector_store.delete_by_metadata_field("lang", "chinese")
        vector_store.delete_by_metadata_field("lang", "english")
@@ -1,165 +0,0 @@
#!/usr/bin/env python3
"""
Test Clickzetta integration in Docker environment
"""

import os
import time

import httpx
from clickzetta import connect


def test_clickzetta_connection():
    """Test direct connection to Clickzetta"""
    print("=== Testing direct Clickzetta connection ===")
    try:
        conn = connect(
            username=os.getenv("CLICKZETTA_USERNAME", "test_user"),
            password=os.getenv("CLICKZETTA_PASSWORD", "test_password"),
            instance=os.getenv("CLICKZETTA_INSTANCE", "test_instance"),
            service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
            workspace=os.getenv("CLICKZETTA_WORKSPACE", "test_workspace"),
            vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default"),
            database=os.getenv("CLICKZETTA_SCHEMA", "dify"),
        )

        with conn.cursor() as cursor:
            # Test basic connectivity
            cursor.execute("SELECT 1 as test")
            result = cursor.fetchone()
            print(f"✓ Connection test: {result}")

            # Check if our test table exists
            cursor.execute("SHOW TABLES IN dify")
            tables = cursor.fetchall()
            print(f"✓ Existing tables: {[t[1] for t in tables if t[0] == 'dify']}")

            # Check if test collection exists
            test_collection = "collection_test_dataset"
            if test_collection in [t[1] for t in tables if t[0] == "dify"]:
                cursor.execute(f"DESCRIBE dify.{test_collection}")
                columns = cursor.fetchall()
                print(f"✓ Table structure for {test_collection}:")
                for col in columns:
                    print(f"  - {col[0]}: {col[1]}")

                # Check for indexes
                cursor.execute(f"SHOW INDEXES IN dify.{test_collection}")
                indexes = cursor.fetchall()
                print(f"✓ Indexes on {test_collection}:")
                for idx in indexes:
                    print(f"  - {idx}")

        return True
    except Exception as e:
        print(f"✗ Connection test failed: {e}")
        return False


def test_dify_api():
    """Test Dify API with Clickzetta backend"""
    print("\n=== Testing Dify API ===")
    base_url = "http://localhost:5001"

    # Wait for API to be ready
    max_retries = 30
    for i in range(max_retries):
        try:
            response = httpx.get(f"{base_url}/console/api/health")
            if response.status_code == 200:
                print("✓ Dify API is ready")
                break
        except:
            if i == max_retries - 1:
                print("✗ Dify API is not responding")
                return False
            time.sleep(2)

    # Check vector store configuration
    try:
        # This is a simplified check - in production, you'd use proper auth
        print("✓ Dify is configured to use Clickzetta as vector store")
        return True
    except Exception as e:
        print(f"✗ API test failed: {e}")
        return False


def verify_table_structure():
    """Verify the table structure meets Dify requirements"""
    print("\n=== Verifying Table Structure ===")

    expected_columns = {
        "id": "VARCHAR",
        "page_content": "VARCHAR",
        "metadata": "VARCHAR",  # JSON stored as VARCHAR in Clickzetta
        "vector": "ARRAY<FLOAT>",
    }

    expected_metadata_fields = ["doc_id", "doc_hash", "document_id", "dataset_id"]

    print("✓ Expected table structure:")
    for col, dtype in expected_columns.items():
        print(f"  - {col}: {dtype}")

    print("\n✓ Required metadata fields:")
    for field in expected_metadata_fields:
        print(f"  - {field}")

    print("\n✓ Index requirements:")
    print("  - Vector index (HNSW) on 'vector' column")
    print("  - Full-text index on 'page_content' (optional)")
    print("  - Functional index on metadata->>'$.doc_id' (recommended)")
    print("  - Functional index on metadata->>'$.document_id' (recommended)")

    return True


def main():
    """Run all tests"""
    print("Starting Clickzetta integration tests for Dify Docker\n")

    tests = [
        ("Direct Clickzetta Connection", test_clickzetta_connection),
        ("Dify API Status", test_dify_api),
        ("Table Structure Verification", verify_table_structure),
    ]

    results = []
    for test_name, test_func in tests:
        try:
            success = test_func()
            results.append((test_name, success))
        except Exception as e:
            print(f"\n✗ {test_name} crashed: {e}")
            results.append((test_name, False))

    # Summary
    print("\n" + "=" * 50)
    print("Test Summary:")
    print("=" * 50)

    passed = sum(1 for _, success in results if success)
    total = len(results)

    for test_name, success in results:
        status = "✅ PASSED" if success else "❌ FAILED"
        print(f"{test_name}: {status}")

    print(f"\nTotal: {passed}/{total} tests passed")

    if passed == total:
        print("\n🎉 All tests passed! Clickzetta is ready for Dify Docker deployment.")
        print("\nNext steps:")
        print("1. Run: cd docker && docker-compose -f docker-compose.yaml -f docker-compose.clickzetta.yaml up -d")
        print("2. Access Dify at http://localhost:3000")
        print("3. Create a dataset and test vector storage with Clickzetta")
        return 0
    else:
        print("\n⚠️ Some tests failed. Please check the errors above.")
        return 1


if __name__ == "__main__":
    exit(main())
@@ -1,100 +0,0 @@
import os
import uuid

import tablestore
from _pytest.python_api import approx

from core.rag.datasource.vdb.tablestore.tablestore_vector import (
    TableStoreConfig,
    TableStoreVector,
)
from tests.integration_tests.vdb.test_vector_store import (
    AbstractVectorTest,
    get_example_document,
    get_example_text,
    setup_mock_redis,
)


class TableStoreVectorTest(AbstractVectorTest):
    def __init__(self, normalize_full_text_score: bool = False):
        super().__init__()
        self.vector = TableStoreVector(
            collection_name=self.collection_name,
            config=TableStoreConfig(
                endpoint=os.getenv("TABLESTORE_ENDPOINT"),
                instance_name=os.getenv("TABLESTORE_INSTANCE_NAME"),
                access_key_id=os.getenv("TABLESTORE_ACCESS_KEY_ID"),
                access_key_secret=os.getenv("TABLESTORE_ACCESS_KEY_SECRET"),
                normalize_full_text_bm25_score=normalize_full_text_score,
            ),
        )

    def get_ids_by_metadata_field(self):
        ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
        assert ids is not None
        assert len(ids) == 1
        assert ids[0] == self.example_doc_id

    def create_vector(self):
        self.vector.create(
            texts=[get_example_document(doc_id=self.example_doc_id)],
            embeddings=[self.example_embedding],
        )
        while True:
            search_response = self.vector._tablestore_client.search(
                table_name=self.vector._table_name,
                index_name=self.vector._index_name,
                search_query=tablestore.SearchQuery(query=tablestore.MatchAllQuery(), get_total_count=True, limit=0),
                columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
            )
            if search_response.total_count == 1:
                break

    def search_by_vector(self):
        super().search_by_vector()
        docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[self.example_doc_id])
        assert len(docs) == 1
        assert docs[0].metadata["doc_id"] == self.example_doc_id
        assert docs[0].metadata["score"] > 0

        docs = self.vector.search_by_vector(self.example_embedding, document_ids_filter=[str(uuid.uuid4())])
        assert len(docs) == 0

    def search_by_full_text(self):
        super().search_by_full_text()
        docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[self.example_doc_id])
        assert len(docs) == 1
        assert docs[0].metadata["doc_id"] == self.example_doc_id
        if self.vector._config.normalize_full_text_bm25_score:
            assert docs[0].metadata["score"] == approx(0.1214, abs=1e-3)
        else:
            assert docs[0].metadata.get("score") is None

        # return none if normalize_full_text_score=true and score_threshold > 0
        docs = self.vector.search_by_full_text(
            get_example_text(), document_ids_filter=[self.example_doc_id], score_threshold=0.5
        )
        if self.vector._config.normalize_full_text_bm25_score:
            assert len(docs) == 0
        else:
            assert len(docs) == 1
            assert docs[0].metadata["doc_id"] == self.example_doc_id
            assert docs[0].metadata.get("score") is None

        docs = self.vector.search_by_full_text(get_example_text(), document_ids_filter=[str(uuid.uuid4())])
        assert len(docs) == 0

    def run_all_tests(self):
        try:
            self.vector.delete()
        except Exception:
            pass

        return super().run_all_tests()


def test_tablestore_vector(setup_mock_redis):
    TableStoreVectorTest().run_all_tests()
    TableStoreVectorTest(normalize_full_text_score=True).run_all_tests()
    TableStoreVectorTest(normalize_full_text_score=False).run_all_tests()
@@ -9,7 +9,6 @@ from core.ops.entities.config_entity import (
    OpikConfig,
    PhoenixConfig,
    TracingProviderEnum,
    WeaveConfig,
)


@@ -23,7 +22,6 @@ class TestTracingProviderEnum:
        assert TracingProviderEnum.LANGFUSE == "langfuse"
        assert TracingProviderEnum.LANGSMITH == "langsmith"
        assert TracingProviderEnum.OPIK == "opik"
        assert TracingProviderEnum.WEAVE == "weave"
        assert TracingProviderEnum.ALIYUN == "aliyun"


@@ -228,64 +226,6 @@ class TestOpikConfig:
            OpikConfig(url="ftp://custom.comet.com/opik/api/")


class TestWeaveConfig:
    """Test cases for WeaveConfig"""

    def test_valid_config(self):
        """Test valid Weave configuration"""
        config = WeaveConfig(
            api_key="test_key",
            entity="test_entity",
            project="test_project",
            endpoint="https://custom.wandb.ai",
            host="https://custom.host.com",
        )
        assert config.api_key == "test_key"
        assert config.entity == "test_entity"
        assert config.project == "test_project"
        assert config.endpoint == "https://custom.wandb.ai"
        assert config.host == "https://custom.host.com"

    def test_default_values(self):
        """Test default values are set correctly"""
        config = WeaveConfig(api_key="key", project="project")
        assert config.entity is None
        assert config.endpoint == "https://trace.wandb.ai"
        assert config.host is None

    def test_missing_required_fields(self):
        """Test that required fields are enforced"""
        with pytest.raises(ValidationError):
            WeaveConfig()

        with pytest.raises(ValidationError):
            WeaveConfig(api_key="key")

        with pytest.raises(ValidationError):
            WeaveConfig(project="project")

    def test_endpoint_validation_https_only(self):
        """Test endpoint validation only allows HTTPS"""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            WeaveConfig(api_key="key", project="project", endpoint="http://insecure.wandb.ai")

    def test_host_validation_optional(self):
        """Test host validation is optional but validates when provided"""
        config = WeaveConfig(api_key="key", project="project", host=None)
        assert config.host is None

        config = WeaveConfig(api_key="key", project="project", host="")
        assert config.host == ""

        config = WeaveConfig(api_key="key", project="project", host="https://valid.host.com")
        assert config.host == "https://valid.host.com"

    def test_host_validation_invalid_scheme(self):
        """Test host validation rejects invalid schemes when provided"""
        with pytest.raises(ValidationError, match="URL scheme must be one of"):
            WeaveConfig(api_key="key", project="project", host="ftp://invalid.host.com")


class TestAliyunConfig:
    """Test cases for AliyunConfig"""

@@ -379,7 +319,6 @@ class TestConfigIntegration:
            LangfuseConfig(public_key="public", secret_key="secret"),
            LangSmithConfig(api_key="key", project="project"),
            OpikConfig(api_key="key"),
            WeaveConfig(api_key="key", project="project"),
            AliyunConfig(license_key="test_license", endpoint="https://tracing-analysis-dc-hz.aliyuncs.com"),
        ]

File diff suppressed because it is too large.

api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py (new file, 106 lines)
@@ -0,0 +1,106 @@
from unittest import mock

import pytest

from core.model_manager import ModelInstance
from dify_graph.model_runtime.entities import ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent
from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage
from dify_graph.nodes.llm import llm_utils
from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage
from dify_graph.nodes.llm.exc import NoPromptFoundError
from dify_graph.runtime import VariablePool


def _fetch_prompt_messages_with_mocked_content(content):
    variable_pool = VariablePool.empty()
    model_instance = mock.MagicMock(spec=ModelInstance)
    prompt_template = [
        LLMNodeChatModelMessage(
            text="You are a classifier.",
            role=PromptMessageRole.SYSTEM,
            edition_type="basic",
        )
    ]

    with (
        mock.patch(
            "dify_graph.nodes.llm.llm_utils.fetch_model_schema",
            return_value=mock.MagicMock(features=[]),
        ),
        mock.patch(
            "dify_graph.nodes.llm.llm_utils.handle_list_messages",
            return_value=[SystemPromptMessage(content=content)],
        ),
        mock.patch(
            "dify_graph.nodes.llm.llm_utils.handle_memory_chat_mode",
            return_value=[],
        ),
    ):
        return llm_utils.fetch_prompt_messages(
            sys_query=None,
            sys_files=[],
            context=None,
            memory=None,
            model_instance=model_instance,
            prompt_template=prompt_template,
            stop=["END"],
            memory_config=None,
            vision_enabled=False,
            vision_detail=ImagePromptMessageContent.DETAIL.HIGH,
            variable_pool=variable_pool,
            jinja2_variables=[],
            template_renderer=None,
        )


def test_fetch_prompt_messages_skips_messages_when_all_contents_are_filtered_out():
    with pytest.raises(NoPromptFoundError):
        _fetch_prompt_messages_with_mocked_content(
            [
                ImagePromptMessageContent(
                    format="url",
                    url="https://example.com/image.png",
                    mime_type="image/png",
                ),
            ]
        )


def test_fetch_prompt_messages_flattens_single_text_content_after_filtering_unsupported_multimodal_items():
    prompt_messages, stop = _fetch_prompt_messages_with_mocked_content(
        [
            TextPromptMessageContent(data="You are a classifier."),
            ImagePromptMessageContent(
                format="url",
                url="https://example.com/image.png",
                mime_type="image/png",
            ),
        ]
    )

    assert stop == ["END"]
    assert prompt_messages == [SystemPromptMessage(content="You are a classifier.")]


def test_fetch_prompt_messages_keeps_list_content_when_multiple_supported_items_remain():
    prompt_messages, stop = _fetch_prompt_messages_with_mocked_content(
        [
            TextPromptMessageContent(data="You are"),
            TextPromptMessageContent(data=" a classifier."),
            ImagePromptMessageContent(
                format="url",
                url="https://example.com/image.png",
                mime_type="image/png",
            ),
        ]
    )

    assert stop == ["END"]
    assert prompt_messages == [
        SystemPromptMessage(
            content=[
                TextPromptMessageContent(data="You are"),
                TextPromptMessageContent(data=" a classifier."),
            ]
        )
    ]
@@ -0,0 +1,63 @@
from collections.abc import Mapping

from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from tests.workflow_test_utils import build_test_graph_init_params


def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]:
    init_params = build_test_graph_init_params(
        graph_config=graph_config,
        user_from="account",
        invoke_from="debugger",
    )
    runtime_state = GraphRuntimeState(
        variable_pool=VariablePool(
            system_variables=SystemVariable(user_id="user", files=[]),
            user_inputs={"payload": "value"},
        ),
        start_at=0.0,
    )
    return init_params, runtime_state


def _build_node_config() -> NodeConfigDict:
    return NodeConfigDictAdapter.validate_python(
        {
            "id": "node-1",
            "data": {
                "type": TRIGGER_PLUGIN_NODE_TYPE,
                "title": "Trigger Event",
                "plugin_id": "plugin-id",
                "provider_id": "provider-id",
                "event_name": "event-name",
                "subscription_id": "subscription-id",
                "plugin_unique_identifier": "plugin-unique-identifier",
                "event_parameters": {},
            },
        }
    )


def test_trigger_event_node_run_populates_trigger_info_metadata() -> None:
    init_params, runtime_state = _build_context(graph_config={})
    node = TriggerEventNode(
        id="node-1",
        config=_build_node_config(),
        graph_init_params=init_params,
        graph_runtime_state=runtime_state,
    )

    result = node._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == {
        "provider_id": "provider-id",
        "event_name": "event-name",
        "plugin_unique_identifier": "plugin-unique-identifier",
    }
api/tests/unit_tests/dify_graph/node_events/test_base.py (new file, 19 lines)
@@ -0,0 +1,19 @@
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.node_events.base import NodeRunResult


def test_node_run_result_accepts_trigger_info_metadata() -> None:
    result = NodeRunResult(
        status=WorkflowNodeExecutionStatus.SUCCEEDED,
        metadata={
            WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: {
                "provider_id": "provider-id",
                "event_name": "event-name",
            }
        },
    )

    assert result.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == {
        "provider_id": "provider-id",
        "event_name": "event-name",
    }
api/tests/unit_tests/models/test_enums_creator_user_role.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import pytest

from models.enums import CreatorUserRole


def test_creator_user_role_missing_maps_hyphen_to_enum():
    # given an alias with hyphen
    value = "end-user"

    # when converting to enum (invokes StrEnum._missing_ override)
    role = CreatorUserRole(value)

    # then it should map to END_USER
    assert role is CreatorUserRole.END_USER


def test_creator_user_role_missing_raises_for_unknown():
    with pytest.raises(ValueError):
        CreatorUserRole("unknown")
@@ -58,7 +58,6 @@ class TestOpsService:
            ("phoenix", "https://app.phoenix.arize.com/projects/"),
            ("langsmith", "https://smith.langchain.com/"),
            ("opik", "https://www.comet.com/opik/"),
            ("weave", "https://wandb.ai/"),
            ("aliyun", "https://arms.console.aliyun.com/"),
            ("tencent", "https://console.cloud.tencent.com/apm"),
            ("mlflow", "http://localhost:5000/"),
@@ -88,7 +87,7 @@ class TestOpsService:
    @patch("services.ops_service.db")
    @patch("services.ops_service.OpsTraceManager")
    @pytest.mark.parametrize(
        "provider", ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"]
        "provider", ["arize", "phoenix", "langsmith", "opik", "aliyun", "tencent", "mlflow", "databricks"]
    )
    def test_get_tracing_app_config_providers_success(self, mock_ops_trace_manager, mock_db, provider):
        # Arrange

api/uv.lock (generated, 5577 lines; file diff suppressed because it is too large)
@@ -541,7 +541,7 @@ SUPABASE_URL=your-server-url
# ------------------------------

# The type of vector store to use.
# Supported values are `weaviate`, `oceanbase`, `seekdb`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`, `vastbase`, `tidb`, `tidb_on_qdrant`, `baidu`, `lindorm`, `huawei_cloud`, `upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`, `iris`, `hologres`.
# Supported values are `weaviate`, `oceanbase`, `seekdb`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`, `vastbase`, `tidb`, `tidb_on_qdrant`, `baidu`, `lindorm`, `huawei_cloud`, `upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`, `iris`, `hologres`.
VECTOR_STORE=weaviate
# Prefix used to create collection name in vector database
VECTOR_INDEX_NAME_PREFIX=Vector_index
@@ -646,20 +646,6 @@ PGVECTO_RS_USER=postgres
PGVECTO_RS_PASSWORD=difyai123456
PGVECTO_RS_DATABASE=dify

# analyticdb configurations, only available when VECTOR_STORE is `analyticdb`
ANALYTICDB_KEY_ID=your-ak
ANALYTICDB_KEY_SECRET=your-sk
ANALYTICDB_REGION_ID=cn-hangzhou
ANALYTICDB_INSTANCE_ID=gp-ab123456
ANALYTICDB_ACCOUNT=testaccount
ANALYTICDB_PASSWORD=testpassword
ANALYTICDB_NAMESPACE=dify
ANALYTICDB_NAMESPACE_PASSWORD=difypassword
ANALYTICDB_HOST=gp-test.aliyuncs.com
ANALYTICDB_PORT=5432
ANALYTICDB_MIN_CONNECTION=1
ANALYTICDB_MAX_CONNECTION=5

# TiDB vector configurations, only available when VECTOR_STORE is `tidb_vector`
TIDB_VECTOR_HOST=tidb
TIDB_VECTOR_PORT=4000

@@ -21,7 +21,7 @@ services:

  # API service
  api:
    image: langgenius/dify-api:1.13.1
    image: langgenius/dify-api:1.13.2
    restart: always
    environment:
      # Use the shared environment variables.
@@ -63,7 +63,7 @@ services:
  # worker service
  # The Celery worker for processing all queues (dataset, workflow, mail, etc.)
  worker:
    image: langgenius/dify-api:1.13.1
    image: langgenius/dify-api:1.13.2
    restart: always
    environment:
      # Use the shared environment variables.
@@ -102,7 +102,7 @@ services:
  # worker_beat service
  # Celery beat for scheduling periodic tasks.
  worker_beat:
    image: langgenius/dify-api:1.13.1
    image: langgenius/dify-api:1.13.2
    restart: always
    environment:
      # Use the shared environment variables.
@@ -132,7 +132,7 @@ services:

  # Frontend web application.
  web:
    image: langgenius/dify-web:1.13.1
    image: langgenius/dify-web:1.13.2
    restart: always
    environment:
      CONSOLE_API_URL: ${CONSOLE_API_URL:-}

@@ -247,18 +247,6 @@ x-shared-env: &shared-api-worker-env
  PGVECTO_RS_USER: ${PGVECTO_RS_USER:-postgres}
  PGVECTO_RS_PASSWORD: ${PGVECTO_RS_PASSWORD:-difyai123456}
  PGVECTO_RS_DATABASE: ${PGVECTO_RS_DATABASE:-dify}
  ANALYTICDB_KEY_ID: ${ANALYTICDB_KEY_ID:-your-ak}
  ANALYTICDB_KEY_SECRET: ${ANALYTICDB_KEY_SECRET:-your-sk}
  ANALYTICDB_REGION_ID: ${ANALYTICDB_REGION_ID:-cn-hangzhou}
  ANALYTICDB_INSTANCE_ID: ${ANALYTICDB_INSTANCE_ID:-gp-ab123456}
  ANALYTICDB_ACCOUNT: ${ANALYTICDB_ACCOUNT:-testaccount}
  ANALYTICDB_PASSWORD: ${ANALYTICDB_PASSWORD:-testpassword}
  ANALYTICDB_NAMESPACE: ${ANALYTICDB_NAMESPACE:-dify}
  ANALYTICDB_NAMESPACE_PASSWORD: ${ANALYTICDB_NAMESPACE_PASSWORD:-difypassword}
  ANALYTICDB_HOST: ${ANALYTICDB_HOST:-gp-test.aliyuncs.com}
  ANALYTICDB_PORT: ${ANALYTICDB_PORT:-5432}
  ANALYTICDB_MIN_CONNECTION: ${ANALYTICDB_MIN_CONNECTION:-1}
  ANALYTICDB_MAX_CONNECTION: ${ANALYTICDB_MAX_CONNECTION:-5}
  TIDB_VECTOR_HOST: ${TIDB_VECTOR_HOST:-tidb}
  TIDB_VECTOR_PORT: ${TIDB_VECTOR_PORT:-4000}
  TIDB_VECTOR_USER: ${TIDB_VECTOR_USER:-}
@@ -370,11 +358,6 @@ x-shared-env: &shared-api-worker-env
  HUAWEI_CLOUD_PASSWORD: ${HUAWEI_CLOUD_PASSWORD:-admin}
  UPSTASH_VECTOR_URL: ${UPSTASH_VECTOR_URL:-https://xxx-vector.upstash.io}
  UPSTASH_VECTOR_TOKEN: ${UPSTASH_VECTOR_TOKEN:-dify}
  TABLESTORE_ENDPOINT: ${TABLESTORE_ENDPOINT:-https://instance-name.cn-hangzhou.ots.aliyuncs.com}
  TABLESTORE_INSTANCE_NAME: ${TABLESTORE_INSTANCE_NAME:-instance-name}
  TABLESTORE_ACCESS_KEY_ID: ${TABLESTORE_ACCESS_KEY_ID:-xxx}
  TABLESTORE_ACCESS_KEY_SECRET: ${TABLESTORE_ACCESS_KEY_SECRET:-xxx}
  TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE: ${TABLESTORE_NORMALIZE_FULLTEXT_BM25_SCORE:-false}
  CLICKZETTA_USERNAME: ${CLICKZETTA_USERNAME:-}
  CLICKZETTA_PASSWORD: ${CLICKZETTA_PASSWORD:-}
  CLICKZETTA_INSTANCE: ${CLICKZETTA_INSTANCE:-}
@@ -728,7 +711,7 @@ services:

  # API service
  api:
    image: langgenius/dify-api:1.13.1
    image: langgenius/dify-api:1.13.2
    restart: always
    environment:
      # Use the shared environment variables.
@@ -770,7 +753,7 @@ services:
  # worker service
  # The Celery worker for processing all queues (dataset, workflow, mail, etc.)
  worker:
    image: langgenius/dify-api:1.13.1
    image: langgenius/dify-api:1.13.2
    restart: always
    environment:
      # Use the shared environment variables.
@@ -809,7 +792,7 @@ services:
  # worker_beat service
  # Celery beat for scheduling periodic tasks.
  worker_beat:
    image: langgenius/dify-api:1.13.1
    image: langgenius/dify-api:1.13.2
    restart: always
    environment:
      # Use the shared environment variables.
@@ -839,7 +822,7 @@ services:

  # Frontend web application.
  web:
    image: langgenius/dify-web:1.13.1
    image: langgenius/dify-web:1.13.2
    restart: always
    environment:
      CONSOLE_API_URL: ${CONSOLE_API_URL:-}

@@ -1,6 +1,6 @@
|
||||
'use client'
|
||||
import type { FC, JSX } from 'react'
|
||||
import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangSmithConfig, MLflowConfig, OpikConfig, PhoenixConfig, TencentConfig, WeaveConfig } from './type'
|
||||
import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangSmithConfig, MLflowConfig, OpikConfig, PhoenixConfig, TencentConfig } from './type'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useState } from 'react'
|
||||
@@ -29,12 +29,11 @@ export type PopupProps = {
|
||||
langSmithConfig: LangSmithConfig | null
|
||||
langFuseConfig: LangFuseConfig | null
|
||||
opikConfig: OpikConfig | null
|
||||
weaveConfig: WeaveConfig | null
|
||||
aliyunConfig: AliyunConfig | null
|
||||
mlflowConfig: MLflowConfig | null
|
||||
databricksConfig: DatabricksConfig | null
|
||||
tencentConfig: TencentConfig | null
|
||||
onConfigUpdated: (provider: TracingProvider, payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig | MLflowConfig | DatabricksConfig) => void
|
||||
onConfigUpdated: (provider: TracingProvider, payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | AliyunConfig | TencentConfig | MLflowConfig | DatabricksConfig) => void
|
||||
onConfigRemoved: (provider: TracingProvider) => void
|
||||
}
|
||||
|
||||
@@ -50,7 +49,6 @@ const ConfigPopup: FC<PopupProps> = ({
|
||||
langSmithConfig,
|
||||
langFuseConfig,
|
||||
opikConfig,
|
||||
weaveConfig,
|
||||
aliyunConfig,
|
||||
mlflowConfig,
|
||||
databricksConfig,
|
||||
@@ -78,7 +76,7 @@ const ConfigPopup: FC<PopupProps> = ({
|
||||
}
|
||||
}, [onChooseProvider])
|
||||
|
||||
const handleConfigUpdated = useCallback((payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | MLflowConfig | DatabricksConfig | TencentConfig) => {
|
||||
const handleConfigUpdated = useCallback((payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | AliyunConfig | MLflowConfig | DatabricksConfig | TencentConfig) => {
|
||||
onConfigUpdated(currentProvider!, payload)
|
||||
hideConfigModal()
|
||||
}, [currentProvider, hideConfigModal, onConfigUpdated])
|
||||
@@ -88,8 +86,8 @@ const ConfigPopup: FC<PopupProps> = ({
|
||||
hideConfigModal()
|
||||
}, [currentProvider, hideConfigModal, onConfigRemoved])
|
||||
|
||||
const providerAllConfigured = arizeConfig && phoenixConfig && langSmithConfig && langFuseConfig && opikConfig && weaveConfig && aliyunConfig && mlflowConfig && databricksConfig && tencentConfig
|
||||
const providerAllNotConfigured = !arizeConfig && !phoenixConfig && !langSmithConfig && !langFuseConfig && !opikConfig && !weaveConfig && !aliyunConfig && !mlflowConfig && !databricksConfig && !tencentConfig
|
||||
const providerAllConfigured = arizeConfig && phoenixConfig && langSmithConfig && langFuseConfig && opikConfig && aliyunConfig && mlflowConfig && databricksConfig && tencentConfig
|
||||
const providerAllNotConfigured = !arizeConfig && !phoenixConfig && !langSmithConfig && !langFuseConfig && !opikConfig && !aliyunConfig && !mlflowConfig && !databricksConfig && !tencentConfig
|
||||
|
||||
const switchContent = (
|
||||
<Switch
|
||||
@@ -164,19 +162,6 @@ const ConfigPopup: FC<PopupProps> = ({
|
||||
/>
|
||||
)
|
||||
|
||||
const weavePanel = (
|
||||
<ProviderPanel
|
||||
type={TracingProvider.weave}
|
||||
readOnly={readOnly}
|
||||
config={weaveConfig}
|
||||
hasConfigured={!!weaveConfig}
|
||||
onConfig={handleOnConfig(TracingProvider.weave)}
|
||||
isChosen={chosenProvider === TracingProvider.weave}
|
||||
onChoose={handleOnChoose(TracingProvider.weave)}
|
||||
key="weave-provider-panel"
|
||||
/>
|
||||
)
|
||||
|
||||
const aliyunPanel = (
|
||||
<ProviderPanel
|
||||
type={TracingProvider.aliyun}
|
||||
@@ -240,9 +225,6 @@ const ConfigPopup: FC<PopupProps> = ({
|
||||
if (opikConfig)
|
||||
configuredPanels.push(opikPanel)
|
||||
|
||||
if (weaveConfig)
|
||||
configuredPanels.push(weavePanel)
|
||||
|
||||
if (arizeConfig)
|
||||
configuredPanels.push(arizePanel)
|
||||
|
||||
@@ -282,9 +264,6 @@ const ConfigPopup: FC<PopupProps> = ({
|
||||
if (!opikConfig)
|
||||
notConfiguredPanels.push(opikPanel)
|
||||
|
||||
if (!weaveConfig)
|
||||
notConfiguredPanels.push(weavePanel)
|
||||
|
||||
if (!aliyunConfig)
|
||||
notConfiguredPanels.push(aliyunPanel)
|
||||
|
||||
@@ -319,7 +298,7 @@ const ConfigPopup: FC<PopupProps> = ({
|
||||
return aliyunConfig
|
||||
if (currentProvider === TracingProvider.tencent)
|
||||
return tencentConfig
|
||||
return weaveConfig
|
||||
return opikConfig
|
||||
}
|
||||
|
||||
return (
|
||||
@@ -365,7 +344,6 @@ const ConfigPopup: FC<PopupProps> = ({
|
||||
{opikPanel}
|
||||
{mlflowPanel}
|
||||
{databricksPanel}
|
||||
{weavePanel}
|
||||
{arizePanel}
|
||||
{phoenixPanel}
|
||||
{aliyunPanel}
|
||||
|
||||
@@ -6,7 +6,6 @@ export const docURL = {
|
||||
[TracingProvider.langSmith]: 'https://docs.smith.langchain.com/',
|
||||
[TracingProvider.langfuse]: 'https://docs.langfuse.com',
|
||||
[TracingProvider.opik]: 'https://www.comet.com/docs/opik/integrations/dify',
|
||||
[TracingProvider.weave]: 'https://weave-docs.wandb.ai/',
|
||||
[TracingProvider.aliyun]: 'https://help.aliyun.com/zh/arms/tracing-analysis/untitled-document-1750672984680',
|
||||
[TracingProvider.mlflow]: 'https://mlflow.org/docs/latest/genai/',
|
||||
[TracingProvider.databricks]: 'https://docs.databricks.com/aws/en/mlflow3/genai/tracing/',
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangSmithConfig, MLflowConfig, OpikConfig, PhoenixConfig, TencentConfig, WeaveConfig } from './type'
|
||||
import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangSmithConfig, MLflowConfig, OpikConfig, PhoenixConfig, TencentConfig } from './type'
|
||||
import type { TracingStatus } from '@/models/app'
|
||||
import {
|
||||
RiArrowDownDoubleLine,
|
||||
@@ -12,7 +12,7 @@ import * as React from 'react'
|
||||
import { useEffect, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import { AliyunIcon, ArizeIcon, DatabricksIcon, LangfuseIcon, LangsmithIcon, MlflowIcon, OpikIcon, PhoenixIcon, TencentIcon, WeaveIcon } from '@/app/components/base/icons/src/public/tracing'
|
||||
import { AliyunIcon, ArizeIcon, DatabricksIcon, LangfuseIcon, LangsmithIcon, MlflowIcon, OpikIcon, PhoenixIcon, TencentIcon } from '@/app/components/base/icons/src/public/tracing'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import Indicator from '@/app/components/header/indicator'
|
||||
@@ -70,7 +70,6 @@ const Panel: FC = () => {
|
||||
[TracingProvider.langSmith]: LangsmithIcon,
|
||||
[TracingProvider.langfuse]: LangfuseIcon,
|
||||
[TracingProvider.opik]: OpikIcon,
|
||||
[TracingProvider.weave]: WeaveIcon,
|
||||
[TracingProvider.aliyun]: AliyunIcon,
|
||||
[TracingProvider.mlflow]: MlflowIcon,
|
||||
[TracingProvider.databricks]: DatabricksIcon,
@@ -83,12 +82,11 @@ const Panel: FC = () => {
const [langSmithConfig, setLangSmithConfig] = useState<LangSmithConfig | null>(null)
const [langFuseConfig, setLangFuseConfig] = useState<LangFuseConfig | null>(null)
const [opikConfig, setOpikConfig] = useState<OpikConfig | null>(null)
const [weaveConfig, setWeaveConfig] = useState<WeaveConfig | null>(null)
const [aliyunConfig, setAliyunConfig] = useState<AliyunConfig | null>(null)
const [mlflowConfig, setMLflowConfig] = useState<MLflowConfig | null>(null)
const [databricksConfig, setDatabricksConfig] = useState<DatabricksConfig | null>(null)
const [tencentConfig, setTencentConfig] = useState<TencentConfig | null>(null)
const hasConfiguredTracing = !!(langSmithConfig || langFuseConfig || opikConfig || weaveConfig || arizeConfig || phoenixConfig || aliyunConfig || mlflowConfig || databricksConfig || tencentConfig)
const hasConfiguredTracing = !!(langSmithConfig || langFuseConfig || opikConfig || arizeConfig || phoenixConfig || aliyunConfig || mlflowConfig || databricksConfig || tencentConfig)

const fetchTracingConfig = async () => {
const getArizeConfig = async () => {
@@ -116,11 +114,6 @@ const Panel: FC = () => {
if (!OpikHasNotConfig)
setOpikConfig(opikConfig as OpikConfig)
}
const getWeaveConfig = async () => {
const { tracing_config: weaveConfig, has_not_configured: weaveHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.weave })
if (!weaveHasNotConfig)
setWeaveConfig(weaveConfig as WeaveConfig)
}
const getAliyunConfig = async () => {
const { tracing_config: aliyunConfig, has_not_configured: aliyunHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.aliyun })
if (!aliyunHasNotConfig)
@@ -147,7 +140,6 @@ const Panel: FC = () => {
getLangSmithConfig(),
getLangFuseConfig(),
getOpikConfig(),
getWeaveConfig(),
getAliyunConfig(),
getMLflowConfig(),
getDatabricksConfig(),
@@ -168,8 +160,6 @@ const Panel: FC = () => {
setLangFuseConfig(tracing_config as LangFuseConfig)
else if (provider === TracingProvider.opik)
setOpikConfig(tracing_config as OpikConfig)
else if (provider === TracingProvider.weave)
setWeaveConfig(tracing_config as WeaveConfig)
else if (provider === TracingProvider.aliyun)
setAliyunConfig(tracing_config as AliyunConfig)
else if (provider === TracingProvider.tencent)
@@ -187,8 +177,6 @@ const Panel: FC = () => {
setLangFuseConfig(null)
else if (provider === TracingProvider.opik)
setOpikConfig(null)
else if (provider === TracingProvider.weave)
setWeaveConfig(null)
else if (provider === TracingProvider.aliyun)
setAliyunConfig(null)
else if (provider === TracingProvider.mlflow)
@@ -240,7 +228,6 @@ const Panel: FC = () => {
langSmithConfig={langSmithConfig}
langFuseConfig={langFuseConfig}
opikConfig={opikConfig}
weaveConfig={weaveConfig}
aliyunConfig={aliyunConfig}
mlflowConfig={mlflowConfig}
databricksConfig={databricksConfig}
@@ -279,7 +266,6 @@ const Panel: FC = () => {
langSmithConfig={langSmithConfig}
langFuseConfig={langFuseConfig}
opikConfig={opikConfig}
weaveConfig={weaveConfig}
aliyunConfig={aliyunConfig}
mlflowConfig={mlflowConfig}
databricksConfig={databricksConfig}

@@ -1,6 +1,6 @@
'use client'
import type { FC } from 'react'
import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangSmithConfig, MLflowConfig, OpikConfig, PhoenixConfig, TencentConfig, WeaveConfig } from './type'
import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangSmithConfig, MLflowConfig, OpikConfig, PhoenixConfig, TencentConfig } from './type'
import { useBoolean } from 'ahooks'
import * as React from 'react'
import { useCallback, useState } from 'react'
@@ -23,10 +23,10 @@ import { TracingProvider } from './type'
type Props = {
appId: string
type: TracingProvider
payload?: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | MLflowConfig | DatabricksConfig | TencentConfig | null
payload?: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | AliyunConfig | MLflowConfig | DatabricksConfig | TencentConfig | null
onRemoved: () => void
onCancel: () => void
onSaved: (payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | MLflowConfig | DatabricksConfig | TencentConfig) => void
onSaved: (payload: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | AliyunConfig | MLflowConfig | DatabricksConfig | TencentConfig) => void
onChosen: (provider: TracingProvider) => void
}

@@ -64,14 +64,6 @@ const opikConfigTemplate = {
workspace: '',
}

const weaveConfigTemplate = {
api_key: '',
entity: '',
project: '',
endpoint: '',
host: '',
}

const aliyunConfigTemplate = {
app_name: '',
license_key: '',
@@ -112,7 +104,7 @@ const ProviderConfigModal: FC<Props> = ({
const isEdit = !!payload
const isAdd = !isEdit
const [isSaving, setIsSaving] = useState(false)
const [config, setConfig] = useState<ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | WeaveConfig | AliyunConfig | MLflowConfig | DatabricksConfig | TencentConfig>((() => {
const [config, setConfig] = useState<ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | OpikConfig | AliyunConfig | MLflowConfig | DatabricksConfig | TencentConfig>((() => {
if (isEdit)
return payload

@@ -143,7 +135,7 @@ const ProviderConfigModal: FC<Props> = ({
else if (type === TracingProvider.tencent)
return tencentConfigTemplate

return weaveConfigTemplate
return opikConfigTemplate
})())
const [isShowRemoveConfirm, {
setTrue: showRemoveConfirm,
@@ -215,14 +207,6 @@ const ProviderConfigModal: FC<Props> = ({
// const postData = config as OpikConfig
}

if (type === TracingProvider.weave) {
const postData = config as WeaveConfig
if (!errorMessage && !postData.api_key)
errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: 'API Key' })
if (!errorMessage && !postData.project)
errorMessage = t('errorMsg.fieldRequired', { ns: 'common', field: t(`${I18N_PREFIX}.project`, { ns: 'app' }) })
}

if (type === TracingProvider.aliyun) {
const postData = config as AliyunConfig
if (!errorMessage && !postData.app_name)
@@ -424,47 +408,6 @@ const ProviderConfigModal: FC<Props> = ({
/>
</>
)}
{type === TracingProvider.weave && (
<>
<Field
label="API Key"
labelClassName="!text-sm"
isRequired
value={(config as WeaveConfig).api_key}
onChange={handleConfigChange('api_key')}
placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'API Key' })!}
/>
<Field
label={t(`${I18N_PREFIX}.project`, { ns: 'app' })!}
labelClassName="!text-sm"
isRequired
value={(config as WeaveConfig).project}
onChange={handleConfigChange('project')}
placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: t(`${I18N_PREFIX}.project`, { ns: 'app' }) })!}
/>
<Field
label="Entity"
labelClassName="!text-sm"
value={(config as WeaveConfig).entity}
onChange={handleConfigChange('entity')}
placeholder={t(`${I18N_PREFIX}.placeholder`, { ns: 'app', key: 'Entity' })!}
/>
<Field
label="Endpoint"
labelClassName="!text-sm"
value={(config as WeaveConfig).endpoint}
onChange={handleConfigChange('endpoint')}
placeholder="https://trace.wandb.ai/"
/>
<Field
label="Host"
labelClassName="!text-sm"
value={(config as WeaveConfig).host}
onChange={handleConfigChange('host')}
placeholder="https://api.wandb.ai"
/>
</>
)}
{type === TracingProvider.langSmith && (
<>
<Field

@@ -6,7 +6,7 @@ import {
import * as React from 'react'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import { AliyunIconBig, ArizeIconBig, DatabricksIconBig, LangfuseIconBig, LangsmithIconBig, MlflowIconBig, OpikIconBig, PhoenixIconBig, TencentIconBig, WeaveIconBig } from '@/app/components/base/icons/src/public/tracing'
import { AliyunIconBig, ArizeIconBig, DatabricksIconBig, LangfuseIconBig, LangsmithIconBig, MlflowIconBig, OpikIconBig, PhoenixIconBig, TencentIconBig } from '@/app/components/base/icons/src/public/tracing'
import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general'
import { cn } from '@/utils/classnames'
import { TracingProvider } from './type'
@@ -30,7 +30,6 @@ const getIcon = (type: TracingProvider) => {
[TracingProvider.langSmith]: LangsmithIconBig,
[TracingProvider.langfuse]: LangfuseIconBig,
[TracingProvider.opik]: OpikIconBig,
[TracingProvider.weave]: WeaveIconBig,
[TracingProvider.aliyun]: AliyunIconBig,
[TracingProvider.mlflow]: MlflowIconBig,
[TracingProvider.databricks]: DatabricksIconBig,

@@ -4,7 +4,6 @@ export enum TracingProvider {
langSmith = 'langsmith',
langfuse = 'langfuse',
opik = 'opik',
weave = 'weave',
aliyun = 'aliyun',
mlflow = 'mlflow',
databricks = 'databricks',
@@ -42,15 +41,6 @@ export type OpikConfig = {
workspace: string
url: string
}

export type WeaveConfig = {
api_key: string
entity: string
project: string
endpoint: string
host: string
}

export type AliyunConfig = {
app_name: string
license_key: string

File diff suppressed because one or more lines are too long
@@ -1,20 +0,0 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY

import type { IconData } from '@/app/components/base/icons/IconBase'
import * as React from 'react'
import IconBase from '@/app/components/base/icons/IconBase'
import data from './WeaveIcon.json'

const Icon = (
{
ref,
...props
}: React.SVGProps<SVGSVGElement> & {
ref?: React.RefObject<React.RefObject<HTMLOrSVGElement>>
},
) => <IconBase {...props} ref={ref} data={data as IconData} />

Icon.displayName = 'WeaveIcon'

export default Icon

File diff suppressed because one or more lines are too long
@@ -1,20 +0,0 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY

import type { IconData } from '@/app/components/base/icons/IconBase'
import * as React from 'react'
import IconBase from '@/app/components/base/icons/IconBase'
import data from './WeaveIconBig.json'

const Icon = (
{
ref,
...props
}: React.SVGProps<SVGSVGElement> & {
ref?: React.RefObject<React.RefObject<HTMLOrSVGElement>>
},
) => <IconBase {...props} ref={ref} data={data as IconData} />

Icon.displayName = 'WeaveIconBig'

export default Icon

@@ -17,5 +17,3 @@ export { default as PhoenixIconBig } from './PhoenixIconBig'
export { default as TencentIcon } from './TencentIcon'
export { default as TencentIconBig } from './TencentIconBig'
export { default as TracingIcon } from './TracingIcon'
export { default as WeaveIcon } from './WeaveIcon'
export { default as WeaveIconBig } from './WeaveIconBig'

@@ -265,8 +265,6 @@
"tracing.tracing": "تتبع",
"tracing.tracingDescription": "التقاط السياق الكامل لتنفيذ التطبيق، بما في ذلك مكالمات LLM، والسياق، والمطالبات، وطلبات HTTP، والمزيد، إلى منصة تتبع تابعة لجهة خارجية.",
"tracing.view": "عرض",
"tracing.weave.description": "Weave هي منصة مفتوحة المصدر لتقييم واختبار ومراقبة تطبيقات LLM.",
"tracing.weave.title": "Weave",
"typeSelector.advanced": "Chatflow",
"typeSelector.agent": "Agent",
"typeSelector.all": "كل الأنواع",

@@ -265,8 +265,6 @@
"tracing.tracing": "Nachverfolgung",
"tracing.tracingDescription": "Erfassung des vollständigen Kontexts der Anwendungsausführung, einschließlich LLM-Aufrufe, Kontext, Prompts, HTTP-Anfragen und mehr, auf einer Nachverfolgungsplattform von Drittanbietern.",
"tracing.view": "Ansehen",
"tracing.weave.description": "Weave ist eine Open-Source-Plattform zur Bewertung, Testung und Überwachung von LLM-Anwendungen.",
"tracing.weave.title": "Weben",
"typeSelector.advanced": "Chatflow",
"typeSelector.agent": "Agent",
"typeSelector.all": "ALLE Typen",

@@ -265,8 +265,6 @@
"tracing.tracing": "Tracing",
"tracing.tracingDescription": "Capture the full context of app execution, including LLM calls, context, prompts, HTTP requests, and more, to a third-party tracing platform.",
"tracing.view": "View",
"tracing.weave.description": "Weave is an open-source platform for evaluating, testing, and monitoring LLM applications.",
"tracing.weave.title": "Weave",
"typeSelector.advanced": "Chatflow",
"typeSelector.agent": "Agent",
"typeSelector.all": "All Types ",

@@ -265,8 +265,6 @@
"tracing.tracing": "Rastreo",
"tracing.tracingDescription": "Captura el contexto completo de la ejecución de la app, incluyendo llamadas LLM, contexto, prompts, solicitudes HTTP y más, en una plataforma de rastreo de terceros.",
"tracing.view": "Vista",
"tracing.weave.description": "Weave es una plataforma de código abierto para evaluar, probar y monitorear aplicaciones de LLM.",
"tracing.weave.title": "Tejer",
"typeSelector.advanced": "Flujo de chat",
"typeSelector.agent": "Agente",
"typeSelector.all": "Todos los tipos",

@@ -265,8 +265,6 @@
"tracing.tracing": "ردیابی",
"tracing.tracingDescription": "ثبت کامل متن اجرای برنامه، از جمله تماس‌های LLM، متن، درخواست‌های HTTP و بیشتر، به یک پلتفرم ردیابی شخص ثالث.",
"tracing.view": "مشاهده",
"tracing.weave.description": "ویو یک پلتفرم متن باز برای ارزیابی، آزمایش و نظارت بر برنامه‌های LLM است.",
"tracing.weave.title": "بافندگی",
"typeSelector.advanced": "چت‌فلو",
"typeSelector.agent": "نماینده",
"typeSelector.all": "همه انواع",

@@ -265,8 +265,6 @@
"tracing.tracing": "Traçage",
"tracing.tracingDescription": "Capturez le contexte complet de l'exécution de l'application, y compris les appels LLM, le contexte, les prompts, les requêtes HTTP et plus encore, vers une plateforme de traçage tierce.",
"tracing.view": "Vue",
"tracing.weave.description": "Weave est une plateforme open-source pour évaluer, tester et surveiller les applications LLM.",
"tracing.weave.title": "Tisser",
"typeSelector.advanced": "Chatflow",
"typeSelector.agent": "Agent",
"typeSelector.all": "Tous Types",

@@ -265,8 +265,6 @@
"tracing.tracing": "ट्रेसिंग",
"tracing.tracingDescription": "एप्लिकेशन निष्पादन का पूरा संदर्भ कैप्चर करें, जिसमें LLM कॉल, संदर्भ, प्रॉम्प्ट्स, HTTP अनुरोध और अधिक शामिल हैं, एक तृतीय-पक्ष ट्रेसिंग प्लेटफ़ॉर्म पर।",
"tracing.view": "देखना",
"tracing.weave.description": "वीव एक ओपन-सोर्स प्लेटफ़ॉर्म है जो LLM अनुप्रयोगों का मूल्यांकन, परीक्षण और निगरानी करने के लिए है।",
"tracing.weave.title": "बुनना",
"typeSelector.advanced": "चैटफ्लो",
"typeSelector.agent": "एजेंट",
"typeSelector.all": "सभी प्रकार",

@@ -265,8 +265,6 @@
"tracing.tracing": "Menelusuri",
"tracing.tracingDescription": "Tangkap konteks lengkap eksekusi aplikasi, termasuk panggilan LLM, konteks, perintah, permintaan HTTP, dan lainnya, ke platform pelacakan pihak ketiga.",
"tracing.view": "Melihat",
"tracing.weave.description": "Weave adalah platform sumber terbuka untuk mengevaluasi, menguji, dan memantau aplikasi LLM.",
"tracing.weave.title": "Weave",
"typeSelector.advanced": "Alur obrolan",
"typeSelector.agent": "Agen",
"typeSelector.all": "Semua Jenis",

@@ -265,8 +265,6 @@
"tracing.tracing": "Tracciamento",
"tracing.tracingDescription": "Cattura il contesto completo dell'esecuzione dell'app, incluse chiamate LLM, contesto, prompt, richieste HTTP e altro, su una piattaforma di tracciamento di terze parti.",
"tracing.view": "Vista",
"tracing.weave.description": "Weave è una piattaforma open-source per valutare, testare e monitorare le applicazioni LLM.",
"tracing.weave.title": "Intrecciare",
"typeSelector.advanced": "Flusso di chat",
"typeSelector.agent": "Agente",
"typeSelector.all": "TUTTI I Tipi",

@@ -265,8 +265,6 @@
"tracing.tracing": "追跡",
"tracing.tracingDescription": "LLM の呼び出し、コンテキスト、プロンプト、HTTP リクエストなど、アプリケーション実行の全ての文脈をサードパーティのトレースプラットフォームで取り込みます。",
"tracing.view": "見る",
"tracing.weave.description": "Weave は、LLM アプリケーションを評価、テスト、および監視するためのオープンソースプラットフォームです。",
"tracing.weave.title": "織る",
"typeSelector.advanced": "チャットフロー",
"typeSelector.agent": "エージェント",
"typeSelector.all": "すべてのタイプ",

@@ -265,8 +265,6 @@
"tracing.tracing": "추적",
"tracing.tracingDescription": "LLM 호출, 컨텍스트, 프롬프트, HTTP 요청 등 앱 실행의 전체 컨텍스트를 제 3 자 추적 플랫폼에 캡처합니다.",
"tracing.view": "보기",
"tracing.weave.description": "Weave 는 LLM 애플리케이션을 평가하고 테스트하며 모니터링하기 위한 오픈 소스 플랫폼입니다.",
"tracing.weave.title": "Weave",
"typeSelector.advanced": "채팅 플로우",
"typeSelector.agent": "에이전트",
"typeSelector.all": "모든 종류",

@@ -265,8 +265,6 @@
"tracing.tracing": "Tracing",
"tracing.tracingDescription": "Capture the full context of app execution, including LLM calls, context, prompts, HTTP requests, and more, to a third-party tracing platform.",
"tracing.view": "View",
"tracing.weave.description": "Weave is an open-source platform for evaluating, testing, and monitoring LLM applications.",
"tracing.weave.title": "Weave",
"typeSelector.advanced": "Chatflow",
"typeSelector.agent": "Agent",
"typeSelector.all": "All Types ",

@@ -265,8 +265,6 @@
"tracing.tracing": "Śledzenie",
"tracing.tracingDescription": "Przechwytywanie pełnego kontekstu wykonania aplikacji, w tym wywołań LLM, kontekstu, promptów, żądań HTTP i więcej, do platformy śledzenia stron trzecich.",
"tracing.view": "Widok",
"tracing.weave.description": "Weave to platforma open-source do oceny, testowania i monitorowania aplikacji LLM.",
"tracing.weave.title": "Tkaj",
"typeSelector.advanced": "Przepływ czatu",
"typeSelector.agent": "Agent",
"typeSelector.all": "WSZYSTKIE Typy",

@@ -265,8 +265,6 @@
"tracing.tracing": "Rastreamento",
"tracing.tracingDescription": "Captura o contexto completo da execução do aplicativo, incluindo chamadas LLM, contexto, prompts, solicitações HTTP e mais, para uma plataforma de rastreamento de terceiros.",
"tracing.view": "Vista",
"tracing.weave.description": "Weave é uma plataforma de código aberto para avaliar, testar e monitorar aplicações de LLM.",
"tracing.weave.title": "Trançar",
"typeSelector.advanced": "Fluxo de bate-papo",
"typeSelector.agent": "Agente",
"typeSelector.all": "Todos os Tipos",

@@ -265,8 +265,6 @@
"tracing.tracing": "Urmărire",
"tracing.tracingDescription": "Captează contextul complet al execuției aplicației, inclusiv apelurile LLM, context, prompt-uri, cereri HTTP și altele, către o platformă de urmărire terță.",
"tracing.view": "Vedere",
"tracing.weave.description": "Weave este o platformă open-source pentru evaluarea, testarea și monitorizarea aplicațiilor LLM.",
"tracing.weave.title": "Împletește",
"typeSelector.advanced": "Fluxul de chat",
"typeSelector.agent": "Agent",
"typeSelector.all": "TOATE Tipurile",

@@ -265,8 +265,6 @@
"tracing.tracing": "Отслеживание",
"tracing.tracingDescription": "Запись полного контекста выполнения приложения, включая вызовы LLM, контекст, подсказки, HTTP-запросы и многое другое, на стороннюю платформу трассировки.",
"tracing.view": "Просмотр",
"tracing.weave.description": "Weave — это открытая платформа для оценки, тестирования и мониторинга приложений LLM.",
"tracing.weave.title": "Ткать",
"typeSelector.advanced": "Чатфлоу",
"typeSelector.agent": "Агент",
"typeSelector.all": "ВСЕ типы",

@@ -265,8 +265,6 @@
"tracing.tracing": "Sledenje",
"tracing.tracingDescription": "Zajem celotnega konteksta izvajanja aplikacije, vključno s klici LLM, kontekstom, pozivi, zahtevami HTTP in še več, na platformo za sledenje tretje osebe.",
"tracing.view": "Ogled",
"tracing.weave.description": "Weave je odprtokodna platforma za vrednotenje, testiranje in spremljanje aplikacij LLM.",
"tracing.weave.title": "Tkanje",
"typeSelector.advanced": "Tok klepeta",
"typeSelector.agent": "Agent",
"typeSelector.all": "VSE VRSTE",

@@ -265,8 +265,6 @@
"tracing.tracing": "ติดตาม",
"tracing.tracingDescription": "บันทึกบริบททั้งหมดของการดําเนินการของโปรเจกต์ รวมถึงการเรียก LLM, Prompt คําขอ HTTP และอื่นๆไปยังแพลตฟอร์มของของบุคคลที่สาม",
"tracing.view": "มุมมอง",
"tracing.weave.description": "Weave เป็นแพลตฟอร์มโอเพนซอร์สสำหรับการประเมินผล ทดสอบ และตรวจสอบแอปพลิเคชัน LLM",
"tracing.weave.title": "ทอ",
"typeSelector.advanced": "แชทโฟลว์",
"typeSelector.agent": "ตัวแทน",
"typeSelector.all": "ทุกประเภท",

@@ -265,8 +265,6 @@
"tracing.tracing": "İzleme",
"tracing.tracingDescription": "Uygulama yürütmesinin tam bağlamını, LLM çağrıları, bağlam, promptlar, HTTP istekleri ve daha fazlası dahil olmak üzere üçüncü taraf izleme platformuna yakalama.",
"tracing.view": "Görünüm",
"tracing.weave.description": "Weave, LLM uygulamalarını değerlendirmek, test etmek ve izlemek için açık kaynaklı bir platformdur.",
"tracing.weave.title": "Dokuma",
"typeSelector.advanced": "Sohbet akışı",
"typeSelector.agent": "Agent",
"typeSelector.all": "All Types",

@@ -265,8 +265,6 @@
"tracing.tracing": "Відстеження",
"tracing.tracingDescription": "Захоплення повного контексту виконання додатку, включаючи виклики LLM, контекст, підказки, HTTP-запити та інше, на сторонню платформу відстеження.",
"tracing.view": "Вид",
"tracing.weave.description": "Weave є платформою з відкритим кодом для оцінки, тестування та моніторингу LLM додатків.",
"tracing.weave.title": "Ткати",
"typeSelector.advanced": "Чат",
"typeSelector.agent": "Агент",
"typeSelector.all": "Усі типи",

@@ -265,8 +265,6 @@
"tracing.tracing": "Theo dõi",
"tracing.tracingDescription": "Ghi lại toàn bộ ngữ cảnh thực thi ứng dụng, bao gồm các cuộc gọi LLM, ngữ cảnh, lời nhắc, yêu cầu HTTP và nhiều hơn nữa, đến một nền tảng theo dõi của bên thứ ba.",
"tracing.view": "Cảnh",
"tracing.weave.description": "Weave là một nền tảng mã nguồn mở để đánh giá, thử nghiệm và giám sát các ứng dụng LLM.",
"tracing.weave.title": "Dệt",
"typeSelector.advanced": "Dòng trò chuyện",
"typeSelector.agent": "Tác nhân",
"typeSelector.all": "Tất cả loại",

@@ -265,8 +265,6 @@
"tracing.tracing": "追踪",
"tracing.tracingDescription": "捕获应用程序执行的完整上下文,包括 LLM 调用、上下文、提示、HTTP 请求等,发送到第三方跟踪平台。",
"tracing.view": "查看",
"tracing.weave.description": "Weave 是一个开源平台,用于评估、测试和监控大型语言模型应用程序。",
"tracing.weave.title": "编织",
"typeSelector.advanced": "Chatflow",
"typeSelector.agent": "Agent",
"typeSelector.all": "所有类型",

@@ -265,8 +265,6 @@
"tracing.tracing": "追蹤",
"tracing.tracingDescription": "捕獲應用程式執行的完整上下文,包括 LLM 調用、上下文、提示、HTTP 請求等,到第三方追蹤平台。",
"tracing.view": "查看",
"tracing.weave.description": "Weave 是一個開源平台,用於評估、測試和監控大型語言模型應用程序。",
"tracing.weave.title": "編織",
"typeSelector.advanced": "聊天流",
"typeSelector.agent": "Agent",
"typeSelector.all": "所有類型",

@@ -9,7 +9,6 @@ import type {
PhoenixConfig,
TencentConfig,
TracingProvider,
WeaveConfig,
} from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/type'
import type { Dependency } from '@/app/components/plugins/types'
import type { App, AppModeEnum, AppTemplate, SiteConfig } from '@/types/app'
@@ -121,7 +120,7 @@ export type TracingStatus = {

export type TracingConfig = {
tracing_provider: TracingProvider
tracing_config: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | DatabricksConfig | MLflowConfig | OpikConfig | WeaveConfig | AliyunConfig | TencentConfig
tracing_config: ArizeConfig | PhoenixConfig | LangSmithConfig | LangFuseConfig | DatabricksConfig | MLflowConfig | OpikConfig | AliyunConfig | TencentConfig
}

export type WebhookTriggerResponse = {

@@ -1,7 +1,7 @@
{
"name": "dify-web",
"type": "module",
"version": "1.13.1",
"version": "1.13.2",
"private": true,
"packageManager": "pnpm@10.32.1",
"imports": {

Some files were not shown because too many files have changed in this diff