SeparableConv2d

PHOTO EMBED

Sat Mar 12 2022 16:29:53 GMT+0000 (Coordinated Universal Time)

Saved by @mridulav #python #pytorch

import torch.nn as nn
from torch.nn.utils import spectral_norm

def conv2d(*args, **kwargs):
    """Create an ``nn.Conv2d`` wrapped with spectral normalization.

    Every positional and keyword argument is forwarded unchanged to
    ``nn.Conv2d``; the freshly built layer is then passed through
    ``spectral_norm`` and the result is returned.
    """
    layer = nn.Conv2d(*args, **kwargs)
    return spectral_norm(layer)

class SeparableConv2d(nn.Module):
    """Depthwise-separable 2D convolution made of two ``conv2d`` layers.

    The input first goes through a depthwise convolution
    (``groups=in_channels``, so the channel count is unchanged) and then
    through a 1x1 pointwise convolution that maps ``in_channels`` to
    ``out_channels``. Both layers are created with the module-level
    ``conv2d`` helper.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the pointwise layer.
        kernel_size: Kernel size of the depthwise convolution.
        bias: Whether both convolutions carry a bias term.
        padding: Padding of the depthwise convolution. Defaults to ``1``,
            the value that used to be hard-coded; with that default the
            spatial size is preserved only for a 3x3 kernel, so pass e.g.
            ``kernel_size // 2`` when using another odd kernel size.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=False,
                 padding=1):
        super().__init__()
        # Depthwise stage: one group per input channel.
        self.depthwise = conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            bias=bias,
            padding=padding,
        )
        # Pointwise stage: 1x1 convolution that mixes channels.
        self.pointwise = conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=bias,
        )

    def forward(self, x):
        """Apply the depthwise then the pointwise convolution to ``x``."""
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out
content_copyCOPY

Source: https://github.com/autonomousvision/projected_gan/blob/main/pg_modules/blocks.py