diff --git a/api/tests/test_containers_integration_tests/.ruff.toml b/api/tests/test_containers_integration_tests/.ruff.toml new file mode 100644 index 0000000000..be0109f462 --- /dev/null +++ b/api/tests/test_containers_integration_tests/.ruff.toml @@ -0,0 +1,19 @@ +extend = "../../.ruff.toml" +src = ["../.."] + +[lint] +extend-select = ["ANN401", "TID251"] + +[lint.per-file-ignores] +"**/*.py" = ["S110", "T201"] +"core/rag/pipeline/test_queue_integration.py" = ["ANN401", "TID251"] +"models/test_types_enum_text.py" = ["ANN401", "TID251"] +"services/test_app_dsl_service.py" = ["ANN401", "TID251"] +"services/test_file_service_zip_and_lookup.py" = ["ANN401", "TID251"] +"services/test_hit_testing_service.py" = ["ANN401", "TID251"] +"services/test_recommended_app_service.py" = ["ANN401", "TID251"] +"trigger/conftest.py" = ["ANN401", "TID251"] +"trigger/test_trigger_e2e.py" = ["ANN401", "TID251"] + +[lint.flake8-tidy-imports.banned-api."typing.Any"] +msg = "Use object, Protocol, TypedDict, TypeVar, ParamSpec, or a localized cast instead." diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index dd742f99d0..2ee9ae68b2 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -41,7 +41,7 @@ SANDBOX_TEST_IMAGE_ENV = "DIFY_SANDBOX_TEST_IMAGE" class _CloserProtocol(Protocol): """_Closer is any type which implement the close() method.""" - def close(self): + def close(self) -> None: """close the current object, release any external resouece (file, transaction, connection etc.) associated with it. """ @@ -67,7 +67,7 @@ class DifyTestContainers: caches, and search engines. """ - def __init__(self): + def __init__(self) -> None: """Initialize container management with default configurations.""" self.network: Network | None = None self.postgres: PostgresContainer | None = None @@ -77,7 +77,7 @@ class DifyTestContainers: self._containers_started = False logger.info("DifyTestContainers initialized - ready to manage test containers") - def start_containers_with_env(self): + def start_containers_with_env(self) -> None: """ Start all required containers for integration testing. @@ -199,6 +199,8 @@ class DifyTestContainers: # Get container internal network addresses postgres_container_name = self.postgres.get_wrapped_container().name redis_container_name = self.redis.get_wrapped_container().name + assert postgres_container_name is not None + assert redis_container_name is not None self.dify_plugin_daemon.env = { "DB_HOST": postgres_container_name, # Use container name for internal network communication @@ -251,7 +253,7 @@ class DifyTestContainers: self._containers_started = True logger.info("All test containers started successfully") - def stop_containers(self): + def stop_containers(self) -> None: """ Stop and clean up all test containers. @@ -290,7 +292,7 @@ def _get_migration_dir() -> Path: return conftest_dir.parent.parent / "migrations" -def _get_engine_url(engine: Engine): +def _get_engine_url(engine: Engine) -> str: try: return engine.url.render_as_string(hide_password=False).replace("%", "%%") except AttributeError: @@ -409,7 +411,7 @@ def set_up_containers_and_env() -> Generator[DifyTestContainers, None, None]: @pytest.fixture(scope="session") -def flask_app_with_containers(set_up_containers_and_env) -> Flask: +def flask_app_with_containers(set_up_containers_and_env: DifyTestContainers) -> Flask: """ Session-scoped Flask application fixture using test containers. @@ -552,6 +554,7 @@ def isolate_container_database(request: pytest.FixtureRequest) -> Generator[None return app = request.getfixturevalue("flask_app_with_containers") + assert isinstance(app, Flask) try: _truncate_container_database(app) finally: @@ -559,7 +562,7 @@ def isolate_container_database(request: pytest.FixtureRequest) -> Generator[None @pytest.fixture(scope="package", autouse=True) -def mock_ssrf_proxy_requests(): +def mock_ssrf_proxy_requests() -> Generator[None, None, None]: """ Avoid outbound network during containerized tests by stubbing SSRF proxy helpers. """ @@ -568,7 +571,7 @@ def mock_ssrf_proxy_requests(): import httpx - def _fake_request(method, url, **kwargs): + def _fake_request(method: str, url: str, **_kwargs: object) -> httpx.Response: request = httpx.Request(method=method, url=url) return httpx.Response(200, request=request, content=b"") diff --git a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_members.py b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_members.py index 3c1688293e..1c84b70b08 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_members.py +++ b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_members.py @@ -1,9 +1,14 @@ from __future__ import annotations +import inspect +from collections.abc import Callable +from typing import cast from unittest.mock import patch from uuid import uuid4 import pytest +from flask import Flask +from sqlalchemy.orm import Session from werkzeug.exceptions import HTTPException import services @@ -12,16 +17,33 @@ from controllers.console.workspace import members as members_module from controllers.console.workspace.members import MemberCancelInviteApi, MemberUpdateRoleApi, OwnerTransfer from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus +JsonResponse = dict[str, object] +StatusResponse = tuple[JsonResponse, int] -def unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func + +def unwrap(func: Callable[..., object]) -> Callable[..., object]: + return cast(Callable[..., object], inspect.unwrap(func)) + + +def unwrap_status_response(func: Callable[..., object]) -> Callable[..., StatusResponse]: + return cast(Callable[..., StatusResponse], inspect.unwrap(func)) + + +def unwrap_json_response(func: Callable[..., object]) -> Callable[..., JsonResponse]: + return cast(Callable[..., JsonResponse], inspect.unwrap(func)) + + +def unwrap_json_or_status_response(func: Callable[..., object]) -> Callable[..., JsonResponse | StatusResponse]: + return cast(Callable[..., JsonResponse | StatusResponse], inspect.unwrap(func)) + + +def unwrap_raises(func: Callable[..., object]) -> Callable[..., object]: + return unwrap(func) class WorkspaceMembersIntegrationFactory: @staticmethod - def create_tenant(db_session_with_containers) -> Tenant: + def create_tenant(db_session_with_containers: Session) -> Tenant: tenant = Tenant(name=f"Tenant {uuid4()}", plan="basic", status=TenantStatus.NORMAL) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -29,7 +51,7 @@ class WorkspaceMembersIntegrationFactory: @staticmethod def create_account( - db_session_with_containers, + db_session_with_containers: Session, *, email_prefix: str, tenant: Tenant | None = None, @@ -60,7 +82,7 @@ class WorkspaceMembersIntegrationFactory: return account @staticmethod - def create_owner_workspace(db_session_with_containers) -> tuple[Tenant, Account]: + def create_owner_workspace(db_session_with_containers: Session) -> tuple[Tenant, Account]: tenant = WorkspaceMembersIntegrationFactory.create_tenant(db_session_with_containers) owner = WorkspaceMembersIntegrationFactory.create_account( db_session_with_containers, @@ -82,7 +104,7 @@ class WorkspaceMembersIntegrationFactory: return token @staticmethod - def get_join(db_session_with_containers, *, tenant: Tenant, account: Account) -> TenantAccountJoin: + def get_join(db_session_with_containers: Session, *, tenant: Tenant, account: Account) -> TenantAccountJoin: tenant_id = tenant.id account_id = account.id db_session_with_containers.expire_all() @@ -95,9 +117,9 @@ class WorkspaceMembersIntegrationFactory: class TestMemberCancelInviteApiWithContainers: - def test_cancel_success(self, flask_app_with_containers, db_session_with_containers): + def test_cancel_success(self, flask_app_with_containers: Flask, db_session_with_containers: Session) -> None: api = MemberCancelInviteApi() - method = unwrap(api.delete) + method = unwrap_status_response(api.delete) factory = WorkspaceMembersIntegrationFactory tenant, current_user = factory.create_owner_workspace(db_session_with_containers) member = factory.create_account(db_session_with_containers, email_prefix="member") @@ -116,9 +138,9 @@ class TestMemberCancelInviteApiWithContainers: assert called_member.id == member.id assert called_current_user.id == current_user.id - def test_cancel_not_found(self, flask_app_with_containers, db_session_with_containers): + def test_cancel_not_found(self, flask_app_with_containers: Flask, db_session_with_containers: Session) -> None: api = MemberCancelInviteApi() - method = unwrap(api.delete) + method = unwrap_raises(api.delete) factory = WorkspaceMembersIntegrationFactory tenant, current_user = factory.create_owner_workspace(db_session_with_containers) @@ -126,9 +148,11 @@ class TestMemberCancelInviteApiWithContainers: with pytest.raises(HTTPException): method(api, current_user, str(uuid4())) - def test_cancel_cannot_operate_self(self, flask_app_with_containers, db_session_with_containers): + def test_cancel_cannot_operate_self( + self, flask_app_with_containers: Flask, db_session_with_containers: Session + ) -> None: api = MemberCancelInviteApi() - method = unwrap(api.delete) + method = unwrap_status_response(api.delete) factory = WorkspaceMembersIntegrationFactory tenant, current_user = factory.create_owner_workspace(db_session_with_containers) member = factory.create_account(db_session_with_containers, email_prefix="member") @@ -146,9 +170,9 @@ class TestMemberCancelInviteApiWithContainers: assert status == 400 assert result["code"] == "cannot-operate-self" - def test_cancel_no_permission(self, flask_app_with_containers, db_session_with_containers): + def test_cancel_no_permission(self, flask_app_with_containers: Flask, db_session_with_containers: Session) -> None: api = MemberCancelInviteApi() - method = unwrap(api.delete) + method = unwrap_status_response(api.delete) factory = WorkspaceMembersIntegrationFactory tenant, current_user = factory.create_owner_workspace(db_session_with_containers) member = factory.create_account(db_session_with_containers, email_prefix="member") @@ -166,9 +190,11 @@ class TestMemberCancelInviteApiWithContainers: assert status == 403 assert result["code"] == "forbidden" - def test_cancel_member_not_in_tenant(self, flask_app_with_containers, db_session_with_containers): + def test_cancel_member_not_in_tenant( + self, flask_app_with_containers: Flask, db_session_with_containers: Session + ) -> None: api = MemberCancelInviteApi() - method = unwrap(api.delete) + method = unwrap_status_response(api.delete) factory = WorkspaceMembersIntegrationFactory tenant, current_user = factory.create_owner_workspace(db_session_with_containers) member = factory.create_account(db_session_with_containers, email_prefix="member") @@ -188,9 +214,9 @@ class TestMemberCancelInviteApiWithContainers: class TestMemberUpdateRoleApiWithContainers: - def test_update_success(self, flask_app_with_containers, db_session_with_containers): + def test_update_success(self, flask_app_with_containers: Flask, db_session_with_containers: Session) -> None: api = MemberUpdateRoleApi() - method = unwrap(api.put) + method = unwrap_json_or_status_response(api.put) factory = WorkspaceMembersIntegrationFactory tenant, current_user = factory.create_owner_workspace(db_session_with_containers) member = factory.create_account( @@ -211,9 +237,11 @@ class TestMemberUpdateRoleApiWithContainers: factory.get_join(db_session_with_containers, tenant=tenant, account=member).role == TenantAccountRole.NORMAL ) - def test_update_member_not_found(self, flask_app_with_containers, db_session_with_containers): + def test_update_member_not_found( + self, flask_app_with_containers: Flask, db_session_with_containers: Session + ) -> None: api = MemberUpdateRoleApi() - method = unwrap(api.put) + method = unwrap_raises(api.put) factory = WorkspaceMembersIntegrationFactory tenant, current_user = factory.create_owner_workspace(db_session_with_containers) @@ -223,9 +251,9 @@ class TestMemberUpdateRoleApiWithContainers: class TestOwnerTransferApiWithContainers: - def test_member_not_in_tenant(self, flask_app_with_containers, db_session_with_containers): + def test_member_not_in_tenant(self, flask_app_with_containers: Flask, db_session_with_containers: Session) -> None: api = OwnerTransfer() - method = unwrap(api.post) + method = unwrap_raises(api.post) factory = WorkspaceMembersIntegrationFactory tenant, current_user = factory.create_owner_workspace(db_session_with_containers) member = factory.create_account(db_session_with_containers, email_prefix="member") @@ -235,9 +263,9 @@ class TestOwnerTransferApiWithContainers: with pytest.raises(MemberNotInTenantError): method(api, current_user, member.id) - def test_member_not_found(self, flask_app_with_containers, db_session_with_containers): + def test_member_not_found(self, flask_app_with_containers: Flask, db_session_with_containers: Session) -> None: api = OwnerTransfer() - method = unwrap(api.post) + method = unwrap_raises(api.post) factory = WorkspaceMembersIntegrationFactory tenant, current_user = factory.create_owner_workspace(db_session_with_containers) token = factory.create_owner_transfer_token(current_user) @@ -246,9 +274,9 @@ class TestOwnerTransferApiWithContainers: with pytest.raises(HTTPException): method(api, current_user, str(uuid4())) - def test_transfer_success(self, flask_app_with_containers, db_session_with_containers): + def test_transfer_success(self, flask_app_with_containers: Flask, db_session_with_containers: Session) -> None: api = OwnerTransfer() - method = unwrap(api.post) + method = unwrap_json_response(api.post) factory = WorkspaceMembersIntegrationFactory tenant, current_user = factory.create_owner_workspace(db_session_with_containers) member = factory.create_account( diff --git a/api/tests/test_containers_integration_tests/pyrefly.toml b/api/tests/test_containers_integration_tests/pyrefly.toml index 777b690108..e8100ee3c9 100644 --- a/api/tests/test_containers_integration_tests/pyrefly.toml +++ b/api/tests/test_containers_integration_tests/pyrefly.toml @@ -1,4 +1,5 @@ preset = "strict" +strict-callable-subtyping = true project-includes = ["."] search-path = ["../.."] @@ -10,7 +11,6 @@ search-path = ["../.."] # rm --force "$tmp_config" project-excludes = [ "commands/test_legacy_model_type_migration.py", - "conftest.py", "controllers/console/app/test_app_apis.py", "controllers/console/app/test_app_import_api.py", "controllers/console/app/test_chat_conversation_status_count_api.py", @@ -28,7 +28,6 @@ project-excludes = [ "controllers/console/explore/test_conversation.py", "controllers/console/test_api_based_extension.py", "controllers/console/test_apikey.py", - "controllers/console/workspace/test_members.py", "controllers/console/workspace/test_tool_provider.py", "controllers/console/workspace/test_trigger_providers.py", "controllers/console/workspace/test_workspace_wraps.py", @@ -66,7 +65,6 @@ project-excludes = [ "services/document_service_status.py", "services/enterprise/test_account_deletion_sync.py", "services/plugin/test_plugin_parameter_service.py", - "services/plugin/test_plugin_permission_service.py", "services/plugin/test_plugin_service.py", "services/rag_pipeline/test_rag_pipeline_service_db.py", "services/recommend_app/test_database_retrieval.py", @@ -122,7 +120,6 @@ project-excludes = [ "services/test_saved_message_service.py", "services/test_schedule_service.py", "services/test_tag_service.py", - "services/test_trigger_provider_service.py", "services/test_web_conversation_service.py", "services/test_webapp_auth_service.py", "services/test_webhook_service.py", @@ -147,7 +144,6 @@ project-excludes = [ "tasks/test_create_segment_to_index_task.py", "tasks/test_dataset_indexing_task.py", "tasks/test_deal_dataset_vector_index_task.py", - "tasks/test_delete_account_task.py", "tasks/test_delete_segment_from_index_task.py", "tasks/test_disable_segment_from_index_task.py", "tasks/test_disable_segments_from_index_task.py", @@ -166,7 +162,6 @@ project-excludes = [ "tasks/test_mail_register_task.py", "tasks/test_rag_pipeline_run_tasks.py", "tasks/test_remove_app_and_related_data_task.py", - "test_container_state_isolation.py", "test_opendal_fs_default_root.py", "test_workflow_pause_integration.py", "trigger/conftest.py", @@ -179,4 +174,7 @@ project-excludes = [ ] [errors] +redundant-cast = true unannotated-return = true +unnecessary-type-conversion = true +unused-ignore = true diff --git a/api/tests/test_containers_integration_tests/services/plugin/test_plugin_permission_service.py b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_permission_service.py index 49d06986fd..dfa3bc9f01 100644 --- a/api/tests/test_containers_integration_tests/services/plugin/test_plugin_permission_service.py +++ b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_permission_service.py @@ -27,7 +27,7 @@ def _count_permissions(session: Session, tenant_id: str) -> int: class TestGetPermission: """Integration tests for PluginPermissionService.get_permission using testcontainers.""" - def test_returns_permission_when_found(self, db_session_with_containers: Session): + def test_returns_permission_when_found(self, db_session_with_containers: Session) -> None: tenant_id = _tenant_id() permission = TenantPluginPermission( tenant_id=tenant_id, @@ -45,7 +45,7 @@ class TestGetPermission: assert result.install_permission == TenantPluginPermission.InstallPermission.ADMINS assert result.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE - def test_returns_none_when_not_found(self, db_session_with_containers: Session): + def test_returns_none_when_not_found(self, db_session_with_containers: Session) -> None: result = PluginPermissionService.get_permission(_tenant_id()) assert result is None @@ -54,7 +54,7 @@ class TestGetPermission: class TestChangePermission: """Integration tests for PluginPermissionService.change_permission using testcontainers.""" - def test_creates_new_permission_when_not_exists(self, db_session_with_containers: Session): + def test_creates_new_permission_when_not_exists(self, db_session_with_containers: Session) -> None: tenant_id = _tenant_id() result = PluginPermissionService.change_permission( @@ -69,7 +69,7 @@ class TestChangePermission: assert permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE assert permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE - def test_updates_existing_permission(self, db_session_with_containers: Session): + def test_updates_existing_permission(self, db_session_with_containers: Session) -> None: tenant_id = _tenant_id() existing = TenantPluginPermission( tenant_id=tenant_id, diff --git a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py index aff550c909..0aea7151e9 100644 --- a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py @@ -1,3 +1,4 @@ +from collections.abc import Generator from unittest.mock import MagicMock, patch import pytest @@ -7,17 +8,20 @@ from sqlalchemy.orm import Session from constants import HIDDEN_VALUE, UNKNOWN_VALUE from core.plugin.entities.plugin_daemon import CredentialType from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity +from models.account import Account, Tenant from models.provider_ids import TriggerProviderID from models.trigger import TriggerSubscription from services.trigger.trigger_provider_service import TriggerProviderService from tests.test_containers_integration_tests.helpers import generate_valid_password +MockExternalServiceDependencies = dict[str, MagicMock] + class TestTriggerProviderService: """Integration tests for TriggerProviderService using testcontainers.""" @pytest.fixture - def mock_external_service_dependencies(self): + def mock_external_service_dependencies(self) -> Generator[MockExternalServiceDependencies, None, None]: """Mock setup for external service dependencies.""" with ( patch("services.trigger.trigger_provider_service.TriggerManager") as mock_trigger_manager, @@ -48,7 +52,11 @@ class TestTriggerProviderService: "account_feature_service": mock_account_feature_service, } - def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): + def _create_test_account_and_tenant( + self, + db_session_with_containers: Session, + mock_external_service_dependencies: MockExternalServiceDependencies, + ) -> tuple[Account, Tenant]: """ Helper method to create a test account and tenant for testing. @@ -80,20 +88,21 @@ class TestTriggerProviderService: ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant + assert tenant is not None return account, tenant def _create_test_subscription( self, db_session_with_containers: Session, - tenant_id, - user_id, - provider_id, - credential_type, - credentials, - mock_external_service_dependencies, + tenant_id: str, + user_id: str, + provider_id: TriggerProviderID, + credential_type: CredentialType, + credentials: dict[str, str], + mock_external_service_dependencies: MockExternalServiceDependencies, name: str | None = None, - ): + ) -> TriggerSubscription: """ Helper method to create a test trigger subscription. @@ -133,7 +142,7 @@ class TestTriggerProviderService: parameters={"param1": "value1"}, properties={"prop1": "value1"}, credentials=dict(credential_encrypter.encrypt(credentials)), - credential_type=credential_type.value, + credential_type=credential_type, credential_expires_at=-1, expires_at=-1, ) @@ -145,8 +154,8 @@ class TestTriggerProviderService: return subscription def test_rebuild_trigger_subscription_success_with_merged_credentials( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MockExternalServiceDependencies + ) -> None: """ Test successful rebuild with credential merging (HIDDEN_VALUE handling). @@ -247,8 +256,8 @@ class TestTriggerProviderService: ) def test_rebuild_trigger_subscription_with_all_new_credentials( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MockExternalServiceDependencies + ) -> None: """ Test rebuild when all credentials are new (no HIDDEN_VALUE). @@ -307,8 +316,8 @@ class TestTriggerProviderService: assert subscribe_credentials["api_secret"] == "completely-new-secret" def test_rebuild_trigger_subscription_with_all_hidden_values( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MockExternalServiceDependencies + ) -> None: """ Test rebuild when all credentials are HIDDEN_VALUE (preserve all existing). @@ -366,8 +375,8 @@ class TestTriggerProviderService: assert subscribe_credentials["api_secret"] == original_credentials["api_secret"] def test_rebuild_trigger_subscription_with_missing_key_uses_unknown_value( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MockExternalServiceDependencies + ) -> None: """ Test rebuild when HIDDEN_VALUE is used for a key that doesn't exist in original. @@ -425,8 +434,8 @@ class TestTriggerProviderService: assert subscribe_credentials["non_existent_key"] == UNKNOWN_VALUE def test_rebuild_trigger_subscription_rollback_on_error( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MockExternalServiceDependencies + ) -> None: """ Test that transaction is rolled back on error. @@ -478,8 +487,8 @@ class TestTriggerProviderService: assert subscription.parameters == original_parameters def test_rebuild_trigger_subscription_subscription_not_found( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MockExternalServiceDependencies + ) -> None: """ Test error when subscription is not found. @@ -504,8 +513,8 @@ class TestTriggerProviderService: ) def test_rebuild_trigger_subscription_name_uniqueness_check( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MockExternalServiceDependencies + ) -> None: """ Test that name uniqueness is checked when updating name. diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_account_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_account_task.py index 68737a4ef6..9dfc6325d0 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_account_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_account_task.py @@ -6,9 +6,12 @@ lookup through the real Testcontainers PostgreSQL session factory instead of a patched session_factory mock. """ +from unittest.mock import MagicMock from uuid import uuid4 import pytest +from _pytest.logging import LogCaptureFixture +from pytest_mock import MockerFixture from sqlalchemy.orm import Session from models.account import Account @@ -26,14 +29,16 @@ def _create_account(db_session: Session, *, email: str = "user@example.com") -> @pytest.fixture -def mock_external_dependencies(mocker): +def mock_external_dependencies(mocker: MockerFixture) -> tuple[MagicMock, MagicMock]: billing_service = mocker.patch("tasks.delete_account_task.BillingService") mail_task = mocker.patch("tasks.delete_account_task.send_deletion_success_task") return billing_service, mail_task def test_billing_enabled_account_exists_calls_billing_and_sends_email( - db_session_with_containers: Session, mock_external_dependencies, mocker + db_session_with_containers: Session, + mock_external_dependencies: tuple[MagicMock, MagicMock], + mocker: MockerFixture, ) -> None: billing_service, mail_task = mock_external_dependencies account = _create_account(db_session_with_containers, email="a@b.com") @@ -46,7 +51,9 @@ def test_billing_enabled_account_exists_calls_billing_and_sends_email( def test_billing_disabled_account_exists_sends_email_only( - db_session_with_containers: Session, mock_external_dependencies, mocker + db_session_with_containers: Session, + mock_external_dependencies: tuple[MagicMock, MagicMock], + mocker: MockerFixture, ) -> None: billing_service, mail_task = mock_external_dependencies account = _create_account(db_session_with_containers, email="x@y.com") @@ -58,7 +65,9 @@ def test_billing_disabled_account_exists_sends_email_only( mail_task.delay.assert_called_once_with(account.email) -def test_billing_enabled_account_not_found_calls_billing_no_email(mock_external_dependencies, mocker, caplog) -> None: +def test_billing_enabled_account_not_found_calls_billing_no_email( + mock_external_dependencies: tuple[MagicMock, MagicMock], mocker: MockerFixture, caplog: LogCaptureFixture +) -> None: billing_service, mail_task = mock_external_dependencies account_id = str(uuid4()) mocker.patch("tasks.delete_account_task.dify_config.BILLING_ENABLED", True) @@ -71,7 +80,9 @@ def test_billing_enabled_account_not_found_calls_billing_no_email(mock_external_ def test_billing_delete_raises_propagates_and_no_email( - db_session_with_containers: Session, mock_external_dependencies, mocker + db_session_with_containers: Session, + mock_external_dependencies: tuple[MagicMock, MagicMock], + mocker: MockerFixture, ) -> None: billing_service, mail_task = mock_external_dependencies account = _create_account(db_session_with_containers, email="err@example.com") diff --git a/api/tests/test_containers_integration_tests/test_container_state_isolation.py b/api/tests/test_containers_integration_tests/test_container_state_isolation.py index 702460c5ad..448fca1202 100644 --- a/api/tests/test_containers_integration_tests/test_container_state_isolation.py +++ b/api/tests/test_containers_integration_tests/test_container_state_isolation.py @@ -2,6 +2,9 @@ from __future__ import annotations from uuid import uuid4 +from flask import Flask +from sqlalchemy.orm import Session + from extensions.ext_redis import redis_client from models.account import Account @@ -10,8 +13,8 @@ REDIS_KEY = f"container-state-isolation:{uuid4()}" def test_1_container_state_can_be_written( - flask_app_with_containers, - db_session_with_containers, + flask_app_with_containers: Flask, + db_session_with_containers: Session, ) -> None: account = Account( name="Container State Isolation", @@ -30,8 +33,8 @@ def test_1_container_state_can_be_written( def test_2_container_state_is_flushed_between_tests( - flask_app_with_containers, - db_session_with_containers, + flask_app_with_containers: Flask, + db_session_with_containers: Session, ) -> None: assert db_session_with_containers.query(Account).filter_by(email=ACCOUNT_EMAIL).one_or_none() is None