From d8571ce965ab9bdbc2c186fd28bda8e258db4b01 Mon Sep 17 00:00:00 2001 From: Evan <2869018789@qq.com> Date: Sun, 31 May 2026 22:44:17 +0800 Subject: [PATCH] refactor: convert isinstance chains to match/case (part 4) (#36274) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato --- api/core/mcp/server/streamable_http.py | 13 +- api/core/mcp/session/base_session.py | 173 ++++++++++-------- api/core/rag/retrieval/dataset_retrieval.py | 13 +- api/models/workflow.py | 156 ++++++++-------- .../app_generate/workflow_execute_task.py | 49 ++--- 5 files changed, 215 insertions(+), 189 deletions(-) diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index fadb6fa2d6..08b4ed0e19 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -203,12 +203,13 @@ def extract_answer_from_response(app: App, response: Any) -> str: """Extract answer from app generate response""" answer = "" - if isinstance(response, RateLimitGenerator): - answer = process_streaming_response(response) - elif isinstance(response, Mapping): - answer = process_mapping_response(app, response) - else: - logger.warning("Unexpected response type: %s", type(response)) + match response: + case RateLimitGenerator(): + answer = process_streaming_response(response) + case Mapping(): + answer = process_mapping_response(app, response) + case _: + logger.warning("Unexpected response type: %s", type(response)) return answer diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 70d45b15c4..00f6ac8fa3 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -240,30 +240,31 @@ class BaseSession[ self.check_receiver_status() continue - if response_or_error is None: - raise MCPConnectionError( - ErrorData( - code=500, - message="No response received", - ) - ) - elif isinstance(response_or_error, HTTPStatusError): - # HTTPStatusError from streamable_client with preserved response object - if response_or_error.response.status_code == 401: - raise MCPAuthError(response=response_or_error.response) - else: + match response_or_error: + case None: raise MCPConnectionError( - ErrorData(code=response_or_error.response.status_code, message=str(response_or_error)) + ErrorData( + code=500, + message="No response received", + ) ) - elif isinstance(response_or_error, JSONRPCError): - if response_or_error.error.code == 401: - raise MCPAuthError(message=response_or_error.error.message) - else: - raise MCPConnectionError( - ErrorData(code=response_or_error.error.code, message=response_or_error.error.message) - ) - else: - return result_type.model_validate(response_or_error.result) + case HTTPStatusError(): + # HTTPStatusError from streamable_client with preserved response object + if response_or_error.response.status_code == 401: + raise MCPAuthError(response=response_or_error.response) + else: + raise MCPConnectionError( + ErrorData(code=response_or_error.response.status_code, message=str(response_or_error)) + ) + case JSONRPCError(): + if response_or_error.error.code == 401: + raise MCPAuthError(message=response_or_error.error.message) + else: + raise MCPConnectionError( + ErrorData(code=response_or_error.error.code, message=response_or_error.error.message) + ) + case _: + return result_type.model_validate(response_or_error.result) finally: self._response_streams.pop(request_id, None) @@ -316,65 +317,79 @@ class BaseSession[ message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT) if message is None: break - if isinstance(message, HTTPStatusError): - response_queue = self._response_streams.get(self._request_id - 1) - if response_queue is not None: - # For 401 errors, pass the HTTPStatusError directly to preserve response object - if message.response.status_code == 401: - response_queue.put(message) - else: - response_queue.put( - JSONRPCError( - jsonrpc="2.0", - id=self._request_id - 1, - error=ErrorData(code=message.response.status_code, message=message.args[0]), + match message: + case HTTPStatusError(): + response_queue = self._response_streams.get(self._request_id - 1) + if response_queue is not None: + # For 401 errors, pass the HTTPStatusError directly to preserve response object + if message.response.status_code == 401: + response_queue.put(message) + else: + response_queue.put( + JSONRPCError( + jsonrpc="2.0", + id=self._request_id - 1, + error=ErrorData(code=message.response.status_code, message=message.args[0]), + ) ) - ) - else: - self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}")) - elif isinstance(message, Exception): - self._handle_incoming(message) - elif isinstance(message.message.root, JSONRPCRequest): - validated_request = self._receive_request_type.model_validate( - message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - - responder = RequestResponder[ReceiveRequestT, SendResultT]( - request_id=message.message.root.id, - request_meta=validated_request.root.params.meta if validated_request.root.params else None, - request=validated_request, # type: ignore[arg-type] # mypy can't narrow constrained TypeVar from model_validate - session=self, - on_complete=lambda r: self._in_flight.pop(r.request_id, None), - ) - - self._in_flight[responder.request_id] = responder - self._received_request(responder) - - if not responder.completed: - self._handle_incoming(responder) - - elif isinstance(message.message.root, JSONRPCNotification): - try: - notification = self._receive_notification_type.model_validate( - message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - # Handle cancellation notifications - if isinstance(notification.root, CancelledNotification): - cancelled_id = notification.root.params.requestId - if cancelled_id in self._in_flight: - self._in_flight[cancelled_id].cancel() else: - self._received_notification(notification) # type: ignore[arg-type] - self._handle_incoming(notification) # type: ignore[arg-type] - except Exception as e: - # For other validation errors, log and continue - logger.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root) - else: # Response or error - response_queue = self._response_streams.get(message.message.root.id) - if response_queue is not None: - response_queue.put(message.message.root) - else: - self._handle_incoming(RuntimeError(f"Server Error: {message}")) + self._handle_incoming( + RuntimeError(f"Received response with an unknown request ID: {message}") + ) + case Exception(): + self._handle_incoming(message) + case SessionMessage(message=JSONRPCMessage(root=JSONRPCRequest())): + request_root = message.message.root + if not isinstance(request_root, JSONRPCRequest): + continue + + validated_request = self._receive_request_type.model_validate( + request_root.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + + responder = RequestResponder[ReceiveRequestT, SendResultT]( + request_id=request_root.id, + request_meta=validated_request.root.params.meta if validated_request.root.params else None, + request=validated_request, # type: ignore[arg-type] # mypy can't narrow constrained TypeVar from model_validate + session=self, + on_complete=lambda r: self._in_flight.pop(r.request_id, None), + ) + + self._in_flight[responder.request_id] = responder + self._received_request(responder) + + if not responder.completed: + self._handle_incoming(responder) + + case SessionMessage(message=JSONRPCMessage(root=JSONRPCNotification())): + try: + notification = self._receive_notification_type.model_validate( + message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + # Handle cancellation notifications + if isinstance(notification.root, CancelledNotification): + cancelled_id = notification.root.params.requestId + if cancelled_id in self._in_flight: + self._in_flight[cancelled_id].cancel() + else: + self._received_notification(notification) # type: ignore[arg-type] + self._handle_incoming(notification) # type: ignore[arg-type] + except Exception as e: + # For other validation errors, log and continue + logger.warning( + "Failed to validate notification: %s. Message was: %s", e, message.message.root + ) + case _: # Response or error + response_root = message.message.root + if not isinstance(response_root, (JSONRPCResponse, JSONRPCError)): + self._handle_incoming(RuntimeError(f"Server Error: {message}")) + continue + + response_queue = self._response_streams.get(response_root.id) + if response_queue is not None: + response_queue.put(response_root) + else: + self._handle_incoming(RuntimeError(f"Server Error: {message}")) except queue.Empty: continue except Exception: diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 039a266f44..dae6f63b7b 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -1554,12 +1554,13 @@ class DatasetRetrieval: case "≥" | ">=": filters.append(DatasetDocument.doc_metadata[metadata_name].as_float() >= value) case "in" | "not in": - if isinstance(value, str): - value_list = [v.strip() for v in value.split(",") if v.strip()] - elif isinstance(value, (list, tuple)): - value_list = [str(v) for v in value if v is not None] - else: - value_list = [str(value)] if value is not None else [] + match value: + case str(): + value_list = [v.strip() for v in value.split(",") if v.strip()] + case list() | tuple(): + value_list = [str(v) for v in value if v is not None] + case _: + value_list = [str(value)] if value is not None else [] if not value_list: # `field in []` is False, `field not in []` is True diff --git a/api/models/workflow.py b/api/models/workflow.py index 282d4a8834..b3cb921b07 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -543,13 +543,16 @@ class Workflow(Base): # bug def decrypt_func( var: VariableBase, ) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable: - if isinstance(var, SecretVariable): - return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) - elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)): - return var - else: - # Other variable types are not supported for environment variables - raise AssertionError(f"Unexpected variable type for environment variable: {type(var)}") + match var: + case SecretVariable(): + return var.model_copy( + update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)} + ) + case StringVariable() | IntegerVariable() | FloatVariable(): + return var + case _: + # Other variable types are not supported for environment variables + raise AssertionError(f"Unexpected variable type for environment variable: {type(var)}") decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = [ decrypt_func(var) for var in results @@ -1638,31 +1641,32 @@ class WorkflowDraftVariable(Base): # rather than their serialized forms. # However, multiple components in the codebase depend on # `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging. - if isinstance(value, dict): - if not maybe_file_object(value): - return cast(Any, value) - tenant_id = _resolve_workflow_app_tenant_id(self.app_id) - return build_file_from_stored_mapping( - file_mapping=cast(dict[str, Any], value), - tenant_id=tenant_id, - ) - elif isinstance(value, list) and value: - value_list = cast(list[Any], value) - first: Any = value_list[0] - if not maybe_file_object(first): - return cast(Any, value) - tenant_id = _resolve_workflow_app_tenant_id(self.app_id) - file_list: list[File] = [] - for item in value_list: - file_list.append( - build_file_from_stored_mapping( - file_mapping=cast(dict[str, Any], item), - tenant_id=tenant_id, - ) + match value: + case dict(): + if not maybe_file_object(value): + return cast(Any, value) + tenant_id = _resolve_workflow_app_tenant_id(self.app_id) + return build_file_from_stored_mapping( + file_mapping=cast(dict[str, Any], value), + tenant_id=tenant_id, ) - return cast(Any, file_list) - else: - return cast(Any, value) + case list() if value: + value_list = cast(list[Any], value) + first: Any = value_list[0] + if not maybe_file_object(first): + return cast(Any, value) + tenant_id = _resolve_workflow_app_tenant_id(self.app_id) + file_list: list[File] = [] + for item in value_list: + file_list.append( + build_file_from_stored_mapping( + file_mapping=cast(dict[str, Any], item), + tenant_id=tenant_id, + ) + ) + return cast(Any, file_list) + case _: + return cast(Any, value) def build_segment_from_serialized_value(self, segment_type: SegmentType, value: Any) -> Segment: # Persisted draft variable rows may contain historical file payloads. @@ -1671,13 +1675,14 @@ class WorkflowDraftVariable(Base): # serialized JSON blob. match segment_type: case SegmentType.FILE: - if isinstance(value, File): - return build_segment_with_type(segment_type, value) - elif isinstance(value, dict): - file = self._rebuild_file_types(value) - return build_segment_with_type(segment_type, file) - else: - raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}") + match value: + case File(): + return build_segment_with_type(segment_type, value) + case dict(): + file = self._rebuild_file_types(value) + return build_segment_with_type(segment_type, file) + case _: + raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}") case SegmentType.ARRAY_FILE: if not isinstance(value, list): raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}") @@ -1692,25 +1697,26 @@ class WorkflowDraftVariable(Base): # structural reconstruction. Persisted draft-variable payloads should go # through `build_segment_from_serialized_value()` so file metadata is # rebuilt from canonical storage records. - if isinstance(value, dict): - if not maybe_file_object(value): - return cast(Any, value) - normalized_file = dict(value) - normalized_file.pop("tenant_id", None) - return build_file_from_mapping_without_lookup(file_mapping=normalized_file) - elif isinstance(value, list) and value: - value_list = cast(list[Any], value) - first: Any = value_list[0] - if not maybe_file_object(first): - return cast(Any, value) - file_list: list[File] = [] - for item in value_list: - normalized_file = dict(cast(dict[str, Any], item)) + match value: + case dict(): + if not maybe_file_object(value): + return cast(Any, value) + normalized_file = dict(value) normalized_file.pop("tenant_id", None) - file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file)) - return cast(Any, file_list) - else: - return cast(Any, value) + return build_file_from_mapping_without_lookup(file_mapping=normalized_file) + case list() if value: + value_list = cast(list[Any], value) + first: Any = value_list[0] + if not maybe_file_object(first): + return cast(Any, value) + file_list: list[File] = [] + for item in value_list: + normalized_file = dict(cast(dict[str, Any], item)) + normalized_file.pop("tenant_id", None) + file_list.append(build_file_from_mapping_without_lookup(file_mapping=normalized_file)) + return cast(Any, file_list) + case _: + return cast(Any, value) @classmethod def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment: @@ -1719,13 +1725,14 @@ class WorkflowDraftVariable(Base): # their serialized dictionary or list representations, respectively. match segment_type: case SegmentType.FILE: - if isinstance(value, File): - return build_segment_with_type(segment_type, value) - elif isinstance(value, dict): - file = cls.rebuild_file_types(value) - return build_segment_with_type(segment_type, file) - else: - raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}") + match value: + case File(): + return build_segment_with_type(segment_type, value) + case dict(): + file = cls.rebuild_file_types(value) + return build_segment_with_type(segment_type, file) + case _: + raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}") case SegmentType.ARRAY_FILE: if not isinstance(value, list): raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}") @@ -2099,17 +2106,18 @@ class WorkflowPauseReason(DefaultFieldsDCMixin, TypeBase): @classmethod def from_entity(cls, *, pause_id: str, pause_reason: PauseReason) -> "WorkflowPauseReason": - if isinstance(pause_reason, HumanInputRequired): - return cls( - pause_id=pause_id, - type_=PauseReasonType.HUMAN_INPUT_REQUIRED, - form_id=pause_reason.form_id, - node_id=pause_reason.node_id, - ) - elif isinstance(pause_reason, SchedulingPause): - return cls(pause_id=pause_id, type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message) - else: - raise AssertionError(f"Unknown pause reason type: {pause_reason}") + match pause_reason: + case HumanInputRequired(): + return cls( + pause_id=pause_id, + type_=PauseReasonType.HUMAN_INPUT_REQUIRED, + form_id=pause_reason.form_id, + node_id=pause_reason.node_id, + ) + case SchedulingPause(): + return cls(pause_id=pause_id, type_=PauseReasonType.SCHEDULED_PAUSE, message=pause_reason.message) + case _: + raise AssertionError(f"Unknown pause reason type: {pause_reason}") def to_entity(self) -> PauseReason: if self.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED: diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index 16121cefa6..3c3b04fde0 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -52,20 +52,21 @@ class _EndUser(BaseModel): def _get_user_type_descriminator(value: Any): - if isinstance(value, (_Account, _EndUser)): - return value.TYPE - elif isinstance(value, dict): - user_type_str = value.get("TYPE") - if user_type_str is None: + match value: + case _Account() | _EndUser(): + return value.TYPE + case dict(): + user_type_str = value.get("TYPE") + if user_type_str is None: + return None + try: + user_type = _UserType(user_type_str) + except ValueError: + return None + return user_type + case _: + # return None if the discriminator value isn't found return None - try: - user_type = _UserType(user_type_str) - except ValueError: - return None - return user_type - else: - # return None if the discriminator value isn't found - return None type User = Annotated[ @@ -221,17 +222,17 @@ class _AppRunner: def _resolve_user(self) -> Account | EndUser: user_params = self._exec_params.user - if isinstance(user_params, _EndUser): - with self._session() as session: - return session.get(EndUser, user_params.end_user_id) - elif not isinstance(user_params, _Account): - raise AssertionError(f"user should only be _Account or _EndUser, got {type(user_params)}") - - with self._session() as session: - user: Account = session.get(Account, user_params.user_id) - user.set_tenant_id(self._exec_params.tenant_id) - - return user + match user_params: + case _EndUser(): + with self._session() as session: + return session.get(EndUser, user_params.end_user_id) + case _Account(): + with self._session() as session: + user: Account = session.get(Account, user_params.user_id) + user.set_tenant_id(self._exec_params.tenant_id) + return user + case _: + raise AssertionError(f"user should only be _Account or _EndUser, got {type(user_params)}") def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Account | EndUser | None: