mirror of
https://github.com/langgenius/dify.git
synced 2026-03-01 04:45:09 +00:00
Merge branch 'main' into feat/rag-2
This commit is contained in:
@@ -122,7 +122,6 @@ class TencentVector(BaseVector):
|
||||
metric_type,
|
||||
params,
|
||||
)
|
||||
index_text = vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER)
|
||||
index_metadate = vdb_index.FilterIndex(self.field_metadata, enum.FieldType.Json, enum.IndexType.FILTER)
|
||||
index_sparse_vector = vdb_index.SparseIndex(
|
||||
name="sparse_vector",
|
||||
@@ -130,7 +129,7 @@ class TencentVector(BaseVector):
|
||||
index_type=enum.IndexType.SPARSE_INVERTED,
|
||||
metric_type=enum.MetricType.IP,
|
||||
)
|
||||
indexes = [index_id, index_vector, index_text, index_metadate]
|
||||
indexes = [index_id, index_vector, index_metadate]
|
||||
if self._enable_hybrid_search:
|
||||
indexes.append(index_sparse_vector)
|
||||
try:
|
||||
@@ -149,7 +148,7 @@ class TencentVector(BaseVector):
|
||||
index_metadate = vdb_index.FilterIndex(
|
||||
self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
|
||||
)
|
||||
indexes = [index_id, index_vector, index_text, index_metadate]
|
||||
indexes = [index_id, index_vector, index_metadate]
|
||||
if self._enable_hybrid_search:
|
||||
indexes.append(index_sparse_vector)
|
||||
self._client.create_collection(
|
||||
|
||||
@@ -17,6 +17,7 @@ from core.workflow.entities.workflow_execution import (
|
||||
)
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from libs.helper import extract_tenant_id
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
@@ -67,7 +68,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
||||
)
|
||||
|
||||
# Extract tenant_id from user
|
||||
tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
|
||||
tenant_id = extract_tenant_id(user)
|
||||
if not tenant_id:
|
||||
raise ValueError("User must have a tenant_id or current_tenant_id")
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
@@ -20,6 +20,7 @@ from core.workflow.entities.workflow_node_execution import (
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from libs.helper import extract_tenant_id
|
||||
from models import (
|
||||
Account,
|
||||
CreatorUserRole,
|
||||
@@ -70,7 +71,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
)
|
||||
|
||||
# Extract tenant_id from user
|
||||
tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
|
||||
tenant_id = extract_tenant_id(user)
|
||||
if not tenant_id:
|
||||
raise ValueError("User must have a tenant_id or current_tenant_id")
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
@@ -12,6 +12,7 @@ from flask_login import user_loaded_from_request, user_logged_in # type: ignore
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
from libs.helper import extract_tenant_id
|
||||
from models import Account, EndUser
|
||||
|
||||
|
||||
@@ -24,11 +25,8 @@ def on_user_loaded(_sender, user: Union["Account", "EndUser"]):
|
||||
if user:
|
||||
try:
|
||||
current_span = get_current_span()
|
||||
if isinstance(user, Account) and user.current_tenant_id:
|
||||
tenant_id = user.current_tenant_id
|
||||
elif isinstance(user, EndUser):
|
||||
tenant_id = user.tenant_id
|
||||
else:
|
||||
tenant_id = extract_tenant_id(user)
|
||||
if not tenant_id:
|
||||
return
|
||||
if current_span:
|
||||
current_span.set_attribute("service.tenant.id", tenant_id)
|
||||
|
||||
@@ -25,6 +25,31 @@ from extensions.ext_redis import redis_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
|
||||
"""
|
||||
Extract tenant_id from Account or EndUser object.
|
||||
|
||||
Args:
|
||||
user: Account or EndUser object
|
||||
|
||||
Returns:
|
||||
tenant_id string if available, None otherwise
|
||||
|
||||
Raises:
|
||||
ValueError: If user is neither Account nor EndUser
|
||||
"""
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
|
||||
if isinstance(user, Account):
|
||||
return user.current_tenant_id
|
||||
elif isinstance(user, EndUser):
|
||||
return user.tenant_id
|
||||
else:
|
||||
raise ValueError(f"Invalid user type: {type(user)}. Expected Account or EndUser.")
|
||||
|
||||
|
||||
def run(script):
|
||||
|
||||
@@ -15,6 +15,7 @@ from core.variables import utils as variable_utils
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
from libs.helper import extract_tenant_id
|
||||
|
||||
from ._workflow_exc import NodeNotFoundError, WorkflowDataError
|
||||
|
||||
@@ -364,12 +365,7 @@ class Workflow(Base):
|
||||
self._environment_variables = "{}"
|
||||
|
||||
# Get tenant_id from current_user (Account or EndUser)
|
||||
if isinstance(current_user, Account):
|
||||
# Account user
|
||||
tenant_id = current_user.current_tenant_id
|
||||
else:
|
||||
# EndUser
|
||||
tenant_id = current_user.tenant_id
|
||||
tenant_id = extract_tenant_id(current_user)
|
||||
|
||||
if not tenant_id:
|
||||
return []
|
||||
@@ -396,12 +392,7 @@ class Workflow(Base):
|
||||
return
|
||||
|
||||
# Get tenant_id from current_user (Account or EndUser)
|
||||
if isinstance(current_user, Account):
|
||||
# Account user
|
||||
tenant_id = current_user.current_tenant_id
|
||||
else:
|
||||
# EndUser
|
||||
tenant_id = current_user.tenant_id
|
||||
tenant_id = extract_tenant_id(current_user)
|
||||
|
||||
if not tenant_id:
|
||||
self._environment_variables = "{}"
|
||||
|
||||
@@ -18,6 +18,7 @@ from core.file import helpers as file_helpers
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from libs.helper import extract_tenant_id
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import EndUser, UploadFile
|
||||
@@ -61,11 +62,7 @@ class FileService:
|
||||
# generate file key
|
||||
file_uuid = str(uuid.uuid4())
|
||||
|
||||
if isinstance(user, Account):
|
||||
current_tenant_id = user.current_tenant_id
|
||||
else:
|
||||
# end_user
|
||||
current_tenant_id = user.tenant_id
|
||||
current_tenant_id = extract_tenant_id(user)
|
||||
|
||||
file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." + extension
|
||||
|
||||
|
||||
@@ -72,6 +72,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
|
||||
DatasetMetadataBinding.dataset_id == dataset_id,
|
||||
DatasetMetadataBinding.document_id == document_id,
|
||||
).delete()
|
||||
db.session.commit()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
|
||||
280
api/tests/unit_tests/core/helper/test_encrypter.py
Normal file
280
api/tests/unit_tests/core/helper/test_encrypter.py
Normal file
@@ -0,0 +1,280 @@
|
||||
import base64
|
||||
import binascii
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.helper.encrypter import (
|
||||
batch_decrypt_token,
|
||||
decrypt_token,
|
||||
encrypt_token,
|
||||
get_decrypt_decoding,
|
||||
obfuscated_token,
|
||||
)
|
||||
from libs.rsa import PrivkeyNotFoundError
|
||||
|
||||
|
||||
class TestObfuscatedToken:
|
||||
@pytest.mark.parametrize(
|
||||
("token", "expected"),
|
||||
[
|
||||
("", ""), # Empty token
|
||||
("1234567", "*" * 20), # Short token (<8 chars)
|
||||
("12345678", "*" * 20), # Boundary case (8 chars)
|
||||
("123456789abcdef", "123456" + "*" * 12 + "ef"), # Long token
|
||||
("abc!@#$%^&*()def", "abc!@#" + "*" * 12 + "ef"), # Special chars
|
||||
],
|
||||
)
|
||||
def test_obfuscation_logic(self, token, expected):
|
||||
"""Test core obfuscation logic for various token lengths"""
|
||||
assert obfuscated_token(token) == expected
|
||||
|
||||
def test_sensitive_data_protection(self):
|
||||
"""Ensure obfuscation never reveals full sensitive data"""
|
||||
token = "api_key_secret_12345"
|
||||
obfuscated = obfuscated_token(token)
|
||||
assert token not in obfuscated
|
||||
assert "*" * 12 in obfuscated
|
||||
|
||||
|
||||
class TestEncryptToken:
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_successful_encryption(self, mock_encrypt, mock_query):
|
||||
"""Test successful token encryption"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_data"
|
||||
|
||||
result = encrypt_token("tenant-123", "test_token")
|
||||
|
||||
assert result == base64.b64encode(b"encrypted_data").decode()
|
||||
mock_encrypt.assert_called_with("test_token", "mock_public_key")
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
def test_tenant_not_found(self, mock_query):
|
||||
"""Test error when tenant doesn't exist"""
|
||||
mock_query.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypt_token("invalid-tenant", "test_token")
|
||||
|
||||
assert "Tenant with id invalid-tenant not found" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestDecryptToken:
|
||||
@patch("libs.rsa.decrypt")
|
||||
def test_successful_decryption(self, mock_decrypt):
|
||||
"""Test successful token decryption"""
|
||||
mock_decrypt.return_value = "decrypted_token"
|
||||
encrypted_data = base64.b64encode(b"encrypted_data").decode()
|
||||
|
||||
result = decrypt_token("tenant-123", encrypted_data)
|
||||
|
||||
assert result == "decrypted_token"
|
||||
mock_decrypt.assert_called_once_with(b"encrypted_data", "tenant-123")
|
||||
|
||||
def test_invalid_base64(self):
|
||||
"""Test handling of invalid base64 input"""
|
||||
with pytest.raises(binascii.Error):
|
||||
decrypt_token("tenant-123", "invalid_base64!!!")
|
||||
|
||||
|
||||
class TestBatchDecryptToken:
|
||||
@patch("libs.rsa.get_decrypt_decoding")
|
||||
@patch("libs.rsa.decrypt_token_with_decoding")
|
||||
def test_batch_decryption(self, mock_decrypt_with_decoding, mock_get_decoding):
|
||||
"""Test batch decryption functionality"""
|
||||
mock_rsa_key = MagicMock()
|
||||
mock_cipher_rsa = MagicMock()
|
||||
mock_get_decoding.return_value = (mock_rsa_key, mock_cipher_rsa)
|
||||
|
||||
# Test multiple tokens
|
||||
mock_decrypt_with_decoding.side_effect = ["token1", "token2", "token3"]
|
||||
tokens = [
|
||||
base64.b64encode(b"encrypted1").decode(),
|
||||
base64.b64encode(b"encrypted2").decode(),
|
||||
base64.b64encode(b"encrypted3").decode(),
|
||||
]
|
||||
result = batch_decrypt_token("tenant-123", tokens)
|
||||
|
||||
assert result == ["token1", "token2", "token3"]
|
||||
# Key should only be loaded once
|
||||
mock_get_decoding.assert_called_once_with("tenant-123")
|
||||
|
||||
|
||||
class TestGetDecryptDecoding:
|
||||
@patch("extensions.ext_redis.redis_client.get")
|
||||
@patch("extensions.ext_storage.storage.load")
|
||||
def test_private_key_not_found(self, mock_storage_load, mock_redis_get):
|
||||
"""Test error when private key file doesn't exist"""
|
||||
mock_redis_get.return_value = None
|
||||
mock_storage_load.side_effect = FileNotFoundError()
|
||||
|
||||
with pytest.raises(PrivkeyNotFoundError) as exc_info:
|
||||
get_decrypt_decoding("tenant-123")
|
||||
|
||||
assert "Private key not found, tenant_id: tenant-123" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestEncryptDecryptIntegration:
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
@patch("libs.rsa.decrypt")
|
||||
def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query):
|
||||
"""Test that encryption and decryption are consistent"""
|
||||
# Setup mock tenant
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
|
||||
# Setup mock encryption/decryption
|
||||
original_token = "test_token_123"
|
||||
mock_encrypt.return_value = b"encrypted_data"
|
||||
mock_decrypt.return_value = original_token
|
||||
|
||||
# Test encryption
|
||||
encrypted = encrypt_token("tenant-123", original_token)
|
||||
|
||||
# Test decryption
|
||||
decrypted = decrypt_token("tenant-123", encrypted)
|
||||
|
||||
assert decrypted == original_token
|
||||
|
||||
|
||||
class TestSecurity:
|
||||
"""Critical security tests for encryption system"""
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_cross_tenant_isolation(self, mock_encrypt, mock_query):
|
||||
"""Ensure tokens encrypted for one tenant cannot be used by another"""
|
||||
# Setup mock tenant
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "tenant1_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_for_tenant1"
|
||||
|
||||
# Encrypt token for tenant1
|
||||
encrypted = encrypt_token("tenant-123", "sensitive_data")
|
||||
|
||||
# Attempt to decrypt with different tenant should fail
|
||||
with patch("libs.rsa.decrypt") as mock_decrypt:
|
||||
mock_decrypt.side_effect = Exception("Invalid tenant key")
|
||||
|
||||
with pytest.raises(Exception, match="Invalid tenant key"):
|
||||
decrypt_token("different-tenant", encrypted)
|
||||
|
||||
@patch("libs.rsa.decrypt")
|
||||
def test_tampered_ciphertext_rejection(self, mock_decrypt):
|
||||
"""Detect and reject tampered ciphertext"""
|
||||
valid_encrypted = base64.b64encode(b"valid_data").decode()
|
||||
|
||||
# Tamper with ciphertext
|
||||
tampered_bytes = bytearray(base64.b64decode(valid_encrypted))
|
||||
tampered_bytes[0] ^= 0xFF
|
||||
tampered = base64.b64encode(bytes(tampered_bytes)).decode()
|
||||
|
||||
mock_decrypt.side_effect = Exception("Decryption error")
|
||||
|
||||
with pytest.raises(Exception, match="Decryption error"):
|
||||
decrypt_token("tenant-123", tampered)
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_encryption_randomness(self, mock_encrypt, mock_query):
|
||||
"""Ensure same plaintext produces different ciphertext"""
|
||||
mock_tenant = MagicMock(encrypt_public_key="key")
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
|
||||
# Different outputs for same input
|
||||
mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"]
|
||||
|
||||
results = [encrypt_token("tenant-123", "token") for _ in range(3)]
|
||||
|
||||
# All results should be different
|
||||
assert len(set(results)) == 3
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Additional security-focused edge case tests"""
|
||||
|
||||
def test_should_handle_empty_string_in_obfuscation(self):
|
||||
"""Test handling of empty string in obfuscation"""
|
||||
# Test empty string (which is a valid str type)
|
||||
assert obfuscated_token("") == ""
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query):
|
||||
"""Test encryption of empty token"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_empty"
|
||||
|
||||
result = encrypt_token("tenant-123", "")
|
||||
|
||||
assert result == base64.b64encode(b"encrypted_empty").decode()
|
||||
mock_encrypt.assert_called_with("", "mock_public_key")
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query):
|
||||
"""Test tokens containing special/unicode characters"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_special"
|
||||
|
||||
# Test various special characters
|
||||
special_tokens = [
|
||||
"token\x00with\x00null", # Null bytes
|
||||
"token_with_emoji_😀🎉", # Unicode emoji
|
||||
"token\nwith\nnewlines", # Newlines
|
||||
"token\twith\ttabs", # Tabs
|
||||
"token_with_中文字符", # Chinese characters
|
||||
]
|
||||
|
||||
for token in special_tokens:
|
||||
result = encrypt_token("tenant-123", token)
|
||||
assert result == base64.b64encode(b"encrypted_special").decode()
|
||||
mock_encrypt.assert_called_with(token, "mock_public_key")
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query):
|
||||
"""Test behavior when token exceeds RSA encryption limits"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock_tenant
|
||||
|
||||
# RSA 2048-bit can only encrypt ~245 bytes
|
||||
# The actual limit depends on padding scheme
|
||||
mock_encrypt.side_effect = ValueError("Message too long for RSA key size")
|
||||
|
||||
# Create a token that would exceed RSA limits
|
||||
long_token = "x" * 300
|
||||
|
||||
with pytest.raises(ValueError, match="Message too long for RSA key size"):
|
||||
encrypt_token("tenant-123", long_token)
|
||||
|
||||
@patch("libs.rsa.get_decrypt_decoding")
|
||||
@patch("libs.rsa.decrypt_token_with_decoding")
|
||||
def test_batch_decrypt_loads_key_only_once(self, mock_decrypt_with_decoding, mock_get_decoding):
|
||||
"""Verify batch decryption optimization - loads key only once"""
|
||||
mock_rsa_key = MagicMock()
|
||||
mock_cipher_rsa = MagicMock()
|
||||
mock_get_decoding.return_value = (mock_rsa_key, mock_cipher_rsa)
|
||||
|
||||
# Test with multiple tokens
|
||||
mock_decrypt_with_decoding.side_effect = ["token1", "token2", "token3", "token4", "token5"]
|
||||
tokens = [base64.b64encode(f"encrypted{i}".encode()).decode() for i in range(5)]
|
||||
|
||||
result = batch_decrypt_token("tenant-123", tokens)
|
||||
|
||||
assert result == ["token1", "token2", "token3", "token4", "token5"]
|
||||
# Key should only be loaded once regardless of token count
|
||||
mock_get_decoding.assert_called_once_with("tenant-123")
|
||||
assert mock_decrypt_with_decoding.call_count == 5
|
||||
65
api/tests/unit_tests/libs/test_helper.py
Normal file
65
api/tests/unit_tests/libs/test_helper.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import pytest
|
||||
|
||||
from libs.helper import extract_tenant_id
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
class TestExtractTenantId:
|
||||
"""Test cases for the extract_tenant_id utility function."""
|
||||
|
||||
def test_extract_tenant_id_from_account_with_tenant(self):
|
||||
"""Test extracting tenant_id from Account with current_tenant_id."""
|
||||
# Create a mock Account object
|
||||
account = Account()
|
||||
# Mock the current_tenant_id property
|
||||
account._current_tenant = type("MockTenant", (), {"id": "account-tenant-123"})()
|
||||
|
||||
tenant_id = extract_tenant_id(account)
|
||||
assert tenant_id == "account-tenant-123"
|
||||
|
||||
def test_extract_tenant_id_from_account_without_tenant(self):
|
||||
"""Test extracting tenant_id from Account without current_tenant_id."""
|
||||
# Create a mock Account object
|
||||
account = Account()
|
||||
account._current_tenant = None
|
||||
|
||||
tenant_id = extract_tenant_id(account)
|
||||
assert tenant_id is None
|
||||
|
||||
def test_extract_tenant_id_from_enduser_with_tenant(self):
|
||||
"""Test extracting tenant_id from EndUser with tenant_id."""
|
||||
# Create a mock EndUser object
|
||||
end_user = EndUser()
|
||||
end_user.tenant_id = "enduser-tenant-456"
|
||||
|
||||
tenant_id = extract_tenant_id(end_user)
|
||||
assert tenant_id == "enduser-tenant-456"
|
||||
|
||||
def test_extract_tenant_id_from_enduser_without_tenant(self):
|
||||
"""Test extracting tenant_id from EndUser without tenant_id."""
|
||||
# Create a mock EndUser object
|
||||
end_user = EndUser()
|
||||
end_user.tenant_id = None
|
||||
|
||||
tenant_id = extract_tenant_id(end_user)
|
||||
assert tenant_id is None
|
||||
|
||||
def test_extract_tenant_id_with_invalid_user_type(self):
|
||||
"""Test extracting tenant_id with invalid user type raises ValueError."""
|
||||
invalid_user = "not_a_user_object"
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
|
||||
extract_tenant_id(invalid_user)
|
||||
|
||||
def test_extract_tenant_id_with_none_user(self):
|
||||
"""Test extracting tenant_id with None user raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
|
||||
extract_tenant_id(None)
|
||||
|
||||
def test_extract_tenant_id_with_dict_user(self):
|
||||
"""Test extracting tenant_id with dict user raises ValueError."""
|
||||
dict_user = {"id": "123", "tenant_id": "456"}
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
|
||||
extract_tenant_id(dict_user)
|
||||
@@ -9,6 +9,7 @@ from core.file.models import File
|
||||
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
|
||||
from core.variables.segments import IntegerSegment, Segment
|
||||
from factories.variable_factory import build_segment
|
||||
from models.model import EndUser
|
||||
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
|
||||
|
||||
|
||||
@@ -43,7 +44,7 @@ def test_environment_variables():
|
||||
)
|
||||
|
||||
# Mock current_user as an EndUser
|
||||
mock_user = mock.Mock()
|
||||
mock_user = mock.Mock(spec=EndUser)
|
||||
mock_user.tenant_id = "tenant_id"
|
||||
|
||||
with (
|
||||
@@ -90,7 +91,7 @@ def test_update_environment_variables():
|
||||
)
|
||||
|
||||
# Mock current_user as an EndUser
|
||||
mock_user = mock.Mock()
|
||||
mock_user = mock.Mock(spec=EndUser)
|
||||
mock_user.tenant_id = "tenant_id"
|
||||
|
||||
with (
|
||||
@@ -136,7 +137,7 @@ def test_to_dict():
|
||||
# Create some EnvironmentVariable instances
|
||||
|
||||
# Mock current_user as an EndUser
|
||||
mock_user = mock.Mock()
|
||||
mock_user = mock.Mock(spec=EndUser)
|
||||
mock_user.tenant_id = "tenant_id"
|
||||
|
||||
with (
|
||||
|
||||
Reference in New Issue
Block a user