mirror of
https://github.com/langgenius/dify.git
synced 2026-02-05 15:43:59 +00:00
Compare commits
1 Commits
refactor/t
...
review-mys
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
beb3ce172d |
@@ -33,6 +33,18 @@ class SortOrder(StrEnum):
|
||||
|
||||
|
||||
class MyScaleVector(BaseVector):
|
||||
_METADATA_KEY_WHITELIST = {
|
||||
"annotation_id",
|
||||
"app_id",
|
||||
"batch",
|
||||
"dataset_id",
|
||||
"doc_hash",
|
||||
"doc_id",
|
||||
"document_id",
|
||||
"lang",
|
||||
"source",
|
||||
}
|
||||
|
||||
def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"):
|
||||
super().__init__(collection_name)
|
||||
self._config = config
|
||||
@@ -45,10 +57,17 @@ class MyScaleVector(BaseVector):
|
||||
password=config.password,
|
||||
)
|
||||
self._client.command("SET allow_experimental_object_type=1")
|
||||
self._qualified_table = f"{self._config.database}.{self._collection_name}"
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.MYSCALE
|
||||
|
||||
@classmethod
|
||||
def _validate_metadata_key(cls, key: str) -> str:
|
||||
if key not in cls._METADATA_KEY_WHITELIST:
|
||||
raise ValueError(f"Unsupported metadata key: {key!r}")
|
||||
return key
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
dimension = len(embeddings[0])
|
||||
self._create_collection(dimension)
|
||||
@@ -59,7 +78,7 @@ class MyScaleVector(BaseVector):
|
||||
self._client.command(f"CREATE DATABASE IF NOT EXISTS {self._config.database}")
|
||||
fts_params = f"('{self._config.fts_params}')" if self._config.fts_params else ""
|
||||
sql = f"""
|
||||
CREATE TABLE IF NOT EXISTS {self._config.database}.{self._collection_name}(
|
||||
CREATE TABLE IF NOT EXISTS {self._qualified_table}(
|
||||
id String,
|
||||
text String,
|
||||
vector Array(Float32),
|
||||
@@ -74,23 +93,21 @@ class MyScaleVector(BaseVector):
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
ids = []
|
||||
columns = ["id", "text", "vector", "metadata"]
|
||||
values = []
|
||||
rows = []
|
||||
for i, doc in enumerate(documents):
|
||||
if doc.metadata is not None:
|
||||
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
|
||||
row = (
|
||||
doc_id,
|
||||
self.escape_str(doc.page_content),
|
||||
embeddings[i],
|
||||
json.dumps(doc.metadata) if doc.metadata else {},
|
||||
rows.append(
|
||||
(
|
||||
doc_id,
|
||||
self.escape_str(doc.page_content),
|
||||
embeddings[i],
|
||||
json.dumps(doc.metadata or {}),
|
||||
)
|
||||
)
|
||||
values.append(str(row))
|
||||
ids.append(doc_id)
|
||||
sql = f"""
|
||||
INSERT INTO {self._config.database}.{self._collection_name}
|
||||
({",".join(columns)}) VALUES {",".join(values)}
|
||||
"""
|
||||
self._client.command(sql)
|
||||
if rows:
|
||||
self._client.insert(self._qualified_table, rows, column_names=columns)
|
||||
return ids
|
||||
|
||||
@staticmethod
|
||||
@@ -98,49 +115,80 @@ class MyScaleVector(BaseVector):
|
||||
return "".join(" " if c in {"\\", "'"} else c for c in str(value))
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
|
||||
results = self._client.query(
|
||||
f"SELECT id FROM {self._qualified_table} WHERE id = %(id)s LIMIT 1",
|
||||
parameters={"id": id},
|
||||
)
|
||||
return results.row_count > 0
|
||||
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
if not ids:
|
||||
return
|
||||
placeholders, params = self._build_in_params("id", ids)
|
||||
self._client.command(
|
||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
|
||||
f"DELETE FROM {self._qualified_table} WHERE id IN ({placeholders})",
|
||||
parameters=params,
|
||||
)
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
safe_key = self._validate_metadata_key(key)
|
||||
rows = self._client.query(
|
||||
f"SELECT DISTINCT id FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
|
||||
f"SELECT DISTINCT id FROM {self._qualified_table} WHERE metadata.{safe_key} = %(value)s",
|
||||
parameters={"value": value},
|
||||
).result_rows
|
||||
return [row[0] for row in rows]
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
safe_key = self._validate_metadata_key(key)
|
||||
self._client.command(
|
||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
|
||||
f"DELETE FROM {self._qualified_table} WHERE metadata.{safe_key} = %(value)s",
|
||||
parameters={"value": value},
|
||||
)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs)
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs)
|
||||
return self._search(
|
||||
"TextSearch('enable_nlq=false')(text, %(query)s)",
|
||||
SortOrder.DESC,
|
||||
parameters={"query": query},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
|
||||
@staticmethod
|
||||
def _build_in_params(prefix: str, values: list[str]) -> tuple[str, dict[str, str]]:
|
||||
params: dict[str, str] = {}
|
||||
placeholders = []
|
||||
for i, value in enumerate(values):
|
||||
name = f"{prefix}_{i}"
|
||||
placeholders.append(f"%({name})s")
|
||||
params[name] = value
|
||||
return ", ".join(placeholders), params
|
||||
|
||||
def _search(
|
||||
self,
|
||||
dist: str,
|
||||
order: SortOrder,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 4)
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
where_str = (
|
||||
f"WHERE dist < {1 - score_threshold}"
|
||||
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
|
||||
else ""
|
||||
)
|
||||
where_clauses = []
|
||||
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0:
|
||||
where_clauses.append(f"dist < {1 - score_threshold}")
|
||||
document_ids_filter = kwargs.get("document_ids_filter")
|
||||
query_params = dict(parameters or {})
|
||||
if document_ids_filter:
|
||||
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||
where_str = f"{where_str} AND metadata['document_id'] in ({document_ids})"
|
||||
placeholders, params = self._build_in_params("document_id", document_ids_filter)
|
||||
where_clauses.append(f"metadata['document_id'] IN ({placeholders})")
|
||||
query_params.update(params)
|
||||
where_str = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
|
||||
sql = f"""
|
||||
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
|
||||
SELECT text, vector, metadata, {dist} as dist FROM {self._qualified_table}
|
||||
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
|
||||
"""
|
||||
try:
|
||||
@@ -150,14 +198,14 @@ class MyScaleVector(BaseVector):
|
||||
vector=r["vector"],
|
||||
metadata=r["metadata"],
|
||||
)
|
||||
for r in self._client.query(sql).named_results()
|
||||
for r in self._client.query(sql, parameters=query_params).named_results()
|
||||
]
|
||||
except Exception:
|
||||
logger.exception("Vector search operation failed")
|
||||
return []
|
||||
|
||||
def delete(self):
|
||||
self._client.command(f"DROP TABLE IF EXISTS {self._config.database}.{self._collection_name}")
|
||||
self._client.command(f"DROP TABLE IF EXISTS {self._qualified_table}")
|
||||
|
||||
|
||||
class MyScaleVectorFactory(AbstractVectorFactory):
|
||||
|
||||
Reference in New Issue
Block a user