Source code for diffusion_models.models.mnist_enc

import torch.nn as nn
import torch

class MNISTEncoder(nn.Module):
    """Small convolutional classifier for single-channel images.

    The network stacks four ``Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d(2)``
    stages whose channel counts double at every stage (1 -> 2 -> 4 -> 8 -> 16),
    followed by one linear layer that maps the 16 resulting features to
    ``out_classes`` outputs.

    Because the linear layer expects exactly 16 features, the spatial size
    must collapse to 1x1 after the four 2x poolings, i.e. input height and
    width must lie in ``[16, 31]`` (28x28 images satisfy this).

    Args:
        out_classes: Number of output features produced by the final
            linear layer.
        kernel_size: Kernel size shared by all convolutions; ``"same"``
            padding keeps the spatial size unchanged through each conv.
    """

    def __init__(self, out_classes=10, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size
        self.out_classes = out_classes

        # Channel widths per stage: 1, 2, 4, 8, 16.
        channels = [2**i for i in range(5)]

        # Kept as a plain list attribute for backward compatibility; the
        # layers are registered as sub-modules through ``self.conv`` below.
        self.encoder = []
        for in_channels, out_channels in zip(channels, channels[1:]):
            self.encoder.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    padding="same",
                )
            )
            self.encoder.append(nn.BatchNorm2d(out_channels))
            self.encoder.append(nn.ReLU())
            self.encoder.append(nn.MaxPool2d(2))

        self.conv = nn.Sequential(*self.encoder)
        self.fc = nn.Linear(channels[-1], self.out_classes)

    def forward(self, x):
        """Run the encoder and return the linear layer's output.

        Args:
            x: Input tensor of shape ``(batch, 1, H, W)``.

        Returns:
            Tensor of shape ``(batch, out_classes)``.
        """
        x = self.conv(x)
        # Flatten everything except the batch axis. A bare ``squeeze()``
        # would also drop the batch axis when the batch holds one sample.
        return self.fc(torch.flatten(x, start_dim=1))