diffusion_models.models.fp16_util
Helpers to train with 16-bit precision.
Functions
Convert primitive modules to float16.
Convert primitive modules to float32, undoing convert_module_to_f16().
Copy model parameters into a (differently-shaped) list of full-precision parameters.
Copy the master parameter data back into the model parameters.
Copy the gradients from the model parameters into the master parameters from make_master_params().
Unflatten the master parameters to look like model_params.
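The summaries above describe a pattern of keeping full-precision "master" parameters next to the 16-bit model parameters. The following is a minimal sketch of why such a copy can help. It is not this module's API: it uses only the Python standard library rather than PyTorch, and the names `to_f16`, `model_param`, `master_param`, and `update` are hypothetical and chosen for illustration only.

```python
import struct


def to_f16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


# Hypothetical toy "model": a single weight stored in 16-bit precision.
model_param = to_f16(1.0)

# Full-precision master copy of the same weight. A Python float is 64-bit;
# here it stands in for a float32 master parameter (an assumption of this sketch).
master_param = float(model_param)

# A small per-step update, e.g. learning rate times gradient (illustrative value).
update = 1e-4

# Naive approach: apply every update directly to the 16-bit weight.
# Near 1.0 the spacing between float16 values is about 0.00098, so each
# update of 0.0001 is rounded away.
naive = model_param
for _ in range(10):
    naive = to_f16(naive + update)

# Master-parameter approach: accumulate in full precision, then copy the
# result back into the 16-bit weight.
for _ in range(10):
    master_param += update
model_param = to_f16(master_param)

print(naive)        # 1.0 (the 16-bit weight never moved)
print(model_param)  # 1.0009765625 (the accumulated update survived)
```

In this toy setting the directly updated 16-bit weight stays at its initial value, while the weight refreshed from the full-precision copy reflects the ten accumulated updates.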