diffusion_models.models.openai_unet.UNetModel

class diffusion_models.models.openai_unet.UNetModel(*args: Any, **kwargs: Any)[source]

Bases: Module

The full UNet model with attention and timestep embedding.

Parameters:
  • in_channels – channels in the input Tensor.

  • model_channels – base channel count for the model.

  • out_channels – channels in the output Tensor.

  • num_res_blocks – number of residual blocks per downsample.

  • attention_resolutions – a collection of downsample rates at which attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x downsampling, attention will be used.

  • dropout – the dropout probability.

  • channel_mult – channel multiplier for each level of the UNet.

  • conv_resample – if True, use learned convolutions for upsampling and downsampling.

  • dims – determines if the signal is 1D, 2D, or 3D.

  • num_classes – if specified (as an int), then this model will be class-conditional with num_classes classes.

  • use_checkpoint – use gradient checkpointing to reduce memory usage.

  • num_heads – the number of attention heads in each attention layer.
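To make the interaction between channel_mult and attention_resolutions concrete, here is a pure-Python sketch of the documented semantics (this is an illustration of how the arguments are interpreted, not the model code; the function name is ours):

```python
# Sketch: which UNet levels use attention, given the constructor arguments.
# The downsample rate starts at 1 and doubles at each subsequent level.

def attention_levels(channel_mult, attention_resolutions):
    """Return (level, downsample_rate, uses_attention) for each UNet level."""
    levels = []
    ds = 1
    for level, mult in enumerate(channel_mult):
        levels.append((level, ds, ds in attention_resolutions))
        ds *= 2
    return levels

# With the default channel_mult=(1, 2, 4, 8) and attention at 4x and 8x:
print(attention_levels((1, 2, 4, 8), {4, 8}))
# [(0, 1, False), (1, 2, False), (2, 4, True), (3, 8, True)]
```

Each level's channel count is model_channels times the corresponding entry of channel_mult.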

__init__(in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, num_classes=None, use_checkpoint=False, num_heads=1, num_heads_upsample=-1, use_scale_shift_norm=False)[source]
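The "timestep embedding" the class name refers to is the standard transformer-style sinusoidal embedding of the diffusion timestep, which the model builds internally from the timesteps argument to forward. A pure-Python sketch of that embedding (the real code operates on tensors; this scalar version is only illustrative):

```python
import math

# Sketch of a sinusoidal timestep embedding: half the dimensions are
# cosines and half are sines of the timestep at geometrically spaced
# frequencies (an assumption mirroring the common diffusion-model recipe).

def timestep_embedding(t, dim, max_period=10000):
    """Embed a scalar timestep t into a dim-dimensional vector."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return [math.cos(t * f) for f in freqs] + [math.sin(t * f) for f in freqs]

emb = timestep_embedding(10, 8)
print(len(emb))  # 8
```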

Methods

__init__(in_channels, model_channels, ...[, ...])

convert_to_fp16()

Convert the torso of the model to float16.

convert_to_fp32()

Convert the torso of the model to float32.

forward(x, timesteps[, y])

Apply the model to an input batch.

get_feature_vectors(x, timesteps[, y])

Apply the model and return all of the intermediate tensors.

Attributes

inner_dtype

Get the dtype used by the torso of the model.

convert_to_fp16()[source]

Convert the torso of the model to float16.

convert_to_fp32()[source]

Convert the torso of the model to float32.

forward(x, timesteps, y=None)[source]

Apply the model to an input batch.

Parameters:
  • x – an [N x C x …] Tensor of inputs.

  • timesteps – a 1-D batch of timesteps.

  • y – an [N] Tensor of labels, if class-conditional.

Returns:

an [N x C x …] Tensor of outputs.
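The label argument y and the num_classes constructor argument must agree: y is required exactly when the model is class-conditional, and must be omitted otherwise. A minimal sketch of that contract (the function name is ours; the real model enforces this internally):

```python
# Sketch of the class-conditioning contract for forward():
# pass labels y if and only if the model was built with num_classes.

def check_conditioning(num_classes, y):
    if (y is None) != (num_classes is None):
        raise ValueError("must specify y if and only if the model is class-conditional")
    return True
```

For an unconditional model, call forward(x, timesteps); for a class-conditional one, forward(x, timesteps, y) with y holding one integer label per batch element.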

get_feature_vectors(x, timesteps, y=None)[source]

Apply the model and return all of the intermediate tensors.

Parameters:
  • x – an [N x C x …] Tensor of inputs.

  • timesteps – a 1-D batch of timesteps.

  • y – an [N] Tensor of labels, if class-conditional.

Returns:

a dict with the following keys:
  • ‘down’: a list of hidden state tensors from downsampling.

  • ‘middle’: the output tensor of the lowest-resolution block in the model.

  • ‘up’: a list of hidden state tensors from upsampling.
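The length of the ‘down’ list follows from the block layout: assuming the reference implementation's structure (one initial convolution, num_res_blocks residual blocks per level, and a downsample block between consecutive levels), it can be sketched as:

```python
# Sketch: expected number of 'down' hidden states, under the assumed layout
# of 1 input conv + num_res_blocks per level + one downsample between levels.

def num_down_states(num_res_blocks, channel_mult):
    levels = len(channel_mult)
    return 1 + num_res_blocks * levels + (levels - 1)

print(num_down_states(2, (1, 2, 4, 8)))  # 12
```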

property inner_dtype

Get the dtype used by the torso of the model.