# Source code for diffusion_models.utils.mp_setup
import os
from typing import Any
import torch
from torch.distributed import init_process_group, destroy_process_group
[docs]
class DDP_Proc_Group:
    """Wrap a callable so it runs inside a ``torch.distributed`` process group.

    Calling an instance initialises the process group (using the first two
    positional arguments as ``rank`` and ``world_size``), invokes the wrapped
    callable with all the arguments it was given, and then destroys the
    process group again.
    """

    def __init__(self, function) -> None:
        """Remember the callable to run once the process group is set up.

        Args:
            function: Callable invoked with exactly the arguments passed to
                this wrapper. Its first two positional arguments are also
                used as ``rank`` and ``world_size`` for the setup.
        """
        self.function = function

    def __call__(self, *args, **kwargs) -> None:
        """Set up the process group, run the wrapped callable, tear down.

        Args:
            *args: Positional arguments forwarded to the wrapped callable.
                ``args[0]`` is used as the rank and ``args[1]`` as the
                world size.
            **kwargs: Keyword arguments forwarded to the wrapped callable.

        Raises:
            IndexError: If fewer than two positional arguments are given.
            Exception: Anything raised by the setup or by the wrapped
                callable propagates unchanged.
        """
        # Setup deliberately stays outside the ``try``: if initialisation
        # fails there is no process group to destroy, and calling
        # ``destroy_process_group`` would only mask the original error.
        self._ddp_setup(args[0], args[1])
        try:
            self.function(*args, **kwargs)
        finally:
            # Tear the group down even when the wrapped callable raises, so
            # a failing worker does not leave the process group behind.
            destroy_process_group()

    def _ddp_setup(self, rank, world_size):
        """Initialise the NCCL process group and select the CUDA device.

        Args:
            rank: Rank passed to ``init_process_group``; also used as the
                CUDA device index for ``torch.cuda.set_device``.
            world_size: World size passed to ``init_process_group``.
        """
        # NOTE(review): these assignments unconditionally overwrite any
        # MASTER_ADDR / MASTER_PORT already present in the environment --
        # confirm that is intended for launcher-driven or multi-node runs.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"
        init_process_group(backend="nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)