Source code for diffusion_models.spine_dataset.spatial_transformer
"""Class and helper functions used for the random deformations."""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as nnf
[docs]
def gauss_gen3D(n=10, s=5, sigma=3) -> torch.Tensor:
    """Generate a normalised 3D Gaussian blur kernel.

    The kernel is sampled on a cubic grid of ``s`` points per axis whose
    coordinates are evenly spaced over ``[-(n - 1) / 2, (n - 1) / 2]``, so the
    Gaussian is centred on the middle of the grid.

    Args:
        n (int, optional): extent of the sampling grid; coordinates span
            ``-(n - 1) / 2`` to ``(n - 1) / 2`` on every axis. Defaults to 10.
        s (int, optional): number of samples per axis, i.e. the side length of
            the returned kernel. Defaults to 5.
        sigma (int, optional): standard deviation of the Gaussian, in the same
            units as the grid coordinates. Defaults to 3.

    Returns:
        torch.Tensor: blur kernel of shape ``(1, 1, s, s, s)`` whose entries
        sum to 1. The dtype follows NumPy's default float (float64).
    """
    # NOTE: the caller-supplied ``sigma`` is used as given; an earlier version
    # overwrote it with the constant 3, which made the parameter a no-op.
    coords = np.linspace(-(n - 1) / 2, (n - 1) / 2, s)
    # All three axes share the same coordinates, so the kernel is isotropic.
    xv, yv, zv = np.meshgrid(coords, coords, coords)
    unnormalised = np.exp(-(xv**2 + yv**2 + zv**2) / (2 * sigma**2))
    kernel = unnormalised / np.sum(unnormalised)
    # Prepend two singleton dimensions ahead of the three spatial axes.
    blur_kernel = torch.from_numpy(kernel)[None, None, :]
    return blur_kernel
[docs]
class SpatialTransformer(nn.Module):
    """N-D Spatial Transformer.

    Warps an image by sampling it at the locations of an identity grid
    displaced by a flow field. The identity grid is built once for a fixed
    spatial ``size`` and stored as the buffer ``grid``.
    """

    def __init__(self, size: tuple, mode: str = "bilinear") -> None:
        """Initialize Spatial Transformer.

        Args:
            size (tuple): Spatial size of the images to deform, one entry per
                spatial dimension.
            mode (str, optional): Interpolation mode passed to
                ``grid_sample``. Defaults to "bilinear".
        """
        super().__init__()
        self.mode = mode

        # Identity sampling grid of shape (1, len(size), *size): channel i
        # holds the integer index along spatial dimension i.
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors, indexing="ij")
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)
        grid = grid.type(torch.float)

        # Registering the grid as a buffer cleanly moves it along with the
        # module (e.g. to the GPU), but it also adds it to the state dict, so
        # saved model files are larger than they need to be.
        # NOTE(review): ``register_buffer(..., persistent=False)`` would keep it
        # out of the state dict, but that changes which keys checkpoints
        # contain - confirm no saved weights rely on "grid" before switching.
        # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
        self.register_buffer("grid", grid)

    def forward(self, src: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
        """Forward pass of the spatial transformer. Generates the deformed image.

        Args:
            src (torch.Tensor): image to deform, laid out as
                ``(batch, channels, *spatial)`` as required by ``grid_sample``.
            flow (torch.Tensor): displacement field of shape
                ``(batch, ndim, *spatial)``; channel ``i`` is the offset, in
                grid index units, added to the identity grid along spatial
                dimension ``i``.

        Returns:
            torch.Tensor: deformed image.
        """
        # Absolute sampling locations: identity grid plus displacement.
        new_locs = self.grid.to(flow.device) + flow
        shape = flow.shape[2:]

        # grid_sample expects coordinates normalised to [-1, 1].
        for i, extent in enumerate(shape):
            # A singleton dimension would give a divisor of 0 and a NaN grid.
            # With align_corners=True every finite coordinate on such an axis
            # resolves to its single element, so clamping the divisor is safe.
            divisor = max(extent - 1, 1)
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / divisor - 0.5)

        # Move the coordinate channel to the last position and reverse it:
        # grid_sample reads the last dimension as (x, y[, z]) where x indexes
        # the LAST spatial dimension, whereas the grid built with
        # indexing="ij" stores coordinates in (dim0, dim1, ...) order.
        # Other ranks are passed through unchanged; grid_sample documents
        # support for 4-D and 5-D inputs only.
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        return nnf.grid_sample(
            src, new_locs, align_corners=True, mode=self.mode
        )