Toggle Light / Dark / Auto color theme
Toggle table of contents sidebar
Source code for diffusion_models.utils.helpers
from jaxtyping import Float , Complex
from torch import Tensor
import torch
[docs]
class dotdict ( dict ):
"""dot.notation access to dictionary attributes"""
__getattr__ = dict . __getitem__
__setattr__ = dict . __setitem__
__delattr__ = dict . __delitem__
[docs]
def bytes_to_gb ( bytes : int ):
kb = bytes / 1024
mb = kb / 1024
gb = mb / 1024
return gb
[docs]
def complex_to_2channelfloat ( x : Complex [ Tensor , "*batch height width" ]) -> Float [ Tensor , "*batch 2 height width" ]:
x = torch . view_as_real ( x )
return x . permute ( * [ i for i in range ( x . dim () - 3 )], - 1 , - 3 , - 2 )