"""Helpers to train with 16-bit precision."""importtorch.nnasnnfromtorch._utilsimport_flatten_dense_tensors,_unflatten_dense_tensors
def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        l.bias.data = l.bias.data.half()
def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.float()
        l.bias.data = l.bias.data.float()
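# A minimal usage sketch (not part of the original module): the converters
# above only touch primitive conv modules, so they are meant to be applied
# recursively via nn.Module.apply(). The Sequential model here is a
# hypothetical stand-in for any module tree containing conv layers.
def _example_convert_precision():
    import torch

    model = nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1),
        nn.ReLU(),
        nn.Conv2d(16, 3, 3, padding=1),
    )
    # apply() visits every submodule, so each Conv2d is halved in place.
    model.apply(convert_module_to_f16)
    assert model[0].weight.dtype == torch.float16
    # Undo the conversion when full precision is needed again.
    model.apply(convert_module_to_f32)
    assert model[0].weight.dtype == torch.float32
    return model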
def make_master_params(model_params):
    """
    Copy model parameters into a (differently-shaped) list of full-precision
    parameters.
    """
    master_params = _flatten_dense_tensors(
        [param.detach().float() for param in model_params]
    )
    master_params = nn.Parameter(master_params)
    master_params.requires_grad = True
    return [master_params]
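# Sketch of intended usage (an assumption, not part of the original module):
# make_master_params() returns a single flat fp32 nn.Parameter, and it is
# this master copy, not the fp16 model parameters, that the optimizer
# should own. The choice of AdamW is illustrative only.
def _example_make_master_params(model):
    import torch

    model_params = list(model.parameters())
    master_params = make_master_params(model_params)
    # One flat fp32 tensor stands in for the whole parameter list.
    assert len(master_params) == 1
    assert master_params[0].dtype == torch.float32
    opt = torch.optim.AdamW(master_params, lr=1e-4)
    return model_params, master_params, opt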
def model_grads_to_master_grads(model_params, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().
    """
    master_params[0].grad = _flatten_dense_tensors(
        [param.grad.data.detach().float() for param in model_params]
    )
def master_params_to_model_params(model_params, master_params):
    """
    Copy the master parameter data back into the model parameters.
    """
    # Without copying to a list, if a generator is passed, this will
    # silently not copy any parameters.
    model_params = list(model_params)

    for param, master_param in zip(
        model_params, unflatten_master_params(model_params, master_params)
    ):
        param.detach().copy_(master_param)
def unflatten_master_params(model_params, master_params):
    """
    Unflatten the master parameters to look like model_params.
    """
    return _unflatten_dense_tensors(master_params[0].detach(), tuple(model_params))
def zero_grad(model_params):
    """
    Zero out the gradients of the given parameters, detaching them from
    their computation graphs first.
    """
    for param in model_params:
        # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
        if param.grad is not None:
            param.grad.detach_()
            param.grad.zero_()
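# Putting the helpers together: a sketch of one fp16 training step. This is
# an assumption about intended usage, not part of the original module, and
# the loss-scale factor is likewise an assumption: this module implements no
# loss scaling itself, but fp16 gradients often underflow without it.
def _example_train_step(model_params, master_params, opt, loss, loss_scale=2 ** 10):
    zero_grad(model_params)
    # Backprop through the fp16 model; scale the loss so that small
    # gradients survive the limited fp16 dynamic range.
    (loss * loss_scale).backward()
    # Copy the fp16 grads into the flat fp32 master grad, then unscale.
    model_grads_to_master_grads(model_params, master_params)
    master_params[0].grad.mul_(1.0 / loss_scale)
    # Update in full precision, then sync the new values back into the
    # fp16 model parameters.
    opt.step()
    master_params_to_model_params(model_params, master_params)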