Source code for diffusion_models.spine_dataset.spatial_transformer

"""Class and helper functions used for the random deformations."""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as nnf


def gauss_gen3D(n=10, s=5, sigma=3) -> torch.Tensor:
    """Generate a 3D Gaussian blur kernel.

    Args:
        n (int, optional): extent of the sampling grid in each dimension. Defaults to 10.
        s (int, optional): number of steps at which the kernel is sampled. Defaults to 5.
        sigma (int, optional): standard deviation of the Gaussian. Defaults to 3.

    Returns:
        torch.Tensor: blur kernel of shape (1, 1, s, s, s).
    """
    x = np.linspace(-(n - 1) / 2, (n - 1) / 2, s)
    y = np.linspace(-(n - 1) / 2, (n - 1) / 2, s)
    z = np.linspace(-(n - 1) / 2, (n - 1) / 2, s)
    xv, yv, zv = np.meshgrid(x, y, z)
    # unnormalized Gaussian evaluated on the 3D grid
    hg = np.exp(-(xv**2 + yv**2 + zv**2) / (2 * sigma**2))
    # normalize so the kernel sums to one
    h = hg / np.sum(hg)
    blur_kernel = torch.from_numpy(h)[None, None, :]
    return blur_kernel
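A kernel like this is typically used to smooth a random displacement field before it is passed to the spatial transformer below. A minimal usage sketch, assuming a 3-channel flow field for a 3D volume (the volume size and the per-channel loop are illustrative, not part of this module):

# hypothetical usage sketch: smooth a random 3D displacement field
kernel = gauss_gen3D(n=10, s=5, sigma=3).float()  # shape (1, 1, 5, 5, 5)
flow = torch.randn(1, 3, 64, 64, 64)  # one displacement channel per spatial dim
# convolve each channel independently; padding=2 preserves size for a 5x5x5 kernel
smoothed = torch.cat(
    [nnf.conv3d(flow[:, i : i + 1], kernel, padding=2) for i in range(3)],
    dim=1,
)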
class SpatialTransformer(nn.Module):
    """N-D spatial transformer."""
    def __init__(self, size: tuple, mode="bilinear") -> None:
        """Initialize the spatial transformer.

        Args:
            size (tuple): spatial size of the images to be deformed.
            mode (str, optional): interpolation mode used for sampling. Defaults to "bilinear".
        """
        super().__init__()
        self.mode = mode

        # create the identity sampling grid
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors, indexing="ij")
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)
        grid = grid.type(torch.float)

        # registering the grid as a buffer cleanly moves it to the GPU, but it also
        # adds it to the state dict. This is annoying since everything in the state
        # dict is included when saving weights to disk, so the model files are bigger
        # than they need to be. So far, there does not appear to be an elegant solution.
        # See: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
        self.register_buffer("grid", grid)
    def forward(self, src: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
        """Forward pass of the spatial transformer. Generates the deformed image.

        Args:
            src (torch.Tensor): image to deform.
            flow (torch.Tensor): displacement field.

        Returns:
            torch.Tensor: deformed image.
        """
        # new sampling locations: identity grid plus displacement
        new_locs = self.grid.to(flow.device) + flow
        shape = flow.shape[2:]

        # grid_sample expects grid values normalized to [-1, 1]
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)

        # move the channels dim to the last position and reverse the channel order,
        # since grid_sample expects coordinates as (x, y[, z]) while the grid was
        # built in (z, y, x) / "ij" order
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        return nnf.grid_sample(src, new_locs, align_corners=True, mode=self.mode)
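A minimal usage sketch, assuming a single-channel 3D volume (the shapes here are illustrative):

# hypothetical usage sketch: warp a random volume with a random flow field
size = (32, 32, 32)
stn = SpatialTransformer(size, mode="bilinear")
src = torch.rand(1, 1, *size)    # (batch, channels, D, H, W)
flow = torch.randn(1, 3, *size)  # one displacement channel per spatial dim
warped = stn(src, flow)          # same shape as src

Note that an all-zero flow samples exactly at the grid points, so the output reproduces src; this makes a quick sanity check for the transformer.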