refactor(api): migrate tenant/user via DI for several endpoints (#36971)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
chariri
2026-06-03 13:24:17 +09:00
committed by GitHub
parent 7968d2c3c8
commit 1b972c4e09
18 changed files with 181 additions and 217 deletions
@@ -11,10 +11,10 @@ from pydantic import Field
from controllers.common.schema import register_response_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.model import App, AppMode
from services.agent.roster_service import AgentRosterService
@@ -49,8 +49,8 @@ class AgentAppReferencingWorkflowsResource(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT])
def get(self, app_model: App):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, tenant_id: str, app_model: App):
workflows = AgentRosterService(db.session).list_workflows_referencing_app_agent(
tenant_id=tenant_id, app_id=app_model.id
)
@@ -18,9 +18,15 @@ from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user,
)
from events.app_event import app_model_config_was_updated
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from models.model import App, AppMode
from services.agent_app_feature_service import AgentAppFeatureConfigService
@@ -65,9 +71,9 @@ class AgentAppFeatureConfigResource(Resource):
@edit_permission_required
@account_initialization_required
@get_app_model(mode=[AppMode.AGENT])
def post(self, app_model: App):
@with_current_user
def post(self, current_user: Account, app_model: App):
args = AgentAppFeaturesRequest.model_validate(console_ns.payload)
current_user, _ = current_account_with_tenant()
new_app_model_config = AgentAppFeatureConfigService.update_features(
app_model=app_model,
+17 -9
View File
@@ -8,9 +8,16 @@ from werkzeug.exceptions import BadRequest
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from controllers.console.wraps import (
account_initialization_required,
only_edition_cloud,
setup_required,
with_current_tenant_id,
with_current_user,
)
from enums.cloud_plan import CloudPlan
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from services.billing_service import BillingService
@@ -32,8 +39,9 @@ class Subscription(Resource):
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True))
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
@@ -45,8 +53,9 @@ class Invoices(Resource):
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_invoices(current_user.email, current_tenant_id)
@@ -63,9 +72,8 @@ class PartnerTenants(Resource):
@login_required
@account_initialization_required
@only_edition_cloud
def put(self, partner_key: str):
current_user, _ = current_account_with_tenant()
@with_current_user
def put(self, current_user: Account, partner_key: str):
try:
args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
click_id = args.click_id
+12 -4
View File
@@ -3,11 +3,18 @@ from flask_restx import Resource
from pydantic import BaseModel, Field
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from services.billing_service import BillingService
from .. import console_ns
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
from ..wraps import (
account_initialization_required,
only_edition_cloud,
setup_required,
with_current_tenant_id,
with_current_user,
)
class ComplianceDownloadQuery(BaseModel):
@@ -29,8 +36,9 @@ class ComplianceApi(Resource):
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True))
ip_address = extract_remote_ip(request)
@@ -11,10 +11,13 @@ from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.dataset_fields import dataset_detail_fields
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
@@ -35,9 +38,10 @@ class CreateRagPipelineDatasetApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
@@ -79,10 +83,10 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
dataset = DatasetService.create_empty_rag_pipeline_dataset(
+27 -15
View File
@@ -14,7 +14,13 @@ from sqlalchemy.orm import Session, sessionmaker
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, model_validate, setup_required
from controllers.console.wraps import (
account_initialization_required,
model_validate,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
@@ -23,8 +29,8 @@ from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import App
from libs.login import login_required
from models import Account, App
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
@@ -48,9 +54,8 @@ class ConsoleHumanInputFormApi(Resource):
"""Console API for getting human input form definition."""
@staticmethod
def _ensure_console_access(form: Form):
_, current_tenant_id = current_account_with_tenant()
def _ensure_console_access(form: Form, current_tenant_id: str) -> None:
"""Ensure a console form token resolves only inside the current tenant."""
if form.tenant_id != current_tenant_id:
raise NotFoundError("App not found")
@@ -62,7 +67,8 @@ class ConsoleHumanInputFormApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, form_token: str):
@with_current_tenant_id
def get(self, current_tenant_id: str, form_token: str):
"""
Get human input form definition by form token.
@@ -73,15 +79,23 @@ class ConsoleHumanInputFormApi(Resource):
if form is None:
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form)
self._ensure_console_access(form, current_tenant_id)
return _jsonify_form_definition(form)
@account_initialization_required
@login_required
@with_current_user
@with_current_tenant_id
@model_validate(HumanInputFormSubmitPayload)
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
def post(self, payload: HumanInputFormSubmitPayload, form_token: str):
def post(
self,
payload: HumanInputFormSubmitPayload,
current_tenant_id: str,
current_user: Account,
form_token: str,
):
"""
Submit human input form by form token.
@@ -95,14 +109,12 @@ class ConsoleHumanInputFormApi(Resource):
"action": "Approve"
}
"""
current_user, _ = current_account_with_tenant()
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form)
self._ensure_console_access(form, current_tenant_id)
self._ensure_console_recipient_type(form)
recipient_type = form.recipient_type
# The type checker is not smart enought to validate the following invariant.
@@ -126,7 +138,9 @@ class ConsoleWorkflowEventsApi(Resource):
@account_initialization_required
@login_required
def get(self, workflow_run_id: str):
@with_current_user
@with_current_tenant_id
def get(self, tenant_id: str, user: Account, workflow_run_id: str):
"""
Get workflow execution events stream after resume.
@@ -134,8 +148,6 @@ class ConsoleWorkflowEventsApi(Resource):
Returns Server-Sent Events stream.
"""
user, tenant_id = current_account_with_tenant()
session_maker = sessionmaker(db.engine)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
@@ -1,9 +1,15 @@
from flask_restx import Resource, fields
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from services.agent_service import AgentService
@@ -19,14 +25,10 @@ class AgentProviderListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
user_id = user.id
tenant_id = current_tenant_id
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
return jsonable_encoder(AgentService.list_agent_providers(current_user.id, current_tenant_id))
@console_ns.route("/workspaces/current/agent-provider/<path:provider_name>")
@@ -42,6 +44,7 @@ class AgentProviderApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_name: str):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, provider_name: str):
return jsonable_encoder(AgentService.get_agent_provider(current_user.id, current_tenant_id, provider_name))
@@ -4,11 +4,16 @@ from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from libs.login import current_account_with_tenant, login_required
from models import TenantAccountRole
from libs.login import login_required
from models import Account, TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
@@ -29,8 +34,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, provider: str):
if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()
@@ -72,8 +78,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str, config_id: str):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, provider: str, config_id: str):
if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()
@@ -2,6 +2,8 @@
from __future__ import annotations
from collections.abc import Callable
from typing import Protocol, cast
from unittest.mock import MagicMock, patch
import pytest
@@ -17,21 +19,26 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_datasets import (
)
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class _WrappedCallable(Protocol):
__wrapped__: Callable[..., object]
def unwrap(func: Callable[..., object]) -> Callable[..., object]:
current: Callable[..., object] | _WrappedCallable = func
while hasattr(current, "__wrapped__"):
current = cast(_WrappedCallable, current).__wrapped__
return cast(Callable[..., object], current)
class TestCreateRagPipelineDatasetApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def _valid_payload(self):
def _valid_payload(self) -> dict[str, str]:
return {"yaml_content": "name: test"}
def test_post_success(self, app: Flask):
def test_post_success(self, app: Flask) -> None:
api = CreateRagPipelineDatasetApi()
method = unwrap(api.post)
@@ -45,21 +52,17 @@ class TestCreateRagPipelineDatasetApi:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
return_value=mock_service,
),
):
response, status = method(api)
response, status = cast(tuple[dict[str, str], int], method(api, "tenant-1", user))
assert status == 201
assert response == import_info
def test_post_forbidden_non_editor(self, app: Flask):
def test_post_forbidden_non_editor(self, app: Flask) -> None:
api = CreateRagPipelineDatasetApi()
method = unwrap(api.post)
@@ -69,15 +72,11 @@ class TestCreateRagPipelineDatasetApi:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
):
with pytest.raises(Forbidden):
method(api)
method(api, "tenant-1", user)
def test_post_dataset_name_duplicate(self, app: Flask):
def test_post_dataset_name_duplicate(self, app: Flask) -> None:
api = CreateRagPipelineDatasetApi()
method = unwrap(api.post)
@@ -90,43 +89,35 @@ class TestCreateRagPipelineDatasetApi:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
return_value=mock_service,
),
):
with pytest.raises(DatasetNameDuplicateError):
method(api)
method(api, "tenant-1", user)
def test_post_invalid_payload(self, app: Flask):
def test_post_invalid_payload(self, app: Flask) -> None:
api = CreateRagPipelineDatasetApi()
method = unwrap(api.post)
payload = {}
payload: dict[str, str] = {}
user = MagicMock(is_dataset_editor=True)
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api)
method(api, "tenant-1", user)
class TestCreateEmptyRagPipelineDatasetApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_post_success(self, app: Flask):
def test_post_success(self, app: Flask) -> None:
api = CreateEmptyRagPipelineDatasetApi()
method = unwrap(api.post)
@@ -135,10 +126,6 @@ class TestCreateEmptyRagPipelineDatasetApi:
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.DatasetService.create_empty_rag_pipeline_dataset",
return_value=dataset,
@@ -148,23 +135,17 @@ class TestCreateEmptyRagPipelineDatasetApi:
return_value={"id": "ds-1"},
),
):
response, status = method(api)
response, status = cast(tuple[dict[str, str], int], method(api, "tenant-1", user))
assert status == 201
assert response == {"id": "ds-1"}
def test_post_forbidden_non_editor(self, app: Flask):
def test_post_forbidden_non_editor(self, app: Flask) -> None:
api = CreateEmptyRagPipelineDatasetApi()
method = unwrap(api.post)
user = MagicMock(is_dataset_editor=False)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
):
with app.test_request_context("/"):
with pytest.raises(Forbidden):
method(api)
method(api, "tenant-1", user)
@@ -65,7 +65,7 @@ class TestPartnerTenants:
):
with (
patch(
"controllers.console.billing.billing.current_account_with_tenant",
"controllers.console.wraps.current_account_with_tenant",
return_value=(mock_account, "tenant-456"),
),
patch("libs.login._get_user", return_value=mock_account),
@@ -92,7 +92,7 @@ class TestPartnerTenants:
):
with (
patch(
"controllers.console.billing.billing.current_account_with_tenant",
"controllers.console.wraps.current_account_with_tenant",
return_value=(mock_account, "tenant-456"),
),
patch("libs.login._get_user", return_value=mock_account),
@@ -116,7 +116,7 @@ class TestPartnerTenants:
):
with (
patch(
"controllers.console.billing.billing.current_account_with_tenant",
"controllers.console.wraps.current_account_with_tenant",
return_value=(mock_account, "tenant-456"),
),
patch("libs.login._get_user", return_value=mock_account),
@@ -158,7 +158,7 @@ class TestPartnerTenants:
):
with (
patch(
"controllers.console.billing.billing.current_account_with_tenant",
"controllers.console.wraps.current_account_with_tenant",
return_value=(mock_account, "tenant-456"),
),
patch("libs.login._get_user", return_value=mock_account),
@@ -189,7 +189,7 @@ class TestPartnerTenants:
):
with (
patch(
"controllers.console.billing.billing.current_account_with_tenant",
"controllers.console.wraps.current_account_with_tenant",
return_value=(mock_account, "tenant-456"),
),
patch("libs.login._get_user", return_value=mock_account),
@@ -215,7 +215,7 @@ class TestPartnerTenants:
):
with (
patch(
"controllers.console.billing.billing.current_account_with_tenant",
"controllers.console.wraps.current_account_with_tenant",
return_value=(mock_account, "tenant-456"),
),
patch("libs.login._get_user", return_value=mock_account),
@@ -241,7 +241,7 @@ class TestPartnerTenants:
):
with (
patch(
"controllers.console.billing.billing.current_account_with_tenant",
"controllers.console.wraps.current_account_with_tenant",
return_value=(mock_account, "tenant-456"),
),
patch("libs.login._get_user", return_value=mock_account),
@@ -43,10 +43,9 @@ def test_jsonify_form_definition() -> None:
def test_ensure_console_access_rejects(monkeypatch: pytest.MonkeyPatch) -> None:
form = SimpleNamespace(tenant_id="tenant-1")
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-2"))
with pytest.raises(NotFoundError):
ConsoleHumanInputFormApi._ensure_console_access(form)
ConsoleHumanInputFormApi._ensure_console_access(form, "tenant-2")
def test_get_form_definition_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
@@ -62,14 +61,13 @@ def test_get_form_definition_success(app: Flask, monkeypatch: pytest.MonkeyPatch
return form
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1"))
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.get)
with app.test_request_context("/console/api/form/human_input/token", method="GET"):
response = handler(api, form_token="token")
response = handler(api, "tenant-1", form_token="token")
payload = json.loads(response.get_data(as_text=True))
assert payload["fields"] == ["a"]
@@ -84,7 +82,6 @@ def test_get_form_definition_not_found(app: Flask, monkeypatch: pytest.MonkeyPat
return None
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1"))
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
@@ -92,7 +89,7 @@ def test_get_form_definition_not_found(app: Flask, monkeypatch: pytest.MonkeyPat
with app.test_request_context("/console/api/form/human_input/token", method="GET"):
with pytest.raises(NotFoundError):
handler(api, form_token="token")
handler(api, "tenant-1", form_token="token")
def test_post_form_invalid_recipient_type(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
@@ -106,10 +103,6 @@ def test_post_form_invalid_recipient_type(app: Flask, monkeypatch: pytest.Monkey
return form
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
@@ -124,6 +117,8 @@ def test_post_form_invalid_recipient_type(app: Flask, monkeypatch: pytest.Monkey
handler(
api,
HumanInputFormSubmitPayload.model_validate({"inputs": {"content": "ok"}, "action": "approve"}),
"tenant-1",
SimpleNamespace(id="user-1"),
form_token="token",
)
@@ -139,10 +134,6 @@ def test_post_form_rejects_webapp_recipient_type(app: Flask, monkeypatch: pytest
return form
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
@@ -157,6 +148,8 @@ def test_post_form_rejects_webapp_recipient_type(app: Flask, monkeypatch: pytest
handler(
api,
HumanInputFormSubmitPayload.model_validate({"inputs": {"content": "ok"}, "action": "approve"}),
"tenant-1",
SimpleNamespace(id="user-1"),
form_token="token",
)
@@ -176,10 +169,6 @@ def test_post_form_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
submit_mock(**kwargs)
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
@@ -193,6 +182,8 @@ def test_post_form_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
response = handler(
api,
HumanInputFormSubmitPayload.model_validate({"inputs": {"content": "ok"}, "action": "approve"}),
"tenant-1",
SimpleNamespace(id="user-1"),
form_token="token",
)
@@ -216,10 +207,6 @@ def test_post_form_decorated_success_validates_request_body(app: Flask, monkeypa
submit_mock(**kwargs)
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (current_user, "tenant-1"),
)
monkeypatch.setattr(
"controllers.console.wraps.current_account_with_tenant",
lambda: (current_user, "tenant-1"),
@@ -254,10 +241,6 @@ def test_workflow_events_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch)
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: _RepoStub(),
)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="u1"), "t1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleWorkflowEventsApi()
@@ -265,7 +248,7 @@ def test_workflow_events_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch)
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
with pytest.raises(NotFoundError):
handler(api, workflow_run_id="run-1")
handler(api, "t1", SimpleNamespace(id="u1"), workflow_run_id="run-1")
def test_workflow_events_requires_account(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
@@ -285,10 +268,6 @@ def test_workflow_events_requires_account(app: Flask, monkeypatch: pytest.Monkey
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: _RepoStub(),
)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="u1"), "t1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleWorkflowEventsApi()
@@ -296,7 +275,7 @@ def test_workflow_events_requires_account(app: Flask, monkeypatch: pytest.Monkey
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
with pytest.raises(NotFoundError):
handler(api, workflow_run_id="run-1")
handler(api, "t1", SimpleNamespace(id="u1"), workflow_run_id="run-1")
def test_workflow_events_requires_creator(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
@@ -316,10 +295,6 @@ def test_workflow_events_requires_creator(app: Flask, monkeypatch: pytest.Monkey
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: _RepoStub(),
)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="u1"), "t1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleWorkflowEventsApi()
@@ -327,7 +302,7 @@ def test_workflow_events_requires_creator(app: Flask, monkeypatch: pytest.Monkey
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
with pytest.raises(NotFoundError):
handler(api, workflow_run_id="run-1")
handler(api, "t1", SimpleNamespace(id="u1"), workflow_run_id="run-1")
def test_workflow_events_finished(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
@@ -364,17 +339,13 @@ def test_workflow_events_finished(app: Flask, monkeypatch: pytest.MonkeyPatch) -
"workflow_run_result_to_finish_response",
lambda **_kwargs: response_obj,
)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="user-1"), "t1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleWorkflowEventsApi()
handler = _unwrap(api.get)
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
response = handler(api, workflow_run_id="run-1")
response = handler(api, "t1", SimpleNamespace(id="user-1"), workflow_run_id="run-1")
assert response.mimetype == "text/event-stream"
assert "data" in response.get_data(as_text=True)
@@ -1,9 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.error import AccountNotFound
from controllers.console.workspace.agent_providers import (
AgentProviderApi,
AgentProviderListApi,
@@ -27,16 +25,12 @@ class TestAgentProviderListApi:
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.agent_providers.AgentService.list_agent_providers",
return_value=providers,
),
):
result = method(api)
result = method(api, tenant_id, user)
assert result == providers
@@ -49,33 +43,15 @@ class TestAgentProviderListApi:
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.agent_providers.AgentService.list_agent_providers",
return_value=[],
),
):
result = method(api)
result = method(api, tenant_id, user)
assert result == []
def test_get_account_not_found(self, app: Flask):
api = AgentProviderListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
side_effect=AccountNotFound(),
),
):
with pytest.raises(AccountNotFound):
method(api)
class TestAgentProviderApi:
def test_get_success(self, app: Flask):
@@ -89,16 +65,12 @@ class TestAgentProviderApi:
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.agent_providers.AgentService.get_agent_provider",
return_value=provider_data,
),
):
result = method(api, provider_name)
result = method(api, tenant_id, user, provider_name)
assert result == provider_data
@@ -112,29 +84,11 @@ class TestAgentProviderApi:
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.agent_providers.AgentService.get_agent_provider",
return_value=None,
),
):
result = method(api, provider_name)
result = method(api, tenant_id, user, provider_name)
assert result is None
def test_get_account_not_found(self, app: Flask):
api = AgentProviderApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
side_effect=AccountNotFound(),
),
):
with pytest.raises(AccountNotFound):
method(api, "openai")
@@ -63,7 +63,9 @@ def _mock_user(role: TenantAccountRole) -> SimpleNamespace:
def _prepare_context(module, monkeypatch: pytest.MonkeyPatch, role=TenantAccountRole.OWNER):
user = _mock_user(role)
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "tenant-123"))
from controllers.console import wraps
monkeypatch.setattr(wraps, "current_account_with_tenant", lambda: (user, "tenant-123"))
mock_service = MagicMock()
monkeypatch.setattr(module, "ModelLoadBalancingService", lambda: mock_service)
return mock_service
Generated
+1
View File
@@ -1300,6 +1300,7 @@ requires-dist = [
{ name = "pydantic-ai-slim", extras = ["anthropic", "google", "openai"], marker = "extra == 'server'", specifier = ">=1.85.1,<2.0.0" },
{ name = "pydantic-settings", marker = "extra == 'server'", specifier = ">=2.12.0,<3.0.0" },
{ name = "redis", marker = "extra == 'server'", specifier = ">=7.4.0,<8.0.0" },
{ name = "shell-session-manager", marker = "extra == 'server'", specifier = "==2.1.1" },
{ name = "typing-extensions", specifier = ">=4.12.2,<5.0.0" },
{ name = "uvicorn", extras = ["standard"], marker = "extra == 'server'", specifier = "==0.46.0" },
]
@@ -256,9 +256,7 @@ class DifyShellRuntimeState(BaseModel):
raise ValueError("workspace_cwd requires a matching session_id.")
expected_workspace = _workspace_cwd(self.session_id)
if self.workspace_cwd != expected_workspace:
raise ValueError(
f"workspace_cwd must equal {expected_workspace!r} for session_id {self.session_id!r}."
)
raise ValueError(f"workspace_cwd must equal {expected_workspace!r} for session_id {self.session_id!r}.")
unknown_offset_job_ids = set(self.job_offsets) - set(self.job_ids)
if unknown_offset_job_ids:
names = ", ".join(sorted(unknown_offset_job_ids))
@@ -694,12 +692,12 @@ def _workspace_mkdir_script(*, session_id: str) -> str:
of silently reusing another session's workspace.
"""
safe_session_id = _validated_session_id(session_id)
workspace_dir = f'$HOME/workspace/{safe_session_id}'
workspace_dir = f"$HOME/workspace/{safe_session_id}"
return (
'mkdir -p "$HOME/workspace"; '
f'if mkdir "{workspace_dir}"; then exit 0; fi; '
f'if [ -e "{workspace_dir}" ]; then exit {_WORKSPACE_COLLISION_EXIT_CODE}; fi; '
'exit 1'
"exit 1"
)
@@ -277,6 +277,7 @@ def test_shell_layer_suspend_and_resume_reuse_state_with_fresh_clients() -> None
return next(clients)
compositor = Compositor([LayerNode("shell", _shell_provider(client_factory=factory))])
async def scenario() -> None:
async with compositor.enter(configs={"shell": DifyShellLayerConfig()}) as run:
shell_layer = run.get_layer("shell", DifyShellLayer)
@@ -342,7 +343,10 @@ def test_shell_layer_delete_removes_workspace_then_force_deletes_tracked_jobs_an
assert client.events[:2] == [("run", 'rm -rf -- "$HOME/workspace/abc12ff"'), ("wait", "cleanup-job")]
assert {call.job_id for call in client.delete_calls} == {"user-job", "mkdir-job", "cleanup-job"}
assert all(client.events.index(("delete", call.job_id)) > client.events.index(("wait", "cleanup-job")) for call in client.delete_calls)
assert all(
client.events.index(("delete", call.job_id)) > client.events.index(("wait", "cleanup-job"))
for call in client.delete_calls
)
assert all(call.force is True for call in client.delete_calls)
assert layer.runtime_state.job_ids == []
assert layer.runtime_state.job_offsets == {}
@@ -27,7 +27,9 @@ def test_default_layer_providers_register_shell_layer_with_configured_token_fact
return factory
monkeypatch.setattr(compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory)
monkeypatch.setattr(
compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory
)
providers = create_default_layer_providers(
shellctl_entrypoint="http://shellctl.example",
@@ -56,7 +58,9 @@ def test_default_layer_providers_keep_empty_shellctl_token_by_default(
return factory
monkeypatch.setattr(compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory)
monkeypatch.setattr(
compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory
)
providers = create_default_layer_providers(shellctl_entrypoint="http://shellctl.example")
shell_provider = next(provider for provider in providers if provider.type_id == DIFY_SHELL_LAYER_TYPE_ID)
@@ -684,7 +684,8 @@ def test_runner_rejects_duplicate_tool_names_between_shell_and_other_layers(
),
)
layer_providers = tuple(
provider for provider in create_default_layer_providers(shellctl_entrypoint="http://unused")
provider
for provider in create_default_layer_providers(shellctl_entrypoint="http://unused")
if provider.type_id != DIFY_SHELL_LAYER_TYPE_ID
) + (shell_provider,)