diffusion_models.utils.trainer.Trainer

class diffusion_models.utils.trainer.Trainer(model, train_data, loss_func, optimizer, gpu_id, num_gpus, batch_size, save_every, checkpoint_folder, device_type, log_wandb=True, mixed_precision=False, gradient_accumulation_rate=1, lr_scheduler=None, k_space=False)

Bases: object

Trainer class that trains one model instance on one device; suited for distributed training.

__init__(model, train_data, loss_func, optimizer, gpu_id, num_gpus, batch_size, save_every, checkpoint_folder, device_type, log_wandb=True, mixed_precision=False, gradient_accumulation_rate=1, lr_scheduler=None, k_space=False)

Constructor of the Trainer class.

Parameters:
  • model (Module) – instance of nn.Module to be copied to a GPU

  • train_data (Dataset) – Dataset instance

  • loss_func (Callable) – criterion to determine the loss

  • optimizer (Optimizer) – torch.optim optimizer instance, already constructed with model.parameters() and a learning rate

  • gpu_id (int) – GPU index in the range [0, num_gpus); ignored if device_type != "cuda"

  • num_gpus (int) – number of GPUs; ignored if device_type != "cuda"

  • save_every (int) – checkpoint the model and upload data to wandb every save_every epochs

  • checkpoint_folder (str) – where to save checkpoints to

  • device_type (Literal['cuda', 'mps', 'cpu']) – specify when not training on a CUDA-capable device

  • log_wandb (bool) – whether to log to wandb; requires that the wandb process has been initialized on GPU 0 (and on that GPU only)
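
The device-related parameters above interact: gpu_id and num_gpus only matter when device_type is "cuda". The following is a minimal sketch of checking that combination before constructing a Trainer. The helper device_kwargs is hypothetical (it is not part of diffusion_models), and it assumes GPU indices are zero-based, so valid ids run from 0 to num_gpus - 1.

```python
from typing import Any, Dict

# Device types listed in the documented signature.
VALID_DEVICE_TYPES = ("cuda", "mps", "cpu")


def device_kwargs(device_type: str, gpu_id: int = 0, num_gpus: int = 1) -> Dict[str, Any]:
    """Hypothetical helper: validate and bundle the device-related arguments."""
    if device_type not in VALID_DEVICE_TYPES:
        raise ValueError(
            f"device_type must be one of {VALID_DEVICE_TYPES}, got {device_type!r}"
        )
    if device_type == "cuda":
        # Assumption: GPU indices are zero-based, so valid ids are 0 .. num_gpus - 1.
        if num_gpus < 1:
            raise ValueError("num_gpus must be at least 1 when device_type is 'cuda'")
        if not 0 <= gpu_id < num_gpus:
            raise ValueError(f"gpu_id {gpu_id} is outside [0, {num_gpus})")
    # For "mps" and "cpu" the two GPU values are passed through unchecked,
    # since the documentation says they do not matter there.
    return {"device_type": device_type, "gpu_id": gpu_id, "num_gpus": num_gpus}


cpu_args = device_kwargs("cpu")
print(cpu_args)
# The result could then be forwarded to the constructor, for example:
# Trainer(model, train_data, loss_func, optimizer, batch_size=..., save_every=...,
#         checkpoint_folder=..., **cpu_args)
```

Validating up front keeps a misconfigured rank from surfacing only after the processes have been spawned.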

Methods

__init__(model, train_data, loss_func, ...)

Constructor of the Trainer class.

load_checkpoint(checkpoint_path)

train(max_epochs)

Train method of Trainer class.

train(max_epochs)

Train method of Trainer class.

Parameters:

max_epochs (int) – how many epochs to train the model
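
To get a feel for how max_epochs combines with the constructor's save_every, the sketch below lists the epochs at which a checkpoint would be written. It does not call the library: checkpoint_epochs is a hypothetical helper, and it assumes that epochs are counted from zero and that a checkpoint is written whenever epoch % save_every == 0. The actual rule used by Trainer may differ.

```python
from typing import List


def checkpoint_epochs(max_epochs: int, save_every: int) -> List[int]:
    """Return the zero-based epochs at which a checkpoint would be written.

    Assumption (not confirmed by the documentation): a checkpoint is written
    whenever epoch % save_every == 0.
    """
    if max_epochs < 0:
        raise ValueError("max_epochs must be non-negative")
    if save_every < 1:
        raise ValueError("save_every must be at least 1")
    return [epoch for epoch in range(max_epochs) if epoch % save_every == 0]


print(checkpoint_epochs(max_epochs=10, save_every=4))  # [0, 4, 8]
```

Under this assumption, training for 10 epochs with save_every=4 leaves the last two epochs without a checkpoint, so choosing max_epochs with the save interval in mind avoids losing the final state.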