diffusion_models.models.openai_unet.SuperResModel
- class diffusion_models.models.openai_unet.SuperResModel(*args: Any, **kwargs: Any)
Bases: UNetModel
A UNetModel that performs super-resolution.
Expects an extra kwarg low_res to condition on a low-resolution image.
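The low_res conditioning can be pictured with a minimal NumPy sketch. It assumes the implementation follows OpenAI's improved-diffusion code, where low_res is resized to the spatial size of x and concatenated along the channel axis before the UNet torso runs, so the torso is built with twice as many input channels. The helper name condition_on_low_res is hypothetical, and nearest-neighbour upsampling via np.repeat stands in for the bilinear interpolation the original is believed to use.

```python
import numpy as np


def condition_on_low_res(x: np.ndarray, low_res: np.ndarray) -> np.ndarray:
    """Hypothetical helper: upsample low_res to x's size and stack it on the channel axis.

    x:       [N, C, H, W] noisy high-resolution input
    low_res: [N, C, h, w] low-resolution conditioning image (H % h == 0, W % w == 0)
    returns: [N, 2*C, H, W]
    """
    _, _, height, width = x.shape
    _, _, low_h, low_w = low_res.shape
    if height % low_h or width % low_w:
        raise ValueError("this sketch only supports integer upsampling factors")
    # Nearest-neighbour upsampling (a stand-in for bilinear interpolation).
    upsampled = np.repeat(np.repeat(low_res, height // low_h, axis=2), width // low_w, axis=3)
    # Concatenate along channels; the torso would therefore see 2*C input channels.
    return np.concatenate([x, upsampled], axis=1)


x = np.zeros((2, 3, 64, 64), dtype=np.float32)
low_res = np.ones((2, 3, 16, 16), dtype=np.float32)
print(condition_on_low_res(x, low_res).shape)  # (2, 6, 64, 64)
```

Concatenating along channels keeps the torso a plain UNet; only its first layer has to accept the extra conditioning channels, which is presumably why __init__ takes in_channels explicitly.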
Methods
__init__(in_channels, *args, **kwargs)
convert_to_fp16() – Convert the torso of the model to float16.
convert_to_fp32() – Convert the torso of the model to float32.
forward(x, timesteps[, low_res]) – Apply the model to an input batch.
get_feature_vectors(x, timesteps[, low_res]) – Apply the model and return all of the intermediate tensors.
Attributes
inner_dtype – Get the dtype used by the torso of the model.
- convert_to_fp16()
Convert the torso of the model to float16.
- convert_to_fp32()
Convert the torso of the model to float32.
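What "the torso" means can be illustrated with a toy model. This sketch assumes, as in OpenAI's improved-diffusion UNet, that only the input, middle and output blocks are cast, while the timestep embedding and the final output layer stay in float32. The class ToyUNet and its parameter names are hypothetical, and NumPy arrays stand in for module parameters.

```python
import numpy as np


class ToyUNet:
    """Hypothetical stand-in: parameters grouped into 'torso' blocks and other layers."""

    TORSO = ("input_blocks", "middle_block", "output_blocks")

    def __init__(self) -> None:
        rng = np.random.default_rng(0)
        self.params = {
            name: rng.standard_normal((4, 4)).astype(np.float32)
            for name in (*self.TORSO, "time_embed", "out")
        }

    def _cast_torso(self, dtype) -> None:
        # Only the torso blocks change precision; time_embed and out are left alone.
        for name in self.TORSO:
            self.params[name] = self.params[name].astype(dtype)

    def convert_to_fp16(self) -> None:
        self._cast_torso(np.float16)

    def convert_to_fp32(self) -> None:
        self._cast_torso(np.float32)

    @property
    def inner_dtype(self):
        # The torso's dtype, read from its first block.
        return self.params["input_blocks"].dtype


model = ToyUNet()
model.convert_to_fp16()
print(model.inner_dtype, model.params["out"].dtype)  # float16 float32
```

Keeping the embedding and output layers in float32 is a common mixed-precision choice: the bulk of the compute runs in half precision while the layers most sensitive to rounding keep full precision.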
- forward(x, timesteps, low_res=None, **kwargs)
Apply the model to an input batch.
- Parameters:
x – an [N x C x …] Tensor of inputs.
timesteps – a 1-D batch of timesteps.
low_res – a Tensor holding the low-resolution image to condition on.
y – an [N] Tensor of labels, if class-conditional (passed through **kwargs).
- Returns:
an [N x C x …] Tensor of outputs.
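To make the shape contract concrete, here is a hypothetical helper that builds a random batch matching the parameter descriptions: x of shape [N x C x H x W], a 1-D timesteps array with one entry per sample, and low_res at a lower resolution. The name make_dummy_batch, the scale factor of 4 and num_timesteps=1000 are illustrative assumptions, not values taken from this library.

```python
import numpy as np


def make_dummy_batch(n: int, channels: int, size: int, scale: int = 4,
                     num_timesteps: int = 1000, seed: int = 0):
    """Hypothetical helper returning (x, timesteps, low_res) with matching shapes."""
    if size % scale:
        raise ValueError("size must be divisible by scale")
    rng = np.random.default_rng(seed)
    # [N x C x H x W] high-resolution input
    x = rng.standard_normal((n, channels, size, size)).astype(np.float32)
    # 1-D batch of integer timesteps, one per sample
    timesteps = rng.integers(0, num_timesteps, size=n)
    # [N x C x H/scale x W/scale] low-resolution conditioning image
    low_res = rng.standard_normal(
        (n, channels, size // scale, size // scale)).astype(np.float32)
    return x, timesteps, low_res


x, timesteps, low_res = make_dummy_batch(n=2, channels=3, size=64)
print(x.shape, timesteps.shape, low_res.shape)  # (2, 3, 64, 64) (2,) (2, 3, 16, 16)
```

With the real PyTorch model, these arrays would first be converted (for example with torch.from_numpy) and then passed as model(x, timesteps, low_res=low_res).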
- get_feature_vectors(x, timesteps, low_res=None, **kwargs)
Apply the model and return all of the intermediate tensors.
- Parameters:
x – an [N x C x …] Tensor of inputs.
timesteps – a 1-D batch of timesteps.
low_res – a Tensor holding the low-resolution image to condition on.
y – an [N] Tensor of labels, if class-conditional (passed through **kwargs).
- Returns:
a dict with the following keys:
  - ‘down’: a list of hidden state tensors from downsampling.
  - ‘middle’: the tensor of the output of the lowest-resolution block in the model.
  - ‘up’: a list of hidden state tensors from upsampling.
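The structure of the returned dict can be illustrated with a toy encoder/decoder in NumPy. Here 2x2 average pooling stands in for the downsampling blocks and nearest-neighbour upsampling for the upsampling blocks, and every intermediate result is recorded under the 'down', 'middle' and 'up' keys. The function toy_feature_vectors is hypothetical and has no learned layers or skip connections; only the shape of the returned structure mirrors the documentation.

```python
import numpy as np


def toy_feature_vectors(x: np.ndarray, levels: int = 2) -> dict:
    """Hypothetical sketch: collect hidden states from a down/middle/up pass.

    x must be [N, C, H, W] with H and W divisible by 2**levels.
    """
    result = {"down": [], "middle": None, "up": []}
    h = x
    for _ in range(levels):
        n, c, height, width = h.shape
        # 2x2 average pooling as a stand-in for a downsampling block.
        h = h.reshape(n, c, height // 2, 2, width // 2, 2).mean(axis=(3, 5))
        result["down"].append(h)
    result["middle"] = h  # the lowest-resolution hidden state
    for _ in range(levels):
        # Nearest-neighbour 2x upsampling as a stand-in for an upsampling block.
        h = h.repeat(2, axis=2).repeat(2, axis=3)
        result["up"].append(h)
    return result


feats = toy_feature_vectors(np.ones((1, 3, 16, 16)))
print([t.shape for t in feats["down"]], feats["middle"].shape, [t.shape for t in feats["up"]])
# [(1, 3, 8, 8), (1, 3, 4, 4)] (1, 3, 4, 4) [(1, 3, 8, 8), (1, 3, 16, 16)]
```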
- property inner_dtype
Get the dtype used by the torso of the model.
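One way such a property is typically used is sketched below. It assumes the forward pass mirrors OpenAI's improved-diffusion UNet: the input is cast to inner_dtype before the torso runs and cast back to the caller's dtype afterwards, so a float16 torso can still be driven with float32 inputs. The function run_torso and the doubling stand-in for the torso are hypothetical.

```python
import numpy as np


def run_torso(x: np.ndarray, inner_dtype) -> np.ndarray:
    """Hypothetical sketch of the dtype round trip around a reduced-precision torso."""
    h = x.astype(inner_dtype)  # e.g. float32 -> float16 after convert_to_fp16()
    h = h * 2                  # stand-in for the input/middle/output blocks
    return h.astype(x.dtype)   # back to the caller's dtype before the output layer


x = np.full((1, 3, 4, 4), 1.5, dtype=np.float32)
y = run_torso(x, np.float16)
print(y.dtype, y[0, 0, 0, 0])  # float32 3.0
```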