diffusion_models.models.nn

Various utilities for neural networks.

Functions

avg_pool_nd(dims, *args, **kwargs)

Create a 1D, 2D, or 3D average pooling module.
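The sketch below assumes PyTorch, which this page does not name. It shows one way such a factory could dispatch on `dims`, passing the remaining arguments unchanged to the matching `torch.nn` pooling class. It is an illustration, not the module's source, and the `ValueError` for unsupported `dims` is an assumption.

```python
import torch
from torch import nn


def avg_pool_nd(dims, *args, **kwargs):
    """Sketch: return an average-pooling layer for 1, 2, or 3 spatial dims."""
    pools = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}
    if dims not in pools:
        # Assumed behavior for unsupported dimensionalities.
        raise ValueError(f"unsupported dimensions: {dims}")
    # kernel_size, stride, etc. go straight to the PyTorch class.
    return pools[dims](*args, **kwargs)


pool = avg_pool_nd(2, kernel_size=2, stride=2)
y = pool(torch.randn(1, 3, 8, 8))
print(y.shape)  # torch.Size([1, 3, 4, 4])
```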

checkpoint(func, inputs, params, flag)

Evaluate a function without caching intermediate activations, allowing for reduced memory at the expense of extra compute in the backward pass.
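The following sketch implements the same idea with PyTorch's built-in `torch.utils.checkpoint.checkpoint`. It replaces whatever custom autograd function this module uses, so treat it as an alternative technique. `flag` turns checkpointing on or off. `params` is accepted only to match the signature. The non-reentrant variant already tracks every parameter that `func` uses, so the sketch ignores it.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint


def checkpoint(func, inputs, params, flag):
    """Sketch: evaluate func(*inputs), recomputing activations in backward if flag.

    `params` is kept only for signature parity: the non-reentrant
    torch.utils.checkpoint variant already tracks gradients for every
    parameter that func uses.
    """
    if flag:
        return torch_checkpoint(func, *inputs, use_reentrant=False)
    return func(*inputs)


block = nn.Sequential(nn.Linear(4, 8), nn.SiLU(), nn.Linear(8, 4))
x = torch.randn(2, 4, requires_grad=True)

# With flag=True the hidden activations of `block` are not kept after the
# forward pass; they are recomputed when .backward() runs.
out = checkpoint(block, (x,), tuple(block.parameters()), True)
out.sum().backward()
```

The gradients match those of an ordinary forward pass. The memory saving is paid for with one extra forward pass of `func` during backward.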

conv_nd(dims, *args, **kwargs)

Create a 1D, 2D, or 3D convolution module.
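Assuming PyTorch, `conv_nd` can be pictured as a lookup from `dims` to `nn.Conv1d`, `nn.Conv2d`, or `nn.Conv3d`. This lets a single model definition handle 1D, 2D, and 3D data. The sketch is illustrative, and the error for other `dims` values is an assumption.

```python
import torch
from torch import nn


def conv_nd(dims, *args, **kwargs):
    """Sketch: return nn.Conv1d, nn.Conv2d, or nn.Conv3d depending on dims."""
    convs = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    if dims not in convs:
        raise ValueError(f"unsupported dimensions: {dims}")
    return convs[dims](*args, **kwargs)


# A 3D (e.g. video) layer; kernel 3 with padding 1 keeps the spatial size.
conv = conv_nd(3, in_channels=4, out_channels=16, kernel_size=3, padding=1)
y = conv(torch.randn(2, 4, 5, 8, 8))
print(y.shape)  # torch.Size([2, 16, 5, 8, 8])
```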

linear(*args, **kwargs)

Create a linear module.

mean_flat(tensor)

Take the mean over all non-batch dimensions.
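Here is a sketch assuming PyTorch tensors with the batch on dimension 0. Averaging over every other dimension leaves one value per example. This is a common way to turn an element-wise loss into a per-sample loss.

```python
import torch


def mean_flat(tensor):
    """Sketch: average over every dimension except the first (batch) one."""
    return tensor.mean(dim=list(range(1, tensor.dim())))


x = torch.arange(8.0).reshape(2, 2, 2)
print(mean_flat(x))  # tensor([1.5000, 5.5000])

# Per-sample squared error for a batch of 4 images.
pred, target = torch.zeros(4, 3, 8, 8), torch.ones(4, 3, 8, 8)
per_sample_mse = mean_flat((pred - target) ** 2)
print(per_sample_mse)  # tensor([1., 1., 1., 1.])
```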

normalization(channels)

Make a standard normalization layer.
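The page does not say what "standard" means. The sketch assumes PyTorch group normalization with 32 groups, a common choice in diffusion U-Nets. That group count is a hypothetical default, and `channels` must be divisible by it.

```python
import torch
from torch import nn


def normalization(channels):
    """Sketch: GroupNorm over `channels` channels.

    The 32 groups are an assumption; `channels` must be divisible by 32.
    """
    return nn.GroupNorm(32, channels)


norm = normalization(64)
h = norm(torch.randn(2, 64, 16, 16))
print(h.shape)  # torch.Size([2, 64, 16, 16])
```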

scale_module(module, scale)

Scale the parameters of a module and return it.
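This sketch assumes PyTorch. It multiplies every parameter in place inside `torch.no_grad()` so autograd does not record the update. Returning the module means the call can wrap a layer constructor inline.

```python
import torch
from torch import nn


def scale_module(module, scale):
    """Sketch: multiply every parameter of `module` in place by `scale`."""
    with torch.no_grad():
        for p in module.parameters():
            p.mul_(scale)
    return module


layer = nn.Linear(3, 3)
before = layer.weight.detach().clone()
bias_before = layer.bias.detach().clone()
scale_module(layer, 0.5)
print(torch.allclose(layer.weight, before * 0.5))  # True
```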

timestep_embedding(timesteps, dim[, max_period])

Create sinusoidal timestep embeddings.
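Sinusoidal embeddings map each timestep t to cosines and sines at geometrically spaced frequencies, like Transformer positional encodings. The sketch assumes PyTorch and frequencies ω_i = max_period^(-i / half) for i = 0, 1, ..., half - 1, with half = dim // 2. It also assumes a `max_period` default of 10000, a [cos | sin] layout, and one zero column when `dim` is odd. The signature above marks `max_period` as optional but does not give its default.

```python
import math

import torch


def timestep_embedding(timesteps, dim, max_period=10000):
    """Sketch: sinusoidal embeddings for a 1-D tensor of timesteps.

    Returns shape [len(timesteps), dim]. The max_period default and the
    [cos | sin] ordering are assumptions.
    """
    half = dim // 2
    # Frequencies fall geometrically from 1 toward 1 / max_period.
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    ).to(timesteps.device)
    args = timesteps[:, None].float() * freqs[None, :]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Pad one zero column so odd `dim` still yields the requested width.
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb


t = torch.tensor([0, 10, 500])
emb = timestep_embedding(t, 128)
print(emb.shape)  # torch.Size([3, 128])
```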

update_ema(target_params, source_params[, rate])

Update target parameters to be closer to those of source parameters using an exponential moving average.
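The update is target ← rate · target + (1 - rate) · source, applied in place. The sketch assumes PyTorch parameter iterables. The default `rate=0.99` is an assumption; the signature above only shows that `rate` is optional.

```python
import torch
from torch import nn


def update_ema(target_params, source_params, rate=0.99):
    """Sketch: target <- rate * target + (1 - rate) * source, in place.

    rate=0.99 is an assumed default.
    """
    with torch.no_grad():
        for targ, src in zip(target_params, source_params):
            targ.mul_(rate).add_(src, alpha=1 - rate)


model = nn.Linear(2, 2)
ema_model = nn.Linear(2, 2)
ema_model.load_state_dict(model.state_dict())  # start the EMA at the live weights

# ... after each optimizer step on `model`:
update_ema(ema_model.parameters(), model.parameters(), rate=0.999)
```

A rate close to 1 makes the averaged weights change slowly, so they smooth over many training steps.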

zero_module(module)

Zero out the parameters of a module and return it.
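Here is a sketch assuming PyTorch. Zeroed layers are often placed at the end of a residual branch, so the block initially computes `x + 0` and behaves as the identity. The example shows this with a hypothetical `out_proj` convolution.

```python
import torch
from torch import nn


def zero_module(module):
    """Sketch: set every parameter of `module` to zero in place and return it."""
    with torch.no_grad():
        for p in module.parameters():
            p.zero_()
    return module


out_proj = zero_module(nn.Conv2d(64, 64, kernel_size=3, padding=1))
x = torch.randn(1, 64, 8, 8)
# The residual branch contributes nothing at initialization.
print(torch.equal(x + out_proj(x), x))  # True
```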

Classes

CheckpointFunction(*args, **kwargs)
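`CheckpointFunction` has no description. Its name suggests the custom autograd function behind `checkpoint`. The sketch shows one way a `torch.autograd.Function` can do this, assuming PyTorch and a hypothetical `apply(run_function, length, *args)` convention. Under that convention, the first `length` tensors are inputs and the rest are parameters. Both the convention and the body are assumptions.

```python
import torch
from torch import nn


class CheckpointFunction(torch.autograd.Function):
    """Sketch: run `run_function` without keeping activations; recompute in backward."""

    @staticmethod
    def forward(ctx, run_function, length, *args):
        # First `length` args are inputs; the rest are parameters needing grads.
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with torch.no_grad():
            return run_function(*ctx.input_tensors)

    @staticmethod
    def backward(ctx, *output_grads):
        # Rebuild the graph by re-running the forward pass with autograd on.
        inputs = [t.detach().requires_grad_(True) for t in ctx.input_tensors]
        with torch.enable_grad():
            outputs = ctx.run_function(*inputs)
        grads = torch.autograd.grad(
            outputs, inputs + ctx.input_params, output_grads, allow_unused=True
        )
        # No gradient for run_function or length, then one per tensor in args.
        return (None, None) + tuple(grads)


lin = nn.Linear(4, 4)


def fn(t):
    return torch.tanh(lin(t))


x = torch.randn(3, 4, requires_grad=True)
out = CheckpointFunction.apply(fn, 1, x, *lin.parameters())
out.sum().backward()  # fn(x) is recomputed here
```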

GroupNorm32(*args, **kwargs)
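`GroupNorm32` is also undocumented. A common meaning of this name is a `GroupNorm` that computes in float32 and casts the result back to the input dtype, which helps under mixed precision. The sketch assumes that behavior and PyTorch; this page does not confirm either.

```python
import torch
from torch import nn


class GroupNorm32(nn.GroupNorm):
    """Sketch: GroupNorm computed in float32, cast back to the input dtype."""

    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


norm = GroupNorm32(32, 64)
x = torch.randn(2, 64, 8, 8).half()
print(norm(x).dtype)  # torch.float16
```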

SiLU(*args, **kwargs)
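SiLU, also called swish, computes `x * sigmoid(x)`. Below is a minimal sketch as a PyTorch module, checked against `torch.nn.functional.silu`. This page does not say whether the class does anything beyond that formula.

```python
import torch
from torch import nn


class SiLU(nn.Module):
    """Sketch: SiLU / swish activation, x * sigmoid(x)."""

    def forward(self, x):
        return x * torch.sigmoid(x)


act = SiLU()
x = torch.linspace(-3, 3, 7)
print(torch.allclose(act(x), nn.functional.silu(x)))  # True
```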