
class diffusion_models.models.openai_unet.UNetModel(*args: Any, **kwargs: Any)

Bases: Module

The full UNet model with attention and timestep embedding.
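The description mentions a timestep embedding but not its form. The sketch below assumes the sinusoidal scheme used by OpenAI's reference UNet. In that scheme, half of the vector holds cosines and half holds sines of the timestep, scaled by geometrically spaced frequencies with a maximum period of 10000. `timestep_embedding` is a hypothetical pure-Python helper written for illustration. It is not a function this page documents.

```python
import math
from typing import List, Sequence


def timestep_embedding(
    timesteps: Sequence[float], dim: int, max_period: float = 10000.0
) -> List[List[float]]:
    """Hypothetical sketch: map each timestep to a dim-dimensional sinusoidal vector.

    Assumes the cos-then-sin layout of OpenAI's reference implementation;
    an odd dim is padded with a single trailing zero.
    """
    half = dim // 2
    # Frequencies decay geometrically from 1 towards 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    embeddings = []
    for t in timesteps:
        args = [t * f for f in freqs]
        row = [math.cos(a) for a in args] + [math.sin(a) for a in args]
        if dim % 2:
            row.append(0.0)
        embeddings.append(row)
    return embeddings


emb = timestep_embedding([0, 10, 999], dim=8)
print(len(emb), len(emb[0]))  # 3 8
```

Nearby timesteps get similar vectors and distant ones diverge. This lets a network condition smoothly on the noise level.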

Parameters:

  • in_channels – number of channels in the input Tensor.

  • model_channels – base channel count for the model.

  • out_channels – number of channels in the output Tensor.

  • num_res_blocks – number of residual blocks at each resolution level.

  • attention_resolutions – a collection of downsampling rates at which attention is applied. May be a set, list, or tuple. For example, if it contains 4, attention is used at 4x downsampling.

  • dropout – the dropout probability.

  • channel_mult – channel multiplier for each level of the UNet.

  • conv_resample – if True, use learned convolutions for upsampling and downsampling.

  • dims – the dimensionality of the signal: 1, 2, or 3 for 1D, 2D, or 3D data.

  • num_classes – if specified (as an int), then this model will be class-conditional with num_classes classes.

  • use_checkpoint – use gradient checkpointing to reduce memory usage.

  • num_heads – the number of attention heads in each attention layer.
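To see how `channel_mult` and `attention_resolutions` interact, the sketch below lists each resolution level with its downsampling rate, its channel count, and whether attention is applied. It assumes the layout of OpenAI's reference UNet. In that layout, the rate starts at 1 and doubles between consecutive levels, and level i has `channel_mult[i] * model_channels` channels. `level_plan` is a hypothetical helper and is not part of this class.

```python
from typing import Iterable, List, Sequence, Tuple


def level_plan(
    model_channels: int,
    channel_mult: Sequence[int] = (1, 2, 4, 8),
    attention_resolutions: Iterable[int] = (),
) -> List[Tuple[int, int, bool]]:
    """Hypothetical helper: (downsampling rate, channels, uses attention) per level.

    Assumes, as in OpenAI's reference UNet, that the rate starts at 1 and
    doubles between consecutive levels, and that level i has
    channel_mult[i] * model_channels channels.
    """
    attn = set(attention_resolutions)  # a set, list, or tuple all work
    plan = []
    ds = 1
    for level, mult in enumerate(channel_mult):
        plan.append((ds, mult * model_channels, ds in attn))
        if level != len(channel_mult) - 1:
            ds *= 2  # one downsampling step between consecutive levels
    return plan


for ds, ch, use_attn in level_plan(128, (1, 2, 4, 8), attention_resolutions=(4, 8)):
    print(f"{ds}x downsampling: {ch} channels, attention={use_attn}")
```

With `attention_resolutions=(4, 8)`, attention runs only at the two coarsest levels. This matches the "contains 4, then at 4x downsampling" example in the parameter description.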

__init__(in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, num_classes=None, use_checkpoint=False, num_heads=1, num_heads_upsample=-1, use_scale_shift_norm=False)



Convert the torso of the model to float16.


Convert the torso of the model to float32.

forward(x, timesteps, y=None)

Apply the model to an input batch.

Parameters:

  • x – an [N x C x …] Tensor of inputs.

  • timesteps – a 1-D batch of timesteps.

  • y – an [N] Tensor of labels, if class-conditional.


Returns:

an [N x C x …] Tensor of outputs.
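The shape contract of `forward` can be checked before calling the model. The sketch below works on plain shape tuples, so it runs without PyTorch. `check_forward_shapes` is a hypothetical helper, and it makes two assumptions. First, the output keeps the input's batch and spatial sizes and has `out_channels` channels. Second, `y` is expected if and only if `num_classes` was set. That stricter rule mirrors OpenAI's reference implementation; this page only says "if class-conditional".

```python
from typing import Optional, Sequence, Tuple


def check_forward_shapes(
    x_shape: Sequence[int],
    timesteps_shape: Sequence[int],
    y_shape: Optional[Sequence[int]],
    in_channels: int,
    out_channels: int,
    num_classes: Optional[int] = None,
    dims: int = 2,
) -> Tuple[int, ...]:
    """Hypothetical pre-flight check of forward()'s documented shapes.

    Returns the expected output shape, assuming the output keeps the input's
    batch and spatial sizes and has out_channels channels.
    """
    if len(x_shape) != 2 + dims:
        raise ValueError(f"x must be [N x C x ...] with {dims} spatial dimension(s)")
    n, c = x_shape[0], x_shape[1]
    if c != in_channels:
        raise ValueError(f"expected {in_channels} input channels, got {c}")
    if tuple(timesteps_shape) != (n,):
        raise ValueError("timesteps must be a 1-D batch with one entry per sample")
    # Labels are expected exactly when the model is class-conditional (assumption).
    if (y_shape is None) != (num_classes is None):
        raise ValueError("pass y if and only if the model was built with num_classes")
    if y_shape is not None and tuple(y_shape) != (n,):
        raise ValueError("y must be an [N] tensor of labels")
    return (n, out_channels, *x_shape[2:])


print(check_forward_shapes((4, 3, 64, 64), (4,), None, in_channels=3, out_channels=6))
```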

get_feature_vectors(x, timesteps, y=None)

Apply the model and return all of the intermediate tensors.

Parameters:

  • x – an [N x C x …] Tensor of inputs.

  • timesteps – a 1-D batch of timesteps.

  • y – an [N] Tensor of labels, if class-conditional.


Returns:

a dict with the following keys:

  • ‘down’ – a list of hidden-state tensors from the downsampling path.

  • ‘middle’ – the output tensor of the lowest-resolution block in the model.

  • ‘up’ – a list of hidden-state tensors from the upsampling path.
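The three keys reflect the usual UNet structure. States saved on the way down are reused in reverse order as skip connections on the way up. The toy below tracks only (channels, side length) with plain tuples, so it runs without PyTorch. `toy_feature_shapes` is hypothetical and uses one block per level. Its list lengths will therefore not match the real model's, which depend on `num_res_blocks` and on the implementation.

```python
from typing import Dict, List, Sequence, Tuple, Union

Shape = Tuple[int, int]  # (channels, side length) of a square feature map


def toy_feature_shapes(
    model_channels: int, channel_mult: Sequence[int], size: int
) -> Dict[str, Union[Shape, List[Shape]]]:
    """Hypothetical one-block-per-level UNet that records feature-map shapes
    under the same keys as get_feature_vectors: 'down', 'middle', 'up'."""
    down: List[Shape] = []
    for level, mult in enumerate(channel_mult):
        down.append((mult * model_channels, size))
        if level != len(channel_mult) - 1:
            size //= 2  # downsample between consecutive levels
    middle = down[-1]  # toy middle block keeps the lowest-resolution shape
    skips = list(down)  # saved down states, used as a stack of skip connections
    up: List[Shape] = []
    for level in reversed(range(len(channel_mult))):
        skip_ch, skip_size = skips.pop()  # the most recently saved state comes back first
        assert skip_size == size, "skip connection and upsampling path must match spatially"
        # A real block would take the current channels plus skip_ch as input here.
        up.append((channel_mult[level] * model_channels, size))
        if level != 0:
            size *= 2  # upsample back towards the input resolution
    return {"down": down, "middle": middle, "up": up}


shapes = toy_feature_shapes(64, (1, 2, 4), size=32)
print(shapes["down"])  # [(64, 32), (128, 16), (256, 8)]
print(shapes["up"])    # [(256, 8), (128, 16), (64, 32)]
```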

property inner_dtype

Get the dtype used by the torso of the model.