Source code for diffusion_models.models.fp16_util

"""
Helpers to train with 16-bit precision.
"""

import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.

    Only ``nn.Conv1d``/``nn.Conv2d``/``nn.Conv3d`` instances are touched; any
    other module is left unchanged. The conversion is done in place by
    replacing the ``.data`` of the layer's weight (and bias, if present).

    :param l: a module, typically visited via ``model.apply(...)``.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        # Conv layers constructed with bias=False have ``bias is None``;
        # dereferencing ``.data`` on it would raise AttributeError.
        if l.bias is not None:
            l.bias.data = l.bias.data.half()
def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().

    Only ``nn.Conv1d``/``nn.Conv2d``/``nn.Conv3d`` instances are touched; any
    other module is left unchanged. The conversion is done in place by
    replacing the ``.data`` of the layer's weight (and bias, if present).

    :param l: a module, typically visited via ``model.apply(...)``.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.float()
        # Conv layers constructed with bias=False have ``bias is None``;
        # dereferencing ``.data`` on it would raise AttributeError.
        if l.bias is not None:
            l.bias.data = l.bias.data.float()
def make_master_params(model_params):
    """
    Copy model parameters into a (differently-shaped) list of full-precision
    parameters.

    Every parameter is detached, cast to float32 and flattened into one
    contiguous buffer, which is wrapped in a single trainable ``nn.Parameter``.

    :param model_params: an iterable of model parameters (consumed once).
    :return: a one-element list holding the flat float32 master parameter.
    """
    full_precision = [tensor.detach().float() for tensor in model_params]
    flat_buffer = _flatten_dense_tensors(full_precision)
    master = nn.Parameter(flat_buffer, requires_grad=True)
    return [master]
def _param_grad_or_zeros(param):
    """Return ``param``'s gradient as float32, or float32 zeros if it has none."""
    if param.grad is not None:
        return param.grad.data.detach().float()
    # A parameter that did not take part in the backward pass has no grad.
    # Zeros keep the flat layout aligned with make_master_params().
    return param.new_zeros(param.shape).float()


def model_grads_to_master_grads(model_params, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().

    Each model gradient is detached, cast to float32 and flattened in order.
    The result is stored as ``master_params[0].grad``. Parameters whose
    ``.grad`` is None contribute zeros instead of raising, so unused
    parameters do not break the copy.

    :param model_params: an iterable of model parameters, in the same order
        that was passed to make_master_params().
    :param master_params: the list returned by make_master_params().
    """
    master_params[0].grad = _flatten_dense_tensors(
        [_param_grad_or_zeros(param) for param in model_params]
    )
def master_params_to_model_params(model_params, master_params):
    """
    Copy the master parameter data back into the model parameters.

    The flat master buffer is split back into per-parameter views with
    unflatten_master_params(). Each view is then copied in place into the
    matching model parameter.

    :param model_params: an iterable of model parameters.
    :param master_params: the list returned by make_master_params().
    """
    # model_params is walked twice: once inside unflatten_master_params() and
    # once by the loop below. Without copying to a list, a generator would be
    # exhausted by the first walk, and no parameters would be copied.
    model_params = list(model_params)
    pieces = unflatten_master_params(model_params, master_params)
    for target, source in zip(model_params, pieces):
        target.detach().copy_(source)
def unflatten_master_params(model_params, master_params):
    """
    Unflatten the master parameters to look like model_params.

    :param model_params: an iterable of model parameters whose shapes define
        how the flat buffer is split.
    :param master_params: the list returned by make_master_params().
    :return: the per-parameter tensors produced by ``_unflatten_dense_tensors``
        from the detached flat master buffer.
    """
    flat = master_params[0].detach()
    return _unflatten_dense_tensors(flat, tuple(model_params))
def zero_grad(model_params):
    """
    Zero, in place, the ``.grad`` of every parameter that has one.

    Parameters whose ``.grad`` is None are skipped and left as None.

    :param model_params: an iterable of parameters.
    """
    # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
    for p in model_params:
        grad = p.grad
        if grad is None:
            continue
        # Detach the gradient in place before zeroing it.
        grad.detach_()
        grad.zero_()