diffusion_models.spine_dataset.spatial_transformer.SpatialTransformer

class diffusion_models.spine_dataset.spatial_transformer.SpatialTransformer(size, mode='bilinear')

Bases: Module

N-D Spatial Transformer.
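The warp can be written as a single formula. This assumes the common (VoxelMorph-style) convention, which this page does not state explicitly: `flow` is a displacement field u added to the identity grid, and (s_1, ..., s_N) is `size`.

```latex
\mathrm{out}(\mathbf{p}) = \mathrm{src}\bigl(\mathbf{p} + \mathbf{u}(\mathbf{p})\bigr),
\qquad \mathbf{p} \in \{0,\dots,s_1-1\} \times \cdots \times \{0,\dots,s_N-1\}
```

Sampling locations p + u(p) that fall between grid points are interpolated according to `mode`.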

__init__(size, mode='bilinear')

Initialize Spatial Transformer.

Parameters:
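  • size (tuple) – Spatial shape of the images to warp, excluding the batch and channel dimensions, e.g. (H, W) for 2-D or (D, H, W) for 3-D inputs.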

  • mode (str, optional) – Interpolation mode used when sampling the source image. Defaults to “bilinear”.
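In VoxelMorph-style spatial transformers, `size` is used once, at construction time, to precompute an identity grid of voxel coordinates that `forward` later offsets by `flow`. The sketch below assumes that design; `identity_grid` is a hypothetical helper for illustration, not part of this package.

```python
import torch


def identity_grid(size):
    """Hypothetical helper: voxel-coordinate grid of shape (1, N, *size)."""
    # One 1-D coordinate vector per spatial axis, in the same order as `size`.
    vectors = [torch.arange(s, dtype=torch.float32) for s in size]
    # 'ij' indexing keeps the axis order matching `size` (no x/y swap).
    grids = torch.meshgrid(*vectors, indexing="ij")
    # Stack to (N, *size), then add a leading batch axis.
    return torch.stack(grids).unsqueeze(0)


grid = identity_grid((4, 5))
print(grid.shape)  # torch.Size([1, 2, 4, 5])
```

Channel i of the grid holds the coordinate along spatial axis i, so `grid[0, 0, i, j] == i` and `grid[0, 1, i, j] == j`.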

Methods

  • __init__(size[, mode]) – Initialize Spatial Transformer.

  • forward(src, flow) – Forward pass of the spatial transformer.

forward(src, flow)

Forward pass of the spatial transformer. Generates the deformed image.

Parameters:
  • src (torch.Tensor) – Source image to deform; its spatial dimensions should match size.

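  • flow (torch.Tensor) – Displacement (flow) field describing where each location of src is sampled from, with one channel per spatial dimension.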

Returns:

The deformed image, i.e. src resampled according to flow.

Return type:

torch.Tensor
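The forward pass can be sketched with `torch.nn.functional.grid_sample`. This assumes the VoxelMorph convention, which the class signature resembles but this page does not confirm: `flow` has shape (B, N, *size), channel i is a displacement in voxels along spatial axis i, and output location p takes the value of src at p + flow(p). `warp` is a hypothetical stand-alone function, not this package's API.

```python
import torch
import torch.nn.functional as F


def warp(src, flow, mode="bilinear"):
    """Hypothetical sketch of a spatial-transformer forward pass (2-D or 3-D).

    src:  (B, C, *size) image to deform.
    flow: (B, N, *size) displacement in voxels, channel i along spatial axis i.
    """
    size = flow.shape[2:]
    ndim = len(size)
    # Identity grid of voxel coordinates, shape (1, N, *size).
    vectors = [torch.arange(s, dtype=flow.dtype, device=flow.device) for s in size]
    grid = torch.stack(torch.meshgrid(*vectors, indexing="ij")).unsqueeze(0)
    # Absolute sampling locations: identity grid plus displacement.
    locs = grid + flow
    # Normalize each axis from [0, s - 1] to [-1, 1] (align_corners=True convention).
    locs = torch.stack(
        [2 * locs[:, i] / (size[i] - 1) - 1 for i in range(ndim)], dim=1
    )
    # grid_sample expects channels last and coordinates ordered (x, y[, z]),
    # i.e. the reverse of the tensor's spatial axis order.
    locs = locs.movedim(1, -1).flip(-1)
    return F.grid_sample(src, locs, mode=mode, align_corners=True)


# Usage: shift the sampling location one pixel along the last (width) axis.
src = torch.arange(12, dtype=torch.float32).reshape(1, 1, 3, 4)
flow = torch.zeros(1, 2, 3, 4)
flow[:, 1] = 1.0
out = warp(src, flow)
```

With this flow, each output pixel takes (up to floating-point rounding) the value one column to its right; the last column samples outside the image and comes out as approximately zero under `grid_sample`'s default zero padding. A zero flow reproduces src.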