mirror of
https://github.com/langgenius/dify.git
synced 2026-06-03 08:16:37 +08:00
refactor(tests): use db_session_with_containers in test_storage_key_loader (#35766)
Co-authored-by: yeranyang <yeranyang@tencent.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
+241
-228
@@ -1,4 +1,5 @@
|
|||||||
import unittest
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
@@ -16,7 +17,7 @@ from models.enums import CreatorUserRole
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("flask_req_ctx_with_containers")
|
@pytest.mark.usefixtures("flask_req_ctx_with_containers")
|
||||||
class TestStorageKeyLoader(unittest.TestCase):
|
class TestStorageKeyLoader:
|
||||||
"""
|
"""
|
||||||
Integration tests for StorageKeyLoader class.
|
Integration tests for StorageKeyLoader class.
|
||||||
|
|
||||||
@@ -24,110 +25,82 @@ class TestStorageKeyLoader(unittest.TestCase):
|
|||||||
with different transfer methods: LOCAL_FILE, REMOTE_URL, and TOOL_FILE.
|
with different transfer methods: LOCAL_FILE, REMOTE_URL, and TOOL_FILE.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def setUp(self):
|
# ------------------------------------------------------------------
|
||||||
"""Set up test data before each test method."""
|
# Per-test helpers (use db_session_with_containers as parameter)
|
||||||
self.session = db.session()
|
# ------------------------------------------------------------------
|
||||||
self.tenant_id = str(uuid4())
|
|
||||||
self.user_id = str(uuid4())
|
|
||||||
self.conversation_id = str(uuid4())
|
|
||||||
|
|
||||||
# Create test data that will be cleaned up after each test
|
|
||||||
self.test_upload_files = []
|
|
||||||
self.test_tool_files = []
|
|
||||||
|
|
||||||
# Create StorageKeyLoader instance
|
|
||||||
self.loader = StorageKeyLoader(
|
|
||||||
self.session,
|
|
||||||
self.tenant_id,
|
|
||||||
access_controller=DatabaseFileAccessController(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
"""Clean up test data after each test method."""
|
|
||||||
self.session.rollback()
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def _create_upload_file(
|
def _create_upload_file(
|
||||||
self, file_id: str | None = None, storage_key: str | None = None, tenant_id: str | None = None
|
session: Session,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
*,
|
||||||
|
file_id: str | None = None,
|
||||||
|
storage_key: str | None = None,
|
||||||
|
override_tenant_id: str | None = None,
|
||||||
) -> UploadFile:
|
) -> UploadFile:
|
||||||
"""Helper method to create an UploadFile record for testing."""
|
"""Create and flush an UploadFile record for testing."""
|
||||||
if file_id is None:
|
|
||||||
file_id = str(uuid4())
|
|
||||||
if storage_key is None:
|
|
||||||
storage_key = f"test_storage_key_{uuid4()}"
|
|
||||||
if tenant_id is None:
|
|
||||||
tenant_id = self.tenant_id
|
|
||||||
|
|
||||||
upload_file = UploadFile(
|
upload_file = UploadFile(
|
||||||
tenant_id=tenant_id,
|
tenant_id=override_tenant_id if override_tenant_id is not None else tenant_id,
|
||||||
storage_type=StorageType.LOCAL,
|
storage_type=StorageType.LOCAL,
|
||||||
key=storage_key,
|
key=storage_key or f"test_storage_key_{uuid4()}",
|
||||||
name="test_file.txt",
|
name="test_file.txt",
|
||||||
size=1024,
|
size=1024,
|
||||||
extension=".txt",
|
extension=".txt",
|
||||||
mime_type="text/plain",
|
mime_type="text/plain",
|
||||||
created_by_role=CreatorUserRole.ACCOUNT,
|
created_by_role=CreatorUserRole.ACCOUNT,
|
||||||
created_by=self.user_id,
|
created_by=user_id,
|
||||||
created_at=datetime.now(UTC),
|
created_at=datetime.now(UTC),
|
||||||
used=False,
|
used=False,
|
||||||
)
|
)
|
||||||
upload_file.id = file_id
|
upload_file.id = file_id or str(uuid4())
|
||||||
|
session.add(upload_file)
|
||||||
self.session.add(upload_file)
|
session.flush()
|
||||||
self.session.flush()
|
|
||||||
self.test_upload_files.append(upload_file)
|
|
||||||
|
|
||||||
return upload_file
|
return upload_file
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def _create_tool_file(
|
def _create_tool_file(
|
||||||
self, file_id: str | None = None, file_key: str | None = None, tenant_id: str | None = None
|
session: Session,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
conversation_id: str,
|
||||||
|
*,
|
||||||
|
file_id: str | None = None,
|
||||||
|
file_key: str | None = None,
|
||||||
|
override_tenant_id: str | None = None,
|
||||||
) -> ToolFile:
|
) -> ToolFile:
|
||||||
"""Helper method to create a ToolFile record for testing."""
|
"""Create and flush a ToolFile record for testing."""
|
||||||
if file_id is None:
|
|
||||||
file_id = str(uuid4())
|
|
||||||
if file_key is None:
|
|
||||||
file_key = f"test_file_key_{uuid4()}"
|
|
||||||
if tenant_id is None:
|
|
||||||
tenant_id = self.tenant_id
|
|
||||||
|
|
||||||
tool_file = ToolFile(
|
tool_file = ToolFile(
|
||||||
user_id=self.user_id,
|
user_id=user_id,
|
||||||
tenant_id=tenant_id,
|
tenant_id=override_tenant_id if override_tenant_id is not None else tenant_id,
|
||||||
conversation_id=self.conversation_id,
|
conversation_id=conversation_id,
|
||||||
file_key=file_key,
|
file_key=file_key or f"test_file_key_{uuid4()}",
|
||||||
mimetype="text/plain",
|
mimetype="text/plain",
|
||||||
original_url="http://example.com/file.txt",
|
original_url="http://example.com/file.txt",
|
||||||
name="test_tool_file.txt",
|
name="test_tool_file.txt",
|
||||||
size=2048,
|
size=2048,
|
||||||
)
|
)
|
||||||
tool_file.id = file_id
|
tool_file.id = file_id or str(uuid4())
|
||||||
|
session.add(tool_file)
|
||||||
self.session.add(tool_file)
|
session.flush()
|
||||||
self.session.flush()
|
|
||||||
self.test_tool_files.append(tool_file)
|
|
||||||
|
|
||||||
return tool_file
|
return tool_file
|
||||||
|
|
||||||
def _create_file(self, related_id: str, transfer_method: FileTransferMethod, tenant_id: str | None = None) -> File:
|
@staticmethod
|
||||||
"""Helper method to create a File object for testing."""
|
def _create_file(
|
||||||
if tenant_id is None:
|
tenant_id: str,
|
||||||
tenant_id = self.tenant_id
|
related_id: str,
|
||||||
|
transfer_method: FileTransferMethod,
|
||||||
# Set related_id for LOCAL_FILE and TOOL_FILE transfer methods
|
*,
|
||||||
file_related_id = None
|
override_tenant_id: str | None = None,
|
||||||
remote_url = None
|
) -> File:
|
||||||
|
"""Build a File value-object for testing."""
|
||||||
if transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE):
|
remote_url = "https://example.com/test_file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None
|
||||||
file_related_id = related_id
|
|
||||||
elif transfer_method == FileTransferMethod.REMOTE_URL:
|
|
||||||
remote_url = "https://example.com/test_file.txt"
|
|
||||||
file_related_id = related_id
|
|
||||||
|
|
||||||
return File(
|
return File(
|
||||||
file_id=str(uuid4()), # Generate new UUID for File.id
|
file_id=str(uuid4()),
|
||||||
tenant_id=tenant_id,
|
tenant_id=override_tenant_id if override_tenant_id is not None else tenant_id,
|
||||||
file_type=FileType.DOCUMENT,
|
file_type=FileType.DOCUMENT,
|
||||||
transfer_method=transfer_method,
|
transfer_method=transfer_method,
|
||||||
related_id=file_related_id,
|
related_id=related_id,
|
||||||
remote_url=remote_url,
|
remote_url=remote_url,
|
||||||
filename="test_file.txt",
|
filename="test_file.txt",
|
||||||
extension=".txt",
|
extension=".txt",
|
||||||
@@ -136,240 +109,280 @@ class TestStorageKeyLoader(unittest.TestCase):
|
|||||||
storage_key="initial_key",
|
storage_key="initial_key",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_load_storage_keys_local_file(self):
|
# ------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_load_storage_keys_local_file(self, db_session_with_containers: Session):
|
||||||
"""Test loading storage keys for LOCAL_FILE transfer method."""
|
"""Test loading storage keys for LOCAL_FILE transfer method."""
|
||||||
# Create test data
|
tenant_id = str(uuid4())
|
||||||
upload_file = self._create_upload_file()
|
user_id = str(uuid4())
|
||||||
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
|
||||||
|
|
||||||
# Load storage keys
|
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||||
self.loader.load_storage_keys([file])
|
file = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||||
|
|
||||||
|
loader = StorageKeyLoader(
|
||||||
|
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||||
|
)
|
||||||
|
loader.load_storage_keys([file])
|
||||||
|
|
||||||
# Verify storage key was loaded correctly
|
|
||||||
assert file._storage_key == upload_file.key
|
assert file._storage_key == upload_file.key
|
||||||
|
|
||||||
def test_load_storage_keys_remote_url(self):
|
def test_load_storage_keys_remote_url(self, db_session_with_containers: Session):
|
||||||
"""Test loading storage keys for REMOTE_URL transfer method."""
|
"""Test loading storage keys for REMOTE_URL transfer method."""
|
||||||
# Create test data
|
tenant_id = str(uuid4())
|
||||||
upload_file = self._create_upload_file()
|
user_id = str(uuid4())
|
||||||
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL)
|
|
||||||
|
|
||||||
# Load storage keys
|
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||||
self.loader.load_storage_keys([file])
|
file = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.REMOTE_URL)
|
||||||
|
|
||||||
|
loader = StorageKeyLoader(
|
||||||
|
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||||
|
)
|
||||||
|
loader.load_storage_keys([file])
|
||||||
|
|
||||||
# Verify storage key was loaded correctly
|
|
||||||
assert file._storage_key == upload_file.key
|
assert file._storage_key == upload_file.key
|
||||||
|
|
||||||
def test_load_storage_keys_tool_file(self):
|
def test_load_storage_keys_tool_file(self, db_session_with_containers: Session):
|
||||||
"""Test loading storage keys for TOOL_FILE transfer method."""
|
"""Test loading storage keys for TOOL_FILE transfer method."""
|
||||||
# Create test data
|
tenant_id = str(uuid4())
|
||||||
tool_file = self._create_tool_file()
|
user_id = str(uuid4())
|
||||||
file = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
|
conversation_id = str(uuid4())
|
||||||
|
|
||||||
# Load storage keys
|
tool_file = self._create_tool_file(db_session_with_containers, tenant_id, user_id, conversation_id)
|
||||||
self.loader.load_storage_keys([file])
|
file = self._create_file(tenant_id, related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||||
|
|
||||||
|
loader = StorageKeyLoader(
|
||||||
|
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||||
|
)
|
||||||
|
loader.load_storage_keys([file])
|
||||||
|
|
||||||
# Verify storage key was loaded correctly
|
|
||||||
assert file._storage_key == tool_file.file_key
|
assert file._storage_key == tool_file.file_key
|
||||||
|
|
||||||
def test_load_storage_keys_mixed_methods(self):
|
def test_load_storage_keys_mixed_methods(self, db_session_with_containers: Session):
|
||||||
"""Test batch loading with mixed transfer methods."""
|
"""Test batch loading with mixed transfer methods."""
|
||||||
# Create test data for different transfer methods
|
tenant_id = str(uuid4())
|
||||||
upload_file1 = self._create_upload_file()
|
user_id = str(uuid4())
|
||||||
upload_file2 = self._create_upload_file()
|
conversation_id = str(uuid4())
|
||||||
tool_file = self._create_tool_file()
|
|
||||||
|
|
||||||
file1 = self._create_file(related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
upload_file1 = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||||
file2 = self._create_file(related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL)
|
upload_file2 = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||||
file3 = self._create_file(related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
|
tool_file = self._create_tool_file(db_session_with_containers, tenant_id, user_id, conversation_id)
|
||||||
|
|
||||||
files = [file1, file2, file3]
|
file1 = self._create_file(tenant_id, related_id=upload_file1.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||||
|
file2 = self._create_file(tenant_id, related_id=upload_file2.id, transfer_method=FileTransferMethod.REMOTE_URL)
|
||||||
|
file3 = self._create_file(tenant_id, related_id=tool_file.id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||||
|
|
||||||
# Load storage keys
|
loader = StorageKeyLoader(
|
||||||
self.loader.load_storage_keys(files)
|
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||||
|
)
|
||||||
|
loader.load_storage_keys([file1, file2, file3])
|
||||||
|
|
||||||
# Verify all storage keys were loaded correctly
|
|
||||||
assert file1._storage_key == upload_file1.key
|
assert file1._storage_key == upload_file1.key
|
||||||
assert file2._storage_key == upload_file2.key
|
assert file2._storage_key == upload_file2.key
|
||||||
assert file3._storage_key == tool_file.file_key
|
assert file3._storage_key == tool_file.file_key
|
||||||
|
|
||||||
def test_load_storage_keys_empty_list(self):
|
def test_load_storage_keys_empty_list(self, db_session_with_containers: Session):
|
||||||
"""Test with empty file list."""
|
"""Test with empty file list — should not raise."""
|
||||||
# Should not raise any exceptions
|
tenant_id = str(uuid4())
|
||||||
self.loader.load_storage_keys([])
|
loader = StorageKeyLoader(
|
||||||
|
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||||
|
)
|
||||||
|
loader.load_storage_keys([])
|
||||||
|
|
||||||
def test_load_storage_keys_ignores_legacy_file_tenant_id(self):
|
def test_load_storage_keys_ignores_legacy_file_tenant_id(self, db_session_with_containers: Session):
|
||||||
"""Legacy file tenant_id should not override the loader tenant scope."""
|
"""Legacy file tenant_id should not override the loader tenant scope."""
|
||||||
upload_file = self._create_upload_file()
|
tenant_id = str(uuid4())
|
||||||
|
user_id = str(uuid4())
|
||||||
|
|
||||||
|
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||||
file = self._create_file(
|
file = self._create_file(
|
||||||
related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4())
|
tenant_id,
|
||||||
|
related_id=upload_file.id,
|
||||||
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
|
override_tenant_id=str(uuid4()),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.loader.load_storage_keys([file])
|
loader = StorageKeyLoader(
|
||||||
|
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||||
|
)
|
||||||
|
loader.load_storage_keys([file])
|
||||||
|
|
||||||
assert file._storage_key == upload_file.key
|
assert file._storage_key == upload_file.key
|
||||||
|
|
||||||
def test_load_storage_keys_missing_file_id(self):
|
def test_load_storage_keys_missing_file_id(self, db_session_with_containers: Session):
|
||||||
"""Test with None file.related_id."""
|
"""Test with None file.related_id — should raise ValueError."""
|
||||||
# Create a file with valid parameters first, then manually set related_id to None
|
tenant_id = str(uuid4())
|
||||||
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
|
user_id = str(uuid4())
|
||||||
|
|
||||||
|
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||||
|
file = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||||
file.related_id = None
|
file.related_id = None
|
||||||
|
|
||||||
# Should raise ValueError for None file related_id
|
loader = StorageKeyLoader(
|
||||||
with pytest.raises(ValueError) as context:
|
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||||
self.loader.load_storage_keys([file])
|
)
|
||||||
|
with pytest.raises(ValueError, match="file id should not be None."):
|
||||||
|
loader.load_storage_keys([file])
|
||||||
|
|
||||||
assert str(context.value) == "file id should not be None."
|
def test_load_storage_keys_nonexistent_upload_file_records(self, db_session_with_containers: Session):
|
||||||
|
"""Test with missing UploadFile database records — should raise ValueError."""
|
||||||
|
tenant_id = str(uuid4())
|
||||||
|
file = self._create_file(tenant_id, related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||||
|
|
||||||
def test_load_storage_keys_nonexistent_upload_file_records(self):
|
loader = StorageKeyLoader(
|
||||||
"""Test with missing UploadFile database records."""
|
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||||
# Create file with non-existent upload file id
|
)
|
||||||
non_existent_id = str(uuid4())
|
|
||||||
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
|
||||||
|
|
||||||
# Should raise ValueError for missing record
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
self.loader.load_storage_keys([file])
|
loader.load_storage_keys([file])
|
||||||
|
|
||||||
def test_load_storage_keys_nonexistent_tool_file_records(self):
|
def test_load_storage_keys_nonexistent_tool_file_records(self, db_session_with_containers: Session):
|
||||||
"""Test with missing ToolFile database records."""
|
"""Test with missing ToolFile database records — should raise ValueError."""
|
||||||
# Create file with non-existent tool file id
|
tenant_id = str(uuid4())
|
||||||
non_existent_id = str(uuid4())
|
file = self._create_file(tenant_id, related_id=str(uuid4()), transfer_method=FileTransferMethod.TOOL_FILE)
|
||||||
file = self._create_file(related_id=non_existent_id, transfer_method=FileTransferMethod.TOOL_FILE)
|
|
||||||
|
|
||||||
# Should raise ValueError for missing record
|
loader = StorageKeyLoader(
|
||||||
|
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||||
|
)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
self.loader.load_storage_keys([file])
|
loader.load_storage_keys([file])
|
||||||
|
|
||||||
def test_load_storage_keys_invalid_uuid(self):
|
def test_load_storage_keys_invalid_uuid(self, db_session_with_containers: Session):
|
||||||
"""Test with invalid UUID format."""
|
"""Test with invalid UUID format — should raise ValueError."""
|
||||||
# Create a file with valid parameters first, then manually set invalid related_id
|
tenant_id = str(uuid4())
|
||||||
file = self._create_file(related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE)
|
user_id = str(uuid4())
|
||||||
|
|
||||||
|
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||||
|
file = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||||
file.related_id = "invalid-uuid-format"
|
file.related_id = "invalid-uuid-format"
|
||||||
|
|
||||||
# Should raise ValueError for invalid UUID
|
loader = StorageKeyLoader(
|
||||||
|
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||||
|
)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
self.loader.load_storage_keys([file])
|
loader.load_storage_keys([file])
|
||||||
|
|
||||||
def test_load_storage_keys_batch_efficiency(self):
|
def test_load_storage_keys_batch_efficiency(self, db_session_with_containers: Session):
|
||||||
"""Test batched operations use efficient queries."""
|
"""Batched operations should issue exactly 2 queries for mixed file types."""
|
||||||
# Create multiple files of different types
|
tenant_id = str(uuid4())
|
||||||
upload_files = [self._create_upload_file() for _ in range(3)]
|
user_id = str(uuid4())
|
||||||
tool_files = [self._create_tool_file() for _ in range(2)]
|
conversation_id = str(uuid4())
|
||||||
|
|
||||||
files = []
|
upload_files = [self._create_upload_file(db_session_with_containers, tenant_id, user_id) for _ in range(3)]
|
||||||
files.extend(
|
tool_files = [
|
||||||
[self._create_file(related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE) for uf in upload_files]
|
self._create_tool_file(db_session_with_containers, tenant_id, user_id, conversation_id) for _ in range(2)
|
||||||
|
]
|
||||||
|
|
||||||
|
files = [
|
||||||
|
self._create_file(tenant_id, related_id=uf.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||||
|
for uf in upload_files
|
||||||
|
] + [
|
||||||
|
self._create_file(tenant_id, related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE)
|
||||||
|
for tf in tool_files
|
||||||
|
]
|
||||||
|
|
||||||
|
loader = StorageKeyLoader(
|
||||||
|
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||||
)
|
)
|
||||||
files.extend(
|
with patch.object(
|
||||||
[self._create_file(related_id=tf.id, transfer_method=FileTransferMethod.TOOL_FILE) for tf in tool_files]
|
db_session_with_containers, "scalars", wraps=db_session_with_containers.scalars
|
||||||
)
|
) as mock_scalars:
|
||||||
|
loader.load_storage_keys(files)
|
||||||
# Mock the session to count queries
|
# Exactly 2 DB round-trips: one for UploadFile, one for ToolFile
|
||||||
with patch.object(self.session, "scalars", wraps=self.session.scalars) as mock_scalars:
|
|
||||||
self.loader.load_storage_keys(files)
|
|
||||||
|
|
||||||
# Should make exactly 2 queries (one for upload_files, one for tool_files)
|
|
||||||
assert mock_scalars.call_count == 2
|
assert mock_scalars.call_count == 2
|
||||||
|
|
||||||
# Verify all storage keys were loaded correctly
|
|
||||||
for i, file in enumerate(files[:3]):
|
for i, file in enumerate(files[:3]):
|
||||||
assert file._storage_key == upload_files[i].key
|
assert file._storage_key == upload_files[i].key
|
||||||
for i, file in enumerate(files[3:]):
|
for i, file in enumerate(files[3:]):
|
||||||
assert file._storage_key == tool_files[i].file_key
|
assert file._storage_key == tool_files[i].file_key
|
||||||
|
|
||||||
def test_load_storage_keys_tenant_isolation(self):
|
def test_load_storage_keys_tenant_isolation(self, db_session_with_containers: Session):
|
||||||
"""Test that tenant isolation works correctly."""
|
"""Loader should not surface records belonging to a different tenant."""
|
||||||
# Create files for different tenants
|
tenant_id = str(uuid4())
|
||||||
other_tenant_id = str(uuid4())
|
other_tenant_id = str(uuid4())
|
||||||
|
user_id = str(uuid4())
|
||||||
|
|
||||||
# Create upload file for current tenant
|
upload_file_current = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||||
upload_file_current = self._create_upload_file()
|
|
||||||
file_current = self._create_file(
|
file_current = self._create_file(
|
||||||
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
|
tenant_id, related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create upload file for other tenant (but don't add to cleanup list)
|
upload_file_other = self._create_upload_file(
|
||||||
upload_file_other = UploadFile(
|
db_session_with_containers,
|
||||||
tenant_id=other_tenant_id,
|
tenant_id,
|
||||||
storage_type=StorageType.LOCAL,
|
user_id,
|
||||||
key="other_tenant_key",
|
override_tenant_id=other_tenant_id,
|
||||||
name="other_file.txt",
|
|
||||||
size=1024,
|
|
||||||
extension=".txt",
|
|
||||||
mime_type="text/plain",
|
|
||||||
created_by_role=CreatorUserRole.ACCOUNT,
|
|
||||||
created_by=self.user_id,
|
|
||||||
created_at=datetime.now(UTC),
|
|
||||||
used=False,
|
|
||||||
)
|
)
|
||||||
upload_file_other.id = str(uuid4())
|
|
||||||
self.session.add(upload_file_other)
|
|
||||||
self.session.flush()
|
|
||||||
|
|
||||||
# Create file for other tenant but try to load with current tenant's loader
|
|
||||||
file_other = self._create_file(
|
file_other = self._create_file(
|
||||||
related_id=upload_file_other.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
|
tenant_id,
|
||||||
|
related_id=upload_file_other.id,
|
||||||
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
|
override_tenant_id=other_tenant_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should raise ValueError due to tenant mismatch
|
loader = StorageKeyLoader(
|
||||||
with pytest.raises(ValueError) as context:
|
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||||
self.loader.load_storage_keys([file_other])
|
)
|
||||||
|
|
||||||
assert "Upload file not found for id:" in str(context.value)
|
with pytest.raises(ValueError, match="Upload file not found for id:"):
|
||||||
|
loader.load_storage_keys([file_other])
|
||||||
|
|
||||||
# Current tenant's file should still work
|
# Current-tenant file still resolves correctly
|
||||||
self.loader.load_storage_keys([file_current])
|
loader.load_storage_keys([file_current])
|
||||||
assert file_current._storage_key == upload_file_current.key
|
assert file_current._storage_key == upload_file_current.key
|
||||||
|
|
||||||
def test_load_storage_keys_mixed_tenant_batch(self):
|
def test_load_storage_keys_mixed_tenant_batch(self, db_session_with_containers: Session):
|
||||||
"""Test batch with mixed tenant files (should fail on first mismatch)."""
|
"""A batch containing a foreign-tenant file should fail on the mismatch."""
|
||||||
# Create files for current tenant
|
tenant_id = str(uuid4())
|
||||||
upload_file_current = self._create_upload_file()
|
user_id = str(uuid4())
|
||||||
|
|
||||||
|
upload_file_current = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||||
file_current = self._create_file(
|
file_current = self._create_file(
|
||||||
related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
|
tenant_id, related_id=upload_file_current.id, transfer_method=FileTransferMethod.LOCAL_FILE
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create file for different tenant
|
|
||||||
other_tenant_id = str(uuid4())
|
|
||||||
file_other = self._create_file(
|
file_other = self._create_file(
|
||||||
related_id=str(uuid4()), transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=other_tenant_id
|
tenant_id,
|
||||||
|
related_id=str(uuid4()),
|
||||||
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
|
override_tenant_id=str(uuid4()),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should raise ValueError on tenant mismatch
|
loader = StorageKeyLoader(
|
||||||
with pytest.raises(ValueError) as context:
|
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||||
self.loader.load_storage_keys([file_current, file_other])
|
)
|
||||||
|
with pytest.raises(ValueError, match="Upload file not found for id:"):
|
||||||
|
loader.load_storage_keys([file_current, file_other])
|
||||||
|
|
||||||
assert "Upload file not found for id:" in str(context.value)
|
def test_load_storage_keys_duplicate_file_ids(self, db_session_with_containers: Session):
|
||||||
|
"""Duplicate file IDs in the batch should be handled gracefully."""
|
||||||
|
tenant_id = str(uuid4())
|
||||||
|
user_id = str(uuid4())
|
||||||
|
|
||||||
def test_load_storage_keys_duplicate_file_ids(self):
|
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||||
"""Test handling of duplicate file IDs in the batch."""
|
file1 = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||||
# Create upload file
|
file2 = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||||
upload_file = self._create_upload_file()
|
|
||||||
|
|
||||||
# Create two File objects with same related_id
|
loader = StorageKeyLoader(
|
||||||
file1 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
db_session_with_containers, tenant_id, access_controller=DatabaseFileAccessController()
|
||||||
file2 = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
)
|
||||||
|
loader.load_storage_keys([file1, file2])
|
||||||
|
|
||||||
# Should handle duplicates gracefully
|
|
||||||
self.loader.load_storage_keys([file1, file2])
|
|
||||||
|
|
||||||
# Both files should have the same storage key
|
|
||||||
assert file1._storage_key == upload_file.key
|
assert file1._storage_key == upload_file.key
|
||||||
assert file2._storage_key == upload_file.key
|
assert file2._storage_key == upload_file.key
|
||||||
|
|
||||||
def test_load_storage_keys_session_isolation(self):
|
def test_load_storage_keys_session_isolation(self, db_session_with_containers: Session):
|
||||||
"""Test that the loader uses the provided session correctly."""
|
"""A loader backed by an uncommitted session should not see data from another session."""
|
||||||
# Create test data
|
tenant_id = str(uuid4())
|
||||||
upload_file = self._create_upload_file()
|
user_id = str(uuid4())
|
||||||
file = self._create_file(related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
|
||||||
|
|
||||||
# Create loader with different session (same underlying connection)
|
upload_file = self._create_upload_file(db_session_with_containers, tenant_id, user_id)
|
||||||
|
file = self._create_file(tenant_id, related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE)
|
||||||
|
|
||||||
|
# A loader with a fresh, separate session cannot see uncommitted rows from db_session_with_containers
|
||||||
with Session(bind=db.engine) as other_session:
|
with Session(bind=db.engine) as other_session:
|
||||||
other_loader = StorageKeyLoader(
|
other_loader = StorageKeyLoader(
|
||||||
other_session,
|
other_session,
|
||||||
self.tenant_id,
|
tenant_id,
|
||||||
access_controller=DatabaseFileAccessController(),
|
access_controller=DatabaseFileAccessController(),
|
||||||
)
|
)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import NotRequired, TypedDict
|
|||||||
|
|
||||||
class AdminConfig(TypedDict):
|
class AdminConfig(TypedDict):
|
||||||
"""Configuration for admin section."""
|
"""Configuration for admin section."""
|
||||||
|
|
||||||
username: str
|
username: str
|
||||||
password: str
|
password: str
|
||||||
base_url: str
|
base_url: str
|
||||||
@@ -14,6 +15,7 @@ class AdminConfig(TypedDict):
|
|||||||
|
|
||||||
class AuthConfig(TypedDict):
|
class AuthConfig(TypedDict):
|
||||||
"""Configuration for authentication section."""
|
"""Configuration for authentication section."""
|
||||||
|
|
||||||
access_token: str
|
access_token: str
|
||||||
refresh_token: NotRequired[str]
|
refresh_token: NotRequired[str]
|
||||||
expires_at: NotRequired[int]
|
expires_at: NotRequired[int]
|
||||||
@@ -21,6 +23,7 @@ class AuthConfig(TypedDict):
|
|||||||
|
|
||||||
class AppConfig(TypedDict):
|
class AppConfig(TypedDict):
|
||||||
"""Configuration for app section."""
|
"""Configuration for app section."""
|
||||||
|
|
||||||
app_id: str
|
app_id: str
|
||||||
app_name: NotRequired[str]
|
app_name: NotRequired[str]
|
||||||
description: NotRequired[str]
|
description: NotRequired[str]
|
||||||
@@ -28,6 +31,7 @@ class AppConfig(TypedDict):
|
|||||||
|
|
||||||
class ApiKeyConfig(TypedDict):
|
class ApiKeyConfig(TypedDict):
|
||||||
"""Configuration for API key section."""
|
"""Configuration for API key section."""
|
||||||
|
|
||||||
token: str
|
token: str
|
||||||
key_name: NotRequired[str]
|
key_name: NotRequired[str]
|
||||||
expires_at: NotRequired[int]
|
expires_at: NotRequired[int]
|
||||||
@@ -35,6 +39,7 @@ class ApiKeyConfig(TypedDict):
|
|||||||
|
|
||||||
class StressTestState(TypedDict):
|
class StressTestState(TypedDict):
|
||||||
"""Complete stress test state structure."""
|
"""Complete stress test state structure."""
|
||||||
|
|
||||||
admin: NotRequired[AdminConfig]
|
admin: NotRequired[AdminConfig]
|
||||||
auth: NotRequired[AuthConfig]
|
auth: NotRequired[AuthConfig]
|
||||||
app: NotRequired[AppConfig]
|
app: NotRequired[AppConfig]
|
||||||
|
|||||||
Reference in New Issue
Block a user