diffusion_models.utils.trainer.DiscriminativeTrainer¶
- class diffusion_models.utils.trainer.DiscriminativeTrainer(model, train_data, loss_func, optimizer, gpu_id, num_gpus, batch_size, save_every, checkpoint_folder, device_type, log_wandb=True, mixed_precision=False, gradient_accumulation_rate=1, lr_scheduler=None)[source]¶
Bases: Trainer
- __init__(model, train_data, loss_func, optimizer, gpu_id, num_gpus, batch_size, save_every, checkpoint_folder, device_type, log_wandb=True, mixed_precision=False, gradient_accumulation_rate=1, lr_scheduler=None)[source]¶
Constructor of Trainer Class.
- Parameters:
  - model (Module) – instance of nn.Module to be copied to a GPU
  - train_data (Dataset) – Dataset instance
  - loss_func (Callable[..., Any]) – criterion to determine the loss
  - optimizer (Optimizer) – torch.optim instance with model.parameters and learning rate passed
  - gpu_id (int) – int in range [0, num_gpus]; value does not matter if device_type != "cuda"
  - num_gpus (int) – does not matter if device_type != "cuda"
  - save_every (int) – checkpoint model & upload data to wandb every save_every epochs
  - checkpoint_folder (str) – where to save checkpoints to
  - device_type (Literal['cuda', 'mps', 'cpu']) – specify in case you are not training on a CUDA capable device
  - log_wandb (bool) – whether to log to wandb; requires that initialization of the wandb process has been done on GPU 0 (and on this GPU only!)
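A minimal construction sketch, assuming single-GPU CUDA training. The constructor signature follows the documentation above; the model, dataset, hyperparameter values, and checkpoint path shown here are illustrative placeholders, not part of the library.

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset

from diffusion_models.utils.trainer import DiscriminativeTrainer

# Illustrative classifier and dataset; any nn.Module / Dataset pair works.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
train_data = TensorDataset(
    torch.randn(512, 1, 28, 28),
    torch.randint(0, 10, (512,)),
)

trainer = DiscriminativeTrainer(
    model=model,
    train_data=train_data,
    loss_func=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    gpu_id=0,                        # ignored unless device_type == "cuda"
    num_gpus=1,
    batch_size=64,
    save_every=5,                    # checkpoint & wandb upload every 5 epochs
    checkpoint_folder="./checkpoints",
    device_type="cuda",              # use "mps" or "cpu" without a CUDA device
    log_wandb=False,                 # set True only after wandb.init() on GPU 0
)
```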
Methods
- __init__(model, train_data, loss_func, ...) – Constructor of Trainer Class.
- load_checkpoint(checkpoint_path)
- train(max_epochs) – Train method of Trainer class.
- train(max_epochs)¶
Train method of Trainer class.
- Parameters:
  - max_epochs (int) – how many epochs to train the model
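Assuming a trainer constructed as in the sketch above, training is a single call; the epoch count and checkpoint path below are illustrative.

```python
# Run the training loop; checkpoints are written to checkpoint_folder
# every save_every epochs.
trainer.train(max_epochs=20)

# Optionally resume from a previously saved checkpoint
# (the file name here is a placeholder).
# trainer.load_checkpoint("./checkpoints/checkpoint.pt")
```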