From 1b972c4e0977af6fd0c76fe28785fccd7680b5a3 Mon Sep 17 00:00:00 2001 From: chariri Date: Wed, 3 Jun 2026 13:24:17 +0900 Subject: [PATCH] refactor(api): migrate tenant/user via DI for several endpoints (#36971) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../console/app/agent_app_access.py | 8 +- .../console/app/agent_app_feature.py | 14 +++- api/controllers/console/billing/billing.py | 26 ++++--- api/controllers/console/billing/compliance.py | 16 +++- .../rag_pipeline/rag_pipeline_datasets.py | 16 ++-- api/controllers/console/human_input_form.py | 42 +++++++---- .../console/workspace/agent_providers.py | 27 ++++--- .../workspace/load_balancing_config.py | 21 ++++-- .../test_rag_pipeline_datasets.py | 75 +++++++------------ .../console/billing/test_billing.py | 14 ++-- .../console/test_human_input_form.py | 55 ++++---------- .../console/workspace/test_agent_providers.py | 54 +------------ .../workspace/test_load_balancing_config.py | 4 +- api/uv.lock | 1 + .../src/dify_agent/layers/shell/layer.py | 8 +- .../dify_agent/layers/shell/test_layer.py | 6 +- .../runtime/test_compositor_factory.py | 8 +- .../local/dify_agent/runtime/test_runner.py | 3 +- 18 files changed, 181 insertions(+), 217 deletions(-) diff --git a/api/controllers/console/app/agent_app_access.py b/api/controllers/console/app/agent_app_access.py index bab7752fc9..97ff134490 100644 --- a/api/controllers/console/app/agent_app_access.py +++ b/api/controllers/console/app/agent_app_access.py @@ -11,10 +11,10 @@ from pydantic import Field from controllers.common.schema import register_response_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id from extensions.ext_database import db from fields.base import ResponseModel -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from models.model import App, AppMode from services.agent.roster_service import AgentRosterService @@ -49,8 +49,8 @@ class AgentAppReferencingWorkflowsResource(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.AGENT]) - def get(self, app_model: App): - _, tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, tenant_id: str, app_model: App): workflows = AgentRosterService(db.session).list_workflows_referencing_app_agent( tenant_id=tenant_id, app_id=app_model.id ) diff --git a/api/controllers/console/app/agent_app_feature.py b/api/controllers/console/app/agent_app_feature.py index 51142a1b83..fbedddaf0d 100644 --- a/api/controllers/console/app/agent_app_feature.py +++ b/api/controllers/console/app/agent_app_feature.py @@ -18,9 +18,15 @@ from controllers.common.fields import SimpleResultResponse from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console import console_ns from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, + with_current_user, +) from events.app_event import app_model_config_was_updated -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models import Account from models.model import App, AppMode from services.agent_app_feature_service import AgentAppFeatureConfigService @@ -65,9 +71,9 @@ class AgentAppFeatureConfigResource(Resource): @edit_permission_required @account_initialization_required @get_app_model(mode=[AppMode.AGENT]) - def post(self, app_model: App): + @with_current_user + def post(self, current_user: Account, app_model: App): args = AgentAppFeaturesRequest.model_validate(console_ns.payload) - current_user, _ = current_account_with_tenant() new_app_model_config = AgentAppFeatureConfigService.update_features( app_model=app_model, diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 45de338559..fdd4c27652 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -8,9 +8,16 @@ from werkzeug.exceptions import BadRequest from controllers.common.schema import register_schema_models from controllers.console import console_ns -from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required +from controllers.console.wraps import ( + account_initialization_required, + only_edition_cloud, + setup_required, + with_current_tenant_id, + with_current_user, +) from enums.cloud_plan import CloudPlan -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models import Account from services.billing_service import BillingService @@ -32,8 +39,9 @@ class Subscription(Resource): @login_required @account_initialization_required @only_edition_cloud - def get(self): - current_user, current_tenant_id = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account): args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True)) BillingService.is_tenant_owner_or_admin(current_user) return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id) @@ -45,8 +53,9 @@ class Invoices(Resource): @login_required @account_initialization_required @only_edition_cloud - def get(self): - current_user, current_tenant_id = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account): BillingService.is_tenant_owner_or_admin(current_user) return BillingService.get_invoices(current_user.email, current_tenant_id) @@ -63,9 +72,8 @@ class PartnerTenants(Resource): @login_required @account_initialization_required @only_edition_cloud - def put(self, partner_key: str): - current_user, _ = current_account_with_tenant() - + @with_current_user + def put(self, current_user: Account, partner_key: str): try: args = PartnerTenantsPayload.model_validate(console_ns.payload or {}) click_id = args.click_id diff --git a/api/controllers/console/billing/compliance.py b/api/controllers/console/billing/compliance.py index b5a08e0791..3d528e1ddd 100644 --- a/api/controllers/console/billing/compliance.py +++ b/api/controllers/console/billing/compliance.py @@ -3,11 +3,18 @@ from flask_restx import Resource from pydantic import BaseModel, Field from libs.helper import extract_remote_ip -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models import Account from services.billing_service import BillingService from .. import console_ns -from ..wraps import account_initialization_required, only_edition_cloud, setup_required +from ..wraps import ( + account_initialization_required, + only_edition_cloud, + setup_required, + with_current_tenant_id, + with_current_user, +) class ComplianceDownloadQuery(BaseModel): @@ -29,8 +36,9 @@ class ComplianceApi(Resource): @login_required @account_initialization_required @only_edition_cloud - def get(self): - current_user, current_tenant_id = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account): args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True)) ip_address = extract_remote_ip(request) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py index 39c8aaa451..e5164bb5d9 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py @@ -11,10 +11,13 @@ from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_rate_limit_check, setup_required, + with_current_tenant_id, + with_current_user, ) from extensions.ext_database import db from fields.dataset_fields import dataset_detail_fields -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models import Account from models.dataset import DatasetPermissionEnum from services.dataset_service import DatasetPermissionService, DatasetService from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity @@ -35,9 +38,10 @@ class CreateRagPipelineDatasetApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") - def post(self): + @with_current_user + @with_current_tenant_id + def post(self, current_tenant_id: str, current_user: Account): payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {}) - current_user, current_tenant_id = current_account_with_tenant() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator if not current_user.is_dataset_editor: raise Forbidden() @@ -79,10 +83,10 @@ class CreateEmptyRagPipelineDatasetApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") - def post(self): + @with_current_user + @with_current_tenant_id + def post(self, current_tenant_id: str, current_user: Account): # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator - current_user, current_tenant_id = current_account_with_tenant() - if not current_user.is_dataset_editor: raise Forbidden() dataset = DatasetService.create_empty_rag_pipeline_dataset( diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py index 2ce6bc3e6d..4b34cb6d9c 100644 --- a/api/controllers/console/human_input_form.py +++ b/api/controllers/console/human_input_form.py @@ -14,7 +14,13 @@ from sqlalchemy.orm import Session, sessionmaker from controllers.common.human_input import HumanInputFormSubmitPayload from controllers.common.schema import register_schema_models from controllers.console import console_ns -from controllers.console.wraps import account_initialization_required, model_validate, setup_required +from controllers.console.wraps import ( + account_initialization_required, + model_validate, + setup_required, + with_current_tenant_id, + with_current_user, +) from controllers.web.error import InvalidArgumentError, NotFoundError from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.base_app_generator import BaseAppGenerator @@ -23,8 +29,8 @@ from core.app.apps.message_generator import MessageGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface from extensions.ext_database import db -from libs.login import current_account_with_tenant, login_required -from models import App +from libs.login import login_required +from models import Account, App from models.enums import CreatorUserRole from models.model import AppMode from models.workflow import WorkflowRun @@ -48,9 +54,8 @@ class ConsoleHumanInputFormApi(Resource): """Console API for getting human input form definition.""" @staticmethod - def _ensure_console_access(form: Form): - _, current_tenant_id = current_account_with_tenant() - + def _ensure_console_access(form: Form, current_tenant_id: str) -> None: + """Ensure a console form token resolves only inside the current tenant.""" if form.tenant_id != current_tenant_id: raise NotFoundError("App not found") @@ -62,7 +67,8 @@ class ConsoleHumanInputFormApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, form_token: str): + @with_current_tenant_id + def get(self, current_tenant_id: str, form_token: str): """ Get human input form definition by form token. @@ -73,15 +79,23 @@ class ConsoleHumanInputFormApi(Resource): if form is None: raise NotFoundError(f"form not found, token={form_token}") - self._ensure_console_access(form) + self._ensure_console_access(form, current_tenant_id) return _jsonify_form_definition(form) @account_initialization_required @login_required + @with_current_user + @with_current_tenant_id @model_validate(HumanInputFormSubmitPayload) @console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__]) - def post(self, payload: HumanInputFormSubmitPayload, form_token: str): + def post( + self, + payload: HumanInputFormSubmitPayload, + current_tenant_id: str, + current_user: Account, + form_token: str, + ): """ Submit human input form by form token. @@ -95,14 +109,12 @@ class ConsoleHumanInputFormApi(Resource): "action": "Approve" } """ - current_user, _ = current_account_with_tenant() - service = HumanInputService(db.engine) form = service.get_form_by_token(form_token) if form is None: raise NotFoundError(f"form not found, token={form_token}") - self._ensure_console_access(form) + self._ensure_console_access(form, current_tenant_id) self._ensure_console_recipient_type(form) recipient_type = form.recipient_type # The type checker is not smart enought to validate the following invariant. @@ -126,7 +138,9 @@ class ConsoleWorkflowEventsApi(Resource): @account_initialization_required @login_required - def get(self, workflow_run_id: str): + @with_current_user + @with_current_tenant_id + def get(self, tenant_id: str, user: Account, workflow_run_id: str): """ Get workflow execution events stream after resume. @@ -134,8 +148,6 @@ class ConsoleWorkflowEventsApi(Resource): Returns Server-Sent Events stream. """ - - user, tenant_id = current_account_with_tenant() session_maker = sessionmaker(db.engine) repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) workflow_run = repo.get_workflow_run_by_id_and_tenant_id( diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index 764f488755..8e968bb07f 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -1,9 +1,15 @@ from flask_restx import Resource, fields from controllers.console import console_ns -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + setup_required, + with_current_tenant_id, + with_current_user, +) from graphon.model_runtime.utils.encoders import jsonable_encoder -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models import Account from services.agent_service import AgentService @@ -19,14 +25,10 @@ class AgentProviderListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - current_user, current_tenant_id = current_account_with_tenant() - user = current_user - - user_id = user.id - tenant_id = current_tenant_id - - return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id)) + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account): + return jsonable_encoder(AgentService.list_agent_providers(current_user.id, current_tenant_id)) @console_ns.route("/workspaces/current/agent-provider/") @@ -42,6 +44,7 @@ class AgentProviderApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider_name: str): - current_user, current_tenant_id = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account, provider_name: str): return jsonable_encoder(AgentService.get_agent_provider(current_user.id, current_tenant_id, provider_name)) diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 2a6f37aec8..969483d138 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -4,11 +4,16 @@ from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models from controllers.console import console_ns -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + setup_required, + with_current_tenant_id, + with_current_user, +) from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.errors.validate import CredentialsValidateFailedError -from libs.login import current_account_with_tenant, login_required -from models import TenantAccountRole +from libs.login import login_required +from models import Account, TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService @@ -29,8 +34,9 @@ class LoadBalancingCredentialsValidateApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, provider: str): - current_user, current_tenant_id = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def post(self, current_tenant_id: str, current_user: Account, provider: str): if not TenantAccountRole.is_privileged_role(current_user.current_role): raise Forbidden() @@ -72,8 +78,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, provider: str, config_id: str): - current_user, current_tenant_id = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def post(self, current_tenant_id: str, current_user: Account, provider: str, config_id: str): if not TenantAccountRole.is_privileged_role(current_user.current_role): raise Forbidden() diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py index 7624c1150f..972043c022 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py @@ -2,6 +2,8 @@ from __future__ import annotations +from collections.abc import Callable +from typing import Protocol, cast from unittest.mock import MagicMock, patch import pytest @@ -17,21 +19,26 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_datasets import ( ) -def unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func +class _WrappedCallable(Protocol): + __wrapped__: Callable[..., object] + + +def unwrap(func: Callable[..., object]) -> Callable[..., object]: + current: Callable[..., object] | _WrappedCallable = func + while hasattr(current, "__wrapped__"): + current = cast(_WrappedCallable, current).__wrapped__ + return cast(Callable[..., object], current) class TestCreateRagPipelineDatasetApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def _valid_payload(self): + def _valid_payload(self) -> dict[str, str]: return {"yaml_content": "name: test"} - def test_post_success(self, app: Flask): + def test_post_success(self, app: Flask) -> None: api = CreateRagPipelineDatasetApi() method = unwrap(api.post) @@ -45,21 +52,17 @@ class TestCreateRagPipelineDatasetApi: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService", return_value=mock_service, ), ): - response, status = method(api) + response, status = cast(tuple[dict[str, str], int], method(api, "tenant-1", user)) assert status == 201 assert response == import_info - def test_post_forbidden_non_editor(self, app: Flask): + def test_post_forbidden_non_editor(self, app: Flask) -> None: api = CreateRagPipelineDatasetApi() method = unwrap(api.post) @@ -69,15 +72,11 @@ class TestCreateRagPipelineDatasetApi: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), ): with pytest.raises(Forbidden): - method(api) + method(api, "tenant-1", user) - def test_post_dataset_name_duplicate(self, app: Flask): + def test_post_dataset_name_duplicate(self, app: Flask) -> None: api = CreateRagPipelineDatasetApi() method = unwrap(api.post) @@ -90,43 +89,35 @@ class TestCreateRagPipelineDatasetApi: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService", return_value=mock_service, ), ): with pytest.raises(DatasetNameDuplicateError): - method(api) + method(api, "tenant-1", user) - def test_post_invalid_payload(self, app: Flask): + def test_post_invalid_payload(self, app: Flask) -> None: api = CreateRagPipelineDatasetApi() method = unwrap(api.post) - payload = {} + payload: dict[str, str] = {} user = MagicMock(is_dataset_editor=True) with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), ): with pytest.raises(ValueError): - method(api) + method(api, "tenant-1", user) class TestCreateEmptyRagPipelineDatasetApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_post_success(self, app: Flask): + def test_post_success(self, app: Flask) -> None: api = CreateEmptyRagPipelineDatasetApi() method = unwrap(api.post) @@ -135,10 +126,6 @@ class TestCreateEmptyRagPipelineDatasetApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.DatasetService.create_empty_rag_pipeline_dataset", return_value=dataset, @@ -148,23 +135,17 @@ class TestCreateEmptyRagPipelineDatasetApi: return_value={"id": "ds-1"}, ), ): - response, status = method(api) + response, status = cast(tuple[dict[str, str], int], method(api, "tenant-1", user)) assert status == 201 assert response == {"id": "ds-1"} - def test_post_forbidden_non_editor(self, app: Flask): + def test_post_forbidden_non_editor(self, app: Flask) -> None: api = CreateEmptyRagPipelineDatasetApi() method = unwrap(api.post) user = MagicMock(is_dataset_editor=False) - with ( - app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), - ): + with app.test_request_context("/"): with pytest.raises(Forbidden): - method(api) + method(api, "tenant-1", user) diff --git a/api/tests/unit_tests/controllers/console/billing/test_billing.py b/api/tests/unit_tests/controllers/console/billing/test_billing.py index defa9064fd..52a8513672 100644 --- a/api/tests/unit_tests/controllers/console/billing/test_billing.py +++ b/api/tests/unit_tests/controllers/console/billing/test_billing.py @@ -65,7 +65,7 @@ class TestPartnerTenants: ): with ( patch( - "controllers.console.billing.billing.current_account_with_tenant", + "controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant-456"), ), patch("libs.login._get_user", return_value=mock_account), @@ -92,7 +92,7 @@ class TestPartnerTenants: ): with ( patch( - "controllers.console.billing.billing.current_account_with_tenant", + "controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant-456"), ), patch("libs.login._get_user", return_value=mock_account), @@ -116,7 +116,7 @@ class TestPartnerTenants: ): with ( patch( - "controllers.console.billing.billing.current_account_with_tenant", + "controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant-456"), ), patch("libs.login._get_user", return_value=mock_account), @@ -158,7 +158,7 @@ class TestPartnerTenants: ): with ( patch( - "controllers.console.billing.billing.current_account_with_tenant", + "controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant-456"), ), patch("libs.login._get_user", return_value=mock_account), @@ -189,7 +189,7 @@ class TestPartnerTenants: ): with ( patch( - "controllers.console.billing.billing.current_account_with_tenant", + "controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant-456"), ), patch("libs.login._get_user", return_value=mock_account), @@ -215,7 +215,7 @@ class TestPartnerTenants: ): with ( patch( - "controllers.console.billing.billing.current_account_with_tenant", + "controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant-456"), ), patch("libs.login._get_user", return_value=mock_account), @@ -241,7 +241,7 @@ class TestPartnerTenants: ): with ( patch( - "controllers.console.billing.billing.current_account_with_tenant", + "controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant-456"), ), patch("libs.login._get_user", return_value=mock_account), diff --git a/api/tests/unit_tests/controllers/console/test_human_input_form.py b/api/tests/unit_tests/controllers/console/test_human_input_form.py index 11c9c0275b..80a688ab0e 100644 --- a/api/tests/unit_tests/controllers/console/test_human_input_form.py +++ b/api/tests/unit_tests/controllers/console/test_human_input_form.py @@ -43,10 +43,9 @@ def test_jsonify_form_definition() -> None: def test_ensure_console_access_rejects(monkeypatch: pytest.MonkeyPatch) -> None: form = SimpleNamespace(tenant_id="tenant-1") - monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-2")) with pytest.raises(NotFoundError): - ConsoleHumanInputFormApi._ensure_console_access(form) + ConsoleHumanInputFormApi._ensure_console_access(form, "tenant-2") def test_get_form_definition_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: @@ -62,14 +61,13 @@ def test_get_form_definition_success(app: Flask, monkeypatch: pytest.MonkeyPatch return form monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub) - monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1")) monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) api = ConsoleHumanInputFormApi() handler = _unwrap(api.get) with app.test_request_context("/console/api/form/human_input/token", method="GET"): - response = handler(api, form_token="token") + response = handler(api, "tenant-1", form_token="token") payload = json.loads(response.get_data(as_text=True)) assert payload["fields"] == ["a"] @@ -84,7 +82,6 @@ def test_get_form_definition_not_found(app: Flask, monkeypatch: pytest.MonkeyPat return None monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub) - monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1")) monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) api = ConsoleHumanInputFormApi() @@ -92,7 +89,7 @@ def test_get_form_definition_not_found(app: Flask, monkeypatch: pytest.MonkeyPat with app.test_request_context("/console/api/form/human_input/token", method="GET"): with pytest.raises(NotFoundError): - handler(api, form_token="token") + handler(api, "tenant-1", form_token="token") def test_post_form_invalid_recipient_type(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: @@ -106,10 +103,6 @@ def test_post_form_invalid_recipient_type(app: Flask, monkeypatch: pytest.Monkey return form monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub) - monkeypatch.setattr( - "controllers.console.human_input_form.current_account_with_tenant", - lambda: (SimpleNamespace(id="user-1"), "tenant-1"), - ) monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) api = ConsoleHumanInputFormApi() @@ -124,6 +117,8 @@ def test_post_form_invalid_recipient_type(app: Flask, monkeypatch: pytest.Monkey handler( api, HumanInputFormSubmitPayload.model_validate({"inputs": {"content": "ok"}, "action": "approve"}), + "tenant-1", + SimpleNamespace(id="user-1"), form_token="token", ) @@ -139,10 +134,6 @@ def test_post_form_rejects_webapp_recipient_type(app: Flask, monkeypatch: pytest return form monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub) - monkeypatch.setattr( - "controllers.console.human_input_form.current_account_with_tenant", - lambda: (SimpleNamespace(id="user-1"), "tenant-1"), - ) monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) api = ConsoleHumanInputFormApi() @@ -157,6 +148,8 @@ def test_post_form_rejects_webapp_recipient_type(app: Flask, monkeypatch: pytest handler( api, HumanInputFormSubmitPayload.model_validate({"inputs": {"content": "ok"}, "action": "approve"}), + "tenant-1", + SimpleNamespace(id="user-1"), form_token="token", ) @@ -176,10 +169,6 @@ def test_post_form_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: submit_mock(**kwargs) monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub) - monkeypatch.setattr( - "controllers.console.human_input_form.current_account_with_tenant", - lambda: (SimpleNamespace(id="user-1"), "tenant-1"), - ) monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) api = ConsoleHumanInputFormApi() @@ -193,6 +182,8 @@ def test_post_form_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: response = handler( api, HumanInputFormSubmitPayload.model_validate({"inputs": {"content": "ok"}, "action": "approve"}), + "tenant-1", + SimpleNamespace(id="user-1"), form_token="token", ) @@ -216,10 +207,6 @@ def test_post_form_decorated_success_validates_request_body(app: Flask, monkeypa submit_mock(**kwargs) monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub) - monkeypatch.setattr( - "controllers.console.human_input_form.current_account_with_tenant", - lambda: (current_user, "tenant-1"), - ) monkeypatch.setattr( "controllers.console.wraps.current_account_with_tenant", lambda: (current_user, "tenant-1"), @@ -254,10 +241,6 @@ def test_workflow_events_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) "create_api_workflow_run_repository", lambda *_args, **_kwargs: _RepoStub(), ) - monkeypatch.setattr( - "controllers.console.human_input_form.current_account_with_tenant", - lambda: (SimpleNamespace(id="u1"), "t1"), - ) monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) api = ConsoleWorkflowEventsApi() @@ -265,7 +248,7 @@ def test_workflow_events_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) with app.test_request_context("/console/api/workflow/run/events", method="GET"): with pytest.raises(NotFoundError): - handler(api, workflow_run_id="run-1") + handler(api, "t1", SimpleNamespace(id="u1"), workflow_run_id="run-1") def test_workflow_events_requires_account(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: @@ -285,10 +268,6 @@ def test_workflow_events_requires_account(app: Flask, monkeypatch: pytest.Monkey "create_api_workflow_run_repository", lambda *_args, **_kwargs: _RepoStub(), ) - monkeypatch.setattr( - "controllers.console.human_input_form.current_account_with_tenant", - lambda: (SimpleNamespace(id="u1"), "t1"), - ) monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) api = ConsoleWorkflowEventsApi() @@ -296,7 +275,7 @@ def test_workflow_events_requires_account(app: Flask, monkeypatch: pytest.Monkey with app.test_request_context("/console/api/workflow/run/events", method="GET"): with pytest.raises(NotFoundError): - handler(api, workflow_run_id="run-1") + handler(api, "t1", SimpleNamespace(id="u1"), workflow_run_id="run-1") def test_workflow_events_requires_creator(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: @@ -316,10 +295,6 @@ def test_workflow_events_requires_creator(app: Flask, monkeypatch: pytest.Monkey "create_api_workflow_run_repository", lambda *_args, **_kwargs: _RepoStub(), ) - monkeypatch.setattr( - "controllers.console.human_input_form.current_account_with_tenant", - lambda: (SimpleNamespace(id="u1"), "t1"), - ) monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) api = ConsoleWorkflowEventsApi() @@ -327,7 +302,7 @@ def test_workflow_events_requires_creator(app: Flask, monkeypatch: pytest.Monkey with app.test_request_context("/console/api/workflow/run/events", method="GET"): with pytest.raises(NotFoundError): - handler(api, workflow_run_id="run-1") + handler(api, "t1", SimpleNamespace(id="u1"), workflow_run_id="run-1") def test_workflow_events_finished(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: @@ -364,17 +339,13 @@ def test_workflow_events_finished(app: Flask, monkeypatch: pytest.MonkeyPatch) - "workflow_run_result_to_finish_response", lambda **_kwargs: response_obj, ) - monkeypatch.setattr( - "controllers.console.human_input_form.current_account_with_tenant", - lambda: (SimpleNamespace(id="user-1"), "t1"), - ) monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) api = ConsoleWorkflowEventsApi() handler = _unwrap(api.get) with app.test_request_context("/console/api/workflow/run/events", method="GET"): - response = handler(api, workflow_run_id="run-1") + response = handler(api, "t1", SimpleNamespace(id="user-1"), workflow_run_id="run-1") assert response.mimetype == "text/event-stream" assert "data" in response.get_data(as_text=True) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py index eb0ca15d2e..f70450955a 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py @@ -1,9 +1,7 @@ from unittest.mock import MagicMock, patch -import pytest from flask import Flask -from controllers.console.error import AccountNotFound from controllers.console.workspace.agent_providers import ( AgentProviderApi, AgentProviderListApi, @@ -27,16 +25,12 @@ class TestAgentProviderListApi: with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.agent_providers.current_account_with_tenant", - return_value=(user, tenant_id), - ), patch( "controllers.console.workspace.agent_providers.AgentService.list_agent_providers", return_value=providers, ), ): - result = method(api) + result = method(api, tenant_id, user) assert result == providers @@ -49,33 +43,15 @@ class TestAgentProviderListApi: with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.agent_providers.current_account_with_tenant", - return_value=(user, tenant_id), - ), patch( "controllers.console.workspace.agent_providers.AgentService.list_agent_providers", return_value=[], ), ): - result = method(api) + result = method(api, tenant_id, user) assert result == [] - def test_get_account_not_found(self, app: Flask): - api = AgentProviderListApi() - method = unwrap(api.get) - - with ( - app.test_request_context("/"), - patch( - "controllers.console.workspace.agent_providers.current_account_with_tenant", - side_effect=AccountNotFound(), - ), - ): - with pytest.raises(AccountNotFound): - method(api) - class TestAgentProviderApi: def test_get_success(self, app: Flask): @@ -89,16 +65,12 @@ class TestAgentProviderApi: with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.agent_providers.current_account_with_tenant", - return_value=(user, tenant_id), - ), patch( "controllers.console.workspace.agent_providers.AgentService.get_agent_provider", return_value=provider_data, ), ): - result = method(api, provider_name) + result = method(api, tenant_id, user, provider_name) assert result == provider_data @@ -112,29 +84,11 @@ class TestAgentProviderApi: with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.agent_providers.current_account_with_tenant", - return_value=(user, tenant_id), - ), patch( "controllers.console.workspace.agent_providers.AgentService.get_agent_provider", return_value=None, ), ): - result = method(api, provider_name) + result = method(api, tenant_id, user, provider_name) assert result is None - - def test_get_account_not_found(self, app: Flask): - api = AgentProviderApi() - method = unwrap(api.get) - - with ( - app.test_request_context("/"), - patch( - "controllers.console.workspace.agent_providers.current_account_with_tenant", - side_effect=AccountNotFound(), - ), - ): - with pytest.raises(AccountNotFound): - method(api, "openai") diff --git a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py index b2f949c6e2..a1d08849ee 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py @@ -63,7 +63,9 @@ def _mock_user(role: TenantAccountRole) -> SimpleNamespace: def _prepare_context(module, monkeypatch: pytest.MonkeyPatch, role=TenantAccountRole.OWNER): user = _mock_user(role) - monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "tenant-123")) + from controllers.console import wraps + + monkeypatch.setattr(wraps, "current_account_with_tenant", lambda: (user, "tenant-123")) mock_service = MagicMock() monkeypatch.setattr(module, "ModelLoadBalancingService", lambda: mock_service) return mock_service diff --git a/api/uv.lock b/api/uv.lock index 1dbbd21356..8c7d2b2172 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1300,6 +1300,7 @@ requires-dist = [ { name = "pydantic-ai-slim", extras = ["anthropic", "google", "openai"], marker = "extra == 'server'", specifier = ">=1.85.1,<2.0.0" }, { name = "pydantic-settings", marker = "extra == 'server'", specifier = ">=2.12.0,<3.0.0" }, { name = "redis", marker = "extra == 'server'", specifier = ">=7.4.0,<8.0.0" }, + { name = "shell-session-manager", marker = "extra == 'server'", specifier = "==2.1.1" }, { name = "typing-extensions", specifier = ">=4.12.2,<5.0.0" }, { name = "uvicorn", extras = ["standard"], marker = "extra == 'server'", specifier = "==0.46.0" }, ] diff --git a/dify-agent/src/dify_agent/layers/shell/layer.py b/dify-agent/src/dify_agent/layers/shell/layer.py index d265075631..5e31c0fb39 100644 --- a/dify-agent/src/dify_agent/layers/shell/layer.py +++ b/dify-agent/src/dify_agent/layers/shell/layer.py @@ -256,9 +256,7 @@ class DifyShellRuntimeState(BaseModel): raise ValueError("workspace_cwd requires a matching session_id.") expected_workspace = _workspace_cwd(self.session_id) if self.workspace_cwd != expected_workspace: - raise ValueError( - f"workspace_cwd must equal {expected_workspace!r} for session_id {self.session_id!r}." - ) + raise ValueError(f"workspace_cwd must equal {expected_workspace!r} for session_id {self.session_id!r}.") unknown_offset_job_ids = set(self.job_offsets) - set(self.job_ids) if unknown_offset_job_ids: names = ", ".join(sorted(unknown_offset_job_ids)) @@ -694,12 +692,12 @@ def _workspace_mkdir_script(*, session_id: str) -> str: of silently reusing another session's workspace. """ safe_session_id = _validated_session_id(session_id) - workspace_dir = f'$HOME/workspace/{safe_session_id}' + workspace_dir = f"$HOME/workspace/{safe_session_id}" return ( 'mkdir -p "$HOME/workspace"; ' f'if mkdir "{workspace_dir}"; then exit 0; fi; ' f'if [ -e "{workspace_dir}" ]; then exit {_WORKSPACE_COLLISION_EXIT_CODE}; fi; ' - 'exit 1' + "exit 1" ) diff --git a/dify-agent/tests/local/dify_agent/layers/shell/test_layer.py b/dify-agent/tests/local/dify_agent/layers/shell/test_layer.py index a2ab4e435c..3d4ae3221e 100644 --- a/dify-agent/tests/local/dify_agent/layers/shell/test_layer.py +++ b/dify-agent/tests/local/dify_agent/layers/shell/test_layer.py @@ -277,6 +277,7 @@ def test_shell_layer_suspend_and_resume_reuse_state_with_fresh_clients() -> None return next(clients) compositor = Compositor([LayerNode("shell", _shell_provider(client_factory=factory))]) + async def scenario() -> None: async with compositor.enter(configs={"shell": DifyShellLayerConfig()}) as run: shell_layer = run.get_layer("shell", DifyShellLayer) @@ -342,7 +343,10 @@ def test_shell_layer_delete_removes_workspace_then_force_deletes_tracked_jobs_an assert client.events[:2] == [("run", 'rm -rf -- "$HOME/workspace/abc12ff"'), ("wait", "cleanup-job")] assert {call.job_id for call in client.delete_calls} == {"user-job", "mkdir-job", "cleanup-job"} - assert all(client.events.index(("delete", call.job_id)) > client.events.index(("wait", "cleanup-job")) for call in client.delete_calls) + assert all( + client.events.index(("delete", call.job_id)) > client.events.index(("wait", "cleanup-job")) + for call in client.delete_calls + ) assert all(call.force is True for call in client.delete_calls) assert layer.runtime_state.job_ids == [] assert layer.runtime_state.job_offsets == {} diff --git a/dify-agent/tests/local/dify_agent/runtime/test_compositor_factory.py b/dify-agent/tests/local/dify_agent/runtime/test_compositor_factory.py index 799ec94292..8808cf7a96 100644 --- a/dify-agent/tests/local/dify_agent/runtime/test_compositor_factory.py +++ b/dify-agent/tests/local/dify_agent/runtime/test_compositor_factory.py @@ -27,7 +27,9 @@ def test_default_layer_providers_register_shell_layer_with_configured_token_fact return factory - monkeypatch.setattr(compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory) + monkeypatch.setattr( + compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory + ) providers = create_default_layer_providers( shellctl_entrypoint="http://shellctl.example", @@ -56,7 +58,9 @@ def test_default_layer_providers_keep_empty_shellctl_token_by_default( return factory - monkeypatch.setattr(compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory) + monkeypatch.setattr( + compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory + ) providers = create_default_layer_providers(shellctl_entrypoint="http://shellctl.example") shell_provider = next(provider for provider in providers if provider.type_id == DIFY_SHELL_LAYER_TYPE_ID) diff --git a/dify-agent/tests/local/dify_agent/runtime/test_runner.py b/dify-agent/tests/local/dify_agent/runtime/test_runner.py index 4a899f0a79..ec2acae15c 100644 --- a/dify-agent/tests/local/dify_agent/runtime/test_runner.py +++ b/dify-agent/tests/local/dify_agent/runtime/test_runner.py @@ -684,7 +684,8 @@ def test_runner_rejects_duplicate_tool_names_between_shell_and_other_layers( ), ) layer_providers = tuple( - provider for provider in create_default_layer_providers(shellctl_entrypoint="http://unused") + provider + for provider in create_default_layer_providers(shellctl_entrypoint="http://unused") if provider.type_id != DIFY_SHELL_LAYER_TYPE_ID ) + (shell_provider,)