diffusion_models.utils.trainer.GenerativeTrainer
- class diffusion_models.utils.trainer.GenerativeTrainer(model, train_data, loss_func, optimizer, gpu_id, num_gpus, batch_size, save_every, checkpoint_folder, device_type, log_wandb, num_samples, mixed_precision, gradient_accumulation_rate=1, lr_scheduler=None, k_space=False)
Bases: Trainer
- __init__(model, train_data, loss_func, optimizer, gpu_id, num_gpus, batch_size, save_every, checkpoint_folder, device_type, log_wandb, num_samples, mixed_precision, gradient_accumulation_rate=1, lr_scheduler=None, k_space=False)
Constructor of the GenerativeTrainer class.
- Parameters:
model (Module) – an instance of nn.Module that implements a sample(num_samples: int) method.
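The documented requirement on model is partly structural: besides being an nn.Module, it must provide a sample(num_samples: int) method. A minimal sketch, assuming only that requirement, expresses it as a typing.Protocol so it runs without PyTorch. SupportsSample, ToyGenerator and check_model are hypothetical names, not part of diffusion_models, and a real model would subclass torch.nn.Module and return tensors from sample.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SupportsSample(Protocol):
    """Hypothetical structural type for the sample(num_samples: int) requirement."""

    def sample(self, num_samples: int) -> Any: ...


class ToyGenerator:
    """Hypothetical stand-in for an nn.Module generative model."""

    def sample(self, num_samples: int) -> list[list[float]]:
        # Return `num_samples` dummy "images" (here, flat lists of four zeros).
        return [[0.0] * 4 for _ in range(num_samples)]


def check_model(model: object) -> None:
    # isinstance() on a runtime_checkable Protocol only checks that a
    # `sample` attribute exists; it does not verify the method signature.
    if not isinstance(model, SupportsSample):
        raise TypeError("model must implement sample(num_samples: int)")


check_model(ToyGenerator())
samples = ToyGenerator().sample(num_samples=3)
```

Checking the contract up front surfaces a missing sample method at construction time rather than at the first sampling step; whether GenerativeTrainer itself performs such a check is not stated here.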
Methods