import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Fully convolutional stack: input conv -> repeated hidden conv -> output conv.

    Every convolution is 3x3 with ``padding=1`` (default stride 1), so the
    spatial size of the input is preserved; only the channel count changes.

    NOTE: ``conv_x`` is a single module applied ``hidden_layer`` times in
    ``forward``, i.e. all hidden layers share the same weights. The parameter
    count therefore does not grow with ``hidden_layer``.

    Args:
        hidden_layer: Number of times the shared hidden convolution (plus
            ReLU) is applied. ``0`` means input conv followed directly by the
            output conv.
        hidden_size: Number of channels in the hidden feature maps.
        in_channels: Channels of the input tensor. Defaults to 1 (the value
            previously hard-coded).
        out_channels: Channels of the output tensor. Defaults to 2 (the value
            previously hard-coded).
    """

    def __init__(self, hidden_layer, hidden_size, in_channels=1, out_channels=2):
        super().__init__()
        # Identifier string; encodes only the depth/width hyper-parameters.
        self.name = 'model-%d-%d' % (hidden_layer, hidden_size)
        self.hidden_layer = hidden_layer
        self.conv_a = nn.Conv2d(in_channels, hidden_size, 3, padding=1)
        self.conv_x = nn.Conv2d(hidden_size, hidden_size, 3, padding=1)
        self.conv_z = nn.Conv2d(hidden_size, out_channels, 3, padding=1)

    def forward(self, x):
        """Apply the convolution stack to ``x``.

        Args:
            x: Input tensor accepted by ``nn.Conv2d`` with ``in_channels``
                channels — presumably (N, in_channels, H, W); confirm against
                callers.

        Returns:
            Tensor with ``out_channels`` channels and the same spatial size as
            ``x``. No activation is applied to the output layer.
        """
        x = F.relu(self.conv_a(x))
        for _ in range(self.hidden_layer):
            # Same module every iteration: weights are shared across hidden layers.
            x = F.relu(self.conv_x(x))
        return self.conv_z(x)