Source code for diffusion_models.spine_dataset.collations

import torch
from typing import List

def collate_fn(batch: List[dict]) -> dict:
    """Collate a list of sample dicts into one batched dict.

    The keys of the first sample define the keys of the result. For every
    key, the ``torch.Tensor`` values found across all samples are stacked
    along a new leading dimension (``dim=0``). Non-tensor values are not
    collected. The ``"loss_fn"`` entry is copied unchanged from the first
    sample instead of being stacked.

    Args:
        batch: Sample dicts to merge. The first sample must contain a
            ``"loss_fn"`` entry.

    Returns:
        A dict with the keys of the first sample: stacked tensors for the
        tensor-valued keys and the first sample's ``"loss_fn"`` value.

    Raises:
        ValueError: If ``batch`` is empty, if a key collected no tensor
            from any sample (so there is nothing to stack), or if the
            ``"loss_fn"`` value is neither a string nor a list.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    first = batch[0]
    res = {key: [] for key in first}
    # "loss_fn" is taken from the first sample only; it is not batched.
    res["loss_fn"] = first["loss_fn"]

    for sample in batch:
        for key, elem in sample.items():
            # Only tensors are gathered; every other value is ignored here.
            if isinstance(elem, torch.Tensor):
                res[key].append(elem)

    # Reassigning existing keys while iterating is safe: the dict's size
    # does not change.
    for key, elem in res.items():
        if isinstance(elem, list):
            if not elem:
                # Without this guard torch.stack fails on the empty list
                # with an error that does not say which field is at fault.
                raise ValueError(
                    f"no tensor values collected for key {key!r}; "
                    "cannot stack an empty list"
                )
            res[key] = torch.stack(elem, dim=0)
        elif isinstance(elem, str):
            # Internal invariant: the only non-list value placed in `res`
            # is the "loss_fn" entry.
            assert key == "loss_fn"
        else:
            raise ValueError(f"{key}")
    return res