chore: dep inject for sql session (#36545)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
Asuka Minato
2026-05-25 23:24:58 +09:00
committed by GitHub
parent 4d6f8eba2a
commit 25da7ae0d9
8 changed files with 283 additions and 19 deletions
+4 -5
View File
@@ -16,7 +16,7 @@ from controllers.common.fields import RedirectUrlResponse, SimpleResultResponse
from controllers.common.helpers import FileInfo
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.app.wraps import get_app_model, with_session
from controllers.console.workspace.models import LoadBalancingPayload
from controllers.console.wraps import (
account_initialization_required,
@@ -26,7 +26,6 @@ from controllers.console.wraps import (
is_admin_or_owner_required,
setup_required,
)
from core.db.session_factory import session_factory
from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@@ -852,11 +851,11 @@ class AppTraceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_session
@get_app_model
def get(self, app_model):
def get(self, session: Session, app_model: App):
"""Get app trace"""
with session_factory.create_session() as session:
app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session)
app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session)
return app_trace_config
+93 -3
View File
@@ -1,16 +1,38 @@
"""Controller decorators for console app resources.
`with_session` opens one SQLAlchemy session for a request handler and injects it
as the first argument after `self`. Handlers use a transaction by default so
migrated write paths keep commit/rollback handling; pure read handlers may opt
out with `write=False`. App-loading decorators prefer that injected session when
present, while still supporting existing handlers that have not been migrated
yet and still rely on Flask-SQLAlchemy's scoped `db.session`.
"""
from collections.abc import Callable
from functools import wraps
from typing import overload
from typing import Concatenate, cast, overload
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console.app.error import AppNotFoundError
from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models import App, AppMode
def _load_app_model(app_id: str) -> App | None:
def _load_app_model(session: Session, app_id: str) -> App | None:
"""Load the tenant-scoped app row with the request session owned by `with_session`."""
_, current_tenant_id = current_account_with_tenant()
app_model = session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
return app_model
def _load_app_model_from_scoped_session(app_id: str) -> App | None:
"""Load the app row for legacy handlers that have not adopted request session injection yet."""
_, current_tenant_id = current_account_with_tenant()
app_model = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
@@ -23,6 +45,63 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
return app_model
@overload
def with_session[T, **P, R](
view: Callable[Concatenate[T, Session, P], R],
*,
write: bool = True,
) -> Callable[Concatenate[T, P], R]: ...
@overload
def with_session[T, **P, R](
view: None = None,
*,
write: bool = True,
) -> Callable[[Callable[Concatenate[T, Session, P], R]], Callable[Concatenate[T, P], R]]: ...
def with_session[T, **P, R](
view: Callable[Concatenate[T, Session, P], R] | None = None,
*,
write: bool = True,
) -> (
Callable[Concatenate[T, P], R] | Callable[[Callable[Concatenate[T, Session, P], R]], Callable[Concatenate[T, P], R]]
):
"""Inject a request-scoped session, using a transaction only for write handlers."""
def decorator(view: Callable[Concatenate[T, Session, P], R]) -> Callable[Concatenate[T, P], R]:
@wraps(view)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
if write:
with session_factory.get_session_maker().begin() as session:
return view(self, session, *args, **kwargs)
with session_factory.create_session() as session:
return view(self, session, *args, **kwargs)
return wrapper
if view is None:
return decorator
return decorator(view)
def _get_injected_session(args: tuple[object, ...]) -> Session | None:
"""Return the request session inserted by `with_session`, if this handler has been migrated."""
if len(args) < 2:
return None
candidate = args[1]
if isinstance(candidate, Session):
return candidate
if hasattr(candidate, "scalar") and hasattr(candidate, "commit") and hasattr(candidate, "rollback"):
return cast(Session, candidate)
return None
@overload
def get_app_model[**P, R](
view: Callable[P, R],
@@ -44,6 +123,13 @@ def get_app_model[**P, R](
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
"""Inject the App model for handlers that receive an `app_id` path parameter.
New handlers may compose `@with_session` above this decorator so the app row
is loaded through the same request-scoped session used by the controller.
Existing handlers continue to work through `db.session` until migrated.
"""
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
@@ -55,7 +141,11 @@ def get_app_model[**P, R](
del kwargs["app_id"]
app_model = _load_app_model(app_id)
session = _get_injected_session(args)
if session is None:
app_model = _load_app_model_from_scoped_session(app_id)
else:
app_model = _load_app_model(session, app_id)
if not app_model:
raise AppNotFoundError()
@@ -90,7 +90,7 @@ class TestChatMessageApiPermissions:
# Mock app loading
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model)
# Mock current user
monkeypatch.setattr(completion_api, "current_user", mock_account)
@@ -139,7 +139,7 @@ class TestChatMessageApiPermissions:
"""Ensure GET chat-messages endpoint enforces edit permissions."""
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model)
conversation_id = uuid.uuid4()
created_at = naive_utc_now()
@@ -145,7 +145,7 @@ class TestFeedbackExportApi:
# Setup mocks
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model)
mock_export_feedbacks = mock.Mock(return_value="mock csv response")
monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
@@ -179,7 +179,7 @@ class TestFeedbackExportApi:
# Setup mocks
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model)
# Create mock CSV response
mock_csv_content = (
@@ -220,7 +220,7 @@ class TestFeedbackExportApi:
# Setup mocks
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model)
mock_json_response = {
"export_info": {
@@ -264,7 +264,7 @@ class TestFeedbackExportApi:
# Setup mocks
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model)
mock_export_feedbacks = mock.Mock(return_value="mock filtered response")
monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
@@ -305,7 +305,7 @@ class TestFeedbackExportApi:
# Setup mocks
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model)
# Mock the service to raise ValueError for invalid date
mock_export_feedbacks = mock.Mock(side_effect=ValueError("Invalid date format"))
@@ -330,7 +330,7 @@ class TestFeedbackExportApi:
# Setup mocks
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model)
# Mock the service to raise an exception
mock_export_feedbacks = mock.Mock(side_effect=Exception("Database connection failed"))
@@ -86,7 +86,7 @@ class TestModelConfigResourcePermissions:
# Mock app loading
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model)
# Mock current user
monkeypatch.setattr(model_config_api, "current_user", mock_account)
@@ -49,7 +49,7 @@ def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app
monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD")
monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model)
monkeypatch.setattr(app_wraps, "_load_app_model_from_scoped_session", lambda _app_id: app_model)
monkeypatch.setattr(workflow_comment_module, "current_user", account)
@@ -44,7 +44,7 @@ def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app
monkeypatch.delenv("INIT_PASSWORD", raising=False)
# Avoid hitting the database when resolving the app model
monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model)
monkeypatch.setattr(app_wraps, "_load_app_model_from_scoped_session", lambda _app_id: app_model)
@dataclass
@@ -9,6 +9,89 @@ from controllers.console.app.error import AppNotFoundError
from models.model import AppMode
class FakeSession:
app_model: object | None
committed: bool
rolled_back: bool
closed: bool
scalar_called: bool
def __init__(self, app_model: object | None = None) -> None:
self.app_model = app_model
self.committed = False
self.rolled_back = False
self.closed = False
self.scalar_called = False
def scalar(self, *_args: object, **_kwargs: object) -> object | None:
self.scalar_called = True
return self.app_model
def commit(self) -> None:
self.committed = True
def rollback(self) -> None:
self.rolled_back = True
class FakeSessionBegin:
session: FakeSession
entered: bool
exited: bool
exc_type: object | None
def __init__(self, session: FakeSession) -> None:
self.session = session
self.entered = False
self.exited = False
self.exc_type = None
def __enter__(self) -> FakeSession:
self.entered = True
return self.session
def __exit__(self, exc_type: object | None, *_args: object) -> None:
self.exited = True
self.exc_type = exc_type
if exc_type is None:
self.session.commit()
else:
self.session.rollback()
self.session.closed = True
class FakeSessionContext:
session: FakeSession
entered: bool
exited: bool
exc_type: object | None
def __init__(self, session: FakeSession) -> None:
self.session = session
self.entered = False
self.exited = False
self.exc_type = None
def __enter__(self) -> FakeSession:
self.entered = True
return self.session
def __exit__(self, exc_type: object | None, *_args: object) -> None:
self.exited = True
self.exc_type = exc_type
self.session.closed = True
class FakeSessionMaker:
begin_context: FakeSessionBegin
def __init__(self, session: FakeSession) -> None:
self.begin_context = FakeSessionBegin(session)
def begin(self) -> FakeSessionBegin:
return self.begin_context
def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
@@ -41,3 +124,95 @@ def test_get_app_model_requires_app_id() -> None:
with pytest.raises(ValueError):
handler()
def test_with_session_defaults_to_write_session_for_get_app_model(monkeypatch: pytest.MonkeyPatch) -> None:
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
session = FakeSession(app_model)
session_maker = FakeSessionMaker(session)
monkeypatch.setattr(wraps_module.session_factory, "get_session_maker", lambda: session_maker)
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
monkeypatch.setattr(
wraps_module.db,
"session",
SimpleNamespace(scalar=lambda *_args, **_kwargs: pytest.fail("db.session should not be used")),
)
class Handler:
@wraps_module.with_session
@wraps_module.get_app_model
def get(self, injected_session, app_model):
assert injected_session is session
return app_model.id
assert Handler().get(app_id="app-1") == "app-1"
assert session.scalar_called
assert session.committed
assert not session.rolled_back
assert session.closed
assert session_maker.begin_context.entered
assert session_maker.begin_context.exited
assert session_maker.begin_context.exc_type is None
def test_with_session_read_mode_does_not_commit(monkeypatch: pytest.MonkeyPatch) -> None:
session = FakeSession()
session_context = FakeSessionContext(session)
monkeypatch.setattr(wraps_module.session_factory, "create_session", lambda: session_context)
class Handler:
@wraps_module.with_session(write=False)
def get(self, injected_session):
assert injected_session is session
return "ok"
assert Handler().get() == "ok"
assert session.closed
assert not session.committed
assert not session.rolled_back
assert session_context.entered
assert session_context.exited
assert session_context.exc_type is None
def test_with_session_write_commits_on_success(monkeypatch: pytest.MonkeyPatch) -> None:
session = FakeSession()
session_maker = FakeSessionMaker(session)
monkeypatch.setattr(wraps_module.session_factory, "get_session_maker", lambda: session_maker)
class Handler:
@wraps_module.with_session(write=True)
def post(self, injected_session):
assert injected_session is session
return "ok"
assert Handler().post() == "ok"
assert session.closed
assert session.committed
assert not session.rolled_back
assert session_maker.begin_context.entered
assert session_maker.begin_context.exited
assert session_maker.begin_context.exc_type is None
def test_with_session_write_rolls_back_on_error(monkeypatch: pytest.MonkeyPatch) -> None:
session = FakeSession()
session_maker = FakeSessionMaker(session)
monkeypatch.setattr(wraps_module.session_factory, "get_session_maker", lambda: session_maker)
class Handler:
@wraps_module.with_session(write=True)
def get(self, _session):
raise RuntimeError("boom")
with pytest.raises(RuntimeError, match="boom"):
Handler().get()
assert session.closed
assert not session.committed
assert session.rolled_back
assert session_maker.begin_context.entered
assert session_maker.begin_context.exited
assert session_maker.begin_context.exc_type is RuntimeError