import torch.nn as nn
from torch.nn.utils import spectral_norm


def conv2d(*args, **kwargs):
    """Build an ``nn.Conv2d`` wrapped with spectral normalization.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.Conv2d``. The returned module is the same ``Conv2d`` instance, with
    ``torch.nn.utils.spectral_norm`` applied to its ``weight``.
    """
    return spectral_norm(nn.Conv2d(*args, **kwargs))


class SeparableConv2d(nn.Module):
    """Depthwise-separable 2D convolution with spectrally normalized weights.

    The layer applies two convolutions in sequence:

    1. A depthwise convolution (``groups=in_channels``, ``kernel_size``
       spatial kernel) that keeps the channel count at ``in_channels``.
    2. A 1x1 pointwise convolution that maps ``in_channels`` to
       ``out_channels``.

    Both convolutions are created through :func:`conv2d`, so both are wrapped
    with spectral normalization.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Spatial kernel size of the depthwise convolution.
        bias: Whether both convolutions learn an additive bias.
        padding: Padding applied by the depthwise convolution. It is
            forwarded verbatim to ``nn.Conv2d``. The default of 1 preserves
            spatial size only for ``kernel_size=3``. Pass
            ``kernel_size // 2`` (odd kernels) or ``"same"`` for other sizes.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=False,
                 padding=1):
        super().__init__()
        # groups=in_channels makes each input channel convolve with its own
        # filter, with no cross-channel mixing.
        self.depthwise = conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            bias=bias,
            padding=padding,
        )
        # The 1x1 convolution does all of the cross-channel mixing.
        self.pointwise = conv2d(
            in_channels, out_channels, kernel_size=1, bias=bias
        )

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise convolution."""
        return self.pointwise(self.depthwise(x))