From 540289e6c69a21f05962a580f77c2bdb3dff3797 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Wed, 8 Apr 2026 18:19:03 -0500 Subject: [PATCH] refactor: migrate session.query to select API in delete segment and regenerate summary tasks (#34763) --- api/tasks/delete_segment_from_index_task.py | 16 ++++++------ api/tasks/regenerate_summary_index_task.py | 28 ++++++++++----------- 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index a6a2dcebc8..306a23aeda 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -3,7 +3,7 @@ import time import click from celery import shared_task -from sqlalchemy import delete +from sqlalchemy import delete, select from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -29,12 +29,12 @@ def delete_segment_from_index_task( start_at = time.perf_counter() with session_factory.create_session() as session: try: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logging.warning("Dataset %s not found, skipping index cleanup", dataset_id) return - dataset_document = session.query(Document).where(Document.id == document_id).first() + dataset_document = session.scalar(select(Document).where(Document.id == document_id).limit(1)) if not dataset_document: return @@ -60,11 +60,9 @@ def delete_segment_from_index_task( ) if dataset.is_multimodal: # delete segment attachment binding - segment_attachment_bindings = ( - session.query(SegmentAttachmentBinding) - .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) - .all() - ) + segment_attachment_bindings = session.scalars( + select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) + ).all() if segment_attachment_bindings: attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False) @@ -77,7 +75,7 @@ def delete_segment_from_index_task( session.execute(segment_attachment_bind_delete_stmt) # delete upload file - session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False) + session.execute(delete(UploadFile).where(UploadFile.id.in_(attachment_ids))) session.commit() end_at = time.perf_counter() diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py index 6f490ab7ea..e794195c92 100644 --- a/api/tasks/regenerate_summary_index_task.py +++ b/api/tasks/regenerate_summary_index_task.py @@ -47,7 +47,7 @@ def regenerate_summary_index_task( try: with session_factory.create_session() as session: - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red")) return @@ -84,8 +84,8 @@ def regenerate_summary_index_task( # For embedding_model change: directly query all segments with existing summaries # Don't require document indexing_status == "completed" # Include summaries with status "completed" or "error" (if they have content) - segments_with_summaries = ( - session.query(DocumentSegment, DocumentSegmentSummary) + segments_with_summaries = session.execute( + select(DocumentSegment, DocumentSegmentSummary) .join( DocumentSegmentSummary, DocumentSegment.id == DocumentSegmentSummary.chunk_id, @@ -110,8 +110,7 @@ def regenerate_summary_index_task( DatasetDocument.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents ) .order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc()) - .all() - ) + ).all() if not segments_with_summaries: logger.info( @@ -215,8 +214,8 @@ def regenerate_summary_index_task( try: # Get all segments with existing summaries - segments = ( - session.query(DocumentSegment) + segments = session.scalars( + select(DocumentSegment) .join( DocumentSegmentSummary, DocumentSegment.id == DocumentSegmentSummary.chunk_id, @@ -229,8 +228,7 @@ def regenerate_summary_index_task( DocumentSegmentSummary.dataset_id == dataset_id, ) .order_by(DocumentSegment.position.asc()) - .all() - ) + ).all() if not segments: continue @@ -245,13 +243,13 @@ def regenerate_summary_index_task( summary_record = None try: # Get existing summary record - summary_record = ( - session.query(DocumentSegmentSummary) - .filter_by( - chunk_id=segment.id, - dataset_id=dataset_id, + summary_record = session.scalar( + select(DocumentSegmentSummary) + .where( + DocumentSegmentSummary.chunk_id == segment.id, + DocumentSegmentSummary.dataset_id == dataset_id, ) - .first() + .limit(1) ) if not summary_record: