diffusion_models.models.fp16_util

Helpers to train with 16-bit precision.
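Why keep full-precision copies at all? The stdlib sketch below rounds values through IEEE half precision with `struct`'s `"e"` format and through single precision with `"f"`. A small update applied directly to a float16 weight near 1.0 is rounded away every step. The same update accumulated in a float32 "master" copy survives. The step size and step count are illustrative assumptions, not values taken from this module.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


update = 1e-4   # illustrative per-step update (learning rate times gradient)
steps = 100

w16 = to_fp16(1.0)       # weight stored only in float16
w_master = to_fp32(1.0)  # float32 master copy of the same weight
for _ in range(steps):
    w16 = to_fp16(w16 + update)             # float16 spacing near 1.0 is 2**-10, so +1e-4 rounds away
    w_master = to_fp32(w_master + update)   # float32 spacing near 1.0 is 2**-23, so the update is kept

print(w16)                # 1.0: every update was lost
print(to_fp16(w_master))  # 1.009765625: the accumulated change survives the final cast to float16
```

This is the motivation for the master-parameter helpers below. The model runs in float16, while optimizer updates are applied to float32 copies.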

Functions

convert_module_to_f16(l)

Convert primitive modules to float16.

convert_module_to_f32(l)

Convert primitive modules to float32, undoing convert_module_to_f16().
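The two conversion helpers are meant to be passed to PyTorch's `nn.Module.apply`, which calls a function on every submodule recursively. Below is a minimal sketch of both. It rests on one hypothetical choice: "primitive modules" is taken to mean the convolution layers, so normalization layers stay in float32. The library's actual set of module types may differ.

```python
import torch
from torch import nn

# Hypothetical assumption: the "primitive" modules are the convolution layers.
_PRIMITIVES = (nn.Conv1d, nn.Conv2d, nn.Conv3d)


def convert_module_to_f16(l: nn.Module) -> None:
    """Sketch: cast a primitive module's weight and bias to float16 in place."""
    if isinstance(l, _PRIMITIVES):
        l.weight.data = l.weight.data.half()
        if l.bias is not None:
            l.bias.data = l.bias.data.half()


def convert_module_to_f32(l: nn.Module) -> None:
    """Sketch: undo convert_module_to_f16() by casting back to float32."""
    if isinstance(l, _PRIMITIVES):
        l.weight.data = l.weight.data.float()
        if l.bias is not None:
            l.bias.data = l.bias.data.float()


model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.GroupNorm(2, 8), nn.Conv2d(8, 8, 1))
model.apply(convert_module_to_f16)            # convs become float16, GroupNorm stays float32
print([p.dtype for p in model.parameters()])
model.apply(convert_module_to_f32)            # everything is float32 again
print({p.dtype for p in model.parameters()})
```

Under these assumptions, layers such as `GroupNorm` are left in float32. Normalization statistics are usually more sensitive to reduced precision than convolution weights.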

make_master_params(model_params)

Copy model parameters into a (differently-shaped) list of full-precision parameters.
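One way to read "differently-shaped" is that all model parameters are flattened into a single 1-D float32 tensor, returned as a one-element list. The sketch below assumes that layout. The layout is an assumption, not confirmed by this page.

```python
import torch
from torch import nn


def make_master_params(model_params):
    """Sketch: concatenate every (possibly float16) model parameter, flattened
    and upcast to float32, into one trainable nn.Parameter inside a list."""
    flat = torch.cat([p.detach().float().reshape(-1) for p in model_params])
    return [nn.Parameter(flat)]  # nn.Parameter sets requires_grad=True by default


model = nn.Linear(4, 3).half()   # weight 3x4 plus bias 3 = 15 values
master = make_master_params(model.parameters())
print(len(master), master[0].shape, master[0].dtype)  # 1 torch.Size([15]) torch.float32
```

Because `.float()` creates a new tensor, the master copy does not share storage with the float16 model. An optimizer built over `master` updates only the full-precision values.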

master_params_to_model_params(model_params, ...)

Copy the master parameter data back into the model parameters.
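After an optimizer step on the master copy, the updated values must be written back into the float16 model. The sketch below assumes the hypothetical flat layout (a single 1-D float32 tensor in a one-element list). The two-argument signature is illustrative, since the documented signature is elided.

```python
import torch
from torch import nn


def master_params_to_model_params(model_params, master_params):
    """Sketch: slice the flat float32 master tensor into per-parameter chunks
    and copy each chunk into the matching model parameter in place."""
    flat = master_params[0].detach()
    offset = 0
    for p in model_params:
        n = p.numel()
        # copy_() casts float32 -> float16 when the model parameter is half precision.
        p.detach().copy_(flat[offset:offset + n].view_as(p))
        offset += n


model = nn.Linear(4, 3).half()
master = [nn.Parameter(torch.arange(15.0))]  # pretend these are the post-step values
master_params_to_model_params(model.parameters(), master)
print(model.weight.dtype)  # torch.float16
```

The float16 model stays float16. Only its values change, rounded to the nearest representable half-precision number.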

model_grads_to_master_grads(model_params, ...)

Copy the gradients from the model parameters into the master parameters from make_master_params().
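The gradient path mirrors parameter creation: per-parameter gradients are flattened, upcast to float32, and attached as the `.grad` of the master tensor, so a standard optimizer can step on it. The sketch below assumes the same hypothetical single-flat-tensor layout. The example model is kept in float32 so the backward pass runs on any CPU build; with a float16 model, `.float()` is the step that upcasts the gradients.

```python
import torch
from torch import nn


def model_grads_to_master_grads(model_params, master_params):
    """Sketch: flatten the model gradients into one float32 tensor and store it
    as the gradient of the flat master parameter."""
    master_params[0].grad = torch.cat(
        [p.grad.detach().float().reshape(-1) for p in model_params]
    )


model = nn.Linear(4, 3)
master = [nn.Parameter(torch.cat([p.detach().reshape(-1) for p in model.parameters()]))]
model(torch.ones(2, 4)).sum().backward()  # every weight and bias gradient equals 2
model_grads_to_master_grads(model.parameters(), master)
print(master[0].grad)  # 15 entries, all 2.0
```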

unflatten_master_params(model_params, ...)

Unflatten the master parameters to look like model_params.
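Unflattening is the inverse view. It splits the flat master tensor by each parameter's element count and reshapes each piece. A sketch under the same hypothetical flat-tensor assumption (the two-argument signature is illustrative, since the documented one is elided):

```python
import torch
from torch import nn


def unflatten_master_params(model_params, master_params):
    """Sketch: return float32 views of the flat master tensor, one per model
    parameter, each reshaped to that parameter's shape."""
    model_params = list(model_params)  # iterated twice below, so materialize it
    chunks = torch.split(master_params[0].detach(), [p.numel() for p in model_params])
    return [c.view_as(p) for c, p in zip(chunks, model_params)]


model = nn.Linear(4, 3).half()
master = [nn.Parameter(torch.arange(15.0))]
pieces = unflatten_master_params(model.parameters(), master)
print([tuple(t.shape) for t in pieces])  # [(3, 4), (3,)]
```

The pieces share storage with the master tensor and keep its float32 dtype. Copying them into the model is a separate step.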

zero_grad(model_params)
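`zero_grad(model_params)` has no description on this page. Judging only by its name, it presumably clears the gradients of the given parameters. The sketch below is written under that assumption and zeroes existing gradients in place.

```python
import torch
from torch import nn


def zero_grad(model_params):
    """Sketch (assumed behavior): zero every existing gradient in place and
    skip parameters that have no gradient yet."""
    for p in model_params:
        if p.grad is not None:
            p.grad.zero_()


model = nn.Linear(4, 3)
model(torch.ones(2, 4)).sum().backward()
zero_grad(model.parameters())
print([p.grad.abs().sum().item() for p in model.parameters()])  # [0.0, 0.0]
```

PyTorch's own `Optimizer.zero_grad` can instead set gradients to `None` through its `set_to_none` argument. Zeroing in place keeps the gradient tensors allocated between steps.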