import torch
import torch.nn as nn
import torch.nn.functional as F


class STN(nn.Module):
    """Small CNN classifier preceded by a spatial transformer.

    ``stn`` regresses a 2x3 affine matrix from the input and resamples the
    input with it; the resampled batch is then fed to a LeNet-style classifier
    that outputs log-probabilities over 10 classes.

    The hard-coded flatten sizes (``128 * 4 * 4`` in the localization head and
    ``16 * 5 * 5`` in the classifier) correspond to 3-channel 32x32 inputs,
    i.e. batches of shape ``(N, 3, 32, 32)``; other spatial sizes fail at the
    first linear layer that consumes the flattened features.
    """

    def __init__(self):
        super().__init__()

        # Simple convnet classifier: 3x32x32 -> 16x5x5 -> 120 -> 84 -> 10.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

        # Spatial transformer localization network: 3x32x32 -> 128x4x4.
        self.localization = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(64, 128, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
        )

        # Transformation regressor: flattened features -> 6 affine parameters (theta).
        self.fc_loc = nn.Sequential(
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(True),
            nn.Linear(256, 3 * 2),
        )

        # Start from the identity transform: with zero weights theta equals the
        # bias for every input, and the bias encodes [[1, 0, 0], [0, 1, 0]].
        # Written under no_grad rather than through ``.data`` so the in-place
        # initialization explicitly bypasses autograd.
        with torch.no_grad():
            self.fc_loc[2].weight.zero_()
            self.fc_loc[2].bias.copy_(
                torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)
            )

    def stn(self, x):
        """Warp ``x`` with an affine transform predicted from ``x`` itself.

        Args:
            x: input batch of shape ``(N, C, H, W)``.

        Returns:
            Tensor with the same shape as ``x``, resampled (bilinear, the
            ``grid_sample`` default) on the predicted affine grid.
        """
        xs = self.localization(x)
        # Flatten everything except the batch dimension.
        xs = torch.flatten(xs, 1)
        # Regress the 6 affine parameters and reshape them to (N, 2, 3).
        theta = self.fc_loc(xs).view(-1, 2, 3)
        # align_corners=False is PyTorch's default since 1.3; passing it
        # explicitly keeps that behaviour and avoids the UserWarning that both
        # calls emit on every forward pass when the argument is omitted.
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)

    def forward(self, x):
        """Spatially transform a batch, then classify it.

        Args:
            x: input batch, expected shape ``(N, 3, 32, 32)``.

        Returns:
            ``(N, 10)`` tensor of log-probabilities (``log_softmax`` over dim 1).
        """
        x = self.stn(x)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # flatten(x, 1) keeps the batch dimension intact, whereas
        # view(-1, 16*5*5) would silently fold a mis-sized feature map into a
        # different batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)
Preview:
Download PNG
Download JPEG
Download SVG
Tip: You can change the style, width & colours of the snippet with the inspect tool before clicking Download!
Click to optimize width for Twitter