mirror of
https://github.com/langgenius/dify.git
synced 2026-06-06 08:00:00 +08:00
chore: dep inject for sql session (#36545)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
@@ -16,7 +16,7 @@ from controllers.common.fields import RedirectUrlResponse, SimpleResultResponse
|
||||
from controllers.common.helpers import FileInfo
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.app.wraps import get_app_model, with_session
|
||||
from controllers.console.workspace.models import LoadBalancingPayload
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@@ -26,7 +26,6 @@ from controllers.console.wraps import (
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.db.session_factory import session_factory
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
@@ -852,11 +851,11 @@ class AppTraceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_session
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
def get(self, session: Session, app_model: App):
|
||||
"""Get app trace"""
|
||||
with session_factory.create_session() as session:
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session)
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session)
|
||||
|
||||
return app_trace_config
|
||||
|
||||
|
||||
@@ -1,16 +1,38 @@
|
||||
"""Controller decorators for console app resources.
|
||||
|
||||
`with_session` opens one SQLAlchemy session for a request handler and injects it
|
||||
as the first argument after `self`. Handlers use a transaction by default so
|
||||
migrated write paths keep commit/rollback handling; pure read handlers may opt
|
||||
out with `write=False`. App-loading decorators prefer that injected session when
|
||||
present, while still supporting existing handlers that have not been migrated
|
||||
yet and still rely on Flask-SQLAlchemy's scoped `db.session`.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import overload
|
||||
from typing import Concatenate, cast, overload
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
from models import App, AppMode
|
||||
|
||||
|
||||
def _load_app_model(app_id: str) -> App | None:
|
||||
def _load_app_model(session: Session, app_id: str) -> App | None:
|
||||
"""Load the tenant-scoped app row with the request session owned by `with_session`."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app_model = session.scalar(
|
||||
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
|
||||
)
|
||||
return app_model
|
||||
|
||||
|
||||
def _load_app_model_from_scoped_session(app_id: str) -> App | None:
|
||||
"""Load the app row for legacy handlers that have not adopted request session injection yet."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app_model = db.session.scalar(
|
||||
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
|
||||
@@ -23,6 +45,63 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||
return app_model
|
||||
|
||||
|
||||
@overload
|
||||
def with_session[T, **P, R](
|
||||
view: Callable[Concatenate[T, Session, P], R],
|
||||
*,
|
||||
write: bool = True,
|
||||
) -> Callable[Concatenate[T, P], R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def with_session[T, **P, R](
|
||||
view: None = None,
|
||||
*,
|
||||
write: bool = True,
|
||||
) -> Callable[[Callable[Concatenate[T, Session, P], R]], Callable[Concatenate[T, P], R]]: ...
|
||||
|
||||
|
||||
def with_session[T, **P, R](
|
||||
view: Callable[Concatenate[T, Session, P], R] | None = None,
|
||||
*,
|
||||
write: bool = True,
|
||||
) -> (
|
||||
Callable[Concatenate[T, P], R] | Callable[[Callable[Concatenate[T, Session, P], R]], Callable[Concatenate[T, P], R]]
|
||||
):
|
||||
"""Inject a request-scoped session, using a transaction only for write handlers."""
|
||||
|
||||
def decorator(view: Callable[Concatenate[T, Session, P], R]) -> Callable[Concatenate[T, P], R]:
|
||||
@wraps(view)
|
||||
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if write:
|
||||
with session_factory.get_session_maker().begin() as session:
|
||||
return view(self, session, *args, **kwargs)
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
return view(self, session, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def _get_injected_session(args: tuple[object, ...]) -> Session | None:
|
||||
"""Return the request session inserted by `with_session`, if this handler has been migrated."""
|
||||
if len(args) < 2:
|
||||
return None
|
||||
|
||||
candidate = args[1]
|
||||
if isinstance(candidate, Session):
|
||||
return candidate
|
||||
|
||||
if hasattr(candidate, "scalar") and hasattr(candidate, "commit") and hasattr(candidate, "rollback"):
|
||||
return cast(Session, candidate)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@overload
|
||||
def get_app_model[**P, R](
|
||||
view: Callable[P, R],
|
||||
@@ -44,6 +123,13 @@ def get_app_model[**P, R](
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
"""Inject the App model for handlers that receive an `app_id` path parameter.
|
||||
|
||||
New handlers may compose `@with_session` above this decorator so the app row
|
||||
is loaded through the same request-scoped session used by the controller.
|
||||
Existing handlers continue to work through `db.session` until migrated.
|
||||
"""
|
||||
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
@@ -55,7 +141,11 @@ def get_app_model[**P, R](
|
||||
|
||||
del kwargs["app_id"]
|
||||
|
||||
app_model = _load_app_model(app_id)
|
||||
session = _get_injected_session(args)
|
||||
if session is None:
|
||||
app_model = _load_app_model_from_scoped_session(app_id)
|
||||
else:
|
||||
app_model = _load_app_model(session, app_id)
|
||||
|
||||
if not app_model:
|
||||
raise AppNotFoundError()
|
||||
|
||||
+2
-2
@@ -90,7 +90,7 @@ class TestChatMessageApiPermissions:
|
||||
# Mock app loading
|
||||
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model)
|
||||
|
||||
# Mock current user
|
||||
monkeypatch.setattr(completion_api, "current_user", mock_account)
|
||||
@@ -139,7 +139,7 @@ class TestChatMessageApiPermissions:
|
||||
"""Ensure GET chat-messages endpoint enforces edit permissions."""
|
||||
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model)
|
||||
|
||||
conversation_id = uuid.uuid4()
|
||||
created_at = naive_utc_now()
|
||||
|
||||
@@ -145,7 +145,7 @@ class TestFeedbackExportApi:
|
||||
|
||||
# Setup mocks
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model)
|
||||
|
||||
mock_export_feedbacks = mock.Mock(return_value="mock csv response")
|
||||
monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
|
||||
@@ -179,7 +179,7 @@ class TestFeedbackExportApi:
|
||||
|
||||
# Setup mocks
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model)
|
||||
|
||||
# Create mock CSV response
|
||||
mock_csv_content = (
|
||||
@@ -220,7 +220,7 @@ class TestFeedbackExportApi:
|
||||
|
||||
# Setup mocks
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model)
|
||||
|
||||
mock_json_response = {
|
||||
"export_info": {
|
||||
@@ -264,7 +264,7 @@ class TestFeedbackExportApi:
|
||||
|
||||
# Setup mocks
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model)
|
||||
|
||||
mock_export_feedbacks = mock.Mock(return_value="mock filtered response")
|
||||
monkeypatch.setattr(FeedbackService, "export_feedbacks", mock_export_feedbacks)
|
||||
@@ -305,7 +305,7 @@ class TestFeedbackExportApi:
|
||||
|
||||
# Setup mocks
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model)
|
||||
|
||||
# Mock the service to raise ValueError for invalid date
|
||||
mock_export_feedbacks = mock.Mock(side_effect=ValueError("Invalid date format"))
|
||||
@@ -330,7 +330,7 @@ class TestFeedbackExportApi:
|
||||
|
||||
# Setup mocks
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model)
|
||||
|
||||
# Mock the service to raise an exception
|
||||
mock_export_feedbacks = mock.Mock(side_effect=Exception("Database connection failed"))
|
||||
|
||||
+1
-1
@@ -86,7 +86,7 @@ class TestModelConfigResourcePermissions:
|
||||
|
||||
# Mock app loading
|
||||
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||
monkeypatch.setattr(wraps, "_load_app_model_from_scoped_session", mock_load_app_model)
|
||||
|
||||
# Mock current user
|
||||
monkeypatch.setattr(model_config_api, "current_user", mock_account)
|
||||
|
||||
@@ -49,7 +49,7 @@ def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app
|
||||
monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD")
|
||||
monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model)
|
||||
monkeypatch.setattr(app_wraps, "_load_app_model_from_scoped_session", lambda _app_id: app_model)
|
||||
monkeypatch.setattr(workflow_comment_module, "current_user", account)
|
||||
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app
|
||||
monkeypatch.delenv("INIT_PASSWORD", raising=False)
|
||||
|
||||
# Avoid hitting the database when resolving the app model
|
||||
monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model)
|
||||
monkeypatch.setattr(app_wraps, "_load_app_model_from_scoped_session", lambda _app_id: app_model)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -9,6 +9,89 @@ from controllers.console.app.error import AppNotFoundError
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class FakeSession:
|
||||
app_model: object | None
|
||||
committed: bool
|
||||
rolled_back: bool
|
||||
closed: bool
|
||||
scalar_called: bool
|
||||
|
||||
def __init__(self, app_model: object | None = None) -> None:
|
||||
self.app_model = app_model
|
||||
self.committed = False
|
||||
self.rolled_back = False
|
||||
self.closed = False
|
||||
self.scalar_called = False
|
||||
|
||||
def scalar(self, *_args: object, **_kwargs: object) -> object | None:
|
||||
self.scalar_called = True
|
||||
return self.app_model
|
||||
|
||||
def commit(self) -> None:
|
||||
self.committed = True
|
||||
|
||||
def rollback(self) -> None:
|
||||
self.rolled_back = True
|
||||
|
||||
|
||||
class FakeSessionBegin:
|
||||
session: FakeSession
|
||||
entered: bool
|
||||
exited: bool
|
||||
exc_type: object | None
|
||||
|
||||
def __init__(self, session: FakeSession) -> None:
|
||||
self.session = session
|
||||
self.entered = False
|
||||
self.exited = False
|
||||
self.exc_type = None
|
||||
|
||||
def __enter__(self) -> FakeSession:
|
||||
self.entered = True
|
||||
return self.session
|
||||
|
||||
def __exit__(self, exc_type: object | None, *_args: object) -> None:
|
||||
self.exited = True
|
||||
self.exc_type = exc_type
|
||||
if exc_type is None:
|
||||
self.session.commit()
|
||||
else:
|
||||
self.session.rollback()
|
||||
self.session.closed = True
|
||||
|
||||
|
||||
class FakeSessionContext:
|
||||
session: FakeSession
|
||||
entered: bool
|
||||
exited: bool
|
||||
exc_type: object | None
|
||||
|
||||
def __init__(self, session: FakeSession) -> None:
|
||||
self.session = session
|
||||
self.entered = False
|
||||
self.exited = False
|
||||
self.exc_type = None
|
||||
|
||||
def __enter__(self) -> FakeSession:
|
||||
self.entered = True
|
||||
return self.session
|
||||
|
||||
def __exit__(self, exc_type: object | None, *_args: object) -> None:
|
||||
self.exited = True
|
||||
self.exc_type = exc_type
|
||||
self.session.closed = True
|
||||
|
||||
|
||||
class FakeSessionMaker:
|
||||
begin_context: FakeSessionBegin
|
||||
|
||||
def __init__(self, session: FakeSession) -> None:
|
||||
self.begin_context = FakeSessionBegin(session)
|
||||
|
||||
def begin(self) -> FakeSessionBegin:
|
||||
return self.begin_context
|
||||
|
||||
|
||||
def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
|
||||
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
@@ -41,3 +124,95 @@ def test_get_app_model_requires_app_id() -> None:
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
handler()
|
||||
|
||||
|
||||
def test_with_session_defaults_to_write_session_for_get_app_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
|
||||
session = FakeSession(app_model)
|
||||
session_maker = FakeSessionMaker(session)
|
||||
monkeypatch.setattr(wraps_module.session_factory, "get_session_maker", lambda: session_maker)
|
||||
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
|
||||
monkeypatch.setattr(
|
||||
wraps_module.db,
|
||||
"session",
|
||||
SimpleNamespace(scalar=lambda *_args, **_kwargs: pytest.fail("db.session should not be used")),
|
||||
)
|
||||
|
||||
class Handler:
|
||||
@wraps_module.with_session
|
||||
@wraps_module.get_app_model
|
||||
def get(self, injected_session, app_model):
|
||||
assert injected_session is session
|
||||
return app_model.id
|
||||
|
||||
assert Handler().get(app_id="app-1") == "app-1"
|
||||
assert session.scalar_called
|
||||
assert session.committed
|
||||
assert not session.rolled_back
|
||||
assert session.closed
|
||||
assert session_maker.begin_context.entered
|
||||
assert session_maker.begin_context.exited
|
||||
assert session_maker.begin_context.exc_type is None
|
||||
|
||||
|
||||
def test_with_session_read_mode_does_not_commit(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
session = FakeSession()
|
||||
session_context = FakeSessionContext(session)
|
||||
monkeypatch.setattr(wraps_module.session_factory, "create_session", lambda: session_context)
|
||||
|
||||
class Handler:
|
||||
@wraps_module.with_session(write=False)
|
||||
def get(self, injected_session):
|
||||
assert injected_session is session
|
||||
return "ok"
|
||||
|
||||
assert Handler().get() == "ok"
|
||||
|
||||
assert session.closed
|
||||
assert not session.committed
|
||||
assert not session.rolled_back
|
||||
assert session_context.entered
|
||||
assert session_context.exited
|
||||
assert session_context.exc_type is None
|
||||
|
||||
|
||||
def test_with_session_write_commits_on_success(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
session = FakeSession()
|
||||
session_maker = FakeSessionMaker(session)
|
||||
monkeypatch.setattr(wraps_module.session_factory, "get_session_maker", lambda: session_maker)
|
||||
|
||||
class Handler:
|
||||
@wraps_module.with_session(write=True)
|
||||
def post(self, injected_session):
|
||||
assert injected_session is session
|
||||
return "ok"
|
||||
|
||||
assert Handler().post() == "ok"
|
||||
|
||||
assert session.closed
|
||||
assert session.committed
|
||||
assert not session.rolled_back
|
||||
assert session_maker.begin_context.entered
|
||||
assert session_maker.begin_context.exited
|
||||
assert session_maker.begin_context.exc_type is None
|
||||
|
||||
|
||||
def test_with_session_write_rolls_back_on_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
session = FakeSession()
|
||||
session_maker = FakeSessionMaker(session)
|
||||
monkeypatch.setattr(wraps_module.session_factory, "get_session_maker", lambda: session_maker)
|
||||
|
||||
class Handler:
|
||||
@wraps_module.with_session(write=True)
|
||||
def get(self, _session):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
Handler().get()
|
||||
|
||||
assert session.closed
|
||||
assert not session.committed
|
||||
assert session.rolled_back
|
||||
assert session_maker.begin_context.entered
|
||||
assert session_maker.begin_context.exited
|
||||
assert session_maker.begin_context.exc_type is RuntimeError
|
||||
|
||||
Reference in New Issue
Block a user