refactor: convert isinstance chains to match/case (part 8) (#36869)

This commit is contained in:
Evan
2026-05-31 22:11:05 +08:00
committed by GitHub
parent 480d05bc48
commit c6474a2a8b
5 changed files with 86 additions and 83 deletions
+32 -31
View File
@@ -397,39 +397,40 @@ class CotAgentRunner(BaseAgentRunner, ABC):
current_scratchpad: AgentScratchpadUnit | None = None
for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage):
if not current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or "I am thinking about how to help you",
action_str="",
action=None,
observation=None,
)
scratchpads.append(current_scratchpad)
if message.tool_calls:
try:
current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name,
action_input=json.loads(message.tool_calls[0].function.arguments),
match message:
case AssistantPromptMessage():
if not current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or "I am thinking about how to help you",
action_str="",
action=None,
observation=None,
)
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
except Exception:
logger.exception("Failed to parse tool call from assistant message")
elif isinstance(message, ToolPromptMessage):
if current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad.observation = message.content
else:
raise NotImplementedError("expected str type")
elif isinstance(message, UserPromptMessage):
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
scratchpads = []
current_scratchpad = None
scratchpads.append(current_scratchpad)
if message.tool_calls:
try:
current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name,
action_input=json.loads(message.tool_calls[0].function.arguments),
)
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
except Exception:
logger.exception("Failed to parse tool call from assistant message")
case ToolPromptMessage():
if current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad.observation = message.content
else:
raise NotImplementedError("expected str type")
case UserPromptMessage():
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
scratchpads = []
current_scratchpad = None
result.append(message)
result.append(message)
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
@@ -134,17 +134,18 @@ class AdvancedChatAppGenerateResponseConverter(
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
match sub_stream_response:
case MessageEndStreamResponse():
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
case ErrorStreamResponse():
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
case NodeStartStreamResponse() | NodeFinishStreamResponse():
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
case _:
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
+21 -20
View File
@@ -305,26 +305,27 @@ class AppRunner:
text += message.content
elif isinstance(message.content, list):
for content in message.content:
if isinstance(content, str):
text += content
elif isinstance(content, TextPromptMessageContent):
text += content.data
elif isinstance(content, ImagePromptMessageContent):
if message_id and user_id and tenant_id:
try:
self._handle_multimodal_image_content(
content=content,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
queue_manager=queue_manager,
)
except Exception:
_logger.exception("Failed to handle multimodal image output")
else:
_logger.warning("Received multimodal output but missing required parameters")
else:
text += content.data if hasattr(content, "data") else str(content)
match content:
case str():
text += content
case TextPromptMessageContent():
text += content.data
case ImagePromptMessageContent():
if message_id and user_id and tenant_id:
try:
self._handle_multimodal_image_content(
content=content,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
queue_manager=queue_manager,
)
except Exception:
_logger.exception("Failed to handle multimodal image output")
else:
_logger.warning("Received multimodal output but missing required parameters")
case _:
text += content.data if hasattr(content, "data") else str(content)
if not model:
model = result.model
@@ -77,17 +77,16 @@ class TemplateTransformer(ABC):
"""
def convert_scientific_notation(value: Any) -> Any:
if isinstance(value, str):
# Check if the string looks like scientific notation
if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
match value:
case str() if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
try:
return float(value)
except ValueError:
pass
elif isinstance(value, dict):
return {k: convert_scientific_notation(v) for k, v in value.items()}
elif isinstance(value, list):
return [convert_scientific_notation(v) for v in value]
case dict():
return {k: convert_scientific_notation(v) for k, v in value.items()}
case list():
return [convert_scientific_notation(v) for v in value]
return value
return convert_scientific_notation(result)
+14 -13
View File
@@ -72,20 +72,21 @@ def handle_mcp_request(
try:
# Dispatch request to appropriate handler based on instance type
if isinstance(request_root, mcp_types.InitializeRequest):
return create_success_response(handle_initialize(mcp_server.description))
elif isinstance(request_root, mcp_types.ListToolsRequest):
return create_success_response(
handle_list_tools(
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
match request_root:
case mcp_types.InitializeRequest():
return create_success_response(handle_initialize(mcp_server.description))
case mcp_types.ListToolsRequest():
return create_success_response(
handle_list_tools(
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
)
)
)
elif isinstance(request_root, mcp_types.CallToolRequest):
return create_success_response(handle_call_tool(app, request, user_input_form, end_user))
elif isinstance(request_root, mcp_types.PingRequest):
return create_success_response(handle_ping())
else:
return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}")
case mcp_types.CallToolRequest():
return create_success_response(handle_call_tool(app, request, user_input_form, end_user))
case mcp_types.PingRequest():
return create_success_response(handle_ping())
case _:
return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}")
except ValueError as e:
logger.exception("Invalid params")