diffusion_models.utils.trainer.DiscriminativeTrainer
- class diffusion_models.utils.trainer.DiscriminativeTrainer(model, train_data, loss_func, optimizer, gpu_id, num_gpus, batch_size, save_every, checkpoint_folder, device_type, log_wandb=True, mixed_precision=False, gradient_accumulation_rate=1, lr_scheduler=None)
Bases: Trainer
- __init__(model, train_data, loss_func, optimizer, gpu_id, num_gpus, batch_size, save_every, checkpoint_folder, device_type, log_wandb=True, mixed_precision=False, gradient_accumulation_rate=1, lr_scheduler=None)
Constructor of Trainer Class.
- Parameters:
  - model (Module) – instance of nn.Module to be copied to a GPU
  - train_data (Dataset) – Dataset instance
  - loss_func (Callable[..., Any]) – criterion used to compute the loss
  - optimizer (Optimizer) – torch.optim instance with model.parameters and learning rate passed
  - gpu_id (int) – int in range [0, num_GPUs]; the value does not matter if device_type != "cuda"
  - num_gpus (int) – does not matter if device_type != "cuda"
  - save_every (int) – checkpoint the model and upload data to wandb every save_every epochs
  - checkpoint_folder (str) – folder in which to save checkpoints
  - device_type (Literal['cuda', 'mps', 'cpu']) – specify when not training on a CUDA-capable device
  - log_wandb (bool) – whether to log to wandb; requires that the wandb process has been initialized on GPU 0 (and on this GPU only!)
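The parameters above can be assembled as in the following minimal sketch. The helpers `pick_device_type` and `trainer_settings` and all concrete values (batch size, folder name, and so on) are hypothetical and chosen for illustration only. The commented-out call relies solely on the signature documented here, and `my_model`, `my_dataset`, `my_loss`, and `my_optimizer` are placeholders.

```python
from typing import Literal

DeviceType = Literal["cuda", "mps", "cpu"]


def pick_device_type(cuda_available: bool, mps_available: bool) -> DeviceType:
    """Hypothetical helper: prefer CUDA, then MPS, then fall back to CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def trainer_settings(device_type: DeviceType, gpu_id: int = 0, num_gpus: int = 1) -> dict:
    """Collect the non-tensor constructor arguments (illustrative values only)."""
    return {
        "gpu_id": gpu_id,            # ignored unless device_type == "cuda"
        "num_gpus": num_gpus,        # ignored unless device_type == "cuda"
        "batch_size": 64,
        "save_every": 5,             # checkpoint every 5 epochs
        "checkpoint_folder": "./checkpoints",
        "device_type": device_type,
        "log_wandb": False,          # avoids the wandb initialization requirement
    }


# In practice the two flags would typically come from
# torch.cuda.is_available() and torch.backends.mps.is_available().
settings = trainer_settings(pick_device_type(cuda_available=False, mps_available=False))

# With PyTorch and this package installed, the call would look roughly like:
# trainer = DiscriminativeTrainer(
#     model=my_model, train_data=my_dataset,
#     loss_func=my_loss, optimizer=my_optimizer, **settings,
# )
# trainer.train(max_epochs=20)
```

Keeping the plain settings in one dictionary makes it easy to reuse them across processes when training on several GPUs, where only `gpu_id` differs.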
Methods
__init__(model, train_data, loss_func, ...) – Constructor of Trainer Class.
load_checkpoint(checkpoint_path)
train(max_epochs) – Train method of Trainer class.
- train(max_epochs)
Train method of Trainer class.
- Parameters:
  - max_epochs (int) – number of epochs to train the model for
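To illustrate how `max_epochs` interacts with `save_every`, the sketch below lists the epochs at which a checkpoint would be written. It assumes zero-based epoch counting and a checkpoint whenever `epoch % save_every == 0`; the page does not state the exact rule, so treat this as an assumption rather than the trainer's actual behaviour. The function `checkpoint_epochs` is hypothetical and not part of the library.

```python
from typing import List


def checkpoint_epochs(max_epochs: int, save_every: int) -> List[int]:
    """Return the epochs at which a checkpoint would be saved.

    Assumption: epochs are counted from 0 and a checkpoint is written
    whenever epoch % save_every == 0.
    """
    if save_every < 1:
        raise ValueError("save_every must be a positive integer")
    return [epoch for epoch in range(max_epochs) if epoch % save_every == 0]


schedule = checkpoint_epochs(max_epochs=10, save_every=3)
print(schedule)  # [0, 3, 6, 9]
```

Under this assumption, training for 10 epochs with `save_every=3` would produce four checkpoints in `checkpoint_folder`.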