WIP: test(api): tests for truncation logic
@@ -3,16 +3,27 @@ import unittest
 import uuid
 
 import pytest
 from sqlalchemy import delete
 from sqlalchemy.orm import Session
 
 from core.variables.segments import StringSegment
 from core.variables.types import SegmentType
 from core.variables.variables import StringVariable
 from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
 from core.workflow.nodes import NodeType
 from extensions.ext_storage import storage
 from factories.variable_factory import build_segment
 from libs import datetime_utils
 from models import db
-from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel
-from services.workflow_draft_variable_service import DraftVarLoader, VariableResetError, WorkflowDraftVariableService
+from models.enums import CreatorUserRole
+from models.model import UploadFile
+from models.workflow import Workflow, WorkflowDraftVariable, WorkflowDraftVariableFile, WorkflowNodeExecutionModel
+from services.workflow_draft_variable_service import (
+    DraftVariableSaver,
+    DraftVarLoader,
+    VariableResetError,
+    WorkflowDraftVariableService,
+)
 
 
 @pytest.mark.usefixtures("flask_req_ctx")
@@ -175,6 +186,23 @@ class TestDraftVariableLoader(unittest.TestCase):
    _node1_id = "test_loader_node_1"
    _node_exec_id = str(uuid.uuid4())

    def setUp(self):
        self._test_app_id = str(uuid.uuid4())
        self._test_tenant_id = str(uuid.uuid4())
@@ -241,6 +269,246 @@ class TestDraftVariableLoader(unittest.TestCase):
        node1_var = next(v for v in variables if v.selector[0] == self._node1_id)
        assert node1_var.id == self._node_var_id

    @pytest.mark.usefixtures("setup_account")
    def test_load_offloaded_variable_string_type_integration(self, setup_account):
        """Test _load_offloaded_variable with string type using DraftVariableSaver for data creation."""

        # Create a large string that will be offloaded
        test_content = "x" * 15000  # Create a string larger than LARGE_VARIABLE_THRESHOLD (10KB)
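        # Note: 15,000 one-byte characters come to roughly 14.6 KiB once serialized,
        # comfortably above a 10 KiB threshold (assuming the threshold is measured in
        # bytes of the serialized value rather than in characters).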
        large_string_segment = StringSegment(value=test_content)

        node_execution_id = str(uuid.uuid4())

        try:
            with Session(bind=db.engine, expire_on_commit=False) as session:
                # Use DraftVariableSaver to create offloaded variable (this mimics production)
                saver = DraftVariableSaver(
                    session=session,
                    app_id=self._test_app_id,
                    node_id="test_offload_node",
                    node_type=NodeType.LLM,  # Use a real node type
                    node_execution_id=node_execution_id,
                    user=setup_account,
                )

                # Save the variable - this will trigger offloading due to large size
                saver.save(outputs={"offloaded_string_var": large_string_segment})
                session.commit()
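                # At this point (assuming DraftVariableSaver behaves as in production):
                # the full value has been written to external storage via an UploadFile,
                # a WorkflowDraftVariableFile row links the variable to that file, and
                # the WorkflowDraftVariable row itself keeps only a truncated preview.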

            # Now test loading using DraftVarLoader
            var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)

            # Load the variable using the standard workflow
            variables = var_loader.load_variables([["test_offload_node", "offloaded_string_var"]])

            # Verify results
            assert len(variables) == 1
            loaded_variable = variables[0]
            assert loaded_variable.name == "offloaded_string_var"
            assert loaded_variable.selector == ["test_offload_node", "offloaded_string_var"]
            assert isinstance(loaded_variable.value, StringSegment)
            assert loaded_variable.value.value == test_content

        finally:
            # Clean up - delete all draft variables for this app
            with Session(bind=db.engine) as session:
                service = WorkflowDraftVariableService(session)
                service.delete_workflow_variables(self._test_app_id)
                session.commit()

    def test_load_offloaded_variable_object_type_integration(self):
        """Test _load_offloaded_variable with object type using real storage and service."""

        # Create a test object
        test_object = {"key1": "value1", "key2": 42, "nested": {"inner": "data"}}
        test_json = json.dumps(test_object, ensure_ascii=False, separators=(",", ":"))
        content_bytes = test_json.encode()

        # Create an upload file record
        upload_file = UploadFile(
            tenant_id=self._test_tenant_id,
            storage_type="local",
            key=f"test_offload_{uuid.uuid4()}.json",
            name="test_offload.json",
            size=len(content_bytes),
            extension="json",
            mime_type="application/json",
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=str(uuid.uuid4()),
            created_at=datetime_utils.naive_utc_now(),
            used=True,
            used_by=str(uuid.uuid4()),
            used_at=datetime_utils.naive_utc_now(),
        )

        # Store the content in storage
        storage.save(upload_file.key, content_bytes)

        # Create a variable file record
        variable_file = WorkflowDraftVariableFile(
            upload_file_id=upload_file.id,
            value_type=SegmentType.OBJECT,
            tenant_id=self._test_tenant_id,
            app_id=self._test_app_id,
            user_id=str(uuid.uuid4()),
            size=len(content_bytes),
            created_at=datetime_utils.naive_utc_now(),
        )

        try:
            with Session(bind=db.engine, expire_on_commit=False) as session:
                # Add upload file and variable file first to get their IDs
                session.add_all([upload_file, variable_file])
                session.flush()  # This generates the IDs

                # Now create the offloaded draft variable with the correct file_id
                offloaded_var = WorkflowDraftVariable.new_node_variable(
                    app_id=self._test_app_id,
                    node_id="test_offload_node",
                    name="offloaded_object_var",
                    value=build_segment({"truncated": True}),
                    visible=True,
                    node_execution_id=str(uuid.uuid4()),
                )
                offloaded_var.file_id = variable_file.id
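                # Setting file_id is what marks the row as offloaded; DraftVarLoader is
                # expected to follow this link (variable_file -> upload_file) to fetch
                # the full value back from storage.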

                session.add(offloaded_var)
                session.flush()
                session.commit()

                # Use the service method that properly preloads relationships
                service = WorkflowDraftVariableService(session)
                draft_vars = service.get_draft_variables_by_selectors(
                    self._test_app_id, [["test_offload_node", "offloaded_object_var"]]
                )

                assert len(draft_vars) == 1
                loaded_var = draft_vars[0]
                assert loaded_var.is_truncated()

                # Create DraftVarLoader and test loading
                var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)

                # Test the _load_offloaded_variable method
                selector_tuple, variable = var_loader._load_offloaded_variable(loaded_var)

                # Verify the results
                assert selector_tuple == ("test_offload_node", "offloaded_object_var")
                assert variable.id == loaded_var.id
                assert variable.name == "offloaded_object_var"
                assert variable.value.value == test_object

        finally:
            # Clean up
            with Session(bind=db.engine) as session:
                # Query and delete by ID to ensure they're tracked in this session
                session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete()
                session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete()
                session.query(UploadFile).filter_by(id=upload_file.id).delete()
                session.commit()
            # Clean up storage
            try:
                storage.delete(upload_file.key)
            except Exception:
                pass  # Ignore cleanup failures

    def test_load_variables_with_offloaded_variables_integration(self):
        """Test the load_variables method with a mix of regular and offloaded variables, using real storage."""
        # The regular variables already exist from setUp
        # Create offloaded variable content
        test_content = "This is offloaded content for integration test"
        content_bytes = test_content.encode()

        # Create upload file record
        upload_file = UploadFile(
            tenant_id=self._test_tenant_id,
            storage_type="local",
            key=f"test_integration_{uuid.uuid4()}.txt",
            name="test_integration.txt",
            size=len(content_bytes),
            extension="txt",
            mime_type="text/plain",
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=str(uuid.uuid4()),
            created_at=datetime_utils.naive_utc_now(),
            used=True,
            used_by=str(uuid.uuid4()),
            used_at=datetime_utils.naive_utc_now(),
        )

        # Store the content
        storage.save(upload_file.key, content_bytes)

        # Create variable file
        variable_file = WorkflowDraftVariableFile(
            upload_file_id=upload_file.id,
            value_type=SegmentType.STRING,
            tenant_id=self._test_tenant_id,
            app_id=self._test_app_id,
            user_id=str(uuid.uuid4()),
            size=len(content_bytes),
            created_at=datetime_utils.naive_utc_now(),
        )

        try:
            with Session(bind=db.engine, expire_on_commit=False) as session:
                # Add upload file and variable file first to get their IDs
                session.add_all([upload_file, variable_file])
                session.flush()  # This generates the IDs

                # Now create the offloaded draft variable with the correct file_id
                offloaded_var = WorkflowDraftVariable.new_node_variable(
                    app_id=self._test_app_id,
                    node_id="test_integration_node",
                    name="offloaded_integration_var",
                    value=build_segment("truncated"),
                    visible=True,
                    node_execution_id=str(uuid.uuid4()),
                )
                offloaded_var.file_id = variable_file.id

                session.add(offloaded_var)
                session.flush()
                session.commit()

            # Test load_variables with both regular and offloaded variables.
            # This method should handle the relationship preloading internally.
            var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)

            variables = var_loader.load_variables(
                [
                    [SYSTEM_VARIABLE_NODE_ID, "sys_var"],  # Regular variable from setUp
                    ["test_integration_node", "offloaded_integration_var"],  # Offloaded variable
                ]
            )

            # Verify results
            assert len(variables) == 2

            # Find the regular variable
            regular_var = next(v for v in variables if v.selector[0] == SYSTEM_VARIABLE_NODE_ID)
            assert regular_var.id == self._sys_var_id
            assert regular_var.value == "sys_value"

            # Find the offloaded variable
            offloaded_loaded_var = next(v for v in variables if v.selector[0] == "test_integration_node")
            assert offloaded_loaded_var.id == offloaded_var.id
            assert offloaded_loaded_var.value == test_content

        finally:
            # Clean up
            with Session(bind=db.engine) as session:
                # Query and delete by ID to ensure they're tracked in this session
                session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete()
                session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete()
                session.query(UploadFile).filter_by(id=upload_file.id).delete()
                session.commit()
            # Clean up storage
            try:
                storage.delete(upload_file.key)
            except Exception:
                pass  # Ignore cleanup failures


@pytest.mark.usefixtures("flask_req_ctx")
class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
@@ -272,7 +540,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
             triggered_from="workflow-run",
             workflow_run_id=str(uuid.uuid4()),
             index=1,
-            node_execution_id=self._node_exec_id,
+            node_execution_id=str(uuid.uuid4()),
             node_id=self._node_id,
             node_type=NodeType.LLM.value,
             title="Test Node",
@@ -281,7 +549,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
             outputs='{"test_var": "output_value", "other_var": "other_output"}',
             status="succeeded",
             elapsed_time=1.5,
-            created_by_role="account",
+            created_by_role=CreatorUserRole.ACCOUNT,
             created_by=str(uuid.uuid4()),
         )

@@ -336,10 +604,14 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
         )
         self._conv_var.last_edited_at = datetime_utils.naive_utc_now()
 
+        with Session(db.engine, expire_on_commit=False) as persistent_session, persistent_session.begin():
+            persistent_session.add(
+                self._workflow_node_execution,
+            )
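+        # Assumption: the node-execution row is committed in its own session so that
+        # it survives the per-test session rollback and the service can still query
+        # it; the new tearDown below deletes it again.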
+
         # Add all to database
         db.session.add_all(
             [
-                self._workflow_node_execution,
                 self._node_var_with_exec,
                 self._node_var_without_exec,
                 self._node_var_missing_exec,
@@ -354,6 +626,14 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
         self._node_var_missing_exec_id = self._node_var_missing_exec.id
         self._conv_var_id = self._conv_var.id
 
+    def tearDown(self):
+        self._session.rollback()
+        with Session(db.engine) as session, session.begin():
+            stmt = delete(WorkflowNodeExecutionModel).where(
+                WorkflowNodeExecutionModel.id == self._workflow_node_execution.id
+            )
+            session.execute(stmt)
+
     def _get_test_srv(self) -> WorkflowDraftVariableService:
         return WorkflowDraftVariableService(session=self._session)

@@ -380,9 +660,6 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
         )
         return workflow
 
-    def tearDown(self):
-        self._session.rollback()
-
     def test_reset_node_variable_with_valid_execution_record(self):
         """Test resetting a node variable with valid execution record - should restore from execution"""
         srv = self._get_test_srv()

@@ -1,12 +1,14 @@
 import uuid
 from unittest.mock import patch
 
 import pytest
 from sqlalchemy import delete
 
 from core.variables.segments import StringSegment
 from models import Tenant, db
-from models.model import App
-from models.workflow import WorkflowDraftVariable
+from models.enums import CreatorUserRole
+from models.model import App, UploadFile
+from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
 from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch
@@ -212,3 +214,255 @@ class TestDeleteDraftVariablesIntegration:
            .execution_options(synchronize_session=False)
        )
        db.session.execute(query)


class TestDeleteDraftVariablesWithOffloadIntegration:
    """Integration tests for draft variable deletion with Offload data."""

    @pytest.fixture
    def setup_offload_test_data(self, app_and_tenant):
        """Create test data with draft variables that have associated Offload files."""
        tenant, app = app_and_tenant

        # Create UploadFile records
        from libs.datetime_utils import naive_utc_now

        upload_file1 = UploadFile(
            tenant_id=tenant.id,
            storage_type="local",
            key="test/file1.json",
            name="file1.json",
            size=1024,
            extension="json",
            mime_type="application/json",
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=str(uuid.uuid4()),
            created_at=naive_utc_now(),
            used=False,
        )
        upload_file2 = UploadFile(
            tenant_id=tenant.id,
            storage_type="local",
            key="test/file2.json",
            name="file2.json",
            size=2048,
            extension="json",
            mime_type="application/json",
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=str(uuid.uuid4()),
            created_at=naive_utc_now(),
            used=False,
        )
        db.session.add(upload_file1)
        db.session.add(upload_file2)
        db.session.flush()

        # Create WorkflowDraftVariableFile records
        from core.variables.types import SegmentType

        var_file1 = WorkflowDraftVariableFile(
            tenant_id=tenant.id,
            app_id=app.id,
            user_id=str(uuid.uuid4()),
            upload_file_id=upload_file1.id,
            size=1024,
            length=10,
            value_type=SegmentType.STRING,
        )
        var_file2 = WorkflowDraftVariableFile(
            tenant_id=tenant.id,
            app_id=app.id,
            user_id=str(uuid.uuid4()),
            upload_file_id=upload_file2.id,
            size=2048,
            length=20,
            value_type=SegmentType.OBJECT,
        )
        db.session.add(var_file1)
        db.session.add(var_file2)
        db.session.flush()

        # Create WorkflowDraftVariable records with file associations
        draft_var1 = WorkflowDraftVariable.new_node_variable(
            app_id=app.id,
            node_id="node_1",
            name="large_var_1",
            value=StringSegment(value="truncated..."),
            node_execution_id=str(uuid.uuid4()),
            file_id=var_file1.id,
        )
        draft_var2 = WorkflowDraftVariable.new_node_variable(
            app_id=app.id,
            node_id="node_2",
            name="large_var_2",
            value=StringSegment(value="truncated..."),
            node_execution_id=str(uuid.uuid4()),
            file_id=var_file2.id,
        )
        # Create a regular variable without Offload data
        draft_var3 = WorkflowDraftVariable.new_node_variable(
            app_id=app.id,
            node_id="node_3",
            name="regular_var",
            value=StringSegment(value="regular_value"),
            node_execution_id=str(uuid.uuid4()),
        )

        db.session.add(draft_var1)
        db.session.add(draft_var2)
        db.session.add(draft_var3)
        db.session.commit()

        yield {
            "app": app,
            "tenant": tenant,
            "upload_files": [upload_file1, upload_file2],
            "variable_files": [var_file1, var_file2],
            "draft_variables": [draft_var1, draft_var2, draft_var3],
        }

        # Cleanup
        db.session.rollback()

        # Clean up any remaining records
        for table, ids in [
            (WorkflowDraftVariable, [v.id for v in [draft_var1, draft_var2, draft_var3]]),
            (WorkflowDraftVariableFile, [vf.id for vf in [var_file1, var_file2]]),
            (UploadFile, [uf.id for uf in [upload_file1, upload_file2]]),
        ]:
            cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
            db.session.execute(cleanup_query)

        db.session.commit()
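    # Note on patching: the tests below patch "extensions.ext_storage.storage" at its
    # defining module. That intercepts deletions only if the deletion code resolves
    # `storage` through that module at call time (an assumption about how the task
    # module imports it).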

    @patch("extensions.ext_storage.storage")
    def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data):
        """Test that deleting draft variables also cleans up associated Offload data."""
        data = setup_offload_test_data
        app_id = data["app"].id

        # Mock storage deletion to succeed
        mock_storage.delete.return_value = None

        # Verify initial state
        draft_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
        var_files_before = db.session.query(WorkflowDraftVariableFile).count()
        upload_files_before = db.session.query(UploadFile).count()

        assert draft_vars_before == 3  # 2 with files + 1 regular
        assert var_files_before == 2
        assert upload_files_before == 2

        # Delete draft variables
        deleted_count = delete_draft_variables_batch(app_id, batch_size=10)

        # Verify results
        assert deleted_count == 3

        # Check that all draft variables are deleted
        draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
        assert draft_vars_after == 0

        # Check that associated Offload data is cleaned up
        var_files_after = db.session.query(WorkflowDraftVariableFile).count()
        upload_files_after = db.session.query(UploadFile).count()

        assert var_files_after == 0  # All variable files should be deleted
        assert upload_files_after == 0  # All upload files should be deleted

        # Verify storage deletion was called for both files
        assert mock_storage.delete.call_count == 2
        storage_keys_deleted = [call.args[0] for call in mock_storage.delete.call_args_list]
        assert "test/file1.json" in storage_keys_deleted
        assert "test/file2.json" in storage_keys_deleted

    @patch("extensions.ext_storage.storage")
    def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data):
        """Test that database cleanup continues even when storage deletion fails."""
        data = setup_offload_test_data
        app_id = data["app"].id

        # Mock storage deletion to fail for the first file, succeed for the second
        mock_storage.delete.side_effect = [Exception("Storage error"), None]
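        # With a list side_effect, unittest.mock consumes one entry per call: an
        # exception instance is raised, any other value is returned - so the first
        # delete raises and the second succeeds.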

        # Delete draft variables
        deleted_count = delete_draft_variables_batch(app_id, batch_size=10)

        # Verify that all draft variables are still deleted
        assert deleted_count == 3

        draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
        assert draft_vars_after == 0

        # Database cleanup should still succeed even with storage errors
        var_files_after = db.session.query(WorkflowDraftVariableFile).count()
        upload_files_after = db.session.query(UploadFile).count()

        assert var_files_after == 0
        assert upload_files_after == 0

        # Verify storage deletion was attempted for both files
        assert mock_storage.delete.call_count == 2

    @patch("extensions.ext_storage.storage")
    def test_delete_draft_variables_partial_offload_data(self, mock_storage, setup_offload_test_data):
        """Test deletion with a mix of variables with and without Offload data."""
        data = setup_offload_test_data
        app_id = data["app"].id

        # Create an additional app with only regular variables (no offload data)
        tenant = data["tenant"]
        app2 = App(
            tenant_id=tenant.id,
            name="Test App 2",
            mode="workflow",
            enable_site=True,
            enable_api=True,
        )
        db.session.add(app2)
        db.session.flush()

        # Add regular variables to app2
        regular_vars = []
        for i in range(3):
            var = WorkflowDraftVariable.new_node_variable(
                app_id=app2.id,
                node_id=f"node_{i}",
                name=f"var_{i}",
                value=StringSegment(value="regular_value"),
                node_execution_id=str(uuid.uuid4()),
            )
            db.session.add(var)
            regular_vars.append(var)
        db.session.commit()

        try:
            # Mock storage deletion
            mock_storage.delete.return_value = None

            # Delete variables for app2 (no offload data)
            deleted_count_app2 = delete_draft_variables_batch(app2.id, batch_size=10)
            assert deleted_count_app2 == 3

            # Verify storage wasn't called for app2 (no offload files)
            mock_storage.delete.assert_not_called()

            # Delete variables for the original app (with offload data)
            deleted_count_app1 = delete_draft_variables_batch(app_id, batch_size=10)
            assert deleted_count_app1 == 3

            # Now storage should be called for the offload files
            assert mock_storage.delete.call_count == 2

        finally:
            # Cleanup app2 and its variables
            cleanup_vars_query = (
                delete(WorkflowDraftVariable)
                .where(WorkflowDraftVariable.app_id == app2.id)
                .execution_options(synchronize_session=False)
            )
            db.session.execute(cleanup_vars_query)

            app2_obj = db.session.get(App, app2.id)
            if app2_obj:
                db.session.delete(app2_obj)
            db.session.commit()


api/tests/integration_tests/test_offload.py (new file)
@@ -0,0 +1,213 @@
import uuid

import pytest
from sqlalchemy.orm import Session, joinedload, selectinload

from libs.datetime_utils import naive_utc_now
from libs.uuid_utils import uuidv7
from models import db
from models.enums import CreatorUserRole
from models.model import UploadFile
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom


@pytest.fixture
def session(flask_req_ctx):
    with Session(bind=db.engine, expire_on_commit=False) as session:
        yield session


def test_offload(session, setup_account):
    tenant_id = str(uuid.uuid4())
    app_id = str(uuid.uuid4())
    # step 1: create a UploadFile
    input_upload_file = UploadFile(
        tenant_id=tenant_id,
        storage_type="local",
        key="fake_storage_key",
        name="test_file.txt",
        size=1024,
        extension="txt",
        mime_type="text/plain",
        created_by_role=CreatorUserRole.ACCOUNT,
        created_by=setup_account.id,
        created_at=naive_utc_now(),
        used=False,
    )
    output_upload_file = UploadFile(
        tenant_id=tenant_id,
        storage_type="local",
        key="fake_storage_key",
        name="test_file.txt",
        size=1024,
        extension="txt",
        mime_type="text/plain",
        created_by_role=CreatorUserRole.ACCOUNT,
        created_by=setup_account.id,
        created_at=naive_utc_now(),
        used=False,
    )
    session.add(input_upload_file)
    session.add(output_upload_file)
    session.flush()

    # step 2: create a WorkflowNodeExecutionModel
    node_execution = WorkflowNodeExecutionModel(
        id=str(uuid.uuid4()),
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_id=str(uuid.uuid4()),
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        index=1,
        node_id="test_node_id",
        node_type="test",
        title="Test Node",
        status="succeeded",
        created_by_role=CreatorUserRole.ACCOUNT.value,
        created_by=setup_account.id,
    )
    session.add(node_execution)
    session.flush()

    # step 3: create a WorkflowNodeExecutionOffload
    offload = WorkflowNodeExecutionOffload(
        id=uuidv7(),
        tenant_id=tenant_id,
        app_id=app_id,
        node_execution_id=node_execution.id,
        inputs_file_id=input_upload_file.id,
        outputs_file_id=output_upload_file.id,
    )
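    # uuidv7() (rather than uuid4()) yields time-ordered ids; presumably chosen so
    # offload rows get insert-ordered primary keys, though any unique id would work.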
    session.add(offload)
    session.flush()

    # Test preloading - this should work without raising LazyLoadError
    result = (
        session.query(WorkflowNodeExecutionModel)
        .options(
            selectinload(WorkflowNodeExecutionModel.offload_data).options(
                joinedload(WorkflowNodeExecutionOffload.inputs_file),
                joinedload(WorkflowNodeExecutionOffload.outputs_file),
            )
        )
        .filter(WorkflowNodeExecutionModel.id == node_execution.id)
        .first()
    )
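    # Loading-strategy note (an assumption about the relationship configuration):
    # selectinload issues a second SELECT for offload_data keyed by the parent ids,
    # and the nested joinedload options are meant to bring each UploadFile in eagerly
    # as well, so the attribute access below should not trigger lazy loads. The
    # captured SQL at the bottom of this file shows the queries actually emitted.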

    # Verify the relationships are properly loaded
    assert result is not None
    assert result.offload_data is not None
    assert result.offload_data.inputs_file is not None
    assert result.offload_data.inputs_file.id == input_upload_file.id
    assert result.offload_data.inputs_file.name == "test_file.txt"

    # Test the computed properties
    assert result.inputs_truncated is True
    assert result.outputs_truncated is False


def _test_offload_save(session, setup_account):
    tenant_id = str(uuid.uuid4())
    app_id = str(uuid.uuid4())
    # step 1: create a UploadFile
    input_upload_file = UploadFile(
        tenant_id=tenant_id,
        storage_type="local",
        key="fake_storage_key",
        name="test_file.txt",
        size=1024,
        extension="txt",
        mime_type="text/plain",
        created_by_role=CreatorUserRole.ACCOUNT,
        created_by=setup_account.id,
        created_at=naive_utc_now(),
        used=False,
    )
    output_upload_file = UploadFile(
        tenant_id=tenant_id,
        storage_type="local",
        key="fake_storage_key",
        name="test_file.txt",
        size=1024,
        extension="txt",
        mime_type="text/plain",
        created_by_role=CreatorUserRole.ACCOUNT,
        created_by=setup_account.id,
        created_at=naive_utc_now(),
        used=False,
    )

    node_execution_id = str(uuid.uuid4())

    # step 2: create a WorkflowNodeExecutionOffload
    offload = WorkflowNodeExecutionOffload(
        id=uuidv7(),
        tenant_id=tenant_id,
        app_id=app_id,
        node_execution_id=node_execution_id,
    )
    offload.inputs_file = input_upload_file
    offload.outputs_file = output_upload_file

    # step 3: create a WorkflowNodeExecutionModel
    node_execution = WorkflowNodeExecutionModel(
        id=str(uuid.uuid4()),
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_id=str(uuid.uuid4()),
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        index=1,
        node_id="test_node_id",
        node_type="test",
        title="Test Node",
        status="succeeded",
        created_by_role=CreatorUserRole.ACCOUNT.value,
        created_by=setup_account.id,
    )
    node_execution.offload_data = offload
    session.add(node_execution)
    session.flush()


"""
2025-08-21 15:34:49,570 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2025-08-21 15:34:49,572 INFO sqlalchemy.engine.Engine INSERT INTO upload_files (id, tenant_id, storage_type, key, name, size, extension, mime_type, created_by_role, created_by, created_at, used, used_by, used_at, hash, source_url) VALUES (%(id__0)s::UUID, %(tenant_id__0)s::UUID, %(storage_type__0)s, %(k ... 410 characters truncated ... (created_at__1)s, %(used__1)s, %(used_by__1)s::UUID, %(used_at__1)s, %(hash__1)s, %(source_url__1)s)
2025-08-21 15:34:49,572 INFO sqlalchemy.engine.Engine [generated in 0.00009s (insertmanyvalues) 1/1 (unordered)] {'created_at__0': datetime.datetime(2025, 8, 21, 15, 34, 49, 570482), 'id__0': '366621fa-4326-403e-8709-62e4d0de7367', 'storage_type__0': 'local', 'extension__0': 'txt', 'created_by__0': 'ccc7657c-fb48-46bd-8f42-c837b14eab18', 'used_at__0': None, 'used_by__0': None, 'source_url__0': '', 'mime_type__0': 'text/plain', 'created_by_role__0': 'account', 'used__0': False, 'size__0': 1024, 'tenant_id__0': '4c1bbfc9-a28b-4d93-8987-45db78e3269c', 'hash__0': None, 'key__0': 'fake_storage_key', 'name__0': 'test_file.txt', 'created_at__1': datetime.datetime(2025, 8, 21, 15, 34, 49, 570563), 'id__1': '3cdec641-a452-4df0-a9af-4a1a30c27ea5', 'storage_type__1': 'local', 'extension__1': 'txt', 'created_by__1': 'ccc7657c-fb48-46bd-8f42-c837b14eab18', 'used_at__1': None, 'used_by__1': None, 'source_url__1': '', 'mime_type__1': 'text/plain', 'created_by_role__1': 'account', 'used__1': False, 'size__1': 1024, 'tenant_id__1': '4c1bbfc9-a28b-4d93-8987-45db78e3269c', 'hash__1': None, 'key__1': 'fake_storage_key', 'name__1': 'test_file.txt'}
2025-08-21 15:34:49,576 INFO sqlalchemy.engine.Engine INSERT INTO workflow_node_executions (id, tenant_id, app_id, workflow_id, triggered_from, workflow_run_id, index, predecessor_node_id, node_execution_id, node_id, node_type, title, inputs, process_data, outputs, status, error, execution_metadata, created_by_role, created_by, finished_at) VALUES (%(id)s::UUID, %(tenant_id)s::UUID, %(app_id)s::UUID, %(workflow_id)s::UUID, %(triggered_from)s, %(workflow_run_id)s::UUID, %(index)s, %(predecessor_node_id)s, %(node_execution_id)s, %(node_id)s, %(node_type)s, %(title)s, %(inputs)s, %(process_data)s, %(outputs)s, %(status)s, %(error)s, %(execution_metadata)s, %(created_by_role)s, %(created_by)s::UUID, %(finished_at)s) RETURNING workflow_node_executions.elapsed_time, workflow_node_executions.created_at
2025-08-21 15:34:49,576 INFO sqlalchemy.engine.Engine [generated in 0.00019s] {'id': '9aac28b6-b6fc-4aea-abdf-21da3227e621', 'tenant_id': '4c1bbfc9-a28b-4d93-8987-45db78e3269c', 'app_id': '79fa81c7-2760-40db-af54-74cb2fea2ce7', 'workflow_id': '95d341e3-381c-4c54-a383-f685a9741053', 'triggered_from': <WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN: 'workflow-run'>, 'workflow_run_id': None, 'index': 1, 'predecessor_node_id': None, 'node_execution_id': None, 'node_id': 'test_node_id', 'node_type': 'test', 'title': 'Test Node', 'inputs': None, 'process_data': None, 'outputs': None, 'status': 'succeeded', 'error': None, 'execution_metadata': None, 'created_by_role': 'account', 'created_by': 'ccc7657c-fb48-46bd-8f42-c837b14eab18', 'finished_at': None}
2025-08-21 15:34:49,579 INFO sqlalchemy.engine.Engine INSERT INTO workflow_node_execution_offload (id, created_at, tenant_id, app_id, node_execution_id, inputs_file_id, outputs_file_id) VALUES (%(id)s::UUID, %(created_at)s, %(tenant_id)s::UUID, %(app_id)s::UUID, %(node_execution_id)s::UUID, %(inputs_file_id)s::UUID, %(outputs_file_id)s::UUID)
2025-08-21 15:34:49,579 INFO sqlalchemy.engine.Engine [generated in 0.00016s] {'id': '0198cd44-b7ea-724b-9e1b-5f062a2ef45b', 'created_at': datetime.datetime(2025, 8, 21, 15, 34, 49, 579072), 'tenant_id': '4c1bbfc9-a28b-4d93-8987-45db78e3269c', 'app_id': '79fa81c7-2760-40db-af54-74cb2fea2ce7', 'node_execution_id': '9aac28b6-b6fc-4aea-abdf-21da3227e621', 'inputs_file_id': '366621fa-4326-403e-8709-62e4d0de7367', 'outputs_file_id': '3cdec641-a452-4df0-a9af-4a1a30c27ea5'}
2025-08-21 15:34:49,581 INFO sqlalchemy.engine.Engine SELECT workflow_node_executions.id AS workflow_node_executions_id, workflow_node_executions.tenant_id AS workflow_node_executions_tenant_id, workflow_node_executions.app_id AS workflow_node_executions_app_id, workflow_node_executions.workflow_id AS workflow_node_executions_workflow_id, workflow_node_executions.triggered_from AS workflow_node_executions_triggered_from, workflow_node_executions.workflow_run_id AS workflow_node_executions_workflow_run_id, workflow_node_executions.index AS workflow_node_executions_index, workflow_node_executions.predecessor_node_id AS workflow_node_executions_predecessor_node_id, workflow_node_executions.node_execution_id AS workflow_node_executions_node_execution_id, workflow_node_executions.node_id AS workflow_node_executions_node_id, workflow_node_executions.node_type AS workflow_node_executions_node_type, workflow_node_executions.title AS workflow_node_executions_title, workflow_node_executions.inputs AS workflow_node_executions_inputs, workflow_node_executions.process_data AS workflow_node_executions_process_data, workflow_node_executions.outputs AS workflow_node_executions_outputs, workflow_node_executions.status AS workflow_node_executions_status, workflow_node_executions.error AS workflow_node_executions_error, workflow_node_executions.elapsed_time AS workflow_node_executions_elapsed_time, workflow_node_executions.execution_metadata AS workflow_node_executions_execution_metadata, workflow_node_executions.created_at AS workflow_node_executions_created_at, workflow_node_executions.created_by_role AS workflow_node_executions_created_by_role, workflow_node_executions.created_by AS workflow_node_executions_created_by, workflow_node_executions.finished_at AS workflow_node_executions_finished_at
FROM workflow_node_executions
WHERE workflow_node_executions.id = %(id_1)s::UUID
LIMIT %(param_1)s
2025-08-21 15:34:49,581 INFO sqlalchemy.engine.Engine [generated in 0.00009s] {'id_1': '9aac28b6-b6fc-4aea-abdf-21da3227e621', 'param_1': 1}
2025-08-21 15:34:49,585 INFO sqlalchemy.engine.Engine SELECT workflow_node_execution_offload.node_execution_id AS workflow_node_execution_offload_node_execution_id, workflow_node_execution_offload.id AS workflow_node_execution_offload_id, workflow_node_execution_offload.created_at AS workflow_node_execution_offload_created_at, workflow_node_execution_offload.tenant_id AS workflow_node_execution_offload_tenant_id, workflow_node_execution_offload.app_id AS workflow_node_execution_offload_app_id, workflow_node_execution_offload.inputs_file_id AS workflow_node_execution_offload_inputs_file_id, workflow_node_execution_offload.outputs_file_id AS workflow_node_execution_offload_outputs_file_id
FROM workflow_node_execution_offload
WHERE workflow_node_execution_offload.node_execution_id IN (%(primary_keys_1)s::UUID)
2025-08-21 15:34:49,585 INFO sqlalchemy.engine.Engine [generated in 0.00021s] {'primary_keys_1': '9aac28b6-b6fc-4aea-abdf-21da3227e621'}
2025-08-21 15:34:49,587 INFO sqlalchemy.engine.Engine SELECT upload_files.id AS upload_files_id, upload_files.tenant_id AS upload_files_tenant_id, upload_files.storage_type AS upload_files_storage_type, upload_files.key AS upload_files_key, upload_files.name AS upload_files_name, upload_files.size AS upload_files_size, upload_files.extension AS upload_files_extension, upload_files.mime_type AS upload_files_mime_type, upload_files.created_by_role AS upload_files_created_by_role, upload_files.created_by AS upload_files_created_by, upload_files.created_at AS upload_files_created_at, upload_files.used AS upload_files_used, upload_files.used_by AS upload_files_used_by, upload_files.used_at AS upload_files_used_at, upload_files.hash AS upload_files_hash, upload_files.source_url AS upload_files_source_url
FROM upload_files
WHERE upload_files.id IN (%(primary_keys_1)s::UUID)
2025-08-21 15:34:49,587 INFO sqlalchemy.engine.Engine [generated in 0.00012s] {'primary_keys_1': '3cdec641-a452-4df0-a9af-4a1a30c27ea5'}
2025-08-21 15:34:49,588 INFO sqlalchemy.engine.Engine SELECT upload_files.id AS upload_files_id, upload_files.tenant_id AS upload_files_tenant_id, upload_files.storage_type AS upload_files_storage_type, upload_files.key AS upload_files_key, upload_files.name AS upload_files_name, upload_files.size AS upload_files_size, upload_files.extension AS upload_files_extension, upload_files.mime_type AS upload_files_mime_type, upload_files.created_by_role AS upload_files_created_by_role, upload_files.created_by AS upload_files_created_by, upload_files.created_at AS upload_files_created_at, upload_files.used AS upload_files_used, upload_files.used_by AS upload_files_used_by, upload_files.used_at AS upload_files_used_at, upload_files.hash AS upload_files_hash, upload_files.source_url AS upload_files_source_url
FROM upload_files
WHERE upload_files.id IN (%(primary_keys_1)s::UUID)
2025-08-21 15:34:49,588 INFO sqlalchemy.engine.Engine [generated in 0.00010s] {'primary_keys_1': '366621fa-4326-403e-8709-62e4d0de7367'}
"""


"""
upload_file_id: 366621fa-4326-403e-8709-62e4d0de7367 3cdec641-a452-4df0-a9af-4a1a30c27ea5
workflow_node_executions_id: 9aac28b6-b6fc-4aea-abdf-21da3227e621
offload_id: 0198cd44-b7ea-724b-9e1b-5f062a2ef45b
"""

@@ -0,0 +1,421 @@
"""
Integration tests for process_data truncation functionality.

These tests verify the end-to-end behavior of process_data truncation across
the entire system, from database storage to API responses.
"""

import json
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from unittest.mock import Mock, patch

import pytest
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import NodeType
from models import Account
from models.workflow import WorkflowNodeExecutionTriggeredFrom


@dataclass
class TruncationTestData:
    """Test data for truncation scenarios."""

    name: str
    process_data: dict[str, Any]
    should_truncate: bool
    expected_storage_interaction: bool


class TestProcessDataTruncationIntegration:
    """Integration tests for process_data truncation functionality."""

    @pytest.fixture
    def in_memory_db_engine(self):
        """Create an in-memory SQLite database for testing."""
        engine = create_engine("sqlite:///:memory:")

        # Create minimal table structure for testing
        with engine.connect() as conn:
            # Create workflow_node_executions table
            conn.execute(text("""
                CREATE TABLE workflow_node_executions (
                    id TEXT PRIMARY KEY,
                    tenant_id TEXT NOT NULL,
                    app_id TEXT NOT NULL,
                    workflow_id TEXT NOT NULL,
                    triggered_from TEXT NOT NULL,
                    workflow_run_id TEXT,
                    index_ INTEGER NOT NULL,
                    predecessor_node_id TEXT,
                    node_execution_id TEXT,
                    node_id TEXT NOT NULL,
                    node_type TEXT NOT NULL,
                    title TEXT NOT NULL,
                    inputs TEXT,
                    process_data TEXT,
                    outputs TEXT,
                    status TEXT NOT NULL,
                    error TEXT,
                    elapsed_time REAL DEFAULT 0,
                    execution_metadata TEXT,
                    created_at DATETIME NOT NULL,
                    created_by_role TEXT NOT NULL,
                    created_by TEXT NOT NULL,
                    finished_at DATETIME
                )
            """))

            # Create workflow_node_execution_offload table
            conn.execute(text("""
                CREATE TABLE workflow_node_execution_offload (
                    id TEXT PRIMARY KEY,
                    created_at DATETIME NOT NULL,
                    tenant_id TEXT NOT NULL,
                    app_id TEXT NOT NULL,
                    node_execution_id TEXT NOT NULL UNIQUE,
                    inputs_file_id TEXT,
                    outputs_file_id TEXT,
                    process_data_file_id TEXT
                )
            """))

            # Create upload_files table (simplified)
            conn.execute(text("""
                CREATE TABLE upload_files (
                    id TEXT PRIMARY KEY,
                    tenant_id TEXT NOT NULL,
                    storage_key TEXT NOT NULL,
                    filename TEXT NOT NULL,
                    size INTEGER NOT NULL,
                    created_at DATETIME NOT NULL
                )
            """))

            conn.commit()
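            # Assumption: this hand-written SQLite DDL only approximates the real,
            # migration-managed Postgres schema; it covers just the columns the
            # repository reads and writes in these tests.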

        return engine

    @pytest.fixture
    def mock_account(self):
        """Create a mock account for testing."""
        account = Mock(spec=Account)
        account.id = "test-user-id"
        account.tenant_id = "test-tenant-id"
        return account

    @pytest.fixture
    def repository(self, in_memory_db_engine, mock_account):
        """Create a repository instance for testing."""
        session_factory = sessionmaker(bind=in_memory_db_engine)

        return SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=session_factory,
            user=mock_account,
            app_id="test-app-id",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

    def create_test_execution(
        self,
        process_data: dict[str, Any] | None = None,
        execution_id: str = "test-execution-id",
    ) -> WorkflowNodeExecution:
        """Create a test execution with process_data."""
        return WorkflowNodeExecution(
            id=execution_id,
            workflow_id="test-workflow-id",
            workflow_execution_id="test-run-id",
            index=1,
            node_id="test-node-id",
            node_type=NodeType.LLM,
            title="Test Node",
            process_data=process_data,
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            created_at=datetime.now(),
            finished_at=datetime.now(),
        )

    def get_truncation_test_data(self) -> list[TruncationTestData]:
        """Get test data for various truncation scenarios."""
        return [
            TruncationTestData(
                name="small_process_data",
                process_data={"small": "data", "count": 5},
                should_truncate=False,
                expected_storage_interaction=False,
            ),
            TruncationTestData(
                name="large_process_data",
                process_data={"large_field": "x" * 10000, "metadata": "info"},
                should_truncate=True,
                expected_storage_interaction=True,
            ),
            TruncationTestData(
                name="complex_large_data",
                process_data={
                    "logs": ["log entry"] * 500,  # Large array
                    "config": {"setting": "value"},
                    "status": "processing",
                    "details": {"description": "y" * 5000},  # Large string
                },
                should_truncate=True,
                expected_storage_interaction=True,
            ),
        ]

    @patch("core.repositories.sqlalchemy_workflow_node_execution_repository.dify_config")
    @patch("services.file_service.FileService.upload_file")
    @patch("extensions.ext_storage.storage")
    def test_end_to_end_process_data_truncation(
        self,
        mock_storage,
        mock_upload_file,
        mock_config,
        repository,
    ):
        """Test end-to-end process_data truncation functionality."""
        # Configure truncation limits
        mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000
        mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100
        mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500

        # Create large process_data that should be truncated
        large_process_data = {
            "large_field": "x" * 10000,  # Exceeds string length limit
            "metadata": {"type": "processing", "timestamp": 1234567890},
        }

        # Mock file upload
        mock_file = Mock()
        mock_file.id = "mock-process-data-file-id"
        mock_upload_file.return_value = mock_file

        # Create and save execution
        execution = self.create_test_execution(process_data=large_process_data)
        repository.save(execution)

        # Verify truncation occurred
        assert execution.process_data_truncated is True
        truncated_data = execution.get_truncated_process_data()
        assert truncated_data is not None
        assert truncated_data != large_process_data  # Should be different due to truncation

        # Verify file upload was called for process_data
        assert mock_upload_file.called
        upload_args = mock_upload_file.call_args
        assert "_process_data" in upload_args[1]["filename"]

    @patch("core.repositories.sqlalchemy_workflow_node_execution_repository.dify_config")
    def test_small_process_data_no_truncation(self, mock_config, repository):
        """Test that small process_data is not truncated."""
        # Configure truncation limits
        mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000
        mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100
        mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500

        # Create small process_data
        small_process_data = {"small": "data", "count": 5}

        execution = self.create_test_execution(process_data=small_process_data)
        repository.save(execution)

        # Verify no truncation occurred
        assert execution.process_data_truncated is False
        assert execution.get_truncated_process_data() is None
        assert execution.get_response_process_data() == small_process_data

    @pytest.mark.parametrize(
        "test_data",
        get_truncation_test_data(None),
        ids=[data.name for data in get_truncation_test_data(None)],
    )
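    # (Class-body trick: at decoration time, get_truncation_test_data is still a plain
    # function in the class namespace, so it can be called with None standing in for self.)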
    @patch("core.repositories.sqlalchemy_workflow_node_execution_repository.dify_config")
    @patch("services.file_service.FileService.upload_file")
    def test_various_truncation_scenarios(
        self,
        mock_upload_file,
        mock_config,
        test_data: TruncationTestData,
        repository,
    ):
        """Test various process_data truncation scenarios."""
        # Configure truncation limits
        mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000
        mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100
        mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500

        if test_data.expected_storage_interaction:
            # Mock file upload for truncation scenarios
            mock_file = Mock()
            mock_file.id = f"file-{test_data.name}"
            mock_upload_file.return_value = mock_file

        execution = self.create_test_execution(process_data=test_data.process_data)
        repository.save(execution)

        # Verify truncation behavior matches expectations
        assert execution.process_data_truncated == test_data.should_truncate

        if test_data.should_truncate:
            assert execution.get_truncated_process_data() is not None
            assert execution.get_truncated_process_data() != test_data.process_data
            assert mock_upload_file.called
        else:
            assert execution.get_truncated_process_data() is None
            assert execution.get_response_process_data() == test_data.process_data

    @patch("core.repositories.sqlalchemy_workflow_node_execution_repository.dify_config")
    @patch("services.file_service.FileService.upload_file")
    @patch("extensions.ext_storage.storage")
    def test_load_truncated_execution_from_database(
        self,
        mock_storage,
        mock_upload_file,
        mock_config,
        repository,
        in_memory_db_engine,
    ):
        """Test loading an execution with truncated process_data from the database."""
        # Configure truncation
        mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000
        mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100
        mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500

        # Create and save an execution with large process_data
        large_process_data = {
            "large_field": "x" * 10000,
            "metadata": "info",
        }

        # Mock file upload
        mock_file = Mock()
        mock_file.id = "process-data-file-id"
        mock_upload_file.return_value = mock_file

        execution = self.create_test_execution(process_data=large_process_data)
        repository.save(execution)

        # Mock storage load for reconstruction
        mock_storage.load.return_value = json.dumps(large_process_data).encode()

        # Create a new repository instance to simulate a fresh load
        session_factory = sessionmaker(bind=in_memory_db_engine)
        new_repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=session_factory,
            user=Mock(spec=Account, id="test-user", tenant_id="test-tenant"),
            app_id="test-app-id",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        # Load executions from the database
        executions = new_repository.get_by_workflow_run("test-run-id")

        assert len(executions) == 1
        loaded_execution = executions[0]

        # Verify that the full data is loaded
        assert loaded_execution.process_data == large_process_data
        assert loaded_execution.process_data_truncated is True

        # Verify the truncated data used for responses
        response_data = loaded_execution.get_response_process_data()
        assert response_data != large_process_data  # Should be the truncated version

    def test_process_data_none_handling(self, repository):
        """Test handling of None process_data."""
        execution = self.create_test_execution(process_data=None)
        repository.save(execution)

        # Should handle None gracefully
        assert execution.process_data is None
        assert execution.process_data_truncated is False
        assert execution.get_response_process_data() is None

    def test_empty_process_data_handling(self, repository):
        """Test handling of empty process_data."""
        execution = self.create_test_execution(process_data={})
        repository.save(execution)

        # Should handle an empty dict gracefully
        assert execution.process_data == {}
        assert execution.process_data_truncated is False
        assert execution.get_response_process_data() == {}


class TestProcessDataTruncationApiIntegration:
    """Integration tests for API responses with process_data truncation."""

    def test_api_response_includes_truncated_flag(self):
        """Test that API responses include the process_data_truncated flag."""
        from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
        from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity
        from core.app.entities.queue_entities import QueueNodeSucceededEvent

        # Create an execution with truncated process_data
        execution = WorkflowNodeExecution(
            id="test-execution-id",
            workflow_id="test-workflow-id",
            workflow_execution_id="test-run-id",
            index=1,
            node_id="test-node-id",
            node_type=NodeType.LLM,
            title="Test Node",
            process_data={"large": "x" * 10000},
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            created_at=datetime.now(),
            finished_at=datetime.now(),
        )

        # Set the truncated data
        execution.set_truncated_process_data({"large": "[TRUNCATED]"})

        # Create converter and event
        converter = WorkflowResponseConverter(
            application_generate_entity=Mock(
                spec=WorkflowAppGenerateEntity,
                app_config=Mock(tenant_id="test-tenant"),
            )
        )

        event = QueueNodeSucceededEvent(
            node_id="test-node-id",
            node_type=NodeType.LLM,
            node_data=Mock(),
            parallel_id=None,
            parallel_start_node_id=None,
            parent_parallel_id=None,
            parent_parallel_start_node_id=None,
            in_iteration_id=None,
            in_loop_id=None,
        )

        # Generate the response
        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test-task-id",
            workflow_node_execution=execution,
        )

        # Verify the response includes the truncated flag and data
        assert response is not None
        assert response.data.process_data_truncated is True
        assert response.data.process_data == {"large": "[TRUNCATED]"}

        # Verify the response can be serialized
        response_dict = response.to_dict()
        assert "process_data_truncated" in response_dict["data"]
        assert response_dict["data"]["process_data_truncated"] is True

    def test_workflow_run_fields_include_truncated_flag(self):
        """Test that workflow run fields include process_data_truncated."""
        from fields.workflow_run_fields import workflow_run_node_execution_fields

        # Verify the field is included in the definition
        assert "process_data_truncated" in workflow_run_node_execution_fields

        # The field should be a Boolean field
        field = workflow_run_node_execution_fields["process_data_truncated"]
        from flask_restful import fields

        assert isinstance(field, fields.Boolean)

@@ -14,6 +14,8 @@ from pathlib import Path
 from typing import Optional
 
 import pytest
+from alembic import command as alembic_command
+from alembic.config import Config
 from flask import Flask
 from flask.testing import FlaskClient
 from sqlalchemy import Engine, text
@@ -345,6 +347,12 @@ def _create_app_with_containers() -> Flask:
     with db.engine.connect() as conn, conn.begin():
         conn.execute(text(_UUIDv7SQL))
     db.create_all()
+    # migration_dir = _get_migration_dir()
+    # alembic_config = Config()
+    # alembic_config.config_file_name = str(migration_dir / "alembic.ini")
+    # alembic_config.set_main_option("sqlalchemy.url", _get_engine_url(db.engine))
+    # alembic_config.set_main_option("script_location", str(migration_dir))
+    # alembic_command.upgrade(revision="head", config=alembic_config)
     logger.info("Database schema created successfully")
 
     logger.info("Flask application configured and ready for testing")
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
from typing import Any, NamedTuple
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask_restx import marshal
|
||||
|
||||
from controllers.console.app.workflow_draft_variable import (
|
||||
@@ -9,11 +11,14 @@ from controllers.console.app.workflow_draft_variable import (
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS,
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS,
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
_serialize_full_content,
|
||||
)
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from factories.variable_factory import build_segment
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||
|
||||
_TEST_APP_ID = "test_app_id"
|
||||
@@ -21,6 +26,54 @@ _TEST_NODE_EXEC_ID = str(uuid.uuid4())


class TestWorkflowDraftVariableFields:
    def test_serialize_full_content(self):
        """Test that _serialize_full_content uses pre-loaded relationships."""
        # Create mock objects with relationships pre-loaded
        mock_variable_file = MagicMock(spec=WorkflowDraftVariableFile)
        mock_variable_file.size = 100000
        mock_variable_file.length = 50
        mock_variable_file.value_type = SegmentType.OBJECT
        mock_variable_file.upload_file_id = "test-upload-file-id"

        mock_variable = MagicMock(spec=WorkflowDraftVariable)
        mock_variable.file_id = "test-file-id"
        mock_variable.variable_file = mock_variable_file

        # Mock the file helpers
        with patch("controllers.console.app.workflow_draft_variable.file_helpers") as mock_file_helpers:
            mock_file_helpers.get_signed_file_url.return_value = "http://example.com/signed-url"

            # Call the function
            result = _serialize_full_content(mock_variable)

            # Verify it returns the expected structure
            assert result is not None
            assert result["size_bytes"] == 100000
            assert result["length"] == 50
            assert result["value_type"] == "object"
            assert "download_url" in result
            assert result["download_url"] == "http://example.com/signed-url"

            # Verify it used the pre-loaded relationships (no database queries)
            mock_file_helpers.get_signed_file_url.assert_called_once_with("test-upload-file-id", as_attachment=True)

    def test_serialize_full_content_handles_none_cases(self):
        """Test that _serialize_full_content handles None cases properly."""

        # Test with no file_id
        draft_var = WorkflowDraftVariable()
        draft_var.file_id = None
        result = _serialize_full_content(draft_var)
        assert result is None

    def test_serialize_full_content_should_raise_when_file_id_exists_but_file_is_none(self):
        # Test with a file_id set but no loaded variable_file
        draft_var = WorkflowDraftVariable()
        draft_var.file_id = str(uuid.uuid4())
        draft_var.variable_file = None
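        # _serialize_full_content is expected to assert that variable_file is
        # loaded whenever file_id is set, so this inconsistent state raises.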
        with pytest.raises(AssertionError):
            _serialize_full_content(draft_var)

    def test_conversation_variable(self):
        conv_var = WorkflowDraftVariable.new_conversation_variable(
            app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1)
@@ -39,12 +92,14 @@ class TestWorkflowDraftVariableFields:
                "value_type": "number",
                "edited": False,
                "visible": True,
                "is_truncated": False,
            }
        )

        assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
        expected_with_value = expected_without_value.copy()
        expected_with_value["value"] = 1
        expected_with_value["full_content"] = None
        assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value

    def test_create_sys_variable(self):
@@ -70,11 +125,13 @@ class TestWorkflowDraftVariableFields:
                "value_type": "string",
                "edited": True,
                "visible": True,
                "is_truncated": False,
            }
        )
        assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
        expected_with_value = expected_without_value.copy()
        expected_with_value["value"] = "a"
        expected_with_value["full_content"] = None
        assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value

    def test_node_variable(self):
@@ -100,14 +157,65 @@ class TestWorkflowDraftVariableFields:
                "value_type": "array[any]",
                "edited": True,
                "visible": False,
                "is_truncated": False,
            }
        )

        assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
        expected_with_value = expected_without_value.copy()
        expected_with_value["value"] = [1, "a"]
        expected_with_value["full_content"] = None
        assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value

    def test_node_variable_with_file(self):
        node_var = WorkflowDraftVariable.new_node_variable(
            app_id=_TEST_APP_ID,
            node_id="test_node",
            name="node_var",
            value=build_segment([1, "a"]),
            visible=False,
            node_execution_id=_TEST_NODE_EXEC_ID,
        )

        node_var.id = str(uuid.uuid4())
        node_var.last_edited_at = naive_utc_now()
        variable_file = WorkflowDraftVariableFile(
            id=str(uuidv7()),
            upload_file_id=str(uuid.uuid4()),
            size=1024,
            length=10,
            value_type=SegmentType.ARRAY_STRING,
        )
        node_var.variable_file = variable_file
        node_var.file_id = variable_file.id
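
        # Attaching a WorkflowDraftVariableFile marks the variable as truncated,
        # so marshalling should report is_truncated=True and a populated full_content.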
        expected_without_value: OrderedDict[str, Any] = OrderedDict(
            {
                "id": str(node_var.id),
                "type": node_var.get_variable_type().value,
                "name": "node_var",
                "description": "",
                "selector": ["test_node", "node_var"],
                "value_type": "array[any]",
                "edited": True,
                "visible": False,
                "is_truncated": True,
            }
        )

        with patch("controllers.console.app.workflow_draft_variable.file_helpers") as mock_file_helpers:
            mock_file_helpers.get_signed_file_url.return_value = "http://example.com/signed-url"
            assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
            expected_with_value = expected_without_value.copy()
            expected_with_value["value"] = [1, "a"]
            expected_with_value["full_content"] = {
                "size_bytes": 1024,
                "value_type": "array[string]",
                "length": 10,
                "download_url": "http://example.com/signed-url",
            }
            assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value


class TestWorkflowDraftVariableList:
    def test_workflow_draft_variable_list(self):
@@ -135,6 +243,7 @@ class TestWorkflowDraftVariableList:
                "value_type": "string",
                "edited": False,
                "visible": True,
                "is_truncated": False,
            }
        )

@@ -0,0 +1,429 @@
"""
Unit tests for WorkflowResponseConverter focusing on process_data truncation functionality.
"""

import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from unittest.mock import Mock

import pytest

from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity
from core.app.entities.queue_entities import QueueNodeRetryEvent, QueueNodeSucceededEvent
from core.helper.code_executor.code_executor import CodeLanguage
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.enums import NodeType
from libs.datetime_utils import naive_utc_now


@dataclass
class ProcessDataResponseScenario:
    """Test scenario for process_data in responses."""

    name: str
    original_process_data: dict[str, Any] | None
    truncated_process_data: dict[str, Any] | None
    expected_response_data: dict[str, Any] | None
    expected_truncated_flag: bool


class TestWorkflowResponseConverterScenarios:
    """Test process_data truncation in WorkflowResponseConverter."""

    def create_mock_generate_entity(self) -> WorkflowAppGenerateEntity:
        """Create a mock WorkflowAppGenerateEntity."""
        mock_entity = Mock(spec=WorkflowAppGenerateEntity)
        mock_app_config = Mock()
        mock_app_config.tenant_id = "test-tenant-id"
        mock_entity.app_config = mock_app_config
        return mock_entity

    def create_workflow_response_converter(self) -> WorkflowResponseConverter:
        """Create a WorkflowResponseConverter for testing."""
        mock_entity = self.create_mock_generate_entity()
        return WorkflowResponseConverter(application_generate_entity=mock_entity)

    def create_workflow_node_execution(
        self,
        process_data: dict[str, Any] | None = None,
        truncated_process_data: dict[str, Any] | None = None,
        execution_id: str = "test-execution-id",
    ) -> WorkflowNodeExecution:
        """Create a WorkflowNodeExecution for testing."""
        execution = WorkflowNodeExecution(
            id=execution_id,
            workflow_id="test-workflow-id",
            workflow_execution_id="test-run-id",
            index=1,
            node_id="test-node-id",
            node_type=NodeType.LLM,
            title="Test Node",
            process_data=process_data,
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            created_at=datetime.now(),
            finished_at=datetime.now(),
        )
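
        # Mirror what the repository does after offloading: keep the full value
        # on the model and attach the truncated preview separately.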
        if truncated_process_data is not None:
            execution.set_truncated_process_data(truncated_process_data)

        return execution

    def create_node_succeeded_event(self) -> QueueNodeSucceededEvent:
        """Create a QueueNodeSucceededEvent for testing."""
        return QueueNodeSucceededEvent(
            node_id="test-node-id",
            node_type=NodeType.CODE,
            node_data=CodeNodeData(
                title="test code",
                variables=[],
                code_language=CodeLanguage.PYTHON3,
                code="",
                outputs={},
            ),
            node_execution_id=str(uuid.uuid4()),
            start_at=naive_utc_now(),
            parallel_id=None,
            parallel_start_node_id=None,
            parent_parallel_id=None,
            parent_parallel_start_node_id=None,
            in_iteration_id=None,
            in_loop_id=None,
        )

    def create_node_retry_event(self) -> QueueNodeRetryEvent:
        """Create a QueueNodeRetryEvent for testing."""
        return QueueNodeRetryEvent(
            inputs={"data": "inputs"},
            outputs={"data": "outputs"},
            error="oops",
            retry_index=1,
            node_id="test-node-id",
            node_type=NodeType.CODE,
            node_data=CodeNodeData(
                title="test code",
                variables=[],
                code_language=CodeLanguage.PYTHON3,
                code="",
                outputs={},
            ),
            node_execution_id=str(uuid.uuid4()),
            start_at=naive_utc_now(),
            parallel_id=None,
            parallel_start_node_id=None,
            parent_parallel_id=None,
            parent_parallel_start_node_id=None,
            in_iteration_id=None,
            in_loop_id=None,
        )

    def test_workflow_node_finish_response_uses_truncated_process_data(self):
        """Test that node finish response uses get_response_process_data()."""
        converter = self.create_workflow_response_converter()

        original_data = {"large_field": "x" * 10000, "metadata": "info"}
        truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}

        execution = self.create_workflow_node_execution(
            process_data=original_data, truncated_process_data=truncated_data
        )
        event = self.create_node_succeeded_event()

        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test-task-id",
            workflow_node_execution=execution,
        )

        # Response should use truncated data, not original
        assert response is not None
        assert response.data.process_data == truncated_data
        assert response.data.process_data != original_data
        assert response.data.process_data_truncated is True

    def test_workflow_node_finish_response_without_truncation(self):
        """Test node finish response when no truncation is applied."""
        converter = self.create_workflow_response_converter()

        original_data = {"small": "data"}

        execution = self.create_workflow_node_execution(process_data=original_data)
        event = self.create_node_succeeded_event()

        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test-task-id",
            workflow_node_execution=execution,
        )

        # Response should use original data
        assert response is not None
        assert response.data.process_data == original_data
        assert response.data.process_data_truncated is False

    def test_workflow_node_finish_response_with_none_process_data(self):
        """Test node finish response when process_data is None."""
        converter = self.create_workflow_response_converter()

        execution = self.create_workflow_node_execution(process_data=None)
        event = self.create_node_succeeded_event()

        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test-task-id",
            workflow_node_execution=execution,
        )

        # Response should have None process_data
        assert response is not None
        assert response.data.process_data is None
        assert response.data.process_data_truncated is False

    def test_workflow_node_retry_response_uses_truncated_process_data(self):
        """Test that node retry response uses get_response_process_data()."""
        converter = self.create_workflow_response_converter()

        original_data = {"large_field": "x" * 10000, "metadata": "info"}
        truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}

        execution = self.create_workflow_node_execution(
            process_data=original_data, truncated_process_data=truncated_data
        )
        event = self.create_node_retry_event()

        response = converter.workflow_node_retry_to_stream_response(
            event=event,
            task_id="test-task-id",
            workflow_node_execution=execution,
        )

        # Response should use truncated data, not original
        assert response is not None
        assert response.data.process_data == truncated_data
        assert response.data.process_data != original_data
        assert response.data.process_data_truncated is True

    def test_workflow_node_retry_response_without_truncation(self):
        """Test node retry response when no truncation is applied."""
        converter = self.create_workflow_response_converter()

        original_data = {"small": "data"}

        execution = self.create_workflow_node_execution(process_data=original_data)
        event = self.create_node_retry_event()

        response = converter.workflow_node_retry_to_stream_response(
            event=event,
            task_id="test-task-id",
            workflow_node_execution=execution,
        )

        # Response should use original data
        assert response is not None
        assert response.data.process_data == original_data
        assert response.data.process_data_truncated is False

    def test_iteration_and_loop_nodes_return_none(self):
        """Test that iteration and loop nodes return None (no change from existing behavior)."""
        converter = self.create_workflow_response_converter()

        # Test iteration node
        iteration_execution = self.create_workflow_node_execution(process_data={"test": "data"})
        iteration_execution.node_type = NodeType.ITERATION

        event = self.create_node_succeeded_event()

        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test-task-id",
            workflow_node_execution=iteration_execution,
        )

        # Should return None for iteration nodes
        assert response is None

        # Test loop node
        loop_execution = self.create_workflow_node_execution(process_data={"test": "data"})
        loop_execution.node_type = NodeType.LOOP

        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test-task-id",
            workflow_node_execution=loop_execution,
        )

        # Should return None for loop nodes
        assert response is None

    def test_execution_without_workflow_execution_id_returns_none(self):
        """Test that executions without workflow_execution_id return None."""
        converter = self.create_workflow_response_converter()

        execution = self.create_workflow_node_execution(process_data={"test": "data"})
        execution.workflow_execution_id = None  # Single-step debugging

        event = self.create_node_succeeded_event()

        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test-task-id",
            workflow_node_execution=execution,
        )

        # Should return None for single-step debugging
        assert response is None
    @staticmethod
    def get_process_data_response_scenarios() -> list[ProcessDataResponseScenario]:
        """Create test scenarios for process_data responses."""
        return [
            ProcessDataResponseScenario(
                name="none_process_data",
                original_process_data=None,
                truncated_process_data=None,
                expected_response_data=None,
                expected_truncated_flag=False,
            ),
            ProcessDataResponseScenario(
                name="small_process_data_no_truncation",
                original_process_data={"small": "data"},
                truncated_process_data=None,
                expected_response_data={"small": "data"},
                expected_truncated_flag=False,
            ),
            ProcessDataResponseScenario(
                name="large_process_data_with_truncation",
                original_process_data={"large": "x" * 10000, "metadata": "info"},
                truncated_process_data={"large": "[TRUNCATED]", "metadata": "info"},
                expected_response_data={"large": "[TRUNCATED]", "metadata": "info"},
                expected_truncated_flag=True,
            ),
            ProcessDataResponseScenario(
                name="empty_process_data",
                original_process_data={},
                truncated_process_data=None,
                expected_response_data={},
                expected_truncated_flag=False,
            ),
            ProcessDataResponseScenario(
                name="complex_data_with_truncation",
                original_process_data={
                    "logs": ["entry"] * 1000,  # Large array
                    "config": {"setting": "value"},
                    "status": "processing",
                },
                truncated_process_data={
                    "logs": "[TRUNCATED: 1000 items]",
                    "config": {"setting": "value"},
                    "status": "processing",
                },
                expected_response_data={
                    "logs": "[TRUNCATED: 1000 items]",
                    "config": {"setting": "value"},
                    "status": "processing",
                },
                expected_truncated_flag=True,
            ),
        ]

    @pytest.mark.parametrize(
        "scenario",
        get_process_data_response_scenarios(),
        ids=[scenario.name for scenario in get_process_data_response_scenarios()],
    )
    def test_node_finish_response_scenarios(self, scenario: ProcessDataResponseScenario):
        """Test various scenarios for node finish responses."""
        converter = WorkflowResponseConverter(
            application_generate_entity=Mock(spec=WorkflowAppGenerateEntity, app_config=Mock(tenant_id="test-tenant"))
        )

        execution = WorkflowNodeExecution(
            id="test-execution-id",
            workflow_id="test-workflow-id",
            workflow_execution_id="test-run-id",
            index=1,
            node_id="test-node-id",
            node_type=NodeType.LLM,
            title="Test Node",
            process_data=scenario.original_process_data,
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            created_at=datetime.now(),
            finished_at=datetime.now(),
        )

        if scenario.truncated_process_data is not None:
            execution.set_truncated_process_data(scenario.truncated_process_data)

        event = QueueNodeSucceededEvent(
            node_id="test-node-id",
            node_type=NodeType.CODE,
            node_data=CodeNodeData(
                title="test code",
                variables=[],
                code_language=CodeLanguage.PYTHON3,
                code="",
                outputs={},
            ),
            node_execution_id=str(uuid.uuid4()),
            start_at=naive_utc_now(),
            parallel_id=None,
            parallel_start_node_id=None,
            parent_parallel_id=None,
            parent_parallel_start_node_id=None,
            in_iteration_id=None,
            in_loop_id=None,
        )

        response = converter.workflow_node_finish_to_stream_response(
            event=event,
            task_id="test-task-id",
            workflow_node_execution=execution,
        )

        assert response is not None
        assert response.data.process_data == scenario.expected_response_data
        assert response.data.process_data_truncated == scenario.expected_truncated_flag

    @pytest.mark.parametrize(
        "scenario",
        get_process_data_response_scenarios(),
        ids=[scenario.name for scenario in get_process_data_response_scenarios()],
    )
    def test_node_retry_response_scenarios(self, scenario: ProcessDataResponseScenario):
        """Test various scenarios for node retry responses."""
        converter = WorkflowResponseConverter(
            application_generate_entity=Mock(spec=WorkflowAppGenerateEntity, app_config=Mock(tenant_id="test-tenant"))
        )

        execution = WorkflowNodeExecution(
            id="test-execution-id",
            workflow_id="test-workflow-id",
            workflow_execution_id="test-run-id",
            index=1,
            node_id="test-node-id",
            node_type=NodeType.LLM,
            title="Test Node",
            process_data=scenario.original_process_data,
            status=WorkflowNodeExecutionStatus.FAILED,  # Retry scenario
            created_at=datetime.now(),
            finished_at=datetime.now(),
        )

        if scenario.truncated_process_data is not None:
            execution.set_truncated_process_data(scenario.truncated_process_data)

        event = self.create_node_retry_event()

        response = converter.workflow_node_retry_to_stream_response(
            event=event,
            task_id="test-task-id",
            workflow_node_execution=execution,
        )

        assert response is not None
        assert response.data.process_data == scenario.expected_response_data
        assert response.data.process_data_truncated == scenario.expected_truncated_flag

@@ -0,0 +1,248 @@
"""
Unit tests for WorkflowNodeExecution truncation functionality.

Tests the truncation and offloading logic for large inputs and outputs
in the SQLAlchemyWorkflowNodeExecutionRepository.
"""

import json
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any
from unittest.mock import MagicMock, patch

from sqlalchemy import Engine

from core.repositories.sqlalchemy_workflow_node_execution_repository import (
    SQLAlchemyWorkflowNodeExecutionRepository,
)
from core.workflow.entities.workflow_node_execution import (
    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
)
from core.workflow.nodes.enums import NodeType
from models import Account, WorkflowNodeExecutionTriggeredFrom
from models.enums import ExecutionOffLoadType
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload

TRUNCATION_SIZE_THRESHOLD = 500
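# A small local stand-in for the configured truncation threshold, kept low so
# "large" payloads stay cheap to build in tests.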


@dataclass
class TruncationTestCase:
    """Test case data for truncation scenarios."""

    name: str
    inputs: dict[str, Any] | None
    outputs: dict[str, Any] | None
    should_truncate_inputs: bool
    should_truncate_outputs: bool
    description: str


def create_test_cases() -> list[TruncationTestCase]:
    """Create test cases for different truncation scenarios."""
    # Create data large enough to definitely exceed TRUNCATION_SIZE_THRESHOLD
    large_data = {"data": "x" * (TRUNCATION_SIZE_THRESHOLD + 1000)}
    small_data = {"data": "small"}

    return [
        TruncationTestCase(
            name="small_data_no_truncation",
            inputs=small_data,
            outputs=small_data,
            should_truncate_inputs=False,
            should_truncate_outputs=False,
            description="Small data should not be truncated",
        ),
        TruncationTestCase(
            name="large_inputs_truncation",
            inputs=large_data,
            outputs=small_data,
            should_truncate_inputs=True,
            should_truncate_outputs=False,
            description="Large inputs should be truncated",
        ),
        TruncationTestCase(
            name="large_outputs_truncation",
            inputs=small_data,
            outputs=large_data,
            should_truncate_inputs=False,
            should_truncate_outputs=True,
            description="Large outputs should be truncated",
        ),
        TruncationTestCase(
            name="large_both_truncation",
            inputs=large_data,
            outputs=large_data,
            should_truncate_inputs=True,
            should_truncate_outputs=True,
            description="Both large inputs and outputs should be truncated",
        ),
        TruncationTestCase(
            name="none_inputs_outputs",
            inputs=None,
            outputs=None,
            should_truncate_inputs=False,
            should_truncate_outputs=False,
            description="None inputs and outputs should not be truncated",
        ),
    ]


def create_workflow_node_execution(
    execution_id: str = "test-execution-id",
    inputs: dict[str, Any] | None = None,
    outputs: dict[str, Any] | None = None,
) -> WorkflowNodeExecution:
    """Factory function to create a WorkflowNodeExecution for testing."""
    return WorkflowNodeExecution(
        id=execution_id,
        node_execution_id="test-node-execution-id",
        workflow_id="test-workflow-id",
        workflow_execution_id="test-workflow-execution-id",
        index=1,
        node_id="test-node-id",
        node_type=NodeType.LLM,
        title="Test Node",
        inputs=inputs,
        outputs=outputs,
        status=WorkflowNodeExecutionStatus.SUCCEEDED,
        created_at=datetime.now(UTC),
    )


def mock_user() -> Account:
    """Create a mock Account user for testing."""
    user = MagicMock(spec=Account)
    user.id = "test-user-id"
    user.current_tenant_id = "test-tenant-id"
    return user


class TestSQLAlchemyWorkflowNodeExecutionRepositoryTruncation:
    """Test class for truncation functionality in SQLAlchemyWorkflowNodeExecutionRepository."""

    def create_repository(self) -> SQLAlchemyWorkflowNodeExecutionRepository:
        """Create a repository instance for testing."""
        return SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=MagicMock(spec=Engine),
            user=mock_user(),
            app_id="test-app-id",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

    def test_to_domain_model_without_offload_data(self):
        """Test _to_domain_model correctly handles models without offload data."""
        repo = self.create_repository()

        # Create a mock database model without offload data
        db_model = WorkflowNodeExecutionModel()
        db_model.id = "test-id"
        db_model.node_execution_id = "node-exec-id"
        db_model.workflow_id = "workflow-id"
        db_model.workflow_run_id = "run-id"
        db_model.index = 1
        db_model.predecessor_node_id = None
        db_model.node_id = "node-id"
        db_model.node_type = NodeType.LLM.value
        db_model.title = "Test Node"
        db_model.inputs = json.dumps({"value": "inputs"})
        db_model.process_data = json.dumps({"value": "process_data"})
        db_model.outputs = json.dumps({"value": "outputs"})
        db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
        db_model.error = None
        db_model.elapsed_time = 1.0
        db_model.execution_metadata = "{}"
        db_model.created_at = datetime.now(UTC)
        db_model.finished_at = None
        db_model.offload_data = []
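        # With no offload rows, the conversion should leave the truncated previews unset.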

        domain_model = repo._to_domain_model(db_model)

        # Check that no truncated data was set
        assert domain_model.get_truncated_inputs() is None
        assert domain_model.get_truncated_outputs() is None

    @patch("core.repositories.sqlalchemy_workflow_node_execution_repository.FileService")
    def test_save_with_truncation(self, mock_file_service_class):
        """Test the save method handles truncation and offload record creation."""
        # Setup mock file service
        mock_file_service = MagicMock()
        mock_upload_file = MagicMock()
        mock_upload_file.id = "mock-file-id"
        mock_file_service.upload_file.return_value = mock_upload_file
        mock_file_service_class.return_value = mock_file_service

        large_data = {"data": "x" * (TRUNCATION_SIZE_THRESHOLD + 1)}

        repo = self.create_repository()
        execution = create_workflow_node_execution(
            inputs=large_data,
            outputs=large_data,
        )

        # Mock the session and database operations
        with patch.object(repo, "_session_factory") as mock_session_factory:
            mock_session = MagicMock()
            mock_session_factory.return_value.__enter__.return_value = mock_session

            repo.save(execution)

            # A single merge is expected: the offload records are persisted
            # through the merged db_model rather than merged separately
            assert mock_session.merge.call_count == 1
            mock_session.commit.assert_called_once()


class TestWorkflowNodeExecutionModelTruncatedProperties:
    """Test the truncated properties on WorkflowNodeExecutionModel."""

    def test_inputs_truncated_with_offload_data(self):
        """Test inputs_truncated property when offload data exists."""
        model = WorkflowNodeExecutionModel()
        offload = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)
        model.offload_data = [offload]
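        # The offload row's type_ discriminator is what drives the per-field *_truncated flags.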

        assert model.inputs_truncated is True
        assert model.process_data_truncated is False
        assert model.outputs_truncated is False

    def test_outputs_truncated_with_offload_data(self):
        """Test outputs_truncated property when offload data exists."""
        model = WorkflowNodeExecutionModel()

        # Mock offload data with outputs file
        offload = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS)
        model.offload_data = [offload]

        assert model.inputs_truncated is False
        assert model.process_data_truncated is False
        assert model.outputs_truncated is True

    def test_process_data_truncated_with_offload_data(self):
        model = WorkflowNodeExecutionModel()
        offload = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA)
        model.offload_data = [offload]
        assert model.process_data_truncated is True
        assert model.inputs_truncated is False
        assert model.outputs_truncated is False

    def test_truncated_properties_without_offload_data(self):
        """Test truncated properties when no offload data exists."""
        model = WorkflowNodeExecutionModel()
        model.offload_data = []

        assert model.inputs_truncated is False
        assert model.outputs_truncated is False
        assert model.process_data_truncated is False

    def test_truncated_properties_without_offload_attribute(self):
        """Test truncated properties when offload_data attribute doesn't exist."""
        model = WorkflowNodeExecutionModel()
        # Don't set offload_data attribute at all

        assert model.inputs_truncated is False
        assert model.outputs_truncated is False
        assert model.process_data_truncated is False

@@ -0,0 +1,225 @@
"""
Unit tests for WorkflowNodeExecution domain model, focusing on process_data truncation functionality.
"""

from dataclasses import dataclass
from datetime import datetime
from typing import Any

import pytest

from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
from core.workflow.nodes.enums import NodeType


class TestWorkflowNodeExecutionProcessDataTruncation:
    """Test process_data truncation functionality in WorkflowNodeExecution domain model."""

    def create_workflow_node_execution(
        self,
        process_data: dict[str, Any] | None = None,
    ) -> WorkflowNodeExecution:
        """Create a WorkflowNodeExecution instance for testing."""
        return WorkflowNodeExecution(
            id="test-execution-id",
            workflow_id="test-workflow-id",
            index=1,
            node_id="test-node-id",
            node_type=NodeType.LLM,
            title="Test Node",
            process_data=process_data,
            created_at=datetime.now(),
        )

    def test_initial_process_data_truncated_state(self):
        """Test that process_data_truncated returns False initially."""
        execution = self.create_workflow_node_execution()

        assert execution.process_data_truncated is False
        assert execution.get_truncated_process_data() is None

    def test_set_and_get_truncated_process_data(self):
        """Test setting and getting truncated process_data."""
        execution = self.create_workflow_node_execution()
        test_truncated_data = {"truncated": True, "key": "value"}

        execution.set_truncated_process_data(test_truncated_data)

        assert execution.process_data_truncated is True
        assert execution.get_truncated_process_data() == test_truncated_data

    def test_set_truncated_process_data_to_none(self):
        """Test setting truncated process_data to None."""
        execution = self.create_workflow_node_execution()

        # First set some data
        execution.set_truncated_process_data({"key": "value"})
        assert execution.process_data_truncated is True

        # Then set to None
        execution.set_truncated_process_data(None)
        assert execution.process_data_truncated is False
        assert execution.get_truncated_process_data() is None

    def test_get_response_process_data_with_no_truncation(self):
        """Test get_response_process_data when no truncation is set."""
        original_data = {"original": True, "data": "value"}
        execution = self.create_workflow_node_execution(process_data=original_data)

        response_data = execution.get_response_process_data()

        assert response_data == original_data
        assert execution.process_data_truncated is False

    def test_get_response_process_data_with_truncation(self):
        """Test get_response_process_data when truncation is set."""
        original_data = {"original": True, "large_data": "x" * 10000}
        truncated_data = {"original": True, "large_data": "[TRUNCATED]"}

        execution = self.create_workflow_node_execution(process_data=original_data)
        execution.set_truncated_process_data(truncated_data)

        response_data = execution.get_response_process_data()

        # Should return truncated data, not original
        assert response_data == truncated_data
        assert response_data != original_data
        assert execution.process_data_truncated is True

    def test_get_response_process_data_with_none_process_data(self):
        """Test get_response_process_data when process_data is None."""
        execution = self.create_workflow_node_execution(process_data=None)

        response_data = execution.get_response_process_data()

        assert response_data is None
        assert execution.process_data_truncated is False

    def test_consistency_with_inputs_outputs_pattern(self):
        """Test that process_data truncation follows the same pattern as inputs/outputs."""
        execution = self.create_workflow_node_execution()

        # Test that all truncation methods exist and behave consistently
        test_data = {"test": "data"}

        # Test inputs truncation
        execution.set_truncated_inputs(test_data)
        assert execution.inputs_truncated is True
        assert execution.get_truncated_inputs() == test_data

        # Test outputs truncation
        execution.set_truncated_outputs(test_data)
        assert execution.outputs_truncated is True
        assert execution.get_truncated_outputs() == test_data

        # Test process_data truncation
        execution.set_truncated_process_data(test_data)
        assert execution.process_data_truncated is True
        assert execution.get_truncated_process_data() == test_data

    @pytest.mark.parametrize(
        "test_data",
        [
            {"simple": "value"},
            {"nested": {"key": "value"}},
            {"list": [1, 2, 3]},
            {"mixed": {"string": "value", "number": 42, "list": [1, 2]}},
            {},  # empty dict
        ],
    )
    def test_truncated_process_data_with_various_data_types(self, test_data):
        """Test that truncated process_data works with various data types."""
        execution = self.create_workflow_node_execution()

        execution.set_truncated_process_data(test_data)

        assert execution.process_data_truncated is True
        assert execution.get_truncated_process_data() == test_data
        assert execution.get_response_process_data() == test_data


@dataclass
class ProcessDataScenario:
    """Test scenario data for process_data functionality."""

    name: str
    original_data: dict[str, Any] | None
    truncated_data: dict[str, Any] | None
    expected_truncated_flag: bool
    expected_response_data: dict[str, Any] | None


class TestWorkflowNodeExecutionProcessDataScenarios:
    """Test various scenarios for process_data handling."""

    @staticmethod
    def get_process_data_scenarios() -> list[ProcessDataScenario]:
        """Create test scenarios for process_data functionality."""
        return [
            ProcessDataScenario(
                name="no_process_data",
                original_data=None,
                truncated_data=None,
                expected_truncated_flag=False,
                expected_response_data=None,
            ),
            ProcessDataScenario(
                name="process_data_without_truncation",
                original_data={"small": "data"},
                truncated_data=None,
                expected_truncated_flag=False,
                expected_response_data={"small": "data"},
            ),
            ProcessDataScenario(
                name="process_data_with_truncation",
                original_data={"large": "x" * 10000, "metadata": "info"},
                truncated_data={"large": "[TRUNCATED]", "metadata": "info"},
                expected_truncated_flag=True,
                expected_response_data={"large": "[TRUNCATED]", "metadata": "info"},
            ),
            ProcessDataScenario(
                name="empty_process_data",
                original_data={},
                truncated_data=None,
                expected_truncated_flag=False,
                expected_response_data={},
            ),
            ProcessDataScenario(
                name="complex_nested_data_with_truncation",
                original_data={
                    "config": {"setting": "value"},
                    "logs": ["log1", "log2"] * 1000,  # Large list
                    "status": "running",
                },
                truncated_data={"config": {"setting": "value"}, "logs": "[TRUNCATED: 2000 items]", "status": "running"},
                expected_truncated_flag=True,
                expected_response_data={
                    "config": {"setting": "value"},
                    "logs": "[TRUNCATED: 2000 items]",
                    "status": "running",
                },
            ),
        ]

    @pytest.mark.parametrize(
        "scenario",
        get_process_data_scenarios(),
        ids=[scenario.name for scenario in get_process_data_scenarios()],
    )
    def test_process_data_scenarios(self, scenario: ProcessDataScenario):
        """Test various process_data scenarios."""
        execution = WorkflowNodeExecution(
            id="test-execution-id",
            workflow_id="test-workflow-id",
            index=1,
            node_id="test-node-id",
            node_type=NodeType.LLM,
            title="Test Node",
            process_data=scenario.original_data,
            created_at=datetime.now(),
        )

        if scenario.truncated_data is not None:
            execution.set_truncated_process_data(scenario.truncated_data)

        assert execution.process_data_truncated == scenario.expected_truncated_flag
        assert execution.get_response_process_data() == scenario.expected_response_data

@@ -0,0 +1,181 @@
"""
Unit tests for WorkflowNodeExecutionOffload model, focusing on process_data truncation functionality.
"""

from unittest.mock import Mock

import pytest

from models.model import UploadFile
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload


class TestWorkflowNodeExecutionOffload:
    """Test WorkflowNodeExecutionOffload model with process_data fields."""

    def test_get_exe(self):
        # Placeholder (WIP): model-level assertions for WorkflowNodeExecutionOffload
        # are still to be written.
        assert WorkflowNodeExecutionOffload is not None


class TestWorkflowNodeExecutionModel:
    """Test WorkflowNodeExecutionModel with process_data truncation features."""

    def create_mock_offload_data(
        self,
        inputs_file_id: str | None = None,
        outputs_file_id: str | None = None,
        process_data_file_id: str | None = None,
    ) -> WorkflowNodeExecutionOffload:
        """Create a mock offload data object."""
        offload = Mock(spec=WorkflowNodeExecutionOffload)
        offload.inputs_file_id = inputs_file_id
        offload.outputs_file_id = outputs_file_id
        offload.process_data_file_id = process_data_file_id

        # Mock file objects, present only when the matching file id is set
        offload.inputs_file = Mock(spec=UploadFile) if inputs_file_id else None
        offload.outputs_file = Mock(spec=UploadFile) if outputs_file_id else None
        offload.process_data_file = Mock(spec=UploadFile) if process_data_file_id else None

        return offload
    def test_process_data_truncated_property_false_when_no_offload_data(self):
        """Test process_data_truncated returns False when no offload_data."""
        execution = WorkflowNodeExecutionModel()
        execution.offload_data = None

        assert execution.process_data_truncated is False

    def test_process_data_truncated_property_false_when_no_process_data_file(self):
        """Test process_data_truncated returns False when no process_data file."""
        execution = WorkflowNodeExecutionModel()

        # Create real offload instance
        offload_data = WorkflowNodeExecutionOffload()
        offload_data.inputs_file_id = "inputs-file"
        offload_data.outputs_file_id = "outputs-file"
        offload_data.process_data_file_id = None  # No process_data file
        execution.offload_data = offload_data

        assert execution.process_data_truncated is False

    def test_process_data_truncated_property_true_when_process_data_file_exists(self):
        """Test process_data_truncated returns True when process_data file exists."""
        execution = WorkflowNodeExecutionModel()

        # Create a real offload instance rather than mock
        offload_data = WorkflowNodeExecutionOffload()
        offload_data.process_data_file_id = "process-data-file-id"
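        # Only the process_data file is present, so only that flag should flip.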
        execution.offload_data = offload_data

        assert execution.process_data_truncated is True

    def test_load_full_process_data_with_no_offload_data(self):
        """Test load_full_process_data when no offload data exists."""
        execution = WorkflowNodeExecutionModel()
        execution.offload_data = None
        execution.process_data_dict = {"test": "data"}

        # Mock session and storage
        mock_session = Mock()
        mock_storage = Mock()

        result = execution.load_full_process_data(mock_session, mock_storage)

        assert result == {"test": "data"}

    def test_load_full_process_data_with_no_file(self):
        """Test load_full_process_data when no process_data file exists."""
        execution = WorkflowNodeExecutionModel()
        execution.offload_data = self.create_mock_offload_data(process_data_file_id=None)
        execution.process_data_dict = {"test": "data"}

        # Mock session and storage
        mock_session = Mock()
        mock_storage = Mock()

        result = execution.load_full_process_data(mock_session, mock_storage)

        assert result == {"test": "data"}

    def test_load_full_process_data_with_file(self):
        """Test load_full_process_data when process_data file exists."""
        execution = WorkflowNodeExecutionModel()
        offload_data = self.create_mock_offload_data(process_data_file_id="file-id")
        execution.offload_data = offload_data
        execution.process_data_dict = {"truncated": "data"}

        # Mock session and storage
        mock_session = Mock()
        mock_storage = Mock()

        # Mock the _load_full_content method to return full data
        full_process_data = {"full": "data", "large_field": "x" * 10000}

        with pytest.MonkeyPatch.context() as mp:
            # Mock the _load_full_content method
            def mock_load_full_content(session, file_id, storage):
                assert session == mock_session
                assert file_id == "file-id"
                assert storage == mock_storage
                return full_process_data

            mp.setattr(execution, "_load_full_content", mock_load_full_content)

            result = execution.load_full_process_data(mock_session, mock_storage)

            assert result == full_process_data

    def test_consistency_with_inputs_outputs_truncation(self):
        """Test that process_data truncation behaves consistently with inputs/outputs."""
        execution = WorkflowNodeExecutionModel()

        # Test all three truncation properties together
        offload_data = self.create_mock_offload_data(
            inputs_file_id="inputs-file", outputs_file_id="outputs-file", process_data_file_id="process-data-file"
        )
        execution.offload_data = offload_data

        # All should be truncated
        assert execution.inputs_truncated is True
        assert execution.outputs_truncated is True
        assert execution.process_data_truncated is True

    def test_mixed_truncation_states(self):
        """Test mixed states of truncation."""
        execution = WorkflowNodeExecutionModel()

        # Only process_data is truncated
        offload_data = self.create_mock_offload_data(
            inputs_file_id=None, outputs_file_id=None, process_data_file_id="process-data-file"
        )
        execution.offload_data = offload_data

        assert execution.inputs_truncated is False
        assert execution.outputs_truncated is False
        assert execution.process_data_truncated is True

    def test_preload_offload_data_and_files_method_exists(self):
        """Test that the preload method includes process_data_file."""
        # This test verifies the method exists and can be called
        # The actual SQL behavior would be tested in integration tests
        from sqlalchemy import select

        stmt = select(WorkflowNodeExecutionModel)

        # This should not raise an exception
        preloaded_stmt = WorkflowNodeExecutionModel.preload_offload_data_and_files(stmt)

        # The statement should be modified (different object)
        assert preloaded_stmt is not stmt

@@ -0,0 +1,362 @@
"""
Unit tests for SQLAlchemyWorkflowNodeExecutionRepository, focusing on process_data truncation functionality.
"""

import json
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from unittest.mock import MagicMock, Mock, patch

import pytest
from sqlalchemy.orm import sessionmaker

from core.repositories.sqlalchemy_workflow_node_execution_repository import (
    SQLAlchemyWorkflowNodeExecutionRepository,
    _InputsOutputsTruncationResult,
)
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
from core.workflow.nodes.enums import NodeType
from models import Account, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
from models.model import UploadFile
from models.workflow import WorkflowNodeExecutionOffload


class TestSQLAlchemyWorkflowNodeExecutionRepositoryProcessData:
    """Test process_data truncation functionality in SQLAlchemyWorkflowNodeExecutionRepository."""

    def create_mock_account(self) -> Account:
        """Create a mock Account for testing."""
        account = Mock(spec=Account)
        account.id = "test-user-id"
        account.tenant_id = "test-tenant-id"
        return account

    def create_mock_session_factory(self) -> sessionmaker:
        """Create a mock session factory for testing."""
        mock_session = MagicMock()
        mock_session_factory = MagicMock(spec=sessionmaker)
        mock_session_factory.return_value.__enter__.return_value = mock_session
        mock_session_factory.return_value.__exit__.return_value = None
        return mock_session_factory

    def create_repository(self, mock_file_service=None) -> SQLAlchemyWorkflowNodeExecutionRepository:
        """Create a repository instance for testing."""
        mock_account = self.create_mock_account()
        mock_session_factory = self.create_mock_session_factory()

        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app-id",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        if mock_file_service:
            repository._file_service = mock_file_service

        return repository

    def create_workflow_node_execution(
        self,
        process_data: dict[str, Any] | None = None,
        execution_id: str = "test-execution-id",
    ) -> WorkflowNodeExecution:
        """Create a WorkflowNodeExecution instance for testing."""
        return WorkflowNodeExecution(
            id=execution_id,
            workflow_id="test-workflow-id",
            index=1,
            node_id="test-node-id",
            node_type=NodeType.LLM,
            title="Test Node",
            process_data=process_data,
            created_at=datetime.now(),
        )

    @patch("core.repositories.sqlalchemy_workflow_node_execution_repository.dify_config")
    def test_to_db_model_with_small_process_data(self, mock_config):
        """Test _to_db_model with small process_data that doesn't need truncation."""
        mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000
        mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100
        mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500
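        # These truncation limits are read from dify_config; the values here are
        # arbitrary, set comfortably above the small payload used below.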

        repository = self.create_repository()
        small_process_data = {"small": "data", "count": 5}

        execution = self.create_workflow_node_execution(process_data=small_process_data)

        with patch.object(repository, "_truncate_and_upload", return_value=None) as mock_truncate:
            db_model = repository._to_db_model(execution)

            # Should try to truncate but return None (no truncation needed)
            mock_truncate.assert_called_once_with(
                small_process_data,
                execution.id,
                "_process_data",
            )

            # Process data should be stored directly in database
            assert db_model.process_data is not None
            stored_data = json.loads(db_model.process_data)
            assert stored_data == small_process_data

            # No offload data should be created for process_data
            assert db_model.offload_data is None

    def test_to_db_model_with_large_process_data(self):
        """Test _to_db_model with large process_data that needs truncation."""
        repository = self.create_repository()

        # Create large process_data that would need truncation
        large_process_data = {
            "large_field": "x" * 10000,  # Very large string
            "metadata": {"type": "processing", "timestamp": 1234567890},
        }

        # Mock truncation result
        truncated_data = {
            "large_field": "[TRUNCATED]",
            "metadata": {"type": "processing", "timestamp": 1234567890},
        }

        mock_upload_file = Mock(spec=UploadFile)
        mock_upload_file.id = "mock-file-id"

        truncation_result = _InputsOutputsTruncationResult(
            truncated_value=truncated_data,
            file=mock_upload_file,
        )
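
        # _InputsOutputsTruncationResult pairs the truncated preview with the
        # uploaded file record that is expected to hold the full payload.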
|
||||
|
||||
execution = self.create_workflow_node_execution(process_data=large_process_data)
|
||||
|
||||
with patch.object(repository, '_truncate_and_upload', return_value=truncation_result) as mock_truncate:
|
||||
db_model = repository._to_db_model(execution)
|
||||
|
||||
# Should call truncate with correct parameters
|
||||
mock_truncate.assert_called_once_with(
|
||||
large_process_data,
|
||||
execution.id,
|
||||
"_process_data"
|
||||
)
|
||||
|
||||
# Truncated data should be stored in database
|
||||
assert db_model.process_data is not None
|
||||
stored_data = json.loads(db_model.process_data)
|
||||
assert stored_data == truncated_data
|
||||
|
||||
# Domain model should have truncated data set
|
||||
assert execution.process_data_truncated is True
|
||||
assert execution.get_truncated_process_data() == truncated_data
|
||||
|
||||
# Offload data should be created
|
||||
assert db_model.offload_data is not None
|
||||
assert db_model.offload_data.process_data_file == mock_upload_file
|
||||
assert db_model.offload_data.process_data_file_id == "mock-file-id"
|
||||
|
||||
def test_to_db_model_with_none_process_data(self):
|
||||
"""Test _to_db_model with None process_data."""
|
||||
repository = self.create_repository()
|
||||
execution = self.create_workflow_node_execution(process_data=None)
|
||||
|
||||
with patch.object(repository, '_truncate_and_upload') as mock_truncate:
|
||||
db_model = repository._to_db_model(execution)
|
||||
|
||||
# Should not call truncate for None data
|
||||
mock_truncate.assert_not_called()
|
||||
|
||||
# Process data should be None
|
||||
assert db_model.process_data is None
|
||||
|
||||
# No offload data should be created
|
||||
assert db_model.offload_data is None
|
||||
|
||||
def test_to_domain_model_with_offloaded_process_data(self):
|
||||
"""Test _to_domain_model with offloaded process_data."""
|
||||
repository = self.create_repository()
|
||||
|
||||
# Create mock database model with offload data
|
||||
db_model = Mock(spec=WorkflowNodeExecutionModel)
|
||||
db_model.id = "test-execution-id"
|
||||
db_model.node_execution_id = "test-node-execution-id"
|
||||
db_model.workflow_id = "test-workflow-id"
|
||||
db_model.workflow_run_id = None
|
||||
db_model.index = 1
|
||||
db_model.predecessor_node_id = None
|
||||
db_model.node_id = "test-node-id"
|
||||
db_model.node_type = "llm"
|
||||
db_model.title = "Test Node"
|
||||
db_model.status = "succeeded"
|
||||
db_model.error = None
|
||||
db_model.elapsed_time = 1.5
|
||||
db_model.created_at = datetime.now()
|
||||
db_model.finished_at = None
|
||||
|
||||
# Mock truncated process_data from database
|
||||
truncated_process_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
||||
db_model.process_data_dict = truncated_process_data
|
||||
db_model.inputs_dict = None
|
||||
db_model.outputs_dict = None
|
||||
db_model.execution_metadata_dict = {}
|
||||
|
||||
# Mock offload data with process_data file
|
||||
mock_offload_data = Mock(spec=WorkflowNodeExecutionOffload)
|
||||
mock_offload_data.inputs_file_id = None
|
||||
mock_offload_data.inputs_file = None
|
||||
mock_offload_data.outputs_file_id = None
|
||||
mock_offload_data.outputs_file = None
|
||||
mock_offload_data.process_data_file_id = "process-data-file-id"
|
||||
|
||||
mock_process_data_file = Mock(spec=UploadFile)
|
||||
mock_offload_data.process_data_file = mock_process_data_file
|
||||
|
||||
db_model.offload_data = mock_offload_data
|
||||
|
||||
# Mock the file loading
|
||||
original_process_data = {
|
||||
"large_field": "x" * 10000,
|
||||
"metadata": "info"
|
||||
}
|
||||
|
||||
with patch.object(repository, '_load_file', return_value=original_process_data) as mock_load:
|
||||
domain_model = repository._to_domain_model(db_model)
|
||||
|
||||
# Should load the file
|
||||
mock_load.assert_called_once_with(mock_process_data_file)
|
||||
|
||||
# Domain model should have original data
|
||||
assert domain_model.process_data == original_process_data
|
||||
|
||||
# Domain model should have truncated data set
|
||||
assert domain_model.process_data_truncated is True
|
||||
assert domain_model.get_truncated_process_data() == truncated_process_data
|
||||
|
||||
def test_to_domain_model_without_offload_data(self):
|
||||
"""Test _to_domain_model without offload data."""
|
||||
repository = self.create_repository()
|
||||
|
||||
# Create mock database model without offload data
|
||||
db_model = Mock(spec=WorkflowNodeExecutionModel)
|
||||
db_model.id = "test-execution-id"
|
||||
db_model.node_execution_id = "test-node-execution-id"
|
||||
db_model.workflow_id = "test-workflow-id"
|
||||
db_model.workflow_run_id = None
|
||||
db_model.index = 1
|
||||
db_model.predecessor_node_id = None
|
||||
db_model.node_id = "test-node-id"
|
||||
db_model.node_type = "llm"
|
||||
db_model.title = "Test Node"
|
||||
db_model.status = "succeeded"
|
||||
db_model.error = None
|
||||
db_model.elapsed_time = 1.5
|
||||
db_model.created_at = datetime.now()
|
||||
db_model.finished_at = None
|
||||
|
||||
process_data = {"normal": "data"}
|
||||
db_model.process_data_dict = process_data
|
||||
db_model.inputs_dict = None
|
||||
db_model.outputs_dict = None
|
||||
db_model.execution_metadata_dict = {}
|
||||
db_model.offload_data = None
|
||||
|
||||
domain_model = repository._to_domain_model(db_model)
|
||||
|
||||
# Domain model should have the data from database
|
||||
assert domain_model.process_data == process_data
|
||||
|
||||
# Should not be truncated
|
||||
assert domain_model.process_data_truncated is False
|
||||
assert domain_model.get_truncated_process_data() is None
|
||||
|
||||
|
||||
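# Round-trip assumed by the two offload tests above (a sketch inferred from the
# mocks rather than from the repository's documentation): on write,
# _truncate_and_upload() uploads the full value as an UploadFile and only the
# truncated copy is stored on the row; on read, _load_file() restores the full
# value from that file while the row's copy stays reachable through
# get_truncated_process_data().
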
@dataclass
class TruncationScenario:
    """Test scenario for truncation functionality."""

    name: str
    process_data: dict[str, Any] | None
    should_truncate: bool
    expected_truncated: bool = False


def get_truncation_scenarios() -> list[TruncationScenario]:
    """Create test scenarios for truncation."""
    return [
        TruncationScenario(
            name="none_data",
            process_data=None,
            should_truncate=False,
        ),
        TruncationScenario(
            name="small_data",
            process_data={"key": "value"},
            should_truncate=False,
        ),
        TruncationScenario(
            name="large_data",
            process_data={"large": "x" * 10000},
            should_truncate=True,
            expected_truncated=True,
        ),
        TruncationScenario(
            name="empty_data",
            process_data={},
            should_truncate=False,
        ),
    ]


class TestProcessDataTruncationScenarios:
    """Test various scenarios for process_data truncation."""

    @pytest.mark.parametrize(
        "scenario",
        get_truncation_scenarios(),
        ids=lambda scenario: scenario.name,
    )
    def test_process_data_truncation_scenarios(self, scenario: TruncationScenario):
        """Test various process_data truncation scenarios."""
        repository = SQLAlchemyWorkflowNodeExecutionRepository(
            session_factory=MagicMock(spec=sessionmaker),
            user=Mock(spec=Account, id="test-user", tenant_id="test-tenant"),
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        execution = WorkflowNodeExecution(
            id="test-execution-id",
            workflow_id="test-workflow-id",
            index=1,
            node_id="test-node-id",
            node_type=NodeType.LLM,
            title="Test Node",
            process_data=scenario.process_data,
            created_at=datetime.now(),
        )

        # Mock truncation behavior
        if scenario.should_truncate:
            truncated_data = {"truncated": True}
            mock_file = Mock(spec=UploadFile, id="file-id")
            truncation_result = _InputsOutputsTruncationResult(
                truncated_value=truncated_data,
                file=mock_file,
            )

            with patch.object(repository, "_truncate_and_upload", return_value=truncation_result):
                db_model = repository._to_db_model(execution)

            # Should create offload data
            assert db_model.offload_data is not None
            assert db_model.offload_data.process_data_file_id == "file-id"
            assert execution.process_data_truncated == scenario.expected_truncated
        else:
            with patch.object(repository, "_truncate_and_upload", return_value=None):
                db_model = repository._to_db_model(execution)

            # Should not create offload data or set truncation
            if scenario.process_data is None:
                assert db_model.offload_data is None
                assert db_model.process_data is None
            else:
                # For small data, offload_data may exist for other fields, but not for process_data
                if db_model.offload_data:
                    assert db_model.offload_data.process_data_file_id is None
                    assert db_model.offload_data.process_data_file is None

            assert execution.process_data_truncated is False
api/tests/unit_tests/services/test_variable_truncator.py (new file)
@@ -0,0 +1,709 @@
"""
Comprehensive unit tests for VariableTruncator class based on current implementation.

This test suite covers all functionality of the current VariableTruncator including:
- JSON size calculation for different data types
- String, array, and object truncation logic
- Segment-based truncation interface
- Helper methods for budget-based truncation
- Edge cases and error handling
"""

import functools
import json
import uuid
from typing import Any
from uuid import uuid4

import pytest

from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
from core.variables.segments import (
    ArrayFileSegment,
    ArraySegment,
    FileSegment,
    FloatSegment,
    IntegerSegment,
    NoneSegment,
    ObjectSegment,
    StringSegment,
)
from services.variable_truncator import (
    ARRAY_CHAR_LIMIT,
    LARGE_VARIABLE_THRESHOLD,
    OBJECT_CHAR_LIMIT,
    MaxDepthExceededError,
    TruncationResult,
    UnknownTypeError,
    VariableTruncator,
)


@pytest.fixture
def file() -> File:
    return File(
        id=str(uuid4()),  # Generate new UUID for File.id
        tenant_id=str(uuid.uuid4()),
        type=FileType.DOCUMENT,
        transfer_method=FileTransferMethod.LOCAL_FILE,
        related_id=str(uuid.uuid4()),
        filename="test_file.txt",
        extension=".txt",
        mime_type="text/plain",
        size=1024,
        storage_key="initial_key",
    )


_compact_json_dumps = functools.partial(json.dumps, separators=(",", ":"))

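# Worked example of the size model assumed throughout these tests: the "JSON
# size" of a value is the byte length of its compact JSON encoding. These two
# checks use only the stdlib helper above, not VariableTruncator itself:
assert len(_compact_json_dumps({"a": 1}).encode()) == 7  # {"a":1}
assert len(_compact_json_dumps([1, 2, 3]).encode()) == 7  # [1,2,3]
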
class TestCalculateJsonSize:
    """Test calculate_json_size method with different data types."""

    @pytest.fixture
    def truncator(self):
        return VariableTruncator()

    def test_string_size_calculation(self):
        """Test JSON size calculation for strings."""
        # Simple ASCII string
        assert VariableTruncator.calculate_json_size("hello") == 7  # "hello" + 2 quotes

        # Empty string
        assert VariableTruncator.calculate_json_size("") == 2  # Just quotes

        # Unicode string
        unicode_text = "你好"
        expected_size = len(unicode_text.encode("utf-8")) + 2
        assert VariableTruncator.calculate_json_size(unicode_text) == expected_size

    def test_number_size_calculation(self, truncator):
        """Test JSON size calculation for numbers."""
        assert truncator.calculate_json_size(123) == 3
        assert truncator.calculate_json_size(12.34) == 5
        assert truncator.calculate_json_size(-456) == 4
        assert truncator.calculate_json_size(0) == 1

    def test_boolean_size_calculation(self, truncator):
        """Test JSON size calculation for booleans."""
        assert truncator.calculate_json_size(True) == 4  # "true"
        assert truncator.calculate_json_size(False) == 5  # "false"

    def test_null_size_calculation(self, truncator):
        """Test JSON size calculation for None/null."""
        assert truncator.calculate_json_size(None) == 4  # "null"

    def test_array_size_calculation(self, truncator):
        """Test JSON size calculation for arrays."""
        # Empty array
        assert truncator.calculate_json_size([]) == 2  # "[]"

        # Simple array
        simple_array = [1, 2, 3]
        # [1,2,3] = 3 digits + 2 commas + 2 brackets = 7
        assert truncator.calculate_json_size(simple_array) == 7

        # Array with strings
        string_array = ["a", "b"]
        # ["a","b"] = 3 + 3 + 1 + 2 = 9 (quoted strings + comma + brackets)
        assert truncator.calculate_json_size(string_array) == 9

    def test_object_size_calculation(self, truncator):
        """Test JSON size calculation for objects."""
        # Empty object
        assert truncator.calculate_json_size({}) == 2  # "{}"

        # Simple object
        simple_obj = {"a": 1}
        # {"a":1} = 3 + 1 + 1 + 2 = 7 (quoted key + colon + value + braces)
        assert truncator.calculate_json_size(simple_obj) == 7

        # Multiple keys
        multi_obj = {"a": 1, "b": 2}
        # {"a":1,"b":2} = 3 + 1 + 1 + 1 + 3 + 1 + 1 + 2 = 13
        assert truncator.calculate_json_size(multi_obj) == 13

    def test_nested_structure_size_calculation(self, truncator):
        """Test JSON size calculation for nested structures."""
        nested = {"items": [1, 2, {"nested": "value"}]}
        size = truncator.calculate_json_size(nested)
        assert size > 0  # Should calculate without error

        # Verify it roughly matches the actual JSON length
        actual_json = _compact_json_dumps(nested)
        # Should be close but not exact due to UTF-8 encoding considerations
        assert abs(size - len(actual_json.encode())) <= 5

    def test_calculate_json_size_max_depth_exceeded(self, truncator):
        """Test that calculate_json_size rejects overly deep nesting."""
        # Create a deeply nested structure
        nested: dict[str, Any] = {"level": 0}
        current = nested
        for i in range(25):  # Create deep nesting
            current["next"] = {"level": i + 1}
            current = current["next"]

        # Should raise rather than recurse without bound
        with pytest.raises(MaxDepthExceededError):
            truncator.calculate_json_size(nested)

    def test_calculate_json_size_unknown_type(self, truncator):
        """Test that calculate_json_size raises an error for unknown types."""

        class CustomType:
            pass

        with pytest.raises(UnknownTypeError):
            truncator.calculate_json_size(CustomType())

class TestStringTruncation:
    """Test string truncation functionality."""

    @pytest.fixture
    def small_truncator(self):
        return VariableTruncator(string_length_limit=10)

    def test_short_string_no_truncation(self, small_truncator):
        """Test that short strings are not truncated."""
        short_str = "hello"
        result, was_truncated = small_truncator._truncate_string(short_str)
        assert result == short_str
        assert was_truncated is False

    def test_long_string_truncation(self, small_truncator: VariableTruncator):
        """Test that long strings are truncated with an ellipsis."""
        long_str = "this is a very long string that exceeds the limit"
        result, was_truncated = small_truncator._truncate_string(long_str)

        assert was_truncated is True
        assert result == long_str[:7] + "..."
        assert len(result) == 10  # 7 kept chars + "..." = 10

    def test_exact_limit_string(self, small_truncator):
        """Test a string exactly at the limit."""
        exact_str = "1234567890"  # Exactly 10 chars
        result, was_truncated = small_truncator._truncate_string(exact_str)
        assert result == exact_str
        assert was_truncated is False

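# Ellipsis budget arithmetic assumed by the class above: with
# string_length_limit=10 the truncator keeps limit - 3 = 7 characters and
# appends "...", so the truncated form is exactly `limit` characters long.
# A sketch against the public interface (values are hypothetical, mirroring
# the private-helper assertions above):
#
#     r = VariableTruncator(string_length_limit=10).truncate(StringSegment(value="a" * 50))
#     assert r.truncated is True and r.result.value == "a" * 7 + "..."
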
class TestArrayTruncation:
    """Test array truncation functionality."""

    @pytest.fixture
    def small_truncator(self):
        return VariableTruncator(array_element_limit=3, max_size_bytes=100)

    def test_small_array_no_truncation(self, small_truncator):
        """Test that small arrays are not truncated."""
        small_array = [1, 2]
        result, was_truncated = small_truncator._truncate_array(small_array, 1000)
        assert result == small_array
        assert was_truncated is False

    def test_array_element_limit_truncation(self, small_truncator):
        """Test that arrays over the element limit are truncated."""
        large_array = [1, 2, 3, 4, 5, 6]  # Exceeds limit of 3
        result, was_truncated = small_truncator._truncate_array(large_array, 1000)

        assert was_truncated is True
        assert len(result) == 3
        assert result == [1, 2, 3]

    def test_array_size_budget_truncation(self, small_truncator):
        """Test array truncation due to size budget constraints."""
        # Create an array with strings that will exceed the size budget
        large_strings = ["very long string " * 5, "another long string " * 5]
        result, was_truncated = small_truncator._truncate_array(large_strings, 50)

        assert was_truncated is True
        # Should have truncated the strings within the array
        for item in result:
            assert isinstance(item, str)
        assert len(_compact_json_dumps(result).encode()) <= 50

    def test_array_with_nested_objects(self, small_truncator):
        """Test array truncation with nested objects."""
        nested_array = [
            {"name": "item1", "data": "some data"},
            {"name": "item2", "data": "more data"},
            {"name": "item3", "data": "even more data"},
        ]
        result, was_truncated = small_truncator._truncate_array(nested_array, 80)

        assert isinstance(result, list)
        assert len(result) <= 3
        # Should have processed nested objects appropriately

class TestObjectTruncation:
    """Test object truncation functionality."""

    @pytest.fixture
    def small_truncator(self):
        return VariableTruncator(max_size_bytes=100)

    def test_small_object_no_truncation(self, small_truncator):
        """Test that small objects are not truncated."""
        small_obj = {"a": 1, "b": 2}
        result, was_truncated = small_truncator._truncate_object(small_obj, 1000)
        assert result == small_obj
        assert was_truncated is False

    def test_empty_object_no_truncation(self, small_truncator):
        """Test that empty objects are not truncated."""
        empty_obj = {}
        result, was_truncated = small_truncator._truncate_object(empty_obj, 100)
        assert result == empty_obj
        assert was_truncated is False

    def test_object_value_truncation(self, small_truncator):
        """Test object truncation where values are truncated to fit budget."""
        obj_with_long_values = {
            "key1": "very long string " * 10,
            "key2": "another long string " * 10,
            "key3": "third long string " * 10,
        }
        result, was_truncated = small_truncator._truncate_object(obj_with_long_values, 80)

        assert was_truncated is True
        assert isinstance(result, dict)

        # Keys should be preserved (deterministic order due to sorting)
        if result:  # Only check if result is not empty
            assert list(result.keys()) == sorted(result.keys())

        # Values should be truncated if they exist
        for key, value in result.items():
            if isinstance(value, str):
                original_value = obj_with_long_values[key]
                # Value should be same or smaller
                assert len(value) <= len(original_value)

    def test_object_key_dropping(self, small_truncator):
        """Test object truncation where keys are dropped due to size constraints."""
        large_obj = {f"key{i:02d}": f"value{i}" for i in range(20)}
        result, was_truncated = small_truncator._truncate_object(large_obj, 50)

        assert was_truncated is True
        assert len(result) < len(large_obj)

        # Should maintain sorted key order
        result_keys = list(result.keys())
        assert result_keys == sorted(result_keys)

    def test_object_with_nested_structures(self, small_truncator):
        """Test object truncation with nested arrays and objects."""
        nested_obj = {
            "simple": "value",
            "array": [1, 2, 3, 4, 5],
            "nested": {"inner": "data", "more": ["a", "b", "c"]},
        }
        result, was_truncated = small_truncator._truncate_object(nested_obj, 60)

        assert isinstance(result, dict)
        # Should handle nested structures appropriately

class TestSegmentBasedTruncation:
    """Test the main truncate method that works with Segments."""

    @pytest.fixture
    def truncator(self):
        return VariableTruncator()

    @pytest.fixture
    def small_truncator(self):
        return VariableTruncator(string_length_limit=20, array_element_limit=3, max_size_bytes=200)

    def test_integer_segment_no_truncation(self, truncator):
        """Test that integer segments are never truncated."""
        segment = IntegerSegment(value=12345)
        result = truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is False
        assert result.result == segment

    def test_boolean_as_integer_segment(self, truncator):
        """Test that boolean values in an IntegerSegment are converted to int."""
        segment = IntegerSegment(value=True)
        result = truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is False
        assert isinstance(result.result, IntegerSegment)
        assert result.result.value == 1  # True converted to 1

    def test_float_segment_no_truncation(self, truncator):
        """Test that float segments are never truncated."""
        segment = FloatSegment(value=123.456)
        result = truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is False
        assert result.result == segment

    def test_none_segment_no_truncation(self, truncator):
        """Test that None segments are never truncated."""
        segment = NoneSegment()
        result = truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is False
        assert result.result == segment

    def test_file_segment_no_truncation(self, truncator, file):
        """Test that file segments are never truncated."""
        file_segment = FileSegment(value=file)
        result = truncator.truncate(file_segment)
        assert result.result == file_segment
        assert result.truncated is False

    def test_array_file_segment_no_truncation(self, truncator, file):
        """Test that array file segments are never truncated."""
        array_file_segment = ArrayFileSegment(value=[file] * 20)
        result = truncator.truncate(array_file_segment)
        assert result.result == array_file_segment
        assert result.truncated is False

    def test_string_segment_small_no_truncation(self, truncator):
        """Test that small string segments are not truncated."""
        segment = StringSegment(value="hello world")
        result = truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is False
        assert result.result == segment

    def test_string_segment_large_truncation(self, small_truncator):
        """Test that large string segments are truncated."""
        long_text = "this is a very long string that will definitely exceed the limit"
        segment = StringSegment(value=long_text)
        result = small_truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is True
        assert isinstance(result.result, StringSegment)
        assert len(result.result.value) < len(long_text)
        assert result.result.value.endswith("...")

    def test_array_segment_small_no_truncation(self, truncator):
        """Test that small array segments are not truncated."""
        from factories.variable_factory import build_segment

        segment = build_segment([1, 2, 3])
        result = truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is False
        assert result.result == segment

    def test_array_segment_large_truncation(self, small_truncator):
        """Test that large array segments are truncated."""
        from factories.variable_factory import build_segment

        large_array = list(range(10))  # Exceeds element limit of 3
        segment = build_segment(large_array)
        result = small_truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is True
        assert isinstance(result.result, ArraySegment)
        assert len(result.result.value) <= 3

    def test_object_segment_small_no_truncation(self, truncator):
        """Test that small object segments are not truncated."""
        segment = ObjectSegment(value={"key": "value"})
        result = truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is False
        assert result.result == segment

    def test_object_segment_large_truncation(self, small_truncator):
        """Test that large object segments are truncated."""
        large_obj = {f"key{i}": f"very long value {i}" * 5 for i in range(5)}
        segment = ObjectSegment(value=large_obj)
        result = small_truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is True
        assert isinstance(result.result, ObjectSegment)
        # The result should be smaller than or equal to the original
        original_size = small_truncator.calculate_json_size(large_obj)
        result_size = small_truncator.calculate_json_size(result.result.value)
        assert result_size <= original_size

    def test_final_size_fallback_to_json_string(self, small_truncator):
        """Test the final fallback when the truncated result still exceeds the size limit."""
        # Create data that will still be large after initial truncation
        large_nested_data = {"data": ["very long string " * 5] * 5, "more": {"nested": "content " * 20}}
        segment = ObjectSegment(value=large_nested_data)

        # Use a very small limit to force the JSON-string fallback
        tiny_truncator = VariableTruncator(max_size_bytes=50)
        result = tiny_truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is True
        assert isinstance(result.result, StringSegment)
        # Should be a JSON string, possibly truncated itself
        assert len(result.result.value) <= 53  # 50 + "..." = 53

    def test_final_size_fallback_string_truncation(self, small_truncator):
        """Test the final fallback for a string that still exceeds the limit."""
        # Create a very long string that exceeds the string length limit
        very_long_string = "x" * 6000  # Exceeds default string_length_limit of 5000
        segment = StringSegment(value=very_long_string)

        # Use a small limit to exercise the string fallback path
        tiny_truncator = VariableTruncator(string_length_limit=100, max_size_bytes=50)
        result = tiny_truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is True
        assert isinstance(result.result, StringSegment)
        # Should be truncated by the string limit or the final size limit
        assert len(result.result.value) <= 1000  # Much smaller than original

class TestTruncationHelperMethods:
    """Test helper methods used in truncation."""

    @pytest.fixture
    def truncator(self):
        return VariableTruncator()

    def test_truncate_item_to_budget_string(self, truncator):
        """Test _truncate_item_to_budget with string input."""
        item = "this is a long string"
        budget = 15
        result, was_truncated = truncator._truncate_item_to_budget(item, budget)

        assert isinstance(result, str)
        # Should be truncated to fit budget
        if was_truncated:
            assert len(result) <= budget
            assert result.endswith("...")

    def test_truncate_item_to_budget_dict(self, truncator):
        """Test _truncate_item_to_budget with dict input."""
        item = {"key": "value", "longer": "longer value"}
        budget = 30
        result, was_truncated = truncator._truncate_item_to_budget(item, budget)

        assert isinstance(result, dict)
        # Should apply object truncation logic

    def test_truncate_item_to_budget_list(self, truncator):
        """Test _truncate_item_to_budget with list input."""
        item = [1, 2, 3, 4, 5]
        budget = 15
        result, was_truncated = truncator._truncate_item_to_budget(item, budget)

        assert isinstance(result, list)
        # Should apply array truncation logic

    def test_truncate_item_to_budget_other_types(self, truncator):
        """Test _truncate_item_to_budget with other types."""
        # Small number that fits
        result, was_truncated = truncator._truncate_item_to_budget(123, 10)
        assert result == 123
        assert was_truncated is False

        # Large number that might not fit - should convert to string if needed
        large_num = 123456789012345
        result, was_truncated = truncator._truncate_item_to_budget(large_num, 5)
        if was_truncated:
            assert isinstance(result, str)

    def test_truncate_value_to_budget_string(self, truncator):
        """Test _truncate_value_to_budget with string input."""
        value = "x" * 100
        budget = 20
        result, was_truncated = truncator._truncate_value_to_budget(value, budget)

        assert isinstance(result, str)
        if was_truncated:
            assert len(result) <= 20  # Should respect budget
            assert result.endswith("...")

    def test_truncate_value_to_budget_respects_object_char_limit(self, truncator):
        """Test that _truncate_value_to_budget respects OBJECT_CHAR_LIMIT."""
        # Even with a large budget, should respect OBJECT_CHAR_LIMIT
        large_string = "x" * 10000
        large_budget = 20000
        result, was_truncated = truncator._truncate_value_to_budget(large_string, large_budget)

        if was_truncated:
            assert len(result) <= OBJECT_CHAR_LIMIT + 3  # +3 for "..."

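# Budget semantics assumed by the helper tests above (inferred from the
# assertions, not from documented behavior): _truncate_item_to_budget() appears
# to cap an array element at the smaller of the remaining byte budget and
# ARRAY_CHAR_LIMIT, while _truncate_value_to_budget() caps an object value at
# the smaller of the budget and OBJECT_CHAR_LIMIT; both append "..." whenever
# they shorten a string.
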
class TestEdgeCases:
    """Test edge cases and error conditions."""

    def test_empty_inputs(self):
        """Test the truncator with empty inputs."""
        truncator = VariableTruncator()

        # Empty string
        result = truncator.truncate(StringSegment(value=""))
        assert not result.truncated
        assert result.result.value == ""

        # Empty array
        from factories.variable_factory import build_segment

        result = truncator.truncate(build_segment([]))
        assert not result.truncated
        assert result.result.value == []

        # Empty object
        result = truncator.truncate(ObjectSegment(value={}))
        assert not result.truncated
        assert result.result.value == {}

    def test_zero_and_negative_limits(self):
        """Test that zero or overly small limits are rejected."""
        # String limit below the allowed minimum
        with pytest.raises(ValueError):
            VariableTruncator(string_length_limit=3)

        with pytest.raises(ValueError):
            VariableTruncator(array_element_limit=0)

        with pytest.raises(ValueError):
            VariableTruncator(max_size_bytes=0)

    def test_unicode_and_special_characters(self):
        """Test the truncator with unicode and special characters."""
        truncator = VariableTruncator(string_length_limit=10)

        # Unicode characters: 10 emoji, each a single code point, so exactly at the limit
        unicode_text = "🌍🚀🌍🚀🌍🚀🌍🚀🌍🚀"
        result = truncator.truncate(StringSegment(value=unicode_text))
        if len(unicode_text) > 10:
            assert result.truncated is True

        # Special JSON characters
        special_chars = '{"key": "value with \\"quotes\\" and \\n newlines"}'
        result = truncator.truncate(StringSegment(value=special_chars))
        assert isinstance(result.result, StringSegment)

class TestIntegrationScenarios:
    """Test realistic integration scenarios."""

    def test_workflow_output_scenario(self):
        """Test truncation of typical workflow output data."""
        truncator = VariableTruncator()

        workflow_data = {
            "result": "success",
            "data": {
                "users": [
                    {"id": 1, "name": "Alice", "email": "alice@example.com"},
                    {"id": 2, "name": "Bob", "email": "bob@example.com"},
                ]
                * 3,  # Repeat the list to make it larger
                "metadata": {
                    "count": 6,
                    "processing_time": "1.23s",
                    "details": "x" * 200,  # Long string but not too long
                },
            },
        }

        segment = ObjectSegment(value=workflow_data)
        result = truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert isinstance(result.result, (ObjectSegment, StringSegment))
        # Should handle the complex nested structure appropriately

    def test_large_text_processing_scenario(self):
        """Test truncation of large text data."""
        truncator = VariableTruncator(string_length_limit=100)

        large_text = "This is a very long text document. " * 20  # Larger than the limit

        segment = StringSegment(value=large_text)
        result = truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        assert result.truncated is True
        assert isinstance(result.result, StringSegment)
        assert len(result.result.value) <= 103  # 100 + "..."
        assert result.result.value.endswith("...")

    def test_mixed_data_types_scenario(self):
        """Test truncation with mixed data types in a complex structure."""
        truncator = VariableTruncator(string_length_limit=30, array_element_limit=3, max_size_bytes=300)

        mixed_data = {
            "strings": ["short", "medium length", "very long string " * 3],
            "numbers": [1, 2.5, 999999],
            "booleans": [True, False, True],
            "nested": {
                "more_strings": ["nested string " * 2],
                "more_numbers": list(range(5)),
                "deep": {"level": 3, "content": "deep content " * 3},
            },
            "nulls": [None, None],
        }

        segment = ObjectSegment(value=mixed_data)
        result = truncator.truncate(segment)

        assert isinstance(result, TruncationResult)
        # Should handle all data types appropriately
        if result.truncated:
            # Verify the result is smaller than or equal to the original
            original_size = truncator.calculate_json_size(mixed_data)
            if isinstance(result.result, ObjectSegment):
                result_size = truncator.calculate_json_size(result.result.value)
                assert result_size <= original_size

class TestConstantsAndConfiguration:
    """Test behavior with different configuration constants."""

    def test_large_variable_threshold_constant(self):
        """Test that the LARGE_VARIABLE_THRESHOLD constant is properly used."""
        truncator = VariableTruncator()
        assert truncator._max_size_bytes == LARGE_VARIABLE_THRESHOLD
        assert LARGE_VARIABLE_THRESHOLD == 10 * 1024  # 10KB

    def test_string_truncation_limit_constant(self):
        """Test that the default string length limit is properly used."""
        truncator = VariableTruncator()
        assert truncator._string_length_limit == 5000

    def test_array_char_limit_constant(self):
        """Test that ARRAY_CHAR_LIMIT is used in array item truncation."""
        truncator = VariableTruncator()

        # ARRAY_CHAR_LIMIT should be respected in array item truncation
        long_string = "x" * 2000
        budget = 5000  # Large budget

        result, was_truncated = truncator._truncate_item_to_budget(long_string, budget)
        if was_truncated:
            # Should not exceed ARRAY_CHAR_LIMIT even with a large budget
            assert len(result) <= ARRAY_CHAR_LIMIT + 3  # +3 for "..."

    def test_object_char_limit_constant(self):
        """Test that OBJECT_CHAR_LIMIT is used in object value truncation."""
        truncator = VariableTruncator()

        # OBJECT_CHAR_LIMIT should be respected in object value truncation
        long_string = "x" * 8000
        large_budget = 20000

        result, was_truncated = truncator._truncate_value_to_budget(long_string, large_budget)
        if was_truncated:
            # Should not exceed OBJECT_CHAR_LIMIT even with a large budget
            assert len(result) <= OBJECT_CHAR_LIMIT + 3  # +3 for "..."
@@ -0,0 +1,379 @@
"""Simplified unit tests for DraftVarLoader focusing on core functionality."""

import json
from unittest.mock import Mock, patch

import pytest
from sqlalchemy import Engine

from core.variables.segments import ObjectSegment, StringSegment
from core.variables.types import SegmentType
from models.model import UploadFile
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
from services.workflow_draft_variable_service import DraftVarLoader


class TestDraftVarLoaderSimple:
    """Simplified unit tests for DraftVarLoader core methods."""

    @pytest.fixture
    def mock_engine(self) -> Engine:
        return Mock(spec=Engine)

    @pytest.fixture
    def draft_var_loader(self, mock_engine):
        """Create DraftVarLoader instance for testing."""
        return DraftVarLoader(
            engine=mock_engine,
            app_id="test-app-id",
            tenant_id="test-tenant-id",
            fallback_variables=[],
        )

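    # Flow these unit tests assume for offloaded variables (a sketch inferred
    # from the mocks below, not a documented service contract):
    #
    #     raw = storage.load(draft_var.variable_file.upload_file.key)
    #     segment = WorkflowDraftVariable.build_segment_with_type(value_type, json.loads(raw))  # JSON-typed values
    #     variable = variable_factory.segment_to_variable(segment=segment, ...)
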
    def test_load_offloaded_variable_string_type_unit(self, draft_var_loader):
        """Test _load_offloaded_variable with string type - isolated unit test."""
        # Create mock objects
        upload_file = Mock(spec=UploadFile)
        upload_file.key = "storage/key/test.txt"

        variable_file = Mock(spec=WorkflowDraftVariableFile)
        variable_file.value_type = SegmentType.STRING
        variable_file.upload_file = upload_file

        draft_var = Mock(spec=WorkflowDraftVariable)
        draft_var.id = "draft-var-id"
        draft_var.node_id = "test-node-id"
        draft_var.name = "test_variable"
        draft_var.description = "test description"
        draft_var.get_selector.return_value = ["test-node-id", "test_variable"]
        draft_var.variable_file = variable_file

        test_content = "This is the full string content"

        with patch("services.workflow_draft_variable_service.storage") as mock_storage:
            mock_storage.load.return_value = test_content.encode()

            with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
                mock_variable = Mock()
                mock_variable.id = "draft-var-id"
                mock_variable.name = "test_variable"
                mock_variable.value = StringSegment(value=test_content)
                mock_segment_to_variable.return_value = mock_variable

                # Execute the method
                selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)

                # Verify results
                assert selector_tuple == ("test-node-id", "test_variable")
                assert variable.id == "draft-var-id"
                assert variable.name == "test_variable"
                assert variable.description == "test description"
                assert variable.value == test_content

                # Verify storage was called correctly
                mock_storage.load.assert_called_once_with("storage/key/test.txt")

    def test_load_offloaded_variable_object_type_unit(self, draft_var_loader):
        """Test _load_offloaded_variable with object type - isolated unit test."""
        # Create mock objects
        upload_file = Mock(spec=UploadFile)
        upload_file.key = "storage/key/test.json"

        variable_file = Mock(spec=WorkflowDraftVariableFile)
        variable_file.value_type = SegmentType.OBJECT
        variable_file.upload_file = upload_file

        draft_var = Mock(spec=WorkflowDraftVariable)
        draft_var.id = "draft-var-id"
        draft_var.node_id = "test-node-id"
        draft_var.name = "test_object"
        draft_var.description = "test description"
        draft_var.get_selector.return_value = ["test-node-id", "test_object"]
        draft_var.variable_file = variable_file

        test_object = {"key1": "value1", "key2": 42}
        test_json_content = json.dumps(test_object, ensure_ascii=False, separators=(",", ":"))

        with patch("services.workflow_draft_variable_service.storage") as mock_storage:
            mock_storage.load.return_value = test_json_content.encode()

            with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
                mock_segment = ObjectSegment(value=test_object)
                mock_build_segment.return_value = mock_segment

                with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
                    mock_variable = Mock()
                    mock_variable.id = "draft-var-id"
                    mock_variable.name = "test_object"
                    mock_variable.value = mock_segment
                    mock_segment_to_variable.return_value = mock_variable

                    # Execute the method
                    selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)

                    # Verify results
                    assert selector_tuple == ("test-node-id", "test_object")
                    assert variable.id == "draft-var-id"
                    assert variable.name == "test_object"
                    assert variable.description == "test description"
                    assert variable.value == test_object

                    # Verify method calls
                    mock_storage.load.assert_called_once_with("storage/key/test.json")
                    mock_build_segment.assert_called_once_with(SegmentType.OBJECT, test_object)

    def test_load_offloaded_variable_missing_variable_file_unit(self, draft_var_loader):
        """Test that an assertion error is raised when variable_file is None."""
        draft_var = Mock(spec=WorkflowDraftVariable)
        draft_var.variable_file = None

        with pytest.raises(AssertionError):
            draft_var_loader._load_offloaded_variable(draft_var)

    def test_load_offloaded_variable_missing_upload_file_unit(self, draft_var_loader):
        """Test that an assertion error is raised when upload_file is None."""
        variable_file = Mock(spec=WorkflowDraftVariableFile)
        variable_file.upload_file = None

        draft_var = Mock(spec=WorkflowDraftVariable)
        draft_var.variable_file = variable_file

        with pytest.raises(AssertionError):
            draft_var_loader._load_offloaded_variable(draft_var)

    def test_load_variables_empty_selectors_unit(self, draft_var_loader):
        """Test that load_variables returns an empty list for empty selectors."""
        result = draft_var_loader.load_variables([])
        assert result == []

    def test_selector_to_tuple_unit(self, draft_var_loader):
        """Test the _selector_to_tuple method."""
        selector = ["node_id", "var_name", "extra_field"]
        result = draft_var_loader._selector_to_tuple(selector)
        assert result == ("node_id", "var_name")

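    # Selector convention assumed above: a selector is a list whose first two
    # elements are [node_id, variable_name]; any trailing elements are ignored
    # when identifying a draft variable.
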
    def test_load_offloaded_variable_number_type_unit(self, draft_var_loader):
        """Test _load_offloaded_variable with number type - isolated unit test."""
        # Create mock objects
        upload_file = Mock(spec=UploadFile)
        upload_file.key = "storage/key/test_number.json"

        variable_file = Mock(spec=WorkflowDraftVariableFile)
        variable_file.value_type = SegmentType.NUMBER
        variable_file.upload_file = upload_file

        draft_var = Mock(spec=WorkflowDraftVariable)
        draft_var.id = "draft-var-id"
        draft_var.node_id = "test-node-id"
        draft_var.name = "test_number"
        draft_var.description = "test number description"
        draft_var.get_selector.return_value = ["test-node-id", "test_number"]
        draft_var.variable_file = variable_file

        test_number = 123.45
        test_json_content = json.dumps(test_number)

        with patch("services.workflow_draft_variable_service.storage") as mock_storage:
            mock_storage.load.return_value = test_json_content.encode()

            with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
                from core.variables.segments import FloatSegment

                mock_segment = FloatSegment(value=test_number)
                mock_build_segment.return_value = mock_segment

                with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
                    mock_variable = Mock()
                    mock_variable.id = "draft-var-id"
                    mock_variable.name = "test_number"
                    mock_variable.value = mock_segment
                    mock_segment_to_variable.return_value = mock_variable

                    # Execute the method
                    selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)

                    # Verify results
                    assert selector_tuple == ("test-node-id", "test_number")
                    assert variable.id == "draft-var-id"
                    assert variable.name == "test_number"
                    assert variable.description == "test number description"

                    # Verify method calls
                    mock_storage.load.assert_called_once_with("storage/key/test_number.json")
                    mock_build_segment.assert_called_once_with(SegmentType.NUMBER, test_number)

    def test_load_offloaded_variable_array_type_unit(self, draft_var_loader):
        """Test _load_offloaded_variable with array type - isolated unit test."""
        # Create mock objects
        upload_file = Mock(spec=UploadFile)
        upload_file.key = "storage/key/test_array.json"

        variable_file = Mock(spec=WorkflowDraftVariableFile)
        variable_file.value_type = SegmentType.ARRAY_ANY
        variable_file.upload_file = upload_file

        draft_var = Mock(spec=WorkflowDraftVariable)
        draft_var.id = "draft-var-id"
        draft_var.node_id = "test-node-id"
        draft_var.name = "test_array"
        draft_var.description = "test array description"
        draft_var.get_selector.return_value = ["test-node-id", "test_array"]
        draft_var.variable_file = variable_file

        test_array = ["item1", "item2", "item3"]
        test_json_content = json.dumps(test_array)

        with patch("services.workflow_draft_variable_service.storage") as mock_storage:
            mock_storage.load.return_value = test_json_content.encode()

            with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment:
                from core.variables.segments import ArrayAnySegment

                mock_segment = ArrayAnySegment(value=test_array)
                mock_build_segment.return_value = mock_segment

                with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
                    mock_variable = Mock()
                    mock_variable.id = "draft-var-id"
                    mock_variable.name = "test_array"
                    mock_variable.value = mock_segment
                    mock_segment_to_variable.return_value = mock_variable

                    # Execute the method
                    selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var)

                    # Verify results
                    assert selector_tuple == ("test-node-id", "test_array")
                    assert variable.id == "draft-var-id"
                    assert variable.name == "test_array"
                    assert variable.description == "test array description"

                    # Verify method calls
                    mock_storage.load.assert_called_once_with("storage/key/test_array.json")
                    mock_build_segment.assert_called_once_with(SegmentType.ARRAY_ANY, test_array)

    def test_load_variables_with_offloaded_variables_unit(self, draft_var_loader):
        """Test load_variables method with a mix of regular and offloaded variables."""
        selectors = [
            ["node1", "regular_var"],
            ["node2", "offloaded_var"],
        ]

        # Mock regular variable
        regular_draft_var = Mock(spec=WorkflowDraftVariable)
        regular_draft_var.is_truncated.return_value = False
        regular_draft_var.node_id = "node1"
        regular_draft_var.name = "regular_var"
        regular_draft_var.get_value.return_value = StringSegment(value="regular_value")
        regular_draft_var.get_selector.return_value = ["node1", "regular_var"]
        regular_draft_var.id = "regular-var-id"
        regular_draft_var.description = "regular description"

        # Mock offloaded variable
        upload_file = Mock(spec=UploadFile)
        upload_file.key = "storage/key/offloaded.txt"

        variable_file = Mock(spec=WorkflowDraftVariableFile)
        variable_file.value_type = SegmentType.STRING
        variable_file.upload_file = upload_file

        offloaded_draft_var = Mock(spec=WorkflowDraftVariable)
        offloaded_draft_var.is_truncated.return_value = True
        offloaded_draft_var.node_id = "node2"
        offloaded_draft_var.name = "offloaded_var"
        offloaded_draft_var.get_selector.return_value = ["node2", "offloaded_var"]
        offloaded_draft_var.variable_file = variable_file
        offloaded_draft_var.id = "offloaded-var-id"
        offloaded_draft_var.description = "offloaded description"

        draft_vars = [regular_draft_var, offloaded_draft_var]

        with patch("services.workflow_draft_variable_service.Session") as mock_session_cls:
            mock_session = Mock()
            mock_session_cls.return_value.__enter__.return_value = mock_session

            mock_service = Mock()
            mock_service.get_draft_variables_by_selectors.return_value = draft_vars

            with patch("services.workflow_draft_variable_service.WorkflowDraftVariableService", return_value=mock_service):
                with patch("services.workflow_draft_variable_service.StorageKeyLoader"):
                    with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable:
                        # Mock regular variable creation
                        regular_variable = Mock()
                        regular_variable.selector = ["node1", "regular_var"]

                        # Mock offloaded variable creation
                        offloaded_variable = Mock()
                        offloaded_variable.selector = ["node2", "offloaded_var"]

                        mock_segment_to_variable.return_value = regular_variable

                        with patch("services.workflow_draft_variable_service.storage") as mock_storage:
                            mock_storage.load.return_value = b"offloaded_content"

                            with patch.object(draft_var_loader, "_load_offloaded_variable") as mock_load_offloaded:
                                mock_load_offloaded.return_value = (("node2", "offloaded_var"), offloaded_variable)

                                with patch("concurrent.futures.ThreadPoolExecutor") as mock_executor_cls:
                                    mock_executor = Mock()
                                    mock_executor_cls.return_value.__enter__.return_value = mock_executor
                                    mock_executor.map.return_value = [(("node2", "offloaded_var"), offloaded_variable)]

                                    # Execute the method
                                    result = draft_var_loader.load_variables(selectors)

                                    # Verify results
                                    assert len(result) == 2

                                    # Verify the service method was called
                                    mock_service.get_draft_variables_by_selectors.assert_called_once_with(
                                        draft_var_loader._app_id, selectors
                                    )

                                    # Verify offloaded variable loading was called
                                    mock_load_offloaded.assert_called_once_with(offloaded_draft_var)

    def test_load_variables_all_offloaded_variables_unit(self, draft_var_loader):
        """Test load_variables method with only offloaded variables."""
        selectors = [
            ["node1", "offloaded_var1"],
            ["node2", "offloaded_var2"],
        ]

        # Mock first offloaded variable
        offloaded_var1 = Mock(spec=WorkflowDraftVariable)
        offloaded_var1.is_truncated.return_value = True
        offloaded_var1.node_id = "node1"
        offloaded_var1.name = "offloaded_var1"

        # Mock second offloaded variable
        offloaded_var2 = Mock(spec=WorkflowDraftVariable)
        offloaded_var2.is_truncated.return_value = True
        offloaded_var2.node_id = "node2"
        offloaded_var2.name = "offloaded_var2"

        draft_vars = [offloaded_var1, offloaded_var2]

        with patch("services.workflow_draft_variable_service.Session") as mock_session_cls:
            mock_session = Mock()
            mock_session_cls.return_value.__enter__.return_value = mock_session

            mock_service = Mock()
            mock_service.get_draft_variables_by_selectors.return_value = draft_vars

            with patch("services.workflow_draft_variable_service.WorkflowDraftVariableService", return_value=mock_service):
                with patch("services.workflow_draft_variable_service.StorageKeyLoader"):
                    with patch("services.workflow_draft_variable_service.ThreadPoolExecutor") as mock_executor_cls:
                        mock_executor = Mock()
                        mock_executor_cls.return_value.__enter__.return_value = mock_executor
                        mock_executor.map.return_value = [
                            (("node1", "offloaded_var1"), Mock()),
                            (("node2", "offloaded_var2"), Mock()),
                        ]

                        # Execute the method
                        result = draft_var_loader.load_variables(selectors)

                        # Verify results - with only offloaded variables, there should be 2 results
                        assert len(result) == 2

                        # Verify ThreadPoolExecutor was used
                        mock_executor_cls.assert_called_once_with(max_workers=10)
                        mock_executor.map.assert_called_once()
@@ -1,16 +1,26 @@
import dataclasses
import secrets
import uuid
from unittest.mock import MagicMock, Mock, patch

import pytest
from sqlalchemy import Engine
from sqlalchemy.orm import Session

from core.variables import StringSegment
from core.variables.segments import StringSegment
from core.variables.types import SegmentType
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes.enums import NodeType
from libs.uuid_utils import uuidv7
from models.account import Account
from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
from models.workflow import (
    Workflow,
    WorkflowDraftVariable,
    WorkflowDraftVariableFile,
    WorkflowNodeExecutionModel,
    is_system_variable_editable,
)
from services.workflow_draft_variable_service import (
    DraftVariableSaver,
    VariableResetError,
@@ -37,6 +47,7 @@ class TestDraftVariableSaver:

    def test__should_variable_be_visible(self):
        mock_session = MagicMock(spec=Session)
        mock_user = Account(id=str(uuid.uuid4()))
        test_app_id = self._get_test_app_id()
        saver = DraftVariableSaver(
            session=mock_session,
@@ -44,6 +55,7 @@ class TestDraftVariableSaver:
            node_id="test_node_id",
            node_type=NodeType.START,
            node_execution_id="test_execution_id",
            user=mock_user,
        )
        assert saver._should_variable_be_visible("123_456", NodeType.IF_ELSE, "output") == False
        assert saver._should_variable_be_visible("123", NodeType.START, "output") == True
@@ -83,6 +95,7 @@ class TestDraftVariableSaver:
        ]

        mock_session = MagicMock(spec=Session)
        mock_user = MagicMock()
        test_app_id = self._get_test_app_id()
        saver = DraftVariableSaver(
            session=mock_session,
@@ -90,6 +103,7 @@ class TestDraftVariableSaver:
            node_id=_NODE_ID,
            node_type=NodeType.START,
            node_execution_id="test_execution_id",
            user=mock_user,
        )
        for idx, c in enumerate(cases, 1):
            fail_msg = f"Test case {c.name} failed, index={idx}"
@@ -97,6 +111,76 @@ class TestDraftVariableSaver:
            assert node_id == c.expected_node_id, fail_msg
            assert name == c.expected_name, fail_msg

@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Mock SQLAlchemy session."""
|
||||
from sqlalchemy import Engine
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_engine = MagicMock(spec=Engine)
|
||||
mock_session.get_bind.return_value = mock_engine
|
||||
return mock_session
|
||||
|
||||
@pytest.fixture
|
||||
def draft_saver(self, mock_session):
|
||||
"""Create DraftVariableSaver instance with user context."""
|
||||
# Create a mock user
|
||||
mock_user = MagicMock(spec=Account)
|
||||
mock_user.id = "test-user-id"
|
||||
mock_user.tenant_id = "test-tenant-id"
|
||||
|
||||
return DraftVariableSaver(
|
||||
session=mock_session,
|
||||
app_id="test-app-id",
|
||||
node_id="test-node-id",
|
||||
node_type=NodeType.LLM,
|
||||
node_execution_id="test-execution-id",
|
||||
user=mock_user,
|
||||
)
|
||||
|
||||
def test_draft_saver_with_small_variables(self, draft_saver, mock_session):
|
||||
with patch(
|
||||
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
|
||||
) as _mock_try_offload:
|
||||
_mock_try_offload.return_value = None
|
||||
mock_segment = StringSegment(value="small value")
|
||||
draft_var = draft_saver._create_draft_variable(name="small_var", value=mock_segment, visible=True)
|
||||
|
||||
# Should not have large variable metadata
|
||||
assert draft_var.file_id is None
|
||||
_mock_try_offload.return_value = None
|
||||
|
||||
    def test_draft_saver_with_large_variables(self, draft_saver, mock_session):
        with patch(
            "services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
        ) as _mock_try_offload:
            mock_segment = StringSegment(value="small value")
            mock_draft_var_file = WorkflowDraftVariableFile(
                id=str(uuidv7()),
                size=1024,
                length=10,
                value_type=SegmentType.ARRAY_STRING,
                upload_file_id=str(uuid.uuid4()),
            )

            _mock_try_offload.return_value = mock_segment, mock_draft_var_file
            draft_var = draft_saver._create_draft_variable(name="large_var", value=mock_segment, visible=True)

            # Should reference the offloaded variable file
            assert draft_var.file_id == mock_draft_var_file.id

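    # Taken together, the two tests above pin down the contract this suite
    # assumes for _try_offload_large_variable (a sketch of the shape, not the
    # actual implementation): it returns None when the value is small enough
    # to stay inline, or a (truncated_segment, WorkflowDraftVariableFile) pair
    # once the full value was written to storage, in which case
    # _create_draft_variable records the file's id on the draft variable:
    #
    # result = self._try_offload_large_variable(name, value)
    # if result is None:
    #     file_id = None                    # value kept inline
    # else:
    #     value, variable_file = result     # truncated preview + file record
    #     file_id = variable_file.id
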
@patch("services.workflow_draft_variable_service._batch_upsert_draft_variable")
|
||||
def test_save_method_integration(self, mock_batch_upsert, draft_saver):
|
||||
"""Test complete save workflow."""
|
||||
outputs = {"result": {"data": "test_output"}, "metadata": {"type": "llm_response"}}
|
||||
|
||||
draft_saver.save(outputs=outputs)
|
||||
|
||||
# Should batch upsert draft variables
|
||||
mock_batch_upsert.assert_called_once()
|
||||
draft_vars = mock_batch_upsert.call_args[0][1]
|
||||
assert len(draft_vars) == 2
|
||||
|
||||
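    # call_args[0][1] reads the second positional argument, which suggests a
    # signature along the lines of _batch_upsert_draft_variable(session,
    # draft_variables); the parameter names are an assumption - the test only
    # relies on one draft variable being built per output key.
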

class TestWorkflowDraftVariableService:
    def _get_test_app_id(self):

@@ -1,14 +1,18 @@
from unittest.mock import ANY, MagicMock, call, patch

import pytest
import sqlalchemy as sa

from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch
from tasks.remove_app_and_related_data_task import (
    _delete_draft_variable_offload_data,
    _delete_draft_variables,
    delete_draft_variables_batch,
)


class TestDeleteDraftVariablesBatch:
    @patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
    @patch("tasks.remove_app_and_related_data_task.db")
    def test_delete_draft_variables_batch_success(self, mock_db):
    def test_delete_draft_variables_batch_success(self, mock_db, mock_offload_cleanup):
        """Test successful deletion of draft variables in batches."""
        app_id = "test-app-id"
        batch_size = 100
@@ -24,13 +28,19 @@ class TestDeleteDraftVariablesBatch:
        mock_engine.begin.return_value = mock_context_manager

        # Mock two batches of results, then empty
        batch1_ids = [f"var-{i}" for i in range(100)]
        batch2_ids = [f"var-{i}" for i in range(100, 150)]
        batch1_data = [(f"var-{i}", f"file-{i}" if i % 2 == 0 else None) for i in range(100)]
        batch2_data = [(f"var-{i}", f"file-{i}" if i % 3 == 0 else None) for i in range(100, 150)]

        batch1_ids = [row[0] for row in batch1_data]
        batch1_file_ids = [row[1] for row in batch1_data if row[1] is not None]

        batch2_ids = [row[0] for row in batch2_data]
        batch2_file_ids = [row[1] for row in batch2_data if row[1] is not None]

        # Setup side effects for execute calls in the correct order:
        # 1. SELECT (returns batch1_ids)
        # 1. SELECT (returns batch1_data with id, file_id)
        # 2. DELETE (returns result with rowcount=100)
        # 3. SELECT (returns batch2_ids)
        # 3. SELECT (returns batch2_data)
        # 4. DELETE (returns result with rowcount=50)
        # 5. SELECT (returns empty, ends loop)

@@ -41,14 +51,14 @@ class TestDeleteDraftVariablesBatch:

        # First SELECT result
        select_result1 = MagicMock()
        select_result1.__iter__.return_value = iter([(id_,) for id_ in batch1_ids])
        select_result1.__iter__.return_value = iter(batch1_data)

        # First DELETE result
        delete_result1 = MockResult(rowcount=100)

        # Second SELECT result
        select_result2 = MagicMock()
        select_result2.__iter__.return_value = iter([(id_,) for id_ in batch2_ids])
        select_result2.__iter__.return_value = iter(batch2_data)

        # Second DELETE result
        delete_result2 = MockResult(rowcount=50)
@@ -66,6 +76,9 @@ class TestDeleteDraftVariablesBatch:
            select_result3,  # Third SELECT (empty)
        ]

        # Mock offload data cleanup
        mock_offload_cleanup.side_effect = [len(batch1_file_ids), len(batch2_file_ids)]

        # Execute the function
        result = delete_draft_variables_batch(app_id, batch_size)

@@ -75,65 +88,18 @@ class TestDeleteDraftVariablesBatch:
        # Verify database calls
        assert mock_conn.execute.call_count == 5  # 3 selects + 2 deletes

        # Verify the expected calls in order:
        # 1. SELECT, 2. DELETE, 3. SELECT, 4. DELETE, 5. SELECT
        expected_calls = [
            # First SELECT
            call(
                sa.text("""
                SELECT id FROM workflow_draft_variables
                WHERE app_id = :app_id
                LIMIT :batch_size
            """),
                {"app_id": app_id, "batch_size": batch_size},
            ),
            # First DELETE
            call(
                sa.text("""
                DELETE FROM workflow_draft_variables
                WHERE id IN :ids
            """),
                {"ids": tuple(batch1_ids)},
            ),
            # Second SELECT
            call(
                sa.text("""
                SELECT id FROM workflow_draft_variables
                WHERE app_id = :app_id
                LIMIT :batch_size
            """),
                {"app_id": app_id, "batch_size": batch_size},
            ),
            # Second DELETE
            call(
                sa.text("""
                DELETE FROM workflow_draft_variables
                WHERE id IN :ids
            """),
                {"ids": tuple(batch2_ids)},
            ),
            # Third SELECT (empty result)
            call(
                sa.text("""
                SELECT id FROM workflow_draft_variables
                WHERE app_id = :app_id
                LIMIT :batch_size
            """),
                {"app_id": app_id, "batch_size": batch_size},
            ),
        ]
        # Verify offload cleanup was called for both batches with file_ids
        expected_offload_calls = [call(mock_conn, batch1_file_ids), call(mock_conn, batch2_file_ids)]
        mock_offload_cleanup.assert_has_calls(expected_offload_calls)

        # Check that all calls were made correctly
        actual_calls = mock_conn.execute.call_args_list
        assert len(actual_calls) == len(expected_calls)

        # Simplified verification - just check that the right number of calls were made
        # Simplified verification - check that the right number of calls were made
        # and that the SQL queries contain the expected patterns
        actual_calls = mock_conn.execute.call_args_list
        for i, actual_call in enumerate(actual_calls):
            if i % 2 == 0:  # SELECT calls (even indices: 0, 2, 4)
                # Verify it's a SELECT query
                # Verify it's a SELECT query that now includes file_id
                sql_text = str(actual_call[0][0])
                assert "SELECT id FROM workflow_draft_variables" in sql_text
                assert "SELECT id, file_id FROM workflow_draft_variables" in sql_text
                assert "WHERE app_id = :app_id" in sql_text
                assert "LIMIT :batch_size" in sql_text
            else:  # DELETE calls (odd indices: 1, 3)
@@ -142,8 +108,9 @@ class TestDeleteDraftVariablesBatch:
                assert "DELETE FROM workflow_draft_variables" in sql_text
                assert "WHERE id IN :ids" in sql_text

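    # For reference, the behaviour the success test pins down - sketched here
    # under the assumption that delete_draft_variables_batch loops SELECT ->
    # cleanup -> DELETE until a batch comes back empty; this is an
    # illustration, not the task's actual source:
    #
    # def delete_draft_variables_batch(app_id: str, batch_size: int) -> int:
    #     if batch_size <= 0:
    #         raise ValueError("batch_size must be positive")
    #     total = 0
    #     with db.engine.begin() as conn:
    #         while True:
    #             rows = list(conn.execute(
    #                 sa.text("SELECT id, file_id FROM workflow_draft_variables"
    #                         " WHERE app_id = :app_id LIMIT :batch_size"),
    #                 {"app_id": app_id, "batch_size": batch_size},
    #             ))
    #             if not rows:
    #                 break
    #             file_ids = [file_id for _, file_id in rows if file_id is not None]
    #             if file_ids:
    #                 _delete_draft_variable_offload_data(conn, file_ids)
    #             result = conn.execute(
    #                 sa.text("DELETE FROM workflow_draft_variables WHERE id IN :ids"),
    #                 {"ids": tuple(row[0] for row in rows)},
    #             )
    #             total += result.rowcount
    #             logger.info(...)  # progress log, asserted in the logging test
    #     return total
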
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
|
||||
@patch("tasks.remove_app_and_related_data_task.db")
|
||||
def test_delete_draft_variables_batch_empty_result(self, mock_db):
|
||||
def test_delete_draft_variables_batch_empty_result(self, mock_db, mock_offload_cleanup):
|
||||
"""Test deletion when no draft variables exist for the app."""
|
||||
app_id = "nonexistent-app-id"
|
||||
batch_size = 1000
|
||||
@@ -167,6 +134,7 @@ class TestDeleteDraftVariablesBatch:
|
||||
|
||||
assert result == 0
|
||||
assert mock_conn.execute.call_count == 1 # Only one select query
|
||||
mock_offload_cleanup.assert_not_called() # No files to clean up
|
||||
|
||||
    def test_delete_draft_variables_batch_invalid_batch_size(self):
        """Test that invalid batch size raises ValueError."""
@@ -178,9 +146,10 @@ class TestDeleteDraftVariablesBatch:
        with pytest.raises(ValueError, match="batch_size must be positive"):
            delete_draft_variables_batch(app_id, 0)

@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
|
||||
@patch("tasks.remove_app_and_related_data_task.db")
|
||||
@patch("tasks.remove_app_and_related_data_task.logger")
|
||||
def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_db):
|
||||
def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_db, mock_offload_cleanup):
|
||||
"""Test that batch deletion logs progress correctly."""
|
||||
app_id = "test-app-id"
|
||||
batch_size = 50
|
||||
@@ -196,10 +165,13 @@ class TestDeleteDraftVariablesBatch:
        mock_engine.begin.return_value = mock_context_manager

        # Mock one batch then empty
        batch_ids = [f"var-{i}" for i in range(30)]
        batch_data = [(f"var-{i}", f"file-{i}" if i % 3 == 0 else None) for i in range(30)]
        batch_ids = [row[0] for row in batch_data]
        batch_file_ids = [row[1] for row in batch_data if row[1] is not None]

        # Create properly configured mocks
        select_result = MagicMock()
        select_result.__iter__.return_value = iter([(id_,) for id_ in batch_ids])
        select_result.__iter__.return_value = iter(batch_data)

        # Create simple object with rowcount attribute
        class MockResult:
@@ -220,10 +192,17 @@ class TestDeleteDraftVariablesBatch:
            empty_result,
        ]

        # Mock offload cleanup
        mock_offload_cleanup.return_value = len(batch_file_ids)

        result = delete_draft_variables_batch(app_id, batch_size)

        assert result == 30

        # Verify offload cleanup was called with file_ids
        if batch_file_ids:
            mock_offload_cleanup.assert_called_once_with(mock_conn, batch_file_ids)

        # Verify logging calls
        assert mock_logging.info.call_count == 2
        mock_logging.info.assert_any_call(
@@ -241,3 +220,118 @@ class TestDeleteDraftVariablesBatch:

        assert result == expected_return
        mock_batch_delete.assert_called_once_with(app_id, batch_size=1000)

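# The truncated assertions above close out a delegation test; together they
# imply that _delete_draft_variables simply forwards to the batch helper with
# a default batch size - a sketch under that assumption, not the actual source:
#
# def _delete_draft_variables(app_id: str) -> int:
#     return delete_draft_variables_batch(app_id, batch_size=1000)
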

class TestDeleteDraftVariableOffloadData:
    """Test the offload data cleanup functionality."""

    @patch("extensions.ext_storage.storage")
    def test_delete_draft_variable_offload_data_success(self, mock_storage):
        """Test successful deletion of offload data."""

        # Mock connection
        mock_conn = MagicMock()
        file_ids = ["file-1", "file-2", "file-3"]

        # Mock query results: (variable_file_id, storage_key, upload_file_id)
        query_results = [
            ("file-1", "storage/key/1", "upload-1"),
            ("file-2", "storage/key/2", "upload-2"),
            ("file-3", "storage/key/3", "upload-3"),
        ]

        mock_result = MagicMock()
        mock_result.__iter__.return_value = iter(query_results)
        mock_conn.execute.return_value = mock_result

        # Execute function
        result = _delete_draft_variable_offload_data(mock_conn, file_ids)

        # Verify return value
        assert result == 3

        # Verify storage deletion calls
        expected_storage_calls = [call("storage/key/1"), call("storage/key/2"), call("storage/key/3")]
        mock_storage.delete.assert_has_calls(expected_storage_calls, any_order=True)

        # Verify database calls - should be 3 calls total
        assert mock_conn.execute.call_count == 3

        # Verify the queries were called
        actual_calls = mock_conn.execute.call_args_list

        # First call should be the SELECT query
        select_call_sql = str(actual_calls[0][0][0])
        assert "SELECT wdvf.id, uf.key, uf.id as upload_file_id" in select_call_sql
        assert "FROM workflow_draft_variable_files wdvf" in select_call_sql
        assert "JOIN upload_files uf ON wdvf.upload_file_id = uf.id" in select_call_sql
        assert "WHERE wdvf.id IN :file_ids" in select_call_sql

        # Second call should be DELETE upload_files
        delete_upload_call_sql = str(actual_calls[1][0][0])
        assert "DELETE FROM upload_files" in delete_upload_call_sql
        assert "WHERE id IN :upload_file_ids" in delete_upload_call_sql

        # Third call should be DELETE workflow_draft_variable_files
        delete_variable_files_call_sql = str(actual_calls[2][0][0])
        assert "DELETE FROM workflow_draft_variable_files" in delete_variable_files_call_sql
        assert "WHERE id IN :file_ids" in delete_variable_files_call_sql

    def test_delete_draft_variable_offload_data_empty_file_ids(self):
        """Test handling of empty file_ids list."""
        mock_conn = MagicMock()

        result = _delete_draft_variable_offload_data(mock_conn, [])

        assert result == 0
        mock_conn.execute.assert_not_called()

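    # This pins down an early-return guard at the top of the helper,
    # presumably something like the following:
    #
    # if not file_ids:
    #     return 0
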
@patch("extensions.ext_storage.storage")
|
||||
@patch("tasks.remove_app_and_related_data_task.logging")
|
||||
def test_delete_draft_variable_offload_data_storage_failure(self, mock_logging, mock_storage):
|
||||
"""Test handling of storage deletion failures."""
|
||||
mock_conn = MagicMock()
|
||||
file_ids = ["file-1", "file-2"]
|
||||
|
||||
# Mock query results
|
||||
query_results = [
|
||||
("file-1", "storage/key/1", "upload-1"),
|
||||
("file-2", "storage/key/2", "upload-2"),
|
||||
]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.__iter__.return_value = iter(query_results)
|
||||
mock_conn.execute.return_value = mock_result
|
||||
|
||||
# Make storage.delete fail for the first file
|
||||
mock_storage.delete.side_effect = [Exception("Storage error"), None]
|
||||
|
||||
# Execute function
|
||||
result = _delete_draft_variable_offload_data(mock_conn, file_ids)
|
||||
|
||||
# Should still return 2 (both files processed, even if one storage delete failed)
|
||||
assert result == 1 # Only one storage deletion succeeded
|
||||
|
||||
# Verify warning was logged
|
||||
mock_logging.warning.assert_called_once_with("Failed to delete storage object storage/key/1: Storage error")
|
||||
|
||||
# Verify both database cleanup calls still happened
|
||||
assert mock_conn.execute.call_count == 3
|
||||
|
||||
@patch("tasks.remove_app_and_related_data_task.logging")
|
||||
def test_delete_draft_variable_offload_data_database_failure(self, mock_logging):
|
||||
"""Test handling of database operation failures."""
|
||||
mock_conn = MagicMock()
|
||||
file_ids = ["file-1"]
|
||||
|
||||
# Make execute raise an exception
|
||||
mock_conn.execute.side_effect = Exception("Database error")
|
||||
|
||||
# Execute function - should not raise, but log error
|
||||
result = _delete_draft_variable_offload_data(mock_conn, file_ids)
|
||||
|
||||
# Should return 0 when error occurs
|
||||
assert result == 0
|
||||
|
||||
# Verify error was logged
|
||||
mock_logging.error.assert_called_once_with("Error deleting draft variable offload data: Database error")
|
||||
|
||||
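# Read together, the four tests above fully constrain the cleanup helper's
# observable behaviour. A sketch consistent with them, under the assumption
# that it joins upload_files for the storage keys, best-effort-deletes the
# blobs, then removes both sets of rows - illustrative only, not the task's
# actual source:
#
# def _delete_draft_variable_offload_data(conn, file_ids) -> int:
#     if not file_ids:
#         return 0
#     deleted = 0
#     try:
#         rows = conn.execute(
#             sa.text(
#                 "SELECT wdvf.id, uf.key, uf.id as upload_file_id"
#                 " FROM workflow_draft_variable_files wdvf"
#                 " JOIN upload_files uf ON wdvf.upload_file_id = uf.id"
#                 " WHERE wdvf.id IN :file_ids"
#             ),
#             {"file_ids": tuple(file_ids)},
#         )
#         upload_file_ids = []
#         for _variable_file_id, storage_key, upload_file_id in rows:
#             upload_file_ids.append(upload_file_id)
#             try:
#                 storage.delete(storage_key)
#                 deleted += 1  # only successful blob deletions are counted
#             except Exception as e:
#                 logging.warning(f"Failed to delete storage object {storage_key}: {e}")
#         conn.execute(
#             sa.text("DELETE FROM upload_files WHERE id IN :upload_file_ids"),
#             {"upload_file_ids": tuple(upload_file_ids)},
#         )
#         conn.execute(
#             sa.text("DELETE FROM workflow_draft_variable_files WHERE id IN :file_ids"),
#             {"file_ids": tuple(file_ids)},
#         )
#     except Exception as e:
#         logging.error(f"Error deleting draft variable offload data: {e}")
#         return 0
#     return deleted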