Compare commits

...

10 Commits

Author SHA1 Message Date
jyong
194f79f035 add qdrant to tidb 2025-12-17 20:10:36 +08:00
jyong
f552b543f6 add qdrant to tidb 2025-12-17 19:08:36 +08:00
jyong
03e3208c71 add qdrant to tidb 2025-12-17 19:02:25 +08:00
Jyong
201de91665 Update build-push.yml 2025-12-16 10:44:45 +08:00
jyong
23750a35df Merge branch 'main' into feat/add-qdrant-to-tidb-migration 2025-12-16 10:43:13 +08:00
jyong
9cefbc60f2 Merge branch 'main' into feat/add-qdrant-to-tidb-migration 2025-12-15 16:46:50 +08:00
jyong
396a790c6d add qdrant to tidb 2025-12-15 16:10:35 +08:00
jyong
1baca71e37 migration command 2025-12-12 11:02:05 +08:00
jyong
67eb632f1a add qdrant migrate to tidb 2025-12-09 23:52:00 +08:00
jyong
5ae1f62daf add qdrant migrate to tidb 2025-12-09 18:26:03 +08:00
34 changed files with 796 additions and 100 deletions

View File

@@ -8,6 +8,7 @@ on:
- "build/**"
- "release/e-*"
- "hotfix/**"
- "feat/add-qdrant-to-tidb-migration"
tags:
- "*"

View File

@@ -32,7 +32,7 @@ from libs.helper import email as email_validate
from libs.password import hash_password, password_pattern, valid_password
from libs.rsa import generate_key_pair
from models import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
@@ -175,54 +175,67 @@ def migrate_annotation_vector_database():
create_count = 0
skipped_count = 0
total_count = 0
page = 1
error_count = 0
per_page = 50
# Keyset pagination on AppAnnotationSetting (much smaller dataset than App table)
last_created_at = None
last_id = None
while True:
try:
# get apps info
per_page = 50
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
apps = (
session.query(App)
# Query AppAnnotationSetting directly instead of scanning all Apps
query = (
session.query(AppAnnotationSetting, App, DatasetCollectionBinding)
.join(App, App.id == AppAnnotationSetting.app_id)
.join(
DatasetCollectionBinding,
DatasetCollectionBinding.id == AppAnnotationSetting.collection_binding_id,
)
.where(App.status == "normal")
.order_by(App.created_at.desc())
)
# Apply keyset pagination condition
if last_created_at is not None and last_id is not None:
query = query.where(
sa.or_(
AppAnnotationSetting.created_at < last_created_at,
sa.and_(
AppAnnotationSetting.created_at == last_created_at,
AppAnnotationSetting.id < last_id,
),
)
)
results = (
query.order_by(AppAnnotationSetting.created_at.desc(), AppAnnotationSetting.id.desc())
.limit(per_page)
.offset((page - 1) * per_page)
.all()
)
if not apps:
if not results:
break
# Update cursor to the last record of current batch
last_created_at = results[-1][0].created_at
last_id = results[-1][0].id
except SQLAlchemyError:
raise
page += 1
for app in apps:
total_count = total_count + 1
for app_annotation_setting, app, dataset_collection_binding in results:
total_count += 1
click.echo(
f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped."
f"Processing the {total_count} app {app.id}. {create_count} created, {skipped_count} skipped."
)
try:
click.echo(f"Creating app annotation index: {app.id}")
if not app_annotation_setting:
skipped_count += 1
continue
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
app_annotation_setting = (
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
)
if not app_annotation_setting:
skipped_count = skipped_count + 1
click.echo(f"App annotation setting disabled: {app.id}")
continue
# get dataset_collection_binding info
dataset_collection_binding = (
session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first()
)
if not dataset_collection_binding:
click.echo(f"App annotation collection binding not found: {app.id}")
continue
annotations = session.scalars(
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
).all()
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
@@ -230,6 +243,12 @@ def migrate_annotation_vector_database():
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,
index_struct=json.dumps({
"type": VectorType.QDRANT,
"vector_store": {
"class_prefix": dataset_collection_binding.collection_name,
}
}),
)
documents = []
if annotations:
@@ -240,31 +259,34 @@ def migrate_annotation_vector_database():
)
documents.append(document)
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
original_vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
click.echo(f"Migrating annotations for app: {app.id}.")
try:
vector.delete()
click.echo(click.style(f"Deleted vector index for app {app.id}.", fg="green"))
except Exception as e:
click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red"))
raise e
if documents:
try:
click.echo(
click.style(
f"Creating vector index with {len(documents)} annotations for app {app.id}.",
fg="green",
)
original_annotations = original_vector.search_by_metadata_field("app_id", app.id)
if original_annotations:
new_dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
indexing_technique="high_quality",
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,
)
vector.create(documents)
new_vector = Vector(new_dataset)
new_vector.create_with_vectors(original_annotations)
click.echo(click.style(f"Created vector index for app {app.id}.", fg="green"))
except Exception as e:
click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red"))
raise e
else:
click.echo(click.style(f"No original annotations found for app {app.id}.", fg="green"))
skipped_count += 1
continue
except Exception as e:
click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red"))
raise e
click.echo(f"Successfully migrated app annotation {app.id}.")
create_count += 1
except Exception as e:
error_count += 1
click.echo(
click.style(f"Error creating app annotation index: {e.__class__.__name__} {str(e)}", fg="red")
)
@@ -272,7 +294,7 @@ def migrate_annotation_vector_database():
click.echo(
click.style(
f"Migration complete. Created {create_count} app annotation indexes. Skipped {skipped_count} apps.",
f"Migration complete. Created {create_count} app annotation indexes. Skipped {skipped_count} apps. Errors {error_count} apps.",
fg="green",
)
)
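For readers skimming the hunk above: OFFSET paging over the App table was replaced by a (created_at, id) keyset cursor over AppAnnotationSetting. A minimal standalone sketch of the pattern, with a hypothetical Setting model standing in for AppAnnotationSetting:

import sqlalchemy as sa

def iter_settings(session, per_page: int = 50):
    # Setting is a hypothetical ORM model with created_at and id columns.
    last_created_at, last_id = None, None
    while True:
        query = session.query(Setting)
        if last_created_at is not None:
            # Strictly "older" rows: earlier created_at, or the same
            # created_at with a smaller id as the tie-breaker.
            query = query.where(
                sa.or_(
                    Setting.created_at < last_created_at,
                    sa.and_(
                        Setting.created_at == last_created_at,
                        Setting.id < last_id,
                    ),
                )
            )
        batch = (
            query.order_by(Setting.created_at.desc(), Setting.id.desc())
            .limit(per_page)
            .all()
        )
        if not batch:
            break
        last_created_at, last_id = batch[-1].created_at, batch[-1].id
        yield from batch

Unlike OFFSET paging, each batch is fetched with an indexable range predicate, so query cost stays flat as the cursor advances instead of growing with the page number.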
@@ -285,6 +307,7 @@ def migrate_knowledge_vector_database():
click.echo(click.style("Starting vector database migration.", fg="green"))
create_count = 0
skipped_count = 0
error_count = 0
total_count = 0
vector_type = dify_config.VECTOR_STORE
upper_collection_vector_types = {
@@ -298,6 +321,7 @@ def migrate_knowledge_vector_database():
VectorType.OPENGAUSS,
VectorType.TABLESTORE,
VectorType.MATRIXONE,
VectorType.TIDB_ON_QDRANT,
}
lower_collection_vector_types = {
VectorType.ANALYTICDB,
@@ -334,7 +358,13 @@ def migrate_knowledge_vector_database():
)
try:
click.echo(f"Creating dataset vector database index: {dataset.id}")
if not dataset.index_struct_dict:
skipped_count = skipped_count + 1
continue
if dataset.index_struct_dict:
if not dataset.index_struct_dict["type"]:
skipped_count = skipped_count + 1
continue
if dataset.index_struct_dict["type"] == vector_type:
skipped_count = skipped_count + 1
continue
@@ -361,23 +391,20 @@ def migrate_knowledge_vector_database():
else:
raise ValueError(f"Vector store {vector_type} is not supported.")
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
vector = Vector(dataset)
click.echo(f"Migrating dataset {dataset.id}.")
try:
vector.delete()
click.echo(
click.style(f"Deleted vector index {collection_name} for dataset {dataset.id}.", fg="green")
)
except Exception as e:
click.echo(
click.style(
f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", fg="red"
)
)
raise e
# try:
# vector.delete()
# click.echo(
# click.style(f"Deleted vector index {collection_name} for dataset {dataset.id}.", fg="green")
# )
# except Exception as e:
# click.echo(
# click.style(
# f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", fg="red"
# )
# )
# raise e
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
@@ -390,29 +417,22 @@ def migrate_knowledge_vector_database():
documents = []
segments_count = 0
original_index_vector = Vector(dataset)
for dataset_document in dataset_documents:
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
documents.append(document)
segments_count = segments_count + 1
single_documents = original_index_vector.search_by_metadata_field(
"document_id", dataset_document.id
)
if single_documents:
documents.extend(single_documents)
segments_count += len(single_documents)
# update dataset index_struct_dict
index_struct_dict = {
"type": vector_type,
"vector_store": {"class_prefix": collection_name},
"original_type": dataset.index_struct_dict["type"],
}
dataset.index_struct = json.dumps(index_struct_dict)
if documents:
try:
click.echo(
@@ -422,23 +442,26 @@ def migrate_knowledge_vector_database():
fg="green",
)
)
vector.create(documents)
new_vector = Vector(dataset)
new_vector.create_with_vectors(documents)
click.echo(click.style(f"Created vector index for dataset {dataset.id}.", fg="green"))
except Exception as e:
click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red"))
raise e
db.session.add(dataset)
db.session.add(instance=dataset)
db.session.commit()
click.echo(f"Successfully migrated dataset {dataset.id}.")
create_count += 1
except Exception as e:
db.session.rollback()
click.echo(click.style(f"Error creating dataset index: {e.__class__.__name__} {str(e)}", fg="red"))
error_count += 1
continue
click.echo(
click.style(
f"Migration complete. Created {create_count} dataset indexes. Skipped {skipped_count} datasets.", fg="green"
f"Migration complete. Created {create_count} dataset indexes. Skipped {skipped_count} datasets. Errors {error_count} datasets.",
fg="green",
)
)
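Net effect of this hunk: the command no longer re-embeds every segment; it reads stored vectors out of the old index with search_by_metadata_field and writes them into the new one with create_with_vectors. A condensed sketch of the per-dataset flow, assuming dataset, dataset_documents, collection_name, and vector_type are prepared as above:

original_index_vector = Vector(dataset)  # dataset still points at the old store
documents = []
for dataset_document in dataset_documents:
    # Each Document comes back with its stored embedding attached.
    documents.extend(
        original_index_vector.search_by_metadata_field("document_id", dataset_document.id)
    )

# Re-point the dataset at the new store, remembering where it came from.
index_struct_dict = {
    "type": vector_type,
    "vector_store": {"class_prefix": collection_name},
    "original_type": dataset.index_struct_dict["type"],
}
dataset.index_struct = json.dumps(index_struct_dict)

if documents:
    # Vector(dataset) now resolves to the new store; no embedding model is called.
    Vector(dataset).create_with_vectors(documents)

The "original_type" field is the breadcrumb that records which backend the data was copied from.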

View File

@@ -313,6 +313,20 @@ class AlibabaCloudMySQLVector(BaseVector):
docs.append(Document(page_content=record["text"], metadata=metadata))
return docs
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, embedding FROM {self.table_name} WHERE JSON_UNQUOTE(JSON_EXTRACT(meta, %s)) = %s",
(f"$.{key}", value),
)
docs = []
for record in cur:
metadata = record["meta"]
if isinstance(metadata, str):
metadata = json.loads(metadata)
docs.append(Document(page_content=record["text"], vector=record["embedding"], metadata=metadata))
return docs
def delete(self):
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

View File

@@ -58,6 +58,9 @@ class AnalyticdbVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self.analyticdb_vector.search_by_full_text(query, **kwargs)
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
return self.analyticdb_vector.search_by_metadata_field(key, value, **kwargs)
def delete(self):
self.analyticdb_vector.delete()

View File

@@ -305,6 +305,34 @@ class AnalyticdbVectorOpenAPI:
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
namespace=self.config.namespace,
namespace_password=self.config.namespace_password,
collection=self._collection_name,
include_values=True,
metrics=self.config.metrics,
vector=None, # ty: ignore [invalid-argument-type]
content=None, # ty: ignore [invalid-argument-type]
top_k=999999,
filter=f"metadata_->>'{key}' = '{value}'",
)
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
metadata = json.loads(match.metadata.get("metadata_"))
doc = Document(
page_content=match.metadata.get("page_content"),
vector=match.values.value if match.values else None,
metadata=metadata,
)
documents.append(doc)
return documents
def delete(self):
try:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

View File

@@ -270,6 +270,23 @@ class AnalyticdbVectorBySql:
documents.append(doc)
return documents
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
with self._get_cursor() as cur:
cur.execute(
f"SELECT id, embedding, page_content, metadata_ FROM {self.table_name} WHERE metadata_->>%s = %s",
(key, value),
)
documents = []
for record in cur:
_, vector, page_content, metadata = record
doc = Document(
page_content=page_content,
vector=vector,
metadata=metadata,
)
documents.append(doc)
return documents
def delete(self):
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

View File

@@ -198,6 +198,34 @@ class BaiduVector(BaseVector):
docs.append(doc)
return docs
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
# Escape double quotes in value to prevent injection
escaped_value = value.replace('"', '\\"')
filter = f'metadata["{key}"] = "{escaped_value}"'
res = self._db.table(self._collection_name).select(
projections=[VDBField.CONTENT_KEY, VDBField.METADATA_KEY, VDBField.VECTOR_KEY],
filter=filter,
)
docs = []
for row in res.rows:
row_data = row.get("row", {})
meta = row_data.get(VDBField.METADATA_KEY, {})
if isinstance(meta, str):
try:
meta = json.loads(meta)
except (json.JSONDecodeError, TypeError):
meta = {}
elif not isinstance(meta, dict):
meta = {}
vector = row_data.get(VDBField.VECTOR_KEY)
doc = Document(page_content=row_data.get(VDBField.CONTENT_KEY), vector=vector, metadata=meta)
docs.append(doc)
return docs
def delete(self):
try:
self._db.drop_table(table_name=self._collection_name)

View File

@@ -135,6 +135,30 @@ class ChromaVector(BaseVector):
# chroma does not support BM25 full text searching
return []
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
collection = self._client.get_or_create_collection(self._collection_name)
# FIXME: fix the type error later
results = collection.get(
where={key: {"$eq": value}}, # type: ignore
include=["documents", "metadatas", "embeddings"],
)
if not results["ids"] or not results["documents"] or not results["metadatas"]:
return []
docs = []
for i, doc_id in enumerate(results["ids"]):
metadata = dict(results["metadatas"][i]) if results["metadatas"][i] else {}
vector = results["embeddings"][i] if results.get("embeddings") else None
doc = Document(
page_content=results["documents"][i],
vector=vector,
metadata=metadata,
)
docs.append(doc)
return docs
class ChromaVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:

View File

@@ -1025,6 +1025,38 @@ class ClickzettaVector(BaseVector):
with connection.cursor() as cursor:
cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}")
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
"""Search for documents by metadata field."""
# Check if table exists first
if not self._table_exists():
logger.warning(
"Table %s.%s does not exist, returning empty results",
self._config.schema_name,
self._table_name,
)
return []
# Use json_extract_string function for ClickZetta compatibility
search_sql = f"""
SELECT id, {Field.CONTENT_KEY}, {Field.METADATA_KEY}, {Field.VECTOR}
FROM {self._config.schema_name}.{self._table_name}
WHERE json_extract_string({Field.METADATA_KEY}, '$.{key}') = ?
"""
documents = []
with self.get_connection_context() as connection:
with connection.cursor() as cursor:
cursor.execute(search_sql, binding_params=[value])
results = cursor.fetchall()
for row in results:
metadata = self._parse_metadata(row[2], row[0])
vector = row[3] if len(row) > 3 else None
doc = Document(page_content=row[1], vector=vector, metadata=metadata)
documents.append(doc)
return documents
def _format_vector_simple(self, vector: list[float]) -> str:
"""Simple vector formatting for SQL queries."""
return ",".join(map(str, vector))

View File

@@ -325,6 +325,22 @@ class CouchbaseVector(BaseVector):
return docs
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
query = f"""
SELECT text, metadata, embedding FROM
`{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
WHERE metadata.{key} = $value
"""
result = self._cluster.query(query, named_parameters={"value": value}).execute()
docs = []
for row in result:
text = row.get("text", "")
metadata = row.get("metadata", {})
vector = row.get("embedding")
doc = Document(page_content=text, vector=vector, metadata=metadata)
docs.append(doc)
return docs
def delete(self):
manager = self._bucket.collections()
scopes = manager.get_all_scopes()

View File

@@ -249,6 +249,21 @@ class ElasticSearchVector(BaseVector):
return docs
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
query_str = {"query": {"match": {f"metadata.{key}": value}}}
results = self._client.search(index=self._collection_name, body=query_str, size=999999)
docs = []
for hit in results["hits"]["hits"]:
docs.append(
Document(
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"].get(Field.VECTOR),
metadata=hit["_source"][Field.METADATA_KEY],
)
)
return docs
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas)

View File

@@ -149,6 +149,21 @@ class HuaweiCloudVector(BaseVector):
return docs
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
query_str = {"query": {"match": {f"metadata.{key}": value}}}
results = self._client.search(index=self._collection_name, body=query_str, size=999999)
docs = []
for hit in results["hits"]["hits"]:
docs.append(
Document(
page_content=hit["_source"][Field.CONTENT_KEY],
vector=hit["_source"].get(Field.VECTOR),
metadata=hit["_source"].get(Field.METADATA_KEY, {}),
)
)
return docs
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas)

View File

@@ -326,6 +326,31 @@ class LindormVectorStore(BaseVector):
return docs
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
query: dict[str, Any] = {
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY}.{key}.keyword": value}}]}}
}
if self._using_ugc:
query["query"]["bool"]["must"].append({"term": {f"{ROUTING_FIELD}.keyword": self._routing}})
try:
params: dict[str, Any] = {"timeout": self._client_config.request_timeout}
if self._using_ugc:
params["routing"] = self._routing
response = self._client.search(index=self._collection_name, body=query, params=params, size=999999)
except Exception:
logger.exception("Error executing metadata field search, query: %s", query)
raise
docs = []
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY) or {}
vector = hit["_source"].get(Field.VECTOR)
page_content = hit["_source"].get(Field.CONTENT_KEY)
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
docs.append(doc)
return docs
def create_collection(
self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
):

View File

@@ -217,6 +217,27 @@ class MatrixoneVector(BaseVector):
assert self.client is not None
self.client.delete()
@ensure_client
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
assert self.client is not None
results = self.client.query_by_metadata(filter={key: value})
docs = []
for result in results:
metadata = result.metadata
if isinstance(metadata, str):
metadata = json.loads(metadata)
vector = result.embedding if hasattr(result, "embedding") else None
docs.append(
Document(
page_content=result.document,
vector=vector,
metadata=metadata,
)
)
return docs
class MatrixoneVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector:

View File

@@ -291,6 +291,30 @@ class MilvusVector(BaseVector):
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
"""
Search for documents by metadata field key and value.
"""
if not self._client.has_collection(self._collection_name):
return []
result = self._client.query(
collection_name=self._collection_name,
filter=f'metadata["{key}"] == "{value}"',
output_fields=[Field.CONTENT_KEY, Field.METADATA_KEY, Field.VECTOR],
)
docs = []
for item in result:
metadata = item.get(Field.METADATA_KEY, {})
doc = Document(
page_content=item.get(Field.CONTENT_KEY, ""),
vector=item.get(Field.VECTOR),
metadata=metadata,
)
docs.append(doc)
return docs
def create_collection(
self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
):

View File

@@ -156,6 +156,24 @@ class MyScaleVector(BaseVector):
logger.exception("Vector search operation failed")
return []
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
sql = f"""
SELECT text, vector, metadata FROM {self._config.database}.{self._collection_name}
WHERE metadata['{key}']='{value}'
"""
try:
return [
Document(
page_content=r["text"],
vector=r["vector"],
metadata=r["metadata"],
)
for r in self._client.query(sql).named_results()
]
except Exception:
logger.exception("Metadata field search operation failed")
return []
def delete(self):
self._client.command(f"DROP TABLE IF EXISTS {self._config.database}.{self._collection_name}")

View File

@@ -440,6 +440,42 @@ class OceanBaseVector(BaseVector):
logger.exception("Failed to delete collection '%s'", self._collection_name)
raise Exception(f"Failed to delete collection '{self._collection_name}'") from e
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
try:
import re
from sqlalchemy import text
# Validate key to prevent injection in JSON path
if not re.match(r"^[a-zA-Z0-9_.]+$", key):
raise ValueError(f"Invalid characters in metadata key: {key}")
# Use parameterized query to prevent SQL injection
sql = text(
f"SELECT text, metadata, vector FROM `{self._collection_name}` "
f"WHERE metadata->>'$.{key}' = :value"
)
with self._client.engine.connect() as conn:
result = conn.execute(sql, {"value": value})
rows = result.fetchall()
docs = []
for row in rows:
text_content, metadata, vector = row
if isinstance(metadata, str):
metadata = json.loads(metadata)
docs.append(Document(page_content=text_content, vector=vector, metadata=metadata))
return docs
except Exception as e:
logger.exception(
"Failed to search by metadata field '%s'='%s' in collection '%s'",
key,
value,
self._collection_name,
)
raise Exception(f"Failed to search by metadata field '{key}'") from e
class OceanBaseVectorFactory(AbstractVectorFactory):
def init_vector(

View File

@@ -222,6 +222,18 @@ class OpenGauss(BaseVector):
return docs
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, embedding FROM {self.table_name} WHERE meta->>%s = %s",
(key, value),
)
docs = []
for record in cur:
metadata, text, embedding = record
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
return docs
def delete(self):
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

View File

@@ -236,6 +236,19 @@ class OpenSearchVector(BaseVector):
return docs
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
query = {"query": {"term": {f"{Field.METADATA_KEY}.{key}": value}}}
response = self._client.search(index=self._collection_name.lower(), body=query, size=999999)
docs = []
for hit in response["hits"]["hits"]:
metadata = hit["_source"].get(Field.METADATA_KEY) or {}
vector = hit["_source"].get(Field.VECTOR)
page_content = hit["_source"].get(Field.CONTENT_KEY)
doc = Document(page_content=page_content, vector=vector, metadata=metadata)
docs.append(doc)
return docs
def create_collection(
self, embeddings: list, metadatas: list[dict] | None = None, index_params: dict | None = None
):

View File

@@ -338,6 +338,20 @@ class OracleVector(BaseVector):
else:
return [Document(page_content="", metadata={})]
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
with self._get_connection() as conn:
with conn.cursor() as cur:
cur.execute(
f"SELECT meta, text, embedding FROM {self.table_name} WHERE JSON_VALUE(meta, '$.{key}') = :1",
(value,),
)
docs = []
for record in cur:
metadata, text, embedding = record
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
return docs
def delete(self):
with self._get_connection() as conn:
with conn.cursor() as cur:

View File

@@ -210,6 +210,19 @@ class PGVectoRS(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return []
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
with Session(self._client) as session:
# Bind key and value as parameters instead of interpolating them into
# the SQL string, which left the query open to injection.
select_statement = sql_text(
f"SELECT text, meta, embedding FROM {self._collection_name} WHERE meta->>:key = :value"
)
result = session.execute(select_statement, {"key": key, "value": value}).fetchall()
docs = []
for record in result:
doc = Document(page_content=record[0], vector=record[2], metadata=record[1])
docs.append(doc)
return docs
class PGVectoRSFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> PGVectoRS:

View File

@@ -242,6 +242,18 @@ class PGVector(BaseVector):
return docs
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, embedding FROM {self.table_name} WHERE meta->>%s = %s",
(key, value),
)
docs = []
for record in cur:
metadata, text, embedding = record
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
return docs
def delete(self):
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

View File

@@ -199,6 +199,18 @@ class VastbaseVector(BaseVector):
return docs
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, embedding FROM {self.table_name} WHERE meta->>%s = %s",
(key, value),
)
docs = []
for record in cur:
metadata, text, embedding = record
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
return docs
def delete(self):
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

View File

@@ -434,9 +434,42 @@ class QdrantVector(BaseVector):
return documents
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client._load()
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
try:
scroll_filter = models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self._group_id),
),
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
]
)
response = self._client.scroll(
collection_name=self._collection_name,
scroll_filter=scroll_filter,
limit=999999,
with_payload=True,
with_vectors=True,
)
results = response[0]
documents = []
for result in results:
if result:
doc = self._document_from_scored_point(result, Field.CONTENT_KEY, Field.METADATA_KEY)
documents.append(doc)
return documents
except UnexpectedResponse as e:
if e.status_code == 404:
return []
raise e
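One note on the scroll call above: it fetches everything in a single request through a very large limit. The qdrant client's scroll API also returns a next-page cursor, so the same filter could be paged through incrementally; a sketch:

offset = None
results = []
while True:
    points, offset = self._client.scroll(
        collection_name=self._collection_name,
        scroll_filter=scroll_filter,
        limit=500,
        offset=offset,
        with_payload=True,
        with_vectors=True,
    )
    results.extend(points)
    if offset is None:  # qdrant returns None once the scan is exhausted
        break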
@classmethod
def _document_from_scored_point(

View File

@@ -294,6 +294,27 @@ class RelytVector(BaseVector):
# milvus/zilliz/relyt doesn't support bm25 search
return []
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
sql_query = f"""
SELECT document, metadata, embedding
FROM "{self._collection_name}"
WHERE metadata->>'{key}' = :value
"""
params = {"value": value}
with self.client.connect() as conn:
results = conn.execute(sql_text(sql_query), params).fetchall()
docs = []
for result in results:
doc = Document(
page_content=result.document,
vector=result.embedding,
metadata=result.metadata,
)
docs.append(doc)
return docs
class RelytVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> RelytVector:

View File

@@ -390,6 +390,45 @@ class TableStoreVector(BaseVector):
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
# Search using tags field which stores key=value pairs
tag_value = f"{key}={value}"
query = tablestore.SearchQuery(
tablestore.TermQuery(self._tags_field, tag_value),
limit=999999,
get_total_count=False,
)
search_response = self._tablestore_client.search(
table_name=self._table_name,
index_name=self._index_name,
search_query=query,
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
)
documents = []
for search_hit in search_response.search_hits:
ots_column_map = {}
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]
metadata_str = ots_column_map.get(Field.METADATA_KEY)
metadata = json.loads(metadata_str) if metadata_str else {}
vector_str = ots_column_map.get(Field.VECTOR)
# TableStore stores vector as JSON string, need to parse it
vector = json.loads(vector_str) if vector_str else None
documents.append(
Document(
page_content=ots_column_map.get(Field.CONTENT_KEY) or "",
vector=vector,
metadata=metadata,
)
)
return documents
class TableStoreVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TableStoreVector:

View File

@@ -299,6 +299,28 @@ class TencentVector(BaseVector):
docs.append(doc)
return docs
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
filter = Filter(Filter.In(f"metadata.{key}", [value]))
res = self._client.query(
database_name=self._client_config.database,
collection_name=self.collection_name,
filter=filter,
retrieve_vector=True,
)
docs: list[Document] = []
if res is None or len(res) == 0:
return docs
for result in res:
meta = result.get(self.field_metadata)
if isinstance(meta, str):
meta = json.loads(meta)
vector = result.get(self.field_vector)
doc = Document(page_content=result.get(self.field_text), vector=vector, metadata=meta)
docs.append(doc)
return docs
def delete(self):
if self._has_collection():
self._client.drop_collection(

View File

@@ -393,6 +393,42 @@ class TidbOnQdrantVector(BaseVector):
return documents
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
try:
scroll_filter = models.Filter(
must=[
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
]
)
response = self._client.scroll(
collection_name=self._collection_name,
scroll_filter=scroll_filter,
limit=999999,
with_payload=True,
with_vectors=True,
)
results = response[0]
documents = []
for result in results:
if result:
metadata = (result.payload.get(Field.METADATA_KEY) or {}) if result.payload else {}
page_content = result.payload.get(Field.CONTENT_KEY, "") if result.payload else ""
vector = result.vector if hasattr(result, "vector") else None
documents.append(Document(page_content=page_content, vector=vector, metadata=metadata))
return documents
except UnexpectedResponse as e:
if e.status_code == 404:
return []
raise e
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client._load()

View File

@@ -237,6 +237,21 @@ class TiDBVector(BaseVector):
# tidb doesn't support bm25 search
return []
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
with Session(self._engine) as session:
select_statement = sql_text(f"""
SELECT meta, text, vector FROM {self._collection_name}
WHERE meta->>'$.{key}' = :value
""")
res = session.execute(select_statement, params={"value": value})
results = [(row[0], row[1], row[2]) for row in res]
docs = []
for meta, text, vector in results:
metadata = json.loads(meta)
docs.append(Document(page_content=text, vector=vector, metadata=metadata))
return docs
def delete(self):
with Session(self._engine) as session:
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))

View File

@@ -117,6 +117,24 @@ class UpstashVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return []
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
query_result = self.index.query(
vector=[1.001 * i for i in range(self._get_index_dimension())],
include_metadata=True,
include_data=True,
include_vectors=True,
top_k=999999,
filter=f"{key} = '{value}'",
)
docs = []
for record in query_result:
metadata = record.metadata
text = record.data
vector = record.vector if hasattr(record, "vector") else None
if metadata is not None and text is not None:
docs.append(Document(page_content=text, vector=vector, metadata=metadata))
return docs
def delete(self):
self.index.reset()

View File

@@ -45,6 +45,10 @@ class BaseVector(ABC):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
raise NotImplementedError
@abstractmethod
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
raise NotImplementedError
@abstractmethod
def delete(self):
raise NotImplementedError
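With the hook declared on BaseVector, every backend is now required to implement it, and callers reach it through the Vector facade. A small hypothetical usage, assuming dataset and dataset_document are loaded:

vector = Vector(dataset)
docs = vector.search_by_metadata_field("document_id", dataset_document.id)
for doc in docs:
    # Backends attach the stored embedding as doc.vector where the store
    # can return it; that is what makes vector-copy migration possible.
    print(doc.metadata.get("doc_id"), doc.vector is not None)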

View File

@@ -7,7 +7,8 @@ from typing import Any
from sqlalchemy import select
from configs import dify_config
from core.model_manager import ModelManager
from core.entities.provider_configuration import ProviderModelBundle
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_type import VectorType
@@ -211,6 +212,35 @@ class Vector:
self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
logger.info("Embedding %s texts took %s s", len(texts), time.time() - start)
def create_with_vectors(self, texts: list[Document], **kwargs):
"""
Create documents with vectors.
Args:
texts: List of documents.
**kwargs: Keyword arguments.
"""
embeddings = []
embedding_texts = []
for text in texts:
if text.vector:
embeddings.append(text.vector)
embedding_texts.append(text)
if embeddings and embedding_texts:
# batch create documents with vectors
start = time.time()
batch_size = 500
total_batches = (len(embedding_texts) + batch_size - 1) // batch_size
for i in range(0, len(embedding_texts), batch_size):
batch = embedding_texts[i : i + batch_size]
batch_embeddings = embeddings[i : i + batch_size]
batch_start = time.time()
self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
logger.info(
"Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start
)
logger.info("Embedding %s documents with vectors took %s s", len(embedding_texts), time.time() - start)
def create_multimodal(self, file_documents: list | None = None, **kwargs):
if file_documents:
start = time.time()
@@ -288,6 +318,9 @@ class Vector:
)
return self._vector_processor.search_by_vector(multimodal_vector, **kwargs)
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
return self._vector_processor.search_by_metadata_field(key, value, **kwargs)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._vector_processor.search_by_full_text(query, **kwargs)
@@ -298,16 +331,27 @@ class Vector:
collection_exist_cache_key = f"vector_indexing_{self._vector_processor.collection_name}"
redis_client.delete(collection_exist_cache_key)
def _get_embeddings(self) -> Embeddings:
def _get_embeddings(self) -> Embeddings | None:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=self._dataset.embedding_model,
)
return CacheEmbedding(embedding_model)
try:
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=self._dataset.embedding_model,
)
return CacheEmbedding(embedding_model)
except Exception:
# fall back to the workspace default embedding model
try:
default_embeddings = model_manager.get_default_model_instance(
tenant_id=self._dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
return CacheEmbedding(default_embeddings)
except Exception as e:
logger.warning("Error getting default embeddings: %s", e)
return None
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts.copy():

View File

@@ -202,6 +202,30 @@ class VikingDBVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return []
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
# Query by metadata field using filter on group_id and matching metadata
results = self._client.get_index(self._collection_name, self._index_name).search(
filter={"op": "must", "field": vdb_Field.GROUP_KEY, "conds": [self._group_id]},
limit=5000, # max value is 5000
)
if not results:
return []
docs = []
for result in results:
metadata = result.fields.get(vdb_Field.METADATA_KEY)
if metadata is not None:
if isinstance(metadata, str):
metadata = json.loads(metadata)
if metadata.get(key) == value:
vector = result.fields.get(vdb_Field.VECTOR_KEY)
doc = Document(
page_content=result.fields.get(vdb_Field.CONTENT_KEY), vector=vector, metadata=metadata
)
docs.append(doc)
return docs
def delete(self):
if self._has_index():
self._client.drop_index(self._collection_name, self._index_name)

View File

@@ -442,6 +442,30 @@ class WeaviateVector(BaseVector):
return value.isoformat()
return value
def search_by_metadata_field(self, key: str, value: str, **kwargs: Any) -> list[Document]:
"""Searches for documents matching a specific metadata field value."""
if not self._client.collections.exists(self._collection_name):
return []
col = self._client.collections.use(self._collection_name)
props = list({*self._attributes, Field.TEXT_KEY.value})
res = col.query.fetch_objects(
filters=Filter.by_property(key).equal(value),
limit=999999,
return_properties=props,
include_vector=True,
)
docs: list[Document] = []
for obj in res.objects:
properties = dict(obj.properties or {})
text = properties.pop(Field.TEXT_KEY.value, "")
vector = obj.vector.get("default") if obj.vector else None
docs.append(Document(page_content=text, vector=vector, metadata=properties))
return docs
class WeaviateVectorFactory(AbstractVectorFactory):
"""Factory class for creating WeaviateVector instances."""