diffusion_models.models.openai_unet.UNetModel
- class diffusion_models.models.openai_unet.UNetModel(*args: Any, **kwargs: Any)
Bases: Module
The full UNet model with attention and timestep embedding.
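The docstring mentions a timestep embedding but does not say how integer or fractional timesteps become vectors. A common choice in diffusion UNets, and the one assumed in this minimal sketch, is a Transformer-style sinusoidal embedding. The function name timestep_embedding and the max_period default are illustrative assumptions, not this module's documented API.

```python
import math
from typing import List


def timestep_embedding(t: float, dim: int, max_period: float = 10000.0) -> List[float]:
    """Sinusoidal embedding of one (possibly fractional) timestep.

    Assumption: a Transformer-style scheme; the exact embedding used inside
    UNetModel is not documented on this page.
    """
    half = dim // 2
    # Frequencies decay geometrically from 1 toward 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    # First half holds cosines, second half holds sines.
    emb = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2:
        emb.append(0.0)  # zero-pad odd dimensions so len(emb) == dim
    return emb


print(timestep_embedding(0, 4))  # [1.0, 1.0, 0.0, 0.0]
```

Nearby timesteps get similar vectors while distant ones differ across many frequencies, which is what lets the residual blocks condition smoothly on t.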
- Parameters:
in_channels – channels in the input Tensor.
model_channels – base channel count for the model.
out_channels – channels in the output Tensor.
num_res_blocks – number of residual blocks per downsample.
attention_resolutions – a collection of downsample rates at which attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x downsampling, attention will be used.
dropout – the dropout probability.
channel_mult – channel multiplier for each level of the UNet.
conv_resample – if True, use learned convolutions for upsampling and downsampling.
dims – determines if the signal is 1D, 2D, or 3D.
num_classes – if specified (as an int), then this model will be class-conditional with num_classes classes.
use_checkpoint – use gradient checkpointing to reduce memory usage.
num_heads – the number of attention heads in each attention layer.
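To make attention_resolutions and channel_mult concrete, the sketch below lists each resolution level of the UNet. It assumes level i sits at a downsample rate of 2**i (one halving between consecutive levels) and uses model_channels * channel_mult[i] channels. The helpers plan_levels and resolutions_to_rates are hypothetical; the second reflects a common convention of requesting attention at pixel sizes (e.g. 16) and converting them to rates relative to the image size.

```python
from typing import Iterable, List, NamedTuple, Sequence


class Level(NamedTuple):
    index: int            # 0 is the full-resolution level
    downsample_rate: int  # 2 ** index (assumed: one halving per level)
    channels: int         # model_channels * channel_mult[index]
    attention: bool       # downsample_rate in attention_resolutions


def plan_levels(
    model_channels: int,
    channel_mult: Sequence[int],
    attention_resolutions: Iterable[int],
) -> List[Level]:
    rates = set(attention_resolutions)  # accepts a set, list, or tuple
    return [
        Level(i, 2 ** i, model_channels * mult, 2 ** i in rates)
        for i, mult in enumerate(channel_mult)
    ]


def resolutions_to_rates(image_size: int, resolutions: Iterable[int]) -> List[int]:
    # Hypothetical convention: attention requested at a pixel size such as 16
    # becomes the downsample rate image_size // 16.
    return [image_size // r for r in resolutions]


for level in plan_levels(128, (1, 2, 3, 4), resolutions_to_rates(64, [16, 8])):
    print(level)
```

For a 64x64 input with attention requested at 16x16 and 8x8 pixels, the rates come out as [4, 8], so only the two lowest-resolution levels of this four-level configuration use attention.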
- __init__(in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, num_classes=None, use_checkpoint=False, num_heads=1, num_heads_upsample=-1, use_scale_shift_norm=False)
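With the default channel_mult=(1, 2, 4, 8) the model has four resolution levels. Assuming one exact halving between consecutive levels, the input's spatial sizes should be multiples of 2**(len(channel_mult) - 1) = 8 so that each skip connection joins feature maps of matching size. This is a minimal pre-flight check under that assumption; required_multiple and check_spatial_shape are hypothetical helpers, not part of the module.

```python
from typing import Sequence


def required_multiple(channel_mult: Sequence[int]) -> int:
    # len(channel_mult) levels -> len(channel_mult) - 1 halvings (assumed)
    return 2 ** (len(channel_mult) - 1)


def check_spatial_shape(
    spatial: Sequence[int], channel_mult: Sequence[int] = (1, 2, 4, 8)
) -> None:
    """Raise ValueError if any spatial size cannot be halved exactly at every level."""
    m = required_multiple(channel_mult)
    bad = [s for s in spatial if s % m]
    if bad:
        raise ValueError(
            f"spatial sizes {bad} are not multiples of {m} "
            f"for channel_mult={tuple(channel_mult)}"
        )


check_spatial_shape((64, 64))       # passes with the default channel_mult
check_spatial_shape((30,), (1, 2))  # 1-D signal, two levels: needs a multiple of 2
```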
Methods
__init__(in_channels, model_channels, ...[, ...])
Convert the torso of the model to float16.
Convert the torso of the model to float32.
forward(x, timesteps[, y]) – Apply the model to an input batch.
get_feature_vectors(x, timesteps[, y]) – Apply the model and return all of the intermediate tensors.
Attributes
inner_dtype – Get the dtype used by the torso of the model.
- forward(x, timesteps, y=None)
Apply the model to an input batch.
- Parameters:
x – an [N x C x …] Tensor of inputs.
timesteps – a 1-D batch of timesteps.
y – an [N] Tensor of labels, if class-conditional.
- Returns:
an [N x C x …] Tensor of outputs, where C is out_channels.
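The shape contract of forward can be checked before calling the model. This sketch works on plain shape tuples; it assumes that C in x is in_channels, that the output keeps the batch and spatial dimensions while switching to out_channels, and that y must be passed exactly when num_classes was set. expected_output_shape is a hypothetical helper.

```python
from typing import Optional, Sequence, Tuple


def expected_output_shape(
    x_shape: Sequence[int],
    num_timesteps: int,
    in_channels: int,
    out_channels: int,
    num_labels: Optional[int] = None,
    num_classes: Optional[int] = None,
) -> Tuple[int, ...]:
    """Validate forward(x, timesteps, y) arguments by shape and return the output shape."""
    n, c, *spatial = x_shape
    if c != in_channels:
        raise ValueError(f"x has {c} channels, expected in_channels={in_channels}")
    if num_timesteps != n:
        raise ValueError(f"timesteps must be a 1-D batch of length {n}")
    # Assumed rule: labels are required iff the model is class-conditional.
    if (num_labels is None) != (num_classes is None):
        raise ValueError("pass y exactly when num_classes was set")
    if num_labels is not None and num_labels != n:
        raise ValueError(f"y must be an [N] tensor with N={n}")
    return (n, out_channels, *spatial)


print(expected_output_shape((4, 3, 32, 32), 4, in_channels=3, out_channels=3))
# (4, 3, 32, 32)
```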
- get_feature_vectors(x, timesteps, y=None)
Apply the model and return all of the intermediate tensors.
- Parameters:
x – an [N x C x …] Tensor of inputs.
timesteps – a 1-D batch of timesteps.
y – an [N] Tensor of labels, if class-conditional.
- Returns:
a dict with the following keys:
  - ‘down’: a list of hidden-state tensors from downsampling.
  - ‘middle’: the output tensor of the lowest-resolution block in the model.
  - ‘up’: a list of hidden-state tensors from upsampling.
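The page does not say how many tensors ‘down’ and ‘up’ contain or at what resolution. Assuming the block layout of OpenAI's improved-diffusion UNet (an input convolution, num_res_blocks residual blocks per level, a downsampling block between levels, and num_res_blocks + 1 blocks per level on the way up, the last of which upsamples), the hypothetical helper feature_layout below predicts (channels, downsample rate) for every entry. Verify it against the actual tensors before relying on it.

```python
from typing import Dict, List, Sequence, Tuple, Union

Entry = Tuple[int, int]  # (channels, downsample rate relative to the input)


def feature_layout(
    model_channels: int, channel_mult: Sequence[int], num_res_blocks: int
) -> Dict[str, Union[Entry, List[Entry]]]:
    # Downsampling path: input convolution first.
    down: List[Entry] = [(model_channels, 1)]
    ch, ds = model_channels, 1
    for level, mult in enumerate(channel_mult):
        ch = model_channels * mult
        down.extend([(ch, ds)] * num_res_blocks)  # residual blocks
        if level != len(channel_mult) - 1:
            ds *= 2
            down.append((ch, ds))  # downsampling block halves the resolution
    middle: Entry = (ch, ds)  # lowest-resolution block
    # Upsampling path: levels in reverse, num_res_blocks + 1 blocks each.
    up: List[Entry] = []
    for level in reversed(range(len(channel_mult))):
        ch = model_channels * channel_mult[level]
        for i in range(num_res_blocks + 1):
            if level > 0 and i == num_res_blocks:
                ds //= 2  # last block of the level upsamples
            up.append((ch, ds))
    return {"down": down, "middle": middle, "up": up}


layout = feature_layout(64, (1, 2), num_res_blocks=2)
print(layout["down"])  # [(64, 1), (64, 1), (64, 1), (64, 2), (128, 2), (128, 2)]
print(layout["up"])    # [(128, 2), (128, 2), (128, 1), (64, 1), (64, 1), (64, 1)]
```

Under this layout the two lists always have the same length, because every upsampling-path block consumes one stored skip connection.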
- property inner_dtype
Get the dtype used by the torso of the model.
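Converting the torso to float16 trades precision for memory and speed. The pure-Python sketch below uses the struct module's half-precision "e" format to show the rounding that activations undergo. It assumes, without confirmation from this page, that inputs are cast to inner_dtype before the torso and that the result is cast back to the input dtype afterwards; run_mixed_precision and its torso argument are hypothetical stand-ins.

```python
import struct
from typing import Callable, List


def to_float16(v: float) -> float:
    """Round a float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", v))[0]


def run_mixed_precision(
    xs: List[float], torso: Callable[[float], float]
) -> List[float]:
    # Hypothetical flow: cast inputs down to the torso dtype (float16),
    # round the torso's result to float16 as well, then return ordinary
    # Python floats, standing in for the cast back to the input dtype.
    return [to_float16(torso(to_float16(x))) for x in xs]


print(run_mixed_precision([1 / 3, 2049.0], torso=lambda v: v))
# [0.333251953125, 2048.0]
```

Here 1/3 survives the round trip as 0.333251953125, an error of roughly 8e-5, and 2049 collapses to 2048 because float16 cannot represent odd integers above 2048.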