Compare commits

...

9 Commits

Author SHA1 Message Date
-LAN-
a0af8fb94c Merge branch 'main' into review-myscale-sqli 2026-03-30 15:54:07 +08:00
-LAN-
179b1efb10 Merge branch 'main' into review-myscale-sqli 2026-03-26 16:15:59 +08:00
-LAN-
d37772f81b test: resolve conflicts
Signed-off-by: -LAN- <laipz8200@outlook.com>
2026-03-26 15:17:48 +08:00
-LAN-
0103adc3aa Merge remote-tracking branch 'origin/main' into review-myscale-sqli
# Conflicts:
#	api/tests/unit_tests/core/rag/datasource/vdb/myscale/test_myscale_vector.py
2026-03-26 14:54:06 +08:00
-LAN-
f8ad55212e Merge remote-tracking branch 'origin/main' into review-myscale-sqli 2026-03-18 14:32:57 +08:00
-LAN-
d7b0a1679d Merge remote-tracking branch 'origin/main' into review-myscale-sqli 2026-03-15 07:08:27 +08:00
-LAN-
b923090e47 fix: parameterize myscale query vector and add regression test 2026-03-11 23:41:33 +08:00
-LAN-
e0436bf2db fix: preserve MyScale text content on insert 2026-03-11 23:41:33 +08:00
-LAN-
6a164265d6 Harden MyScale query parameterization 2026-03-11 23:41:30 +08:00
2 changed files with 115 additions and 43 deletions

View File

@@ -33,6 +33,18 @@ class SortOrder(StrEnum):
class MyScaleVector(BaseVector):
_METADATA_KEY_WHITELIST = {
"annotation_id",
"app_id",
"batch",
"dataset_id",
"doc_hash",
"doc_id",
"document_id",
"lang",
"source",
}
def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"):
super().__init__(collection_name)
self._config = config
@@ -45,10 +57,17 @@ class MyScaleVector(BaseVector):
password=config.password,
)
self._client.command("SET allow_experimental_object_type=1")
self._qualified_table = f"{self._config.database}.{self._collection_name}"
def get_type(self) -> str:
    """Return the vector-store type identifier for the MyScale backend."""
    backend_type = VectorType.MYSCALE
    return backend_type
@classmethod
def _validate_metadata_key(cls, key: str) -> str:
if key not in cls._METADATA_KEY_WHITELIST:
raise ValueError(f"Unsupported metadata key: {key!r}")
return key
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
@@ -59,7 +78,7 @@ class MyScaleVector(BaseVector):
self._client.command(f"CREATE DATABASE IF NOT EXISTS {self._config.database}")
fts_params = f"('{self._config.fts_params}')" if self._config.fts_params else ""
sql = f"""
CREATE TABLE IF NOT EXISTS {self._config.database}.{self._collection_name}(
CREATE TABLE IF NOT EXISTS {self._qualified_table}(
id String,
text String,
vector Array(Float32),
@@ -74,73 +93,103 @@ class MyScaleVector(BaseVector):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
ids = []
columns = ["id", "text", "vector", "metadata"]
values = []
rows = []
for i, doc in enumerate(documents):
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
row = (
doc_id,
self.escape_str(doc.page_content),
embeddings[i],
json.dumps(doc.metadata) if doc.metadata else {},
rows.append(
(
doc_id,
doc.page_content,
embeddings[i],
json.dumps(doc.metadata or {}),
)
)
values.append(str(row))
ids.append(doc_id)
sql = f"""
INSERT INTO {self._config.database}.{self._collection_name}
({",".join(columns)}) VALUES {",".join(values)}
"""
self._client.command(sql)
if rows:
self._client.insert(self._qualified_table, rows, column_names=columns)
return ids
@staticmethod
def escape_str(value: Any) -> str:
    """Stringify ``value`` and blank out backslashes and single quotes.

    Each backslash or single-quote character is replaced by one space;
    every other character is kept as-is.
    """
    blanked = {ord("\\"): " ", ord("'"): " "}
    return str(value).translate(blanked)
def text_exists(self, id: str) -> bool:
results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'")
results = self._client.query(
f"SELECT id FROM {self._qualified_table} WHERE id = %(id)s LIMIT 1",
parameters={"id": id},
)
return results.row_count > 0
def delete_by_ids(self, ids: list[str]):
if not ids:
return
placeholders, params = self._build_in_params("id", ids)
self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
f"DELETE FROM {self._qualified_table} WHERE id IN ({placeholders})",
parameters=params,
)
def get_ids_by_metadata_field(self, key: str, value: str):
safe_key = self._validate_metadata_key(key)
rows = self._client.query(
f"SELECT DISTINCT id FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
f"SELECT DISTINCT id FROM {self._qualified_table} WHERE metadata.{safe_key} = %(value)s",
parameters={"value": value},
).result_rows
return [row[0] for row in rows]
def delete_by_metadata_field(self, key: str, value: str):
safe_key = self._validate_metadata_key(key)
self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE metadata.{key}='{value}'"
f"DELETE FROM {self._qualified_table} WHERE metadata.{safe_key} = %(value)s",
parameters={"value": value},
)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
return self._search(f"distance(vector, {str(query_vector)})", self._vec_order, **kwargs)
return self._search(
"distance(vector, %(query_vector)s)",
self._vec_order,
parameters={"query_vector": query_vector},
**kwargs,
)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs)
return self._search(
"TextSearch('enable_nlq=false')(text, %(query)s)",
SortOrder.DESC,
parameters={"query": query},
**kwargs,
)
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
@staticmethod
def _build_in_params(prefix: str, values: list[str]) -> tuple[str, dict[str, str]]:
params: dict[str, str] = {}
placeholders = []
for i, value in enumerate(values):
name = f"{prefix}_{i}"
placeholders.append(f"%({name})s")
params[name] = value
return ", ".join(placeholders), params
def _search(
self,
dist: str,
order: SortOrder,
parameters: dict[str, Any] | None = None,
**kwargs: Any,
) -> list[Document]:
top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("top_k must be a positive integer")
score_threshold = float(kwargs.get("score_threshold") or 0.0)
where_str = (
f"WHERE dist < {1 - score_threshold}"
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
else ""
)
where_clauses = []
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0:
where_clauses.append(f"dist < {1 - score_threshold}")
document_ids_filter = kwargs.get("document_ids_filter")
query_params = dict(parameters or {})
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_str = f"{where_str} AND metadata['document_id'] in ({document_ids})"
placeholders, params = self._build_in_params("document_id", document_ids_filter)
where_clauses.append(f"metadata['document_id'] IN ({placeholders})")
query_params.update(params)
where_str = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
sql = f"""
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
SELECT text, vector, metadata, {dist} as dist FROM {self._qualified_table}
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
"""
try:
@@ -150,14 +199,14 @@ class MyScaleVector(BaseVector):
vector=r["vector"],
metadata=r["metadata"],
)
for r in self._client.query(sql).named_results()
for r in self._client.query(sql, parameters=query_params).named_results()
]
except Exception:
logger.exception("Vector search operation failed")
return []
def delete(self):
self._client.command(f"DROP TABLE IF EXISTS {self._config.database}.{self._collection_name}")
self._client.command(f"DROP TABLE IF EXISTS {self._qualified_table}")
class MyScaleVectorFactory(AbstractVectorFactory):

View File

@@ -2,7 +2,7 @@ import importlib
import sys
import types
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, call, patch
import pytest
@@ -24,6 +24,7 @@ def _build_fake_clickhouse_connect_module():
class Client:
def __init__(self):
self.command = MagicMock()
self.insert = MagicMock()
self.query = MagicMock(return_value=QueryResult())
client = Client()
@@ -58,9 +59,11 @@ def _config(module):
)
def test_escape_str_replaces_backslash_and_quote(myscale_module):
    """A backslash and a single quote are each replaced by one space."""
    raw_text = r"text\with'special"
    assert myscale_module.MyScaleVector.escape_str(raw_text) == "text with special"
def test_build_in_params_creates_named_placeholders(myscale_module):
    """Placeholders are numbered per value and bound under matching names."""
    builder = myscale_module.MyScaleVector._build_in_params
    sql_fragment, bound = builder("document_id", ["doc-1", "doc-2"])
    assert bound == {"document_id_0": "doc-1", "document_id_1": "doc-2"}
    assert sql_fragment == "%(document_id_0)s, %(document_id_1)s"
def test_search_raises_for_invalid_top_k(myscale_module):
@@ -172,9 +175,11 @@ def test_add_texts_inserts_rows_and_returns_ids(myscale_module, monkeypatch):
ids = vector.add_texts(docs, [[0.1], [0.2], [0.3]])
assert ids == ["doc-a", "generated-uuid"]
sql = vector._client.command.call_args.args[0]
assert "INSERT INTO dify.collection_1" in sql
assert "te xt 1" in sql
vector._client.insert.assert_called_once()
insert_table, insert_rows = vector._client.insert.call_args.args[:2]
assert insert_table == "dify.collection_1"
assert insert_rows[0][1] == r"te'xt\1"
assert vector._client.insert.call_args.kwargs["column_names"] == ["id", "text", "vector", "metadata"]
def test_text_exists_and_metadata_operations(myscale_module):
@@ -198,7 +203,22 @@ def test_search_delegation_methods(myscale_module):
assert result_vector == ["result"]
assert result_text == ["result"]
assert vector._search.call_count == 2
vector._search.assert_has_calls(
[
call(
"distance(vector, %(query_vector)s)",
vector._vec_order,
parameters={"query_vector": [0.1, 0.2]},
top_k=2,
),
call(
"TextSearch('enable_nlq=false')(text, %(query)s)",
myscale_module.SortOrder.DESC,
parameters={"query": "hello"},
top_k=2,
),
]
)
def test_search_with_document_filter_and_exception(myscale_module):
@@ -215,7 +235,10 @@ def test_search_with_document_filter_and_exception(myscale_module):
)
assert len(docs) == 1
sql = vector._client.query.call_args.args[0]
assert "metadata['document_id'] in ('doc-1', 'doc-2')" in sql
assert "metadata['document_id'] IN (%(document_id_0)s, %(document_id_1)s)" in sql
query_params = vector._client.query.call_args.kwargs["parameters"]
assert query_params["document_id_0"] == "doc-1"
assert query_params["document_id_1"] == "doc-2"
vector._client.query.side_effect = RuntimeError("boom")
assert vector._search("distance(vector, [0.1])", myscale_module.SortOrder.ASC, top_k=1) == []