diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py index fa844a8647..c9b5121a08 100644 --- a/api/tasks/deal_dataset_index_update_task.py +++ b/api/tasks/deal_dataset_index_update_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task # type: ignore +from sqlalchemy import select, update from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType @@ -26,43 +27,42 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): with session_factory.create_session() as session: try: - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise Exception("Dataset not found") index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "upgrade": - dataset_documents = ( - session.query(DatasetDocument) - .where( + dataset_documents = session.scalars( + select(DatasetDocument).where( DatasetDocument.dataset_id == dataset_id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, DatasetDocument.archived == False, ) - .all() - ) + ).all() if dataset_documents: dataset_documents_ids = [doc.id for doc in dataset_documents] - session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( - {"indexing_status": "indexing"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id.in_(dataset_documents_ids)) + .values(indexing_status="indexing") ) session.commit() for dataset_document in dataset_documents: try: # add from vector index - segments = ( - session.query(DocumentSegment) + segments = session.scalars( + select(DocumentSegment) .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True, ) .order_by(DocumentSegment.position.asc()) - .all() - ) + ).all() if segments: documents = [] for segment in segments: @@ -81,32 +81,36 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): # clean keywords index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False) index_processor.load(dataset, documents, with_keywords=False) - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "completed"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="completed") ) session.commit() except Exception as e: - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "error", "error": str(e)}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="error", error=str(e)) ) session.commit() elif action == "update": - dataset_documents = ( - session.query(DatasetDocument) - .where( + dataset_documents = session.scalars( + select(DatasetDocument).where( DatasetDocument.dataset_id == dataset_id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, DatasetDocument.archived == False, ) - .all() - ) + ).all() # add new index if dataset_documents: # update document status dataset_documents_ids = [doc.id for doc in dataset_documents] - session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( - {"indexing_status": "indexing"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id.in_(dataset_documents_ids)) + .values(indexing_status="indexing") ) session.commit() @@ -116,15 +120,14 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): for dataset_document in dataset_documents: # update from vector index try: - segments = ( - session.query(DocumentSegment) + segments = session.scalars( + select(DocumentSegment) .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True, ) .order_by(DocumentSegment.position.asc()) - .all() - ) + ).all() if segments: documents = [] multimodal_documents = [] @@ -173,13 +176,17 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): index_processor.load( dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False ) - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "completed"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="completed") ) session.commit() except Exception as e: - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "error", "error": str(e)}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="error", error=str(e)) ) session.commit() else: