diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 6026be9bf9..5a965208c6 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -16,7 +16,7 @@ from controllers.common.fields import RedirectUrlResponse, SimpleResultResponse from controllers.common.helpers import FileInfo from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models from controllers.console import console_ns -from controllers.console.app.wraps import get_app_model +from controllers.console.app.wraps import get_app_model, with_session from controllers.console.workspace.models import LoadBalancingPayload from controllers.console.wraps import ( account_initialization_required, @@ -26,7 +26,6 @@ from controllers.console.wraps import ( is_admin_or_owner_required, setup_required, ) -from core.db.session_factory import session_factory from core.ops.ops_trace_manager import OpsTraceManager from core.rag.entities import PreProcessingRule, Rule, Segmentation from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -852,11 +851,11 @@ class AppTraceApi(Resource): @setup_required @login_required @account_initialization_required + @with_session @get_app_model - def get(self, app_model): + def get(self, session: Session, app_model: App): """Get app trace""" - with session_factory.create_session() as session: - app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session) + app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session) return app_trace_config diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index c9cf08072a..8cb0653334 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -1,16 +1,38 @@ +"""Controller decorators for console app resources. + +`with_session` opens one SQLAlchemy session for a request handler and injects it +as the first argument after `self`. Handlers use a transaction by default so +migrated write paths keep commit/rollback handling; pure read handlers may opt +out with `write=False`. App-loading decorators prefer that injected session when +present, while still supporting existing handlers that have not been migrated +yet and still rely on Flask-SQLAlchemy's scoped `db.session`. +""" + from collections.abc import Callable from functools import wraps -from typing import overload +from typing import Concatenate, cast, overload from sqlalchemy import select +from sqlalchemy.orm import Session from controllers.console.app.error import AppNotFoundError +from core.db.session_factory import session_factory from extensions.ext_database import db from libs.login import current_account_with_tenant from models import App, AppMode -def _load_app_model(app_id: str) -> App | None: +def _load_app_model(session: Session, app_id: str) -> App | None: + """Load the tenant-scoped app row with the request session owned by `with_session`.""" + _, current_tenant_id = current_account_with_tenant() + app_model = session.scalar( + select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) + ) + return app_model + + +def _load_app_model_from_scoped_session(app_id: str) -> App | None: + """Load the app row for legacy handlers that have not adopted request session injection yet.""" _, current_tenant_id = current_account_with_tenant() app_model = db.session.scalar( select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1) @@ -23,6 +45,63 @@ def _load_app_model_with_trial(app_id: str) -> App | None: return app_model +@overload +def with_session[T, **P, R]( + view: Callable[Concatenate[T, Session, P], R], + *, + write: bool = True, +) -> Callable[Concatenate[T, P], R]: ... + + +@overload +def with_session[T, **P, R]( + view: None = None, + *, + write: bool = True, +) -> Callable[[Callable[Concatenate[T, Session, P], R]], Callable[Concatenate[T, P], R]]: ... + + +def with_session[T, **P, R]( + view: Callable[Concatenate[T, Session, P], R] | None = None, + *, + write: bool = True, +) -> ( + Callable[Concatenate[T, P], R] | Callable[[Callable[Concatenate[T, Session, P], R]], Callable[Concatenate[T, P], R]] +): + """Inject a request-scoped session, using a transaction only for write handlers.""" + + def decorator(view: Callable[Concatenate[T, Session, P], R]) -> Callable[Concatenate[T, P], R]: + @wraps(view) + def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R: + if write: + with session_factory.get_session_maker().begin() as session: + return view(self, session, *args, **kwargs) + + with session_factory.create_session() as session: + return view(self, session, *args, **kwargs) + + return wrapper + + if view is None: + return decorator + return decorator(view) + + +def _get_injected_session(args: tuple[object, ...]) -> Session | None: + """Return the request session inserted by `with_session`, if this handler has been migrated.""" + if len(args) < 2: + return None + + candidate = args[1] + if isinstance(candidate, Session): + return candidate + + if hasattr(candidate, "scalar") and hasattr(candidate, "commit") and hasattr(candidate, "rollback"): + return cast(Session, candidate) + + return None + + @overload def get_app_model[**P, R]( view: Callable[P, R], @@ -44,6 +123,13 @@ def get_app_model[**P, R]( *, mode: AppMode | list[AppMode] | None = None, ) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]: + """Inject the App model for handlers that receive an `app_id` path parameter. + + New handlers may compose `@with_session` above this decorator so the app row + is loaded through the same request-scoped session used by the controller. + Existing handlers continue to work through `db.session` until migrated. + """ + def decorator(view_func: Callable[P, R]) -> Callable[P, R]: @wraps(view_func) def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R: @@ -55,7 +141,11 @@ def get_app_model[**P, R]( del kwargs["app_id"] - app_model = _load_app_model(app_id) + session = _get_injected_session(args) + if session is None: + app_model = _load_app_model_from_scoped_session(app_id) + else: + app_model = _load_app_model(session, app_id) if not app_model: raise AppNotFoundError() diff --git a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py index 90131fe98d..c6905455ab 100644 --- a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py +++ b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py @@ -90,7 +90,7 @@ class TestChatMessageApiPermissions: # Mock app loading mock_load_app_model = mock.Mock(return_value=mock_app_model) - monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) + monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model) # Mock current user monkeypatch.setattr(completion_api, "current_user", mock_account) @@ -139,7 +139,7 @@ class TestChatMessageApiPermissions: """Ensure GET chat-messages endpoint enforces edit permissions.""" mock_load_app_model = mock.Mock(return_value=mock_app_model) - monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) + monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model) conversation_id = uuid.uuid4() created_at = naive_utc_now() diff --git a/api/tests/integration_tests/controllers/console/app/test_feedback_export_api.py b/api/tests/integration_tests/controllers/console/app/test_feedback_export_api.py index c4db0d5111..93310ad380 100644 --- a/api/tests/integration_tests/controllers/console/app/test_feedback_export_api.py +++ b/api/tests/integration_tests/controllers/console/app/test_feedback_export_api.py @@ -145,7 +145,7 @@ class TestFeedbackExportApi: # Setup mocks mock_load_app_model = mock.Mock(return_value=mock_app_model) - monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) + monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model) mock_export_feedbacks = mock.Mock(return_value="mock csv response") monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks) @@ -179,7 +179,7 @@ class TestFeedbackExportApi: # Setup mocks mock_load_app_model = mock.Mock(return_value=mock_app_model) - monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) + monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model) # Create mock CSV response mock_csv_content = ( @@ -220,7 +220,7 @@ class TestFeedbackExportApi: # Setup mocks mock_load_app_model = mock.Mock(return_value=mock_app_model) - monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) + monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model) mock_json_response = { "export_info": { @@ -264,7 +264,7 @@ class TestFeedbackExportApi: # Setup mocks mock_load_app_model = mock.Mock(return_value=mock_app_model) - monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) + monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model) mock_export_feedbacks = mock.Mock(return_value="mock filtered response") monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks) @@ -305,7 +305,7 @@ class TestFeedbackExportApi: # Setup mocks mock_load_app_model = mock.Mock(return_value=mock_app_model) - monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) + monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model) # Mock the service to raise ValueError for invalid date mock_export_feedbacks = mock.Mock(side_effect=ValueError("Invalid date format")) @@ -330,7 +330,7 @@ class TestFeedbackExportApi: # Setup mocks mock_load_app_model = mock.Mock(return_value=mock_app_model) - monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) + monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model) # Mock the service to raise an exception mock_export_feedbacks = mock.Mock(side_effect=Exception("Database connection failed")) diff --git a/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py b/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py index ab08c7a6d8..3634034c81 100644 --- a/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py +++ b/api/tests/integration_tests/controllers/console/app/test_model_config_permissions.py @@ -86,7 +86,7 @@ class TestModelConfigResourcePermissions: # Mock app loading mock_load_app_model = mock.Mock(return_value=mock_app_model) - monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model) + monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model) # Mock current user monkeypatch.setattr(model_config_api, "current_user", mock_account) diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py index baa21999f9..d29946b65e 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py @@ -49,7 +49,7 @@ def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD") monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id)) - monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model) + monkeypatch.setattr(app_wraps, "_load_app_model_from_scoped_session", lambda _app_id: app_model) monkeypatch.setattr(workflow_comment_module, "current_user", account) diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py index 86a3b2bd93..564682b1b3 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_human_input_debug_api.py @@ -44,7 +44,7 @@ def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app monkeypatch.delenv("INIT_PASSWORD", raising=False) # Avoid hitting the database when resolving the app model - monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model) + monkeypatch.setattr(app_wraps, "_load_app_model_from_scoped_session", lambda _app_id: app_model) @dataclass diff --git a/api/tests/unit_tests/controllers/console/app/test_wraps.py b/api/tests/unit_tests/controllers/console/app/test_wraps.py index b5f751f5a5..d46d22c5a2 100644 --- a/api/tests/unit_tests/controllers/console/app/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/app/test_wraps.py @@ -9,6 +9,89 @@ from controllers.console.app.error import AppNotFoundError from models.model import AppMode +class FakeSession: + app_model: object | None + committed: bool + rolled_back: bool + closed: bool + scalar_called: bool + + def __init__(self, app_model: object | None = None) -> None: + self.app_model = app_model + self.committed = False + self.rolled_back = False + self.closed = False + self.scalar_called = False + + def scalar(self, *_args: object, **_kwargs: object) -> object | None: + self.scalar_called = True + return self.app_model + + def commit(self) -> None: + self.committed = True + + def rollback(self) -> None: + self.rolled_back = True + + +class FakeSessionBegin: + session: FakeSession + entered: bool + exited: bool + exc_type: object | None + + def __init__(self, session: FakeSession) -> None: + self.session = session + self.entered = False + self.exited = False + self.exc_type = None + + def __enter__(self) -> FakeSession: + self.entered = True + return self.session + + def __exit__(self, exc_type: object | None, *_args: object) -> None: + self.exited = True + self.exc_type = exc_type + if exc_type is None: + self.session.commit() + else: + self.session.rollback() + self.session.closed = True + + +class FakeSessionContext: + session: FakeSession + entered: bool + exited: bool + exc_type: object | None + + def __init__(self, session: FakeSession) -> None: + self.session = session + self.entered = False + self.exited = False + self.exc_type = None + + def __enter__(self) -> FakeSession: + self.entered = True + return self.session + + def __exit__(self, exc_type: object | None, *_args: object) -> None: + self.exited = True + self.exc_type = exc_type + self.session.closed = True + + +class FakeSessionMaker: + begin_context: FakeSessionBegin + + def __init__(self, session: FakeSession) -> None: + self.begin_context = FakeSessionBegin(session) + + def begin(self) -> FakeSessionBegin: + return self.begin_context + + def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None: app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1") monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1")) @@ -41,3 +124,95 @@ def test_get_app_model_requires_app_id() -> None: with pytest.raises(ValueError): handler() + + +def test_with_session_defaults_to_write_session_for_get_app_model(monkeypatch: pytest.MonkeyPatch) -> None: + app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1") + session = FakeSession(app_model) + session_maker = FakeSessionMaker(session) + monkeypatch.setattr(wraps_module.session_factory, "get_session_maker", lambda: session_maker) + monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1")) + monkeypatch.setattr( + wraps_module.db, + "session", + SimpleNamespace(scalar=lambda *_args, **_kwargs: pytest.fail("db.session should not be used")), + ) + + class Handler: + @wraps_module.with_session + @wraps_module.get_app_model + def get(self, injected_session, app_model): + assert injected_session is session + return app_model.id + + assert Handler().get(app_id="app-1") == "app-1" + assert session.scalar_called + assert session.committed + assert not session.rolled_back + assert session.closed + assert session_maker.begin_context.entered + assert session_maker.begin_context.exited + assert session_maker.begin_context.exc_type is None + + +def test_with_session_read_mode_does_not_commit(monkeypatch: pytest.MonkeyPatch) -> None: + session = FakeSession() + session_context = FakeSessionContext(session) + monkeypatch.setattr(wraps_module.session_factory, "create_session", lambda: session_context) + + class Handler: + @wraps_module.with_session(write=False) + def get(self, injected_session): + assert injected_session is session + return "ok" + + assert Handler().get() == "ok" + + assert session.closed + assert not session.committed + assert not session.rolled_back + assert session_context.entered + assert session_context.exited + assert session_context.exc_type is None + + +def test_with_session_write_commits_on_success(monkeypatch: pytest.MonkeyPatch) -> None: + session = FakeSession() + session_maker = FakeSessionMaker(session) + monkeypatch.setattr(wraps_module.session_factory, "get_session_maker", lambda: session_maker) + + class Handler: + @wraps_module.with_session(write=True) + def post(self, injected_session): + assert injected_session is session + return "ok" + + assert Handler().post() == "ok" + + assert session.closed + assert session.committed + assert not session.rolled_back + assert session_maker.begin_context.entered + assert session_maker.begin_context.exited + assert session_maker.begin_context.exc_type is None + + +def test_with_session_write_rolls_back_on_error(monkeypatch: pytest.MonkeyPatch) -> None: + session = FakeSession() + session_maker = FakeSessionMaker(session) + monkeypatch.setattr(wraps_module.session_factory, "get_session_maker", lambda: session_maker) + + class Handler: + @wraps_module.with_session(write=True) + def get(self, _session): + raise RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + Handler().get() + + assert session.closed + assert not session.committed + assert session.rolled_back + assert session_maker.begin_context.entered + assert session_maker.begin_context.exited + assert session_maker.begin_context.exc_type is RuntimeError