Source code for diffusion_models.mri_forward.fft
from torch.fft import fftn, ifftn, ifftshift, fftshift
from typing import Union
from jaxtyping import Float, Complex
from torch import Tensor
import torch
from diffusion_models.utils.helpers import complex_to_2channelfloat
[docs]
def to_kspace(
    x: Union[
        Float[Tensor, "*batch 2 height width"],
        Complex[Tensor, "*batch height width"]
    ]
) -> Union[Float[Tensor, "*batch 2 height width"], Complex[Tensor, "*batch height width"]]:
    """Transform an image into k-space with a 2-D FFT over the last two axes.

    The FFT is taken over ``dim=(-2, -1)`` and followed by ``fftshift`` on the
    same axes, so the zero-frequency component ends up in the centre of the
    output.

    Args:
        x: Either a complex tensor ``(*batch, height, width)`` or a real
            tensor ``(*batch, 2, height, width)`` whose size-2 axis holds the
            real and imaginary parts.

    Returns:
        The k-space data in the same representation as the input: a complex
        tensor for complex input, otherwise whatever
        ``complex_to_2channelfloat`` returns for the complex result.
    """
    if torch.is_complex(x):
        return fftshift(fftn(x, dim=(-2, -1)), dim=(-2, -1))

    # Move the real/imag axis to the end so it can be viewed as complex.
    # ``movedim(-3, -1)`` equals ``permute(0, 2, 3, 1)`` for 4-D input but
    # also accepts any number of leading batch axes, as the annotation says.
    as_complex = torch.view_as_complex(x.movedim(-3, -1).contiguous())
    spectrum = fftshift(fftn(as_complex, dim=(-2, -1)), dim=(-2, -1))
    # NOTE(review): complex_to_2channelfloat is defined elsewhere; confirm it
    # also handles inputs with zero or several batch axes.
    return complex_to_2channelfloat(spectrum)
[docs]
def to_imgspace(
    x: Union[
        Float[Tensor, "*batch 2 height width"],
        Complex[Tensor, "*batch height width"]
    ]
) -> Union[Float[Tensor, "*batch 2 height width"], Complex[Tensor, "*batch height width"]]:
    """Transform centred k-space data back to image space (inverse 2-D FFT).

    This is the inverse of ``to_kspace``, which computes
    ``fftshift(fftn(x))``. The centring shift is therefore undone with
    ``ifftshift`` *before* ``ifftn``. Shifting after the inverse transform
    instead yields a circularly shifted, phase-modulated image, so a
    ``to_kspace`` -> ``to_imgspace`` round trip would not return the input.

    Args:
        x: Either a complex tensor ``(*batch, height, width)`` or a real
            tensor ``(*batch, 2, height, width)`` whose size-2 axis holds the
            real and imaginary parts. The zero-frequency component is
            expected at the centre of the last two axes.

    Returns:
        The image in the same representation as the input: a complex tensor
        for complex input, otherwise whatever ``complex_to_2channelfloat``
        returns for the complex result.
    """
    if torch.is_complex(x):
        return ifftn(ifftshift(x, dim=(-2, -1)), dim=(-2, -1))

    # Move the real/imag axis to the end so it can be viewed as complex.
    # ``movedim(-3, -1)`` equals ``permute(0, 2, 3, 1)`` for 4-D input but
    # also accepts any number of leading batch axes, as the annotation says.
    as_complex = torch.view_as_complex(x.movedim(-3, -1).contiguous())
    # Undo the k-space centring first, then invert the transform.
    image = ifftn(ifftshift(as_complex, dim=(-2, -1)), dim=(-2, -1))
    # NOTE(review): complex_to_2channelfloat is defined elsewhere; confirm it
    # also handles inputs with zero or several batch axes.
    return complex_to_2channelfloat(image)