diffusion_models.utils.trainer.Trainer¶
- class diffusion_models.utils.trainer.Trainer(model, train_data, loss_func, optimizer, gpu_id, num_gpus, batch_size, save_every, checkpoint_folder, device_type, log_wandb=True, mixed_precision=False, gradient_accumulation_rate=1, lr_scheduler=None, k_space=False)[source]¶
Bases: object

Trainer class that trains one model instance on one device, suited for distributed training.
- __init__(model, train_data, loss_func, optimizer, gpu_id, num_gpus, batch_size, save_every, checkpoint_folder, device_type, log_wandb=True, mixed_precision=False, gradient_accumulation_rate=1, lr_scheduler=None, k_space=False)[source]¶
Constructor of the Trainer class.
- Parameters:
  - model (Module) – instance of nn.Module to be copied to a GPU
  - train_data (Dataset) – Dataset instance
  - loss_func (Callable) – criterion to determine the loss
  - optimizer (Optimizer) – torch.optim instance with model.parameters and learning rate passed
  - gpu_id (int) – GPU index in range [0, num_gpus); value does not matter if device_type != "cuda"
  - num_gpus (int) – number of GPUs used; does not matter if device_type != "cuda"
  - save_every (int) – checkpoint the model and upload data to wandb every save_every epochs
  - checkpoint_folder (str) – folder to save checkpoints to
  - device_type (Literal['cuda', 'mps', 'cpu']) – specify when not training on a CUDA-capable device
  - log_wandb (bool) – whether to log to wandb; requires that the wandb process has been initialized on GPU 0 (and on that GPU only!)
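For illustration, a minimal construction sketch for a single-device (CPU) run. The Linear model, TensorDataset, and all hyperparameter values below are stand-ins chosen for the example, not values prescribed by this library; only the constructor parameters documented above are supplied, with the remaining keyword arguments left at their defaults.

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset

from diffusion_models.utils.trainer import Trainer

# Stand-in model and dataset; replace with a real diffusion model and data.
model = nn.Linear(16, 16)
train_data = TensorDataset(torch.randn(128, 16))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

trainer = Trainer(
    model=model,
    train_data=train_data,
    loss_func=nn.functional.mse_loss,   # criterion to determine the loss
    optimizer=optimizer,
    gpu_id=0,                           # ignored when device_type != "cuda"
    num_gpus=1,                         # ignored when device_type != "cuda"
    batch_size=32,
    save_every=5,                       # checkpoint & upload to wandb every 5 epochs
    checkpoint_folder="./checkpoints",
    device_type="cpu",                  # one of "cuda", "mps", "cpu"
    log_wandb=False,                    # enable only after wandb.init on GPU 0
)
```

For multi-GPU training, one Trainer instance would be constructed per process, each with its own gpu_id and device_type="cuda"; the wandb initialization must then happen in the GPU-0 process only, as noted for log_wandb above.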
Methods