mirror of
https://github.com/langgenius/dify.git
synced 2026-06-03 08:16:37 +08:00
refactor(api): migrate tenant/user via DI for several endpoints (#36971)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -11,10 +11,10 @@ from pydantic import Field
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode
|
||||
from services.agent.roster_service import AgentRosterService
|
||||
|
||||
@@ -49,8 +49,8 @@ class AgentAppReferencingWorkflowsResource(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App):
|
||||
workflows = AgentRosterService(db.session).list_workflows_referencing_app_agent(
|
||||
tenant_id=tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
@@ -18,9 +18,15 @@ from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.model import App, AppMode
|
||||
from services.agent_app_feature_service import AgentAppFeatureConfigService
|
||||
|
||||
@@ -65,9 +71,9 @@ class AgentAppFeatureConfigResource(Resource):
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
args = AgentAppFeaturesRequest.model_validate(console_ns.payload)
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
new_app_model_config = AgentAppFeatureConfigService.update_features(
|
||||
app_model=app_model,
|
||||
|
||||
@@ -8,9 +8,16 @@ from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
@@ -32,8 +39,9 @@ class Subscription(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True))
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
|
||||
@@ -45,8 +53,9 @@ class Invoices(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_invoices(current_user.email, current_tenant_id)
|
||||
|
||||
@@ -63,9 +72,8 @@ class PartnerTenants(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def put(self, partner_key: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def put(self, current_user: Account, partner_key: str):
|
||||
try:
|
||||
args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
|
||||
click_id = args.click_id
|
||||
|
||||
@@ -3,11 +3,18 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
from .. import console_ns
|
||||
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from ..wraps import (
|
||||
account_initialization_required,
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
|
||||
|
||||
class ComplianceDownloadQuery(BaseModel):
|
||||
@@ -29,8 +36,9 @@ class ComplianceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
|
||||
@@ -11,10 +11,13 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
||||
@@ -35,9 +38,10 @@ class CreateRagPipelineDatasetApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {})
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
@@ -79,10 +83,10 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
|
||||
@@ -14,7 +14,13 @@ from sqlalchemy.orm import Session, sessionmaker
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, model_validate, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
model_validate,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
@@ -23,8 +29,8 @@ from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from libs.login import login_required
|
||||
from models import Account, App
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
@@ -48,9 +54,8 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
"""Console API for getting human input form definition."""
|
||||
|
||||
@staticmethod
|
||||
def _ensure_console_access(form: Form):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
def _ensure_console_access(form: Form, current_tenant_id: str) -> None:
|
||||
"""Ensure a console form token resolves only inside the current tenant."""
|
||||
if form.tenant_id != current_tenant_id:
|
||||
raise NotFoundError("App not found")
|
||||
|
||||
@@ -62,7 +67,8 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, form_token: str):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, form_token: str):
|
||||
"""
|
||||
Get human input form definition by form token.
|
||||
|
||||
@@ -73,15 +79,23 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
if form is None:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
|
||||
self._ensure_console_access(form)
|
||||
self._ensure_console_access(form, current_tenant_id)
|
||||
|
||||
return _jsonify_form_definition(form)
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
@model_validate(HumanInputFormSubmitPayload)
|
||||
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
def post(self, payload: HumanInputFormSubmitPayload, form_token: str):
|
||||
def post(
|
||||
self,
|
||||
payload: HumanInputFormSubmitPayload,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
form_token: str,
|
||||
):
|
||||
"""
|
||||
Submit human input form by form token.
|
||||
|
||||
@@ -95,14 +109,12 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
"action": "Approve"
|
||||
}
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
|
||||
self._ensure_console_access(form)
|
||||
self._ensure_console_access(form, current_tenant_id)
|
||||
self._ensure_console_recipient_type(form)
|
||||
recipient_type = form.recipient_type
|
||||
# The type checker is not smart enought to validate the following invariant.
|
||||
@@ -126,7 +138,9 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
def get(self, workflow_run_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account, workflow_run_id: str):
|
||||
"""
|
||||
Get workflow execution events stream after resume.
|
||||
|
||||
@@ -134,8 +148,6 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
|
||||
Returns Server-Sent Events stream.
|
||||
"""
|
||||
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
session_maker = sessionmaker(db.engine)
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
@@ -19,14 +25,10 @@ class AgentProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
return jsonable_encoder(AgentService.list_agent_providers(current_user.id, current_tenant_id))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/agent-provider/<path:provider_name>")
|
||||
@@ -42,6 +44,7 @@ class AgentProviderApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, provider_name: str):
|
||||
return jsonable_encoder(AgentService.get_agent_provider(current_user.id, current_tenant_id, provider_name))
|
||||
|
||||
@@ -4,11 +4,16 @@ from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import TenantAccountRole
|
||||
from libs.login import login_required
|
||||
from models import Account, TenantAccountRole
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
|
||||
|
||||
@@ -29,8 +34,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, provider: str):
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
@@ -72,8 +78,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str, config_id: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, provider: str, config_id: str):
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
+28
-47
@@ -2,6 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Protocol, cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -17,21 +19,26 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_datasets import (
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
class _WrappedCallable(Protocol):
|
||||
__wrapped__: Callable[..., object]
|
||||
|
||||
|
||||
def unwrap(func: Callable[..., object]) -> Callable[..., object]:
|
||||
current: Callable[..., object] | _WrappedCallable = func
|
||||
while hasattr(current, "__wrapped__"):
|
||||
current = cast(_WrappedCallable, current).__wrapped__
|
||||
return cast(Callable[..., object], current)
|
||||
|
||||
|
||||
class TestCreateRagPipelineDatasetApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def _valid_payload(self):
|
||||
def _valid_payload(self) -> dict[str, str]:
|
||||
return {"yaml_content": "name: test"}
|
||||
|
||||
def test_post_success(self, app: Flask):
|
||||
def test_post_success(self, app: Flask) -> None:
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@@ -45,21 +52,17 @@ class TestCreateRagPipelineDatasetApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
|
||||
return_value=mock_service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
response, status = cast(tuple[dict[str, str], int], method(api, "tenant-1", user))
|
||||
|
||||
assert status == 201
|
||||
assert response == import_info
|
||||
|
||||
def test_post_forbidden_non_editor(self, app: Flask):
|
||||
def test_post_forbidden_non_editor(self, app: Flask) -> None:
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@@ -69,15 +72,11 @@ class TestCreateRagPipelineDatasetApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
method(api, "tenant-1", user)
|
||||
|
||||
def test_post_dataset_name_duplicate(self, app: Flask):
|
||||
def test_post_dataset_name_duplicate(self, app: Flask) -> None:
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@@ -90,43 +89,35 @@ class TestCreateRagPipelineDatasetApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
|
||||
return_value=mock_service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DatasetNameDuplicateError):
|
||||
method(api)
|
||||
method(api, "tenant-1", user)
|
||||
|
||||
def test_post_invalid_payload(self, app: Flask):
|
||||
def test_post_invalid_payload(self, app: Flask) -> None:
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
payload: dict[str, str] = {}
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
method(api, "tenant-1", user)
|
||||
|
||||
|
||||
class TestCreateEmptyRagPipelineDatasetApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_post_success(self, app: Flask):
|
||||
def test_post_success(self, app: Flask) -> None:
|
||||
api = CreateEmptyRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@@ -135,10 +126,6 @@ class TestCreateEmptyRagPipelineDatasetApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.DatasetService.create_empty_rag_pipeline_dataset",
|
||||
return_value=dataset,
|
||||
@@ -148,23 +135,17 @@ class TestCreateEmptyRagPipelineDatasetApi:
|
||||
return_value={"id": "ds-1"},
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
response, status = cast(tuple[dict[str, str], int], method(api, "tenant-1", user))
|
||||
|
||||
assert status == 201
|
||||
assert response == {"id": "ds-1"}
|
||||
|
||||
def test_post_forbidden_non_editor(self, app: Flask):
|
||||
def test_post_forbidden_non_editor(self, app: Flask) -> None:
|
||||
api = CreateEmptyRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(is_dataset_editor=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
method(api, "tenant-1", user)
|
||||
|
||||
@@ -65,7 +65,7 @@ class TestPartnerTenants:
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.billing.billing.current_account_with_tenant",
|
||||
"controllers.console.wraps.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant-456"),
|
||||
),
|
||||
patch("libs.login._get_user", return_value=mock_account),
|
||||
@@ -92,7 +92,7 @@ class TestPartnerTenants:
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.billing.billing.current_account_with_tenant",
|
||||
"controllers.console.wraps.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant-456"),
|
||||
),
|
||||
patch("libs.login._get_user", return_value=mock_account),
|
||||
@@ -116,7 +116,7 @@ class TestPartnerTenants:
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.billing.billing.current_account_with_tenant",
|
||||
"controllers.console.wraps.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant-456"),
|
||||
),
|
||||
patch("libs.login._get_user", return_value=mock_account),
|
||||
@@ -158,7 +158,7 @@ class TestPartnerTenants:
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.billing.billing.current_account_with_tenant",
|
||||
"controllers.console.wraps.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant-456"),
|
||||
),
|
||||
patch("libs.login._get_user", return_value=mock_account),
|
||||
@@ -189,7 +189,7 @@ class TestPartnerTenants:
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.billing.billing.current_account_with_tenant",
|
||||
"controllers.console.wraps.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant-456"),
|
||||
),
|
||||
patch("libs.login._get_user", return_value=mock_account),
|
||||
@@ -215,7 +215,7 @@ class TestPartnerTenants:
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.billing.billing.current_account_with_tenant",
|
||||
"controllers.console.wraps.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant-456"),
|
||||
),
|
||||
patch("libs.login._get_user", return_value=mock_account),
|
||||
@@ -241,7 +241,7 @@ class TestPartnerTenants:
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.billing.billing.current_account_with_tenant",
|
||||
"controllers.console.wraps.current_account_with_tenant",
|
||||
return_value=(mock_account, "tenant-456"),
|
||||
),
|
||||
patch("libs.login._get_user", return_value=mock_account),
|
||||
|
||||
@@ -43,10 +43,9 @@ def test_jsonify_form_definition() -> None:
|
||||
|
||||
def test_ensure_console_access_rejects(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
form = SimpleNamespace(tenant_id="tenant-1")
|
||||
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-2"))
|
||||
|
||||
with pytest.raises(NotFoundError):
|
||||
ConsoleHumanInputFormApi._ensure_console_access(form)
|
||||
ConsoleHumanInputFormApi._ensure_console_access(form, "tenant-2")
|
||||
|
||||
|
||||
def test_get_form_definition_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@@ -62,14 +61,13 @@ def test_get_form_definition_success(app: Flask, monkeypatch: pytest.MonkeyPatch
|
||||
return form
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1"))
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/form/human_input/token", method="GET"):
|
||||
response = handler(api, form_token="token")
|
||||
response = handler(api, "tenant-1", form_token="token")
|
||||
|
||||
payload = json.loads(response.get_data(as_text=True))
|
||||
assert payload["fields"] == ["a"]
|
||||
@@ -84,7 +82,6 @@ def test_get_form_definition_not_found(app: Flask, monkeypatch: pytest.MonkeyPat
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1"))
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
@@ -92,7 +89,7 @@ def test_get_form_definition_not_found(app: Flask, monkeypatch: pytest.MonkeyPat
|
||||
|
||||
with app.test_request_context("/console/api/form/human_input/token", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, form_token="token")
|
||||
handler(api, "tenant-1", form_token="token")
|
||||
|
||||
|
||||
def test_post_form_invalid_recipient_type(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@@ -106,10 +103,6 @@ def test_post_form_invalid_recipient_type(app: Flask, monkeypatch: pytest.Monkey
|
||||
return form
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
@@ -124,6 +117,8 @@ def test_post_form_invalid_recipient_type(app: Flask, monkeypatch: pytest.Monkey
|
||||
handler(
|
||||
api,
|
||||
HumanInputFormSubmitPayload.model_validate({"inputs": {"content": "ok"}, "action": "approve"}),
|
||||
"tenant-1",
|
||||
SimpleNamespace(id="user-1"),
|
||||
form_token="token",
|
||||
)
|
||||
|
||||
@@ -139,10 +134,6 @@ def test_post_form_rejects_webapp_recipient_type(app: Flask, monkeypatch: pytest
|
||||
return form
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
@@ -157,6 +148,8 @@ def test_post_form_rejects_webapp_recipient_type(app: Flask, monkeypatch: pytest
|
||||
handler(
|
||||
api,
|
||||
HumanInputFormSubmitPayload.model_validate({"inputs": {"content": "ok"}, "action": "approve"}),
|
||||
"tenant-1",
|
||||
SimpleNamespace(id="user-1"),
|
||||
form_token="token",
|
||||
)
|
||||
|
||||
@@ -176,10 +169,6 @@ def test_post_form_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
submit_mock(**kwargs)
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
@@ -193,6 +182,8 @@ def test_post_form_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
response = handler(
|
||||
api,
|
||||
HumanInputFormSubmitPayload.model_validate({"inputs": {"content": "ok"}, "action": "approve"}),
|
||||
"tenant-1",
|
||||
SimpleNamespace(id="user-1"),
|
||||
form_token="token",
|
||||
)
|
||||
|
||||
@@ -216,10 +207,6 @@ def test_post_form_decorated_success_validates_request_body(app: Flask, monkeypa
|
||||
submit_mock(**kwargs)
|
||||
|
||||
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (current_user, "tenant-1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.wraps.current_account_with_tenant",
|
||||
lambda: (current_user, "tenant-1"),
|
||||
@@ -254,10 +241,6 @@ def test_workflow_events_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch)
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: _RepoStub(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleWorkflowEventsApi()
|
||||
@@ -265,7 +248,7 @@ def test_workflow_events_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, workflow_run_id="run-1")
|
||||
handler(api, "t1", SimpleNamespace(id="u1"), workflow_run_id="run-1")
|
||||
|
||||
|
||||
def test_workflow_events_requires_account(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@@ -285,10 +268,6 @@ def test_workflow_events_requires_account(app: Flask, monkeypatch: pytest.Monkey
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: _RepoStub(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleWorkflowEventsApi()
|
||||
@@ -296,7 +275,7 @@ def test_workflow_events_requires_account(app: Flask, monkeypatch: pytest.Monkey
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, workflow_run_id="run-1")
|
||||
handler(api, "t1", SimpleNamespace(id="u1"), workflow_run_id="run-1")
|
||||
|
||||
|
||||
def test_workflow_events_requires_creator(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@@ -316,10 +295,6 @@ def test_workflow_events_requires_creator(app: Flask, monkeypatch: pytest.Monkey
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_args, **_kwargs: _RepoStub(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="u1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleWorkflowEventsApi()
|
||||
@@ -327,7 +302,7 @@ def test_workflow_events_requires_creator(app: Flask, monkeypatch: pytest.Monkey
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
handler(api, workflow_run_id="run-1")
|
||||
handler(api, "t1", SimpleNamespace(id="u1"), workflow_run_id="run-1")
|
||||
|
||||
|
||||
def test_workflow_events_finished(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@@ -364,17 +339,13 @@ def test_workflow_events_finished(app: Flask, monkeypatch: pytest.MonkeyPatch) -
|
||||
"workflow_run_result_to_finish_response",
|
||||
lambda **_kwargs: response_obj,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.human_input_form.current_account_with_tenant",
|
||||
lambda: (SimpleNamespace(id="user-1"), "t1"),
|
||||
)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleWorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
|
||||
response = handler(api, workflow_run_id="run-1")
|
||||
response = handler(api, "t1", SimpleNamespace(id="user-1"), workflow_run_id="run-1")
|
||||
|
||||
assert response.mimetype == "text/event-stream"
|
||||
assert "data" in response.get_data(as_text=True)
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.error import AccountNotFound
|
||||
from controllers.console.workspace.agent_providers import (
|
||||
AgentProviderApi,
|
||||
AgentProviderListApi,
|
||||
@@ -27,16 +25,12 @@ class TestAgentProviderListApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
return_value=(user, tenant_id),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.AgentService.list_agent_providers",
|
||||
return_value=providers,
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, tenant_id, user)
|
||||
|
||||
assert result == providers
|
||||
|
||||
@@ -49,33 +43,15 @@ class TestAgentProviderListApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
return_value=(user, tenant_id),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.AgentService.list_agent_providers",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, tenant_id, user)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_get_account_not_found(self, app: Flask):
|
||||
api = AgentProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
side_effect=AccountNotFound(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(AccountNotFound):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestAgentProviderApi:
|
||||
def test_get_success(self, app: Flask):
|
||||
@@ -89,16 +65,12 @@ class TestAgentProviderApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
return_value=(user, tenant_id),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.AgentService.get_agent_provider",
|
||||
return_value=provider_data,
|
||||
),
|
||||
):
|
||||
result = method(api, provider_name)
|
||||
result = method(api, tenant_id, user, provider_name)
|
||||
|
||||
assert result == provider_data
|
||||
|
||||
@@ -112,29 +84,11 @@ class TestAgentProviderApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
return_value=(user, tenant_id),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.AgentService.get_agent_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api, provider_name)
|
||||
result = method(api, tenant_id, user, provider_name)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_get_account_not_found(self, app: Flask):
|
||||
api = AgentProviderApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
side_effect=AccountNotFound(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(AccountNotFound):
|
||||
method(api, "openai")
|
||||
|
||||
@@ -63,7 +63,9 @@ def _mock_user(role: TenantAccountRole) -> SimpleNamespace:
|
||||
|
||||
def _prepare_context(module, monkeypatch: pytest.MonkeyPatch, role=TenantAccountRole.OWNER):
|
||||
user = _mock_user(role)
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "tenant-123"))
|
||||
from controllers.console import wraps
|
||||
|
||||
monkeypatch.setattr(wraps, "current_account_with_tenant", lambda: (user, "tenant-123"))
|
||||
mock_service = MagicMock()
|
||||
monkeypatch.setattr(module, "ModelLoadBalancingService", lambda: mock_service)
|
||||
return mock_service
|
||||
|
||||
Generated
+1
@@ -1300,6 +1300,7 @@ requires-dist = [
|
||||
{ name = "pydantic-ai-slim", extras = ["anthropic", "google", "openai"], marker = "extra == 'server'", specifier = ">=1.85.1,<2.0.0" },
|
||||
{ name = "pydantic-settings", marker = "extra == 'server'", specifier = ">=2.12.0,<3.0.0" },
|
||||
{ name = "redis", marker = "extra == 'server'", specifier = ">=7.4.0,<8.0.0" },
|
||||
{ name = "shell-session-manager", marker = "extra == 'server'", specifier = "==2.1.1" },
|
||||
{ name = "typing-extensions", specifier = ">=4.12.2,<5.0.0" },
|
||||
{ name = "uvicorn", extras = ["standard"], marker = "extra == 'server'", specifier = "==0.46.0" },
|
||||
]
|
||||
|
||||
@@ -256,9 +256,7 @@ class DifyShellRuntimeState(BaseModel):
|
||||
raise ValueError("workspace_cwd requires a matching session_id.")
|
||||
expected_workspace = _workspace_cwd(self.session_id)
|
||||
if self.workspace_cwd != expected_workspace:
|
||||
raise ValueError(
|
||||
f"workspace_cwd must equal {expected_workspace!r} for session_id {self.session_id!r}."
|
||||
)
|
||||
raise ValueError(f"workspace_cwd must equal {expected_workspace!r} for session_id {self.session_id!r}.")
|
||||
unknown_offset_job_ids = set(self.job_offsets) - set(self.job_ids)
|
||||
if unknown_offset_job_ids:
|
||||
names = ", ".join(sorted(unknown_offset_job_ids))
|
||||
@@ -694,12 +692,12 @@ def _workspace_mkdir_script(*, session_id: str) -> str:
|
||||
of silently reusing another session's workspace.
|
||||
"""
|
||||
safe_session_id = _validated_session_id(session_id)
|
||||
workspace_dir = f'$HOME/workspace/{safe_session_id}'
|
||||
workspace_dir = f"$HOME/workspace/{safe_session_id}"
|
||||
return (
|
||||
'mkdir -p "$HOME/workspace"; '
|
||||
f'if mkdir "{workspace_dir}"; then exit 0; fi; '
|
||||
f'if [ -e "{workspace_dir}" ]; then exit {_WORKSPACE_COLLISION_EXIT_CODE}; fi; '
|
||||
'exit 1'
|
||||
"exit 1"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -277,6 +277,7 @@ def test_shell_layer_suspend_and_resume_reuse_state_with_fresh_clients() -> None
|
||||
return next(clients)
|
||||
|
||||
compositor = Compositor([LayerNode("shell", _shell_provider(client_factory=factory))])
|
||||
|
||||
async def scenario() -> None:
|
||||
async with compositor.enter(configs={"shell": DifyShellLayerConfig()}) as run:
|
||||
shell_layer = run.get_layer("shell", DifyShellLayer)
|
||||
@@ -342,7 +343,10 @@ def test_shell_layer_delete_removes_workspace_then_force_deletes_tracked_jobs_an
|
||||
|
||||
assert client.events[:2] == [("run", 'rm -rf -- "$HOME/workspace/abc12ff"'), ("wait", "cleanup-job")]
|
||||
assert {call.job_id for call in client.delete_calls} == {"user-job", "mkdir-job", "cleanup-job"}
|
||||
assert all(client.events.index(("delete", call.job_id)) > client.events.index(("wait", "cleanup-job")) for call in client.delete_calls)
|
||||
assert all(
|
||||
client.events.index(("delete", call.job_id)) > client.events.index(("wait", "cleanup-job"))
|
||||
for call in client.delete_calls
|
||||
)
|
||||
assert all(call.force is True for call in client.delete_calls)
|
||||
assert layer.runtime_state.job_ids == []
|
||||
assert layer.runtime_state.job_offsets == {}
|
||||
|
||||
@@ -27,7 +27,9 @@ def test_default_layer_providers_register_shell_layer_with_configured_token_fact
|
||||
|
||||
return factory
|
||||
|
||||
monkeypatch.setattr(compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory)
|
||||
monkeypatch.setattr(
|
||||
compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory
|
||||
)
|
||||
|
||||
providers = create_default_layer_providers(
|
||||
shellctl_entrypoint="http://shellctl.example",
|
||||
@@ -56,7 +58,9 @@ def test_default_layer_providers_keep_empty_shellctl_token_by_default(
|
||||
|
||||
return factory
|
||||
|
||||
monkeypatch.setattr(compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory)
|
||||
monkeypatch.setattr(
|
||||
compositor_factory_module, "create_shellctl_client_factory", fake_create_shellctl_client_factory
|
||||
)
|
||||
|
||||
providers = create_default_layer_providers(shellctl_entrypoint="http://shellctl.example")
|
||||
shell_provider = next(provider for provider in providers if provider.type_id == DIFY_SHELL_LAYER_TYPE_ID)
|
||||
|
||||
@@ -684,7 +684,8 @@ def test_runner_rejects_duplicate_tool_names_between_shell_and_other_layers(
|
||||
),
|
||||
)
|
||||
layer_providers = tuple(
|
||||
provider for provider in create_default_layer_providers(shellctl_entrypoint="http://unused")
|
||||
provider
|
||||
for provider in create_default_layer_providers(shellctl_entrypoint="http://unused")
|
||||
if provider.type_id != DIFY_SHELL_LAYER_TYPE_ID
|
||||
) + (shell_provider,)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user