refactor: convert isinstance chains to match/case (part 5) (#36503)

This commit is contained in:
Evan
2026-05-31 23:08:59 +08:00
committed by GitHub
parent 8e5f09091b
commit df6b5be50a
6 changed files with 102 additions and 84 deletions
@@ -75,20 +75,23 @@ class AliyunDataTrace(BaseTraceInstance):
self.trace_client = TraceClient(service_name=aliyun_config.app_name, endpoint=endpoint)
def trace(self, trace_info: BaseTraceInfo):
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
pass
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
pass
match trace_info:
case WorkflowTraceInfo():
self.workflow_trace(trace_info)
case MessageTraceInfo():
self.message_trace(trace_info)
case ModerationTraceInfo():
pass
case SuggestedQuestionTraceInfo():
self.suggested_question_trace(trace_info)
case DatasetRetrievalTraceInfo():
self.dataset_retrieval_trace(trace_info)
case ToolTraceInfo():
self.tool_trace(trace_info)
case GenerateNameTraceInfo():
pass
case _:
pass
def api_check(self):
return self.trace_client.api_check()
@@ -708,20 +708,23 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
logger.info("[Arize/Phoenix] Trace Entity Info: %s", trace_info)
logger.info("[Arize/Phoenix] Trace Entity Type: %s", type(trace_info))
try:
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
self.moderation_trace(trace_info)
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
self.generate_name_trace(trace_info)
match trace_info:
case WorkflowTraceInfo():
self.workflow_trace(trace_info)
case MessageTraceInfo():
self.message_trace(trace_info)
case ModerationTraceInfo():
self.moderation_trace(trace_info)
case SuggestedQuestionTraceInfo():
self.suggested_question_trace(trace_info)
case DatasetRetrievalTraceInfo():
self.dataset_retrieval_trace(trace_info)
case ToolTraceInfo():
self.tool_trace(trace_info)
case GenerateNameTraceInfo():
self.generate_name_trace(trace_info)
case _:
pass
except Exception as e:
logger.error("[Arize/Phoenix] Trace Entity Error: %s", str(e), exc_info=True)
@@ -107,20 +107,23 @@ class LangFuseDataTrace(BaseTraceInstance):
return start_time + timedelta(seconds=ttft_seconds)
def trace(self, trace_info: BaseTraceInfo):
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
self.moderation_trace(trace_info)
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
self.generate_name_trace(trace_info)
match trace_info:
case WorkflowTraceInfo():
self.workflow_trace(trace_info)
case MessageTraceInfo():
self.message_trace(trace_info)
case ModerationTraceInfo():
self.moderation_trace(trace_info)
case SuggestedQuestionTraceInfo():
self.suggested_question_trace(trace_info)
case DatasetRetrievalTraceInfo():
self.dataset_retrieval_trace(trace_info)
case ToolTraceInfo():
self.tool_trace(trace_info)
case GenerateNameTraceInfo():
self.generate_name_trace(trace_info)
case _:
pass
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.trace_id or trace_info.workflow_run_id
@@ -48,20 +48,23 @@ class LangSmithDataTrace(BaseTraceInstance):
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
def trace(self, trace_info: BaseTraceInfo):
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
self.moderation_trace(trace_info)
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
self.generate_name_trace(trace_info)
match trace_info:
case WorkflowTraceInfo():
self.workflow_trace(trace_info)
case MessageTraceInfo():
self.message_trace(trace_info)
case ModerationTraceInfo():
self.moderation_trace(trace_info)
case SuggestedQuestionTraceInfo():
self.suggested_question_trace(trace_info)
case DatasetRetrievalTraceInfo():
self.dataset_retrieval_trace(trace_info)
case ToolTraceInfo():
self.tool_trace(trace_info)
case GenerateNameTraceInfo():
self.generate_name_trace(trace_info)
case _:
pass
def workflow_trace(self, trace_info: WorkflowTraceInfo):
# trace_id must equal the root run's run_id (LangSmith protocol); external trace_id
@@ -96,20 +96,23 @@ class OpikDataTrace(BaseTraceInstance):
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
def trace(self, trace_info: BaseTraceInfo):
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
self.moderation_trace(trace_info)
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
self.generate_name_trace(trace_info)
match trace_info:
case WorkflowTraceInfo():
self.workflow_trace(trace_info)
case MessageTraceInfo():
self.message_trace(trace_info)
case ModerationTraceInfo():
self.moderation_trace(trace_info)
case SuggestedQuestionTraceInfo():
self.suggested_question_trace(trace_info)
case DatasetRetrievalTraceInfo():
self.dataset_retrieval_trace(trace_info)
case ToolTraceInfo():
self.tool_trace(trace_info)
case GenerateNameTraceInfo():
self.generate_name_trace(trace_info)
case _:
pass
def workflow_trace(self, trace_info: WorkflowTraceInfo):
workflow_metadata = wrap_metadata(
@@ -80,20 +80,23 @@ class WeaveDataTrace(BaseTraceInstance):
def trace(self, trace_info: BaseTraceInfo):
logger.debug("Trace info: %s", trace_info)
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
self.moderation_trace(trace_info)
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
self.generate_name_trace(trace_info)
match trace_info:
case WorkflowTraceInfo():
self.workflow_trace(trace_info)
case MessageTraceInfo():
self.message_trace(trace_info)
case ModerationTraceInfo():
self.moderation_trace(trace_info)
case SuggestedQuestionTraceInfo():
self.suggested_question_trace(trace_info)
case DatasetRetrievalTraceInfo():
self.dataset_retrieval_trace(trace_info)
case ToolTraceInfo():
self.tool_trace(trace_info)
case GenerateNameTraceInfo():
self.generate_name_trace(trace_info)
case _:
pass
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id