mirror of
https://github.com/langgenius/dify.git
synced 2026-06-03 08:16:37 +08:00
refactor: convert isinstance chains to match/case (part 8) (#36869)
This commit is contained in:
@@ -397,39 +397,40 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
current_scratchpad: AgentScratchpadUnit | None = None
|
||||
|
||||
for message in self.history_prompt_messages:
|
||||
if isinstance(message, AssistantPromptMessage):
|
||||
if not current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad = AgentScratchpadUnit(
|
||||
agent_response=message.content,
|
||||
thought=message.content or "I am thinking about how to help you",
|
||||
action_str="",
|
||||
action=None,
|
||||
observation=None,
|
||||
)
|
||||
scratchpads.append(current_scratchpad)
|
||||
if message.tool_calls:
|
||||
try:
|
||||
current_scratchpad.action = AgentScratchpadUnit.Action(
|
||||
action_name=message.tool_calls[0].function.name,
|
||||
action_input=json.loads(message.tool_calls[0].function.arguments),
|
||||
match message:
|
||||
case AssistantPromptMessage():
|
||||
if not current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad = AgentScratchpadUnit(
|
||||
agent_response=message.content,
|
||||
thought=message.content or "I am thinking about how to help you",
|
||||
action_str="",
|
||||
action=None,
|
||||
observation=None,
|
||||
)
|
||||
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
|
||||
except Exception:
|
||||
logger.exception("Failed to parse tool call from assistant message")
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
if current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad.observation = message.content
|
||||
else:
|
||||
raise NotImplementedError("expected str type")
|
||||
elif isinstance(message, UserPromptMessage):
|
||||
if scratchpads:
|
||||
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
|
||||
scratchpads = []
|
||||
current_scratchpad = None
|
||||
scratchpads.append(current_scratchpad)
|
||||
if message.tool_calls:
|
||||
try:
|
||||
current_scratchpad.action = AgentScratchpadUnit.Action(
|
||||
action_name=message.tool_calls[0].function.name,
|
||||
action_input=json.loads(message.tool_calls[0].function.arguments),
|
||||
)
|
||||
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
|
||||
except Exception:
|
||||
logger.exception("Failed to parse tool call from assistant message")
|
||||
case ToolPromptMessage():
|
||||
if current_scratchpad:
|
||||
assert isinstance(message.content, str)
|
||||
current_scratchpad.observation = message.content
|
||||
else:
|
||||
raise NotImplementedError("expected str type")
|
||||
case UserPromptMessage():
|
||||
if scratchpads:
|
||||
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
|
||||
scratchpads = []
|
||||
current_scratchpad = None
|
||||
|
||||
result.append(message)
|
||||
result.append(message)
|
||||
|
||||
if scratchpads:
|
||||
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
|
||||
|
||||
@@ -134,17 +134,18 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
"created_at": chunk.created_at,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, MessageEndStreamResponse):
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
elif isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
match sub_stream_response:
|
||||
case MessageEndStreamResponse():
|
||||
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
|
||||
metadata = sub_stream_response_dict.get("metadata", {})
|
||||
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
|
||||
response_chunk.update(sub_stream_response_dict)
|
||||
case ErrorStreamResponse():
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
case NodeStartStreamResponse() | NodeFinishStreamResponse():
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
case _:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
|
||||
yield response_chunk
|
||||
|
||||
@@ -305,26 +305,27 @@ class AppRunner:
|
||||
text += message.content
|
||||
elif isinstance(message.content, list):
|
||||
for content in message.content:
|
||||
if isinstance(content, str):
|
||||
text += content
|
||||
elif isinstance(content, TextPromptMessageContent):
|
||||
text += content.data
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
if message_id and user_id and tenant_id:
|
||||
try:
|
||||
self._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
except Exception:
|
||||
_logger.exception("Failed to handle multimodal image output")
|
||||
else:
|
||||
_logger.warning("Received multimodal output but missing required parameters")
|
||||
else:
|
||||
text += content.data if hasattr(content, "data") else str(content)
|
||||
match content:
|
||||
case str():
|
||||
text += content
|
||||
case TextPromptMessageContent():
|
||||
text += content.data
|
||||
case ImagePromptMessageContent():
|
||||
if message_id and user_id and tenant_id:
|
||||
try:
|
||||
self._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
except Exception:
|
||||
_logger.exception("Failed to handle multimodal image output")
|
||||
else:
|
||||
_logger.warning("Received multimodal output but missing required parameters")
|
||||
case _:
|
||||
text += content.data if hasattr(content, "data") else str(content)
|
||||
|
||||
if not model:
|
||||
model = result.model
|
||||
|
||||
@@ -77,17 +77,16 @@ class TemplateTransformer(ABC):
|
||||
"""
|
||||
|
||||
def convert_scientific_notation(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
# Check if the string looks like scientific notation
|
||||
if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
|
||||
match value:
|
||||
case str() if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
pass
|
||||
elif isinstance(value, dict):
|
||||
return {k: convert_scientific_notation(v) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
return [convert_scientific_notation(v) for v in value]
|
||||
case dict():
|
||||
return {k: convert_scientific_notation(v) for k, v in value.items()}
|
||||
case list():
|
||||
return [convert_scientific_notation(v) for v in value]
|
||||
return value
|
||||
|
||||
return convert_scientific_notation(result)
|
||||
|
||||
@@ -72,20 +72,21 @@ def handle_mcp_request(
|
||||
|
||||
try:
|
||||
# Dispatch request to appropriate handler based on instance type
|
||||
if isinstance(request_root, mcp_types.InitializeRequest):
|
||||
return create_success_response(handle_initialize(mcp_server.description))
|
||||
elif isinstance(request_root, mcp_types.ListToolsRequest):
|
||||
return create_success_response(
|
||||
handle_list_tools(
|
||||
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
|
||||
match request_root:
|
||||
case mcp_types.InitializeRequest():
|
||||
return create_success_response(handle_initialize(mcp_server.description))
|
||||
case mcp_types.ListToolsRequest():
|
||||
return create_success_response(
|
||||
handle_list_tools(
|
||||
app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict
|
||||
)
|
||||
)
|
||||
)
|
||||
elif isinstance(request_root, mcp_types.CallToolRequest):
|
||||
return create_success_response(handle_call_tool(app, request, user_input_form, end_user))
|
||||
elif isinstance(request_root, mcp_types.PingRequest):
|
||||
return create_success_response(handle_ping())
|
||||
else:
|
||||
return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}")
|
||||
case mcp_types.CallToolRequest():
|
||||
return create_success_response(handle_call_tool(app, request, user_input_form, end_user))
|
||||
case mcp_types.PingRequest():
|
||||
return create_success_response(handle_ping())
|
||||
case _:
|
||||
return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}")
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Invalid params")
|
||||
|
||||
Reference in New Issue
Block a user