diffusion_models.utils.trainer.Trainer
- class diffusion_models.utils.trainer.Trainer(model, train_data, loss_func, optimizer, gpu_id, num_gpus, batch_size, save_every, checkpoint_folder, device_type, log_wandb=True, mixed_precision=False, gradient_accumulation_rate=1, lr_scheduler=None, k_space=False)
Bases: object
Trainer class that trains one model instance on one device; suited for distributed training.
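For context, a hedged sketch of the one-process-per-GPU launch pattern this design implies. The entry-point name main, the use of torch.multiprocessing.spawn, and the MASTER_ADDR/MASTER_PORT defaults are illustrative assumptions, not part of this package; only the gpu_id/num_gpus parameters below are documented.

```python
import os

import torch
import torch.multiprocessing as mp
from torch.distributed import destroy_process_group, init_process_group


def main(rank: int, world_size: int) -> None:
    # Hypothetical per-process entry point: each spawned process owns one GPU
    # and would build its own Trainer with gpu_id=rank, num_gpus=world_size
    # (see __init__ below for the full parameter list).
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
    # ... build model, data, optimizer, then Trainer(gpu_id=rank, ...)
    destroy_process_group()


if __name__ == "__main__":
    # Rendezvous settings for init_process_group; adjust for multi-node runs.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    world_size = torch.cuda.device_count()
    mp.spawn(main, args=(world_size,), nprocs=world_size)
```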
- __init__(model, train_data, loss_func, optimizer, gpu_id, num_gpus, batch_size, save_every, checkpoint_folder, device_type, log_wandb=True, mixed_precision=False, gradient_accumulation_rate=1, lr_scheduler=None, k_space=False)
Constructor of the Trainer class.
- Parameters:
  - model (Module) – instance of nn.Module to be copied to a GPU
  - train_data (Dataset) – Dataset instance
  - loss_func (Callable) – criterion used to compute the loss
  - optimizer (Optimizer) – torch.optim instance constructed with model.parameters() and a learning rate
  - gpu_id (int) – device index in [0, num_gpus); ignored if device_type != "cuda"
  - num_gpus (int) – number of GPUs; ignored if device_type != "cuda"
  - save_every (int) – checkpoint the model and upload data to wandb every save_every epochs
  - checkpoint_folder (str) – folder to save checkpoints to
  - device_type (Literal['cuda', 'mps', 'cpu']) – specify when not training on a CUDA-capable device
  - log_wandb (bool) – whether to log to wandb; requires that wandb has been initialized on GPU 0 (and on that GPU only!)
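A minimal single-device construction sketch under the signature above. The toy model and TensorDataset are placeholder stand-ins for a real diffusion model and training set; training itself is started through the methods listed below.

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset

from diffusion_models.utils.trainer import Trainer

# Toy stand-ins for a real diffusion model and training set.
model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
train_data = TensorDataset(torch.randn(256, 16), torch.randn(256, 16))

trainer = Trainer(
    model=model,
    train_data=train_data,
    loss_func=nn.MSELoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),
    gpu_id=0,                      # ignored unless device_type == "cuda"
    num_gpus=1,
    batch_size=32,
    save_every=5,                  # checkpoint + wandb upload every 5 epochs
    checkpoint_folder="./checkpoints",
    device_type="cuda" if torch.cuda.is_available() else "cpu",
    log_wandb=False,               # enable only after wandb.init() on GPU 0
)
```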
Methods