diffusion_models.utils.trainer.GenerativeTrainer

class diffusion_models.utils.trainer.GenerativeTrainer(model, train_data, loss_func, optimizer, gpu_id, num_gpus, batch_size, save_every, checkpoint_folder, device_type, log_wandb, num_samples, mixed_precision, gradient_accumulation_rate=1, lr_scheduler=None, k_space=False)

Bases: Trainer

__init__(model, train_data, loss_func, optimizer, gpu_id, num_gpus, batch_size, save_every, checkpoint_folder, device_type, log_wandb, num_samples, mixed_precision, gradient_accumulation_rate=1, lr_scheduler=None, k_space=False)

Constructor of GenerativeTrainer class.

Parameters:

model (Module) – an nn.Module instance; it must implement a sample(num_samples: int) method
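The sample(num_samples: int) requirement on model can be checked before a trainer is constructed. The following is a minimal sketch using only the standard library; supports_sampling, ToyGenerator and NoSampler are hypothetical names introduced for illustration and are not part of diffusion_models. The stand-in classes do not subclass nn.Module, which a real model would have to do.

```python
import inspect


def supports_sampling(model):
    """Return True if model.sample is callable and accepts a num_samples argument."""
    sample = getattr(model, "sample", None)
    if not callable(sample):
        return False
    try:
        # bind() raises TypeError when the call sample(num_samples=1) would be invalid.
        inspect.signature(sample).bind(num_samples=1)
    except (TypeError, ValueError):
        return False
    return True


class ToyGenerator:
    """Hypothetical stand-in; a real model would subclass torch.nn.Module."""

    def sample(self, num_samples: int):
        # Return num_samples dummy values in place of generated data.
        return [0.0] * num_samples


class NoSampler:
    """Hypothetical model that lacks the required sample method."""


print(supports_sampling(ToyGenerator()))  # True
print(supports_sampling(NoSampler()))     # False
```

A check of this kind only verifies the call signature; it says nothing about the shape or type of what sample returns, which this page does not specify.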

Methods

__init__(model, train_data, loss_func, ...) – Constructor of GenerativeTrainer class.

get_samples(num_samples)

load_checkpoint(checkpoint_path)

train(max_epochs) – Train method of Trainer class.

train(max_epochs)

Train method of Trainer class.

Parameters:

max_epochs (int) – number of epochs to train the model for
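Calling train requires a constructed trainer, and the constructor takes thirteen required arguments. As a rough sketch, the argument names from the signature on this page can be used to validate a keyword dictionary before it is passed on. check_trainer_kwargs is a hypothetical helper, the None values are placeholders rather than valid arguments, and max_epochs=10 in the closing comment is an arbitrary illustrative value.

```python
# Required constructor parameters, copied from the documented signature.
REQUIRED_ARGS = (
    "model", "train_data", "loss_func", "optimizer", "gpu_id", "num_gpus",
    "batch_size", "save_every", "checkpoint_folder", "device_type",
    "log_wandb", "num_samples", "mixed_precision",
)

# Optional parameters with the defaults shown in the signature.
OPTIONAL_DEFAULTS = {
    "gradient_accumulation_rate": 1,
    "lr_scheduler": None,
    "k_space": False,
}


def check_trainer_kwargs(kwargs):
    """Return (missing, unknown) argument names for a GenerativeTrainer call."""
    missing = [name for name in REQUIRED_ARGS if name not in kwargs]
    known = set(REQUIRED_ARGS) | set(OPTIONAL_DEFAULTS)
    unknown = sorted(name for name in kwargs if name not in known)
    return missing, unknown


# Placeholder values only; real code would pass a model, training data, etc.
kwargs = {name: None for name in REQUIRED_ARGS if name != "mixed_precision"}
kwargs["lr_schedular"] = None  # deliberate typo to show the unknown-name check

missing, unknown = check_trainer_kwargs(kwargs)
print(missing)  # ['mixed_precision']
print(unknown)  # ['lr_schedular']

# Once both lists are empty, the documented calls would be:
#   trainer = GenerativeTrainer(**kwargs)
#   trainer.train(max_epochs=10)
```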