Shortcuts

Source code for quaterion.train.callbacks.cleanup_callback

from typing import Optional

import pytorch_lightning as pl
from pytorch_lightning.trainer.states import TrainerFn

from quaterion.train.trainable_model import TrainableModel

try:  # fix for version >= 1.9.0
    from pytorch_lightning import Callback
except ImportError:
    from pytorch_lightning.callbacks.base import Callback


class CleanupCallback(Callback):
    """Lightning callback that undoes cache-related changes after fitting.

    On teardown of the fitting stage it unwraps the model's cache, asks the
    trainer to reset its dataloaders, and passes the resulting train and
    validation loaders back through ``pl_module.setup_dataloader``.
    """

    def teardown(
        self,
        trainer: "pl.Trainer",
        pl_module: TrainableModel,
        stage: Optional[str] = None,
    ) -> None:
        """Restore encoders and dataloaders when the fitting stage ends.

        Args:
            trainer: Trainer whose dataloaders are reset and then re-read.
            pl_module: Model providing ``unwrap_cache`` and
                ``setup_dataloader``.
            stage: Stage being torn down. Nothing happens unless it compares
                equal to ``TrainerFn.FITTING``.
        """
        # Deliberately expressed via `==` so that whatever equality
        # TrainerFn defines is what decides, exactly as before.
        if not (stage == TrainerFn.FITTING):
            return

        # If encoders were wrapped, unwrap them
        pl_module.unwrap_cache()

        try:  # fix for pl>=1.9.0
            trainer.reset_train_val_dataloaders()
        except NotImplementedError:
            trainer.reset_train_dataloader()
            # NOTE(review): the primary path resets train + val loaders, while
            # this fallback resets train + *test* — confirm this is intended.
            trainer.reset_test_dataloader()

        # Restore Data Loaders if they were modified for cache
        pl_module.setup_dataloader(trainer.train_dataloader.loaders)

        val_loaders = trainer.val_dataloaders
        if val_loaders:
            for loader in val_loaders:
                pl_module.setup_dataloader(loader)

Qdrant

Learn more about the Qdrant vector search project and ecosystem

Discover Qdrant

Similarity Learning

Explore practical problem solving with Similarity Learning

Learn Similarity Learning

Community

Find people dealing with similar problems and get answers to your questions

Join Community