diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index 874fd8a7e3..95181b93cf 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -7,7 +7,7 @@ from hmac import new as hmac_new from flask import abort, request from configs import dify_config -from extensions.ext_database import db +from core.db.session_factory import session_factory from models.model import EndUser @@ -44,6 +44,8 @@ def enterprise_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]: def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P, R]: + """Inject an EndUser for valid inner API HMAC auth, otherwise pass the request through unchanged.""" + @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs) -> R: if not dify_config.INNER_API: @@ -72,9 +74,9 @@ def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P, if signature_base64 != token: return view(*args, **kwargs) - kwargs["user"] = db.session.get(EndUser, user_id) - - return view(*args, **kwargs) + with session_factory.create_session() as session: + kwargs["user"] = session.get(EndUser, user_id) + return view(*args, **kwargs) return decorated diff --git a/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py b/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py index efe1841f08..ffe0c4e6b3 100644 --- a/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py +++ b/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py @@ -15,6 +15,7 @@ from controllers.inner_api.wraps import ( enterprise_inner_api_user_auth, plugin_inner_api_only, ) +from models.model import EndUser class TestBillingInnerApiOnly: @@ -217,10 +218,12 @@ class TestEnterpriseInnerApiUserAuth: headers={"Authorization": "Bearer user123:wrong_signature", "X-Inner-Api-Key": "valid_key"} ): with patch.object(dify_config, "INNER_API", True): - result = protected_view() + with patch("controllers.inner_api.wraps.session_factory.create_session") as mock_create_session: + result = protected_view() # Assert assert result == "no_user" + mock_create_session.assert_not_called() def test_should_inject_user_when_hmac_signature_valid(self, app: Flask): """Test that user is injected when HMAC signature is valid""" @@ -243,18 +246,26 @@ class TestEnterpriseInnerApiUserAuth: # Create mock user mock_user = MagicMock() mock_user.id = user_id + mock_session = MagicMock() + mock_session.get.return_value = mock_user + mock_session_context = MagicMock() + mock_session_context.__enter__.return_value = mock_session # Act with app.test_request_context( headers={"Authorization": f"Bearer {user_id}:{valid_signature}", "X-Inner-Api-Key": inner_api_key} ): with patch.object(dify_config, "INNER_API", True): - with patch("controllers.inner_api.wraps.db.session.get") as mock_get: - mock_get.return_value = mock_user + with patch( + "controllers.inner_api.wraps.session_factory.create_session", + return_value=mock_session_context, + ) as mock_create_session: result = protected_view() # Assert assert result == mock_user + mock_create_session.assert_called_once_with() + mock_session.get.assert_called_once_with(EndUser, user_id) class TestPluginInnerApiOnly: