diff options
Diffstat (limited to 'network.py')
| -rw-r--r-- | network.py | 196 |
1 files changed, 196 insertions, 0 deletions
diff --git a/network.py b/network.py new file mode 100644 index 0000000..6915731 --- /dev/null +++ b/network.py @@ -0,0 +1,196 @@ +import math +import torch +import torch.utils.serialization + +from SeparableConvolution import SeparableConvolution # the custom SeparableConvolution layer + +torch.cuda.device(1) # change this if you have a multiple graphics cards and you want to utilize them + +torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance + +class Network(torch.nn.Module): + def __init__(self, model_name): + super(Network, self).__init__() + + def Basic(intInput, intOutput): + return torch.nn.Sequential( + torch.nn.Conv2d(in_channels=intInput, out_channels=intOutput, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=intOutput, out_channels=intOutput, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=intOutput, out_channels=intOutput, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False) + ) + # end + + def Subnet(): + return torch.nn.Sequential( + torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Conv2d(in_channels=64, out_channels=51, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False), + torch.nn.Upsample(scale_factor=2, mode='bilinear'), + torch.nn.Conv2d(in_channels=51, out_channels=51, kernel_size=3, stride=1, padding=1) + ) + # end + + self.moduleConv1 = Basic(6, 32) + self.modulePool1 = torch.nn.AvgPool2d(kernel_size=2, stride=2) + + self.moduleConv2 = Basic(32, 64) + self.modulePool2 = torch.nn.AvgPool2d(kernel_size=2, stride=2) + + self.moduleConv3 = Basic(64, 128) + self.modulePool3 = torch.nn.AvgPool2d(kernel_size=2, stride=2) + + self.moduleConv4 = Basic(128, 256) + self.modulePool4 = torch.nn.AvgPool2d(kernel_size=2, stride=2) + + self.moduleConv5 = Basic(256, 512) + self.modulePool5 = torch.nn.AvgPool2d(kernel_size=2, stride=2) + + self.moduleDeconv5 = Basic(512, 512) + self.moduleUpsample5 = torch.nn.Sequential( + torch.nn.Upsample(scale_factor=2, mode='bilinear'), + torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False) + ) + + self.moduleDeconv4 = Basic(512, 256) + self.moduleUpsample4 = torch.nn.Sequential( + torch.nn.Upsample(scale_factor=2, mode='bilinear'), + torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False) + ) + + self.moduleDeconv3 = Basic(256, 128) + self.moduleUpsample3 = torch.nn.Sequential( + torch.nn.Upsample(scale_factor=2, mode='bilinear'), + torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False) + ) + + self.moduleDeconv2 = Basic(128, 64) + self.moduleUpsample2 = torch.nn.Sequential( + torch.nn.Upsample(scale_factor=2, mode='bilinear'), + torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=False) + ) + + self.moduleVertical1 = Subnet() + self.moduleVertical2 = Subnet() + self.moduleHorizontal1 = Subnet() + self.moduleHorizontal2 = Subnet() + + self.modulePad = torch.nn.ReplicationPad2d([ int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)) ]) + + self.load_state_dict(torch.load('./network-' + model_name + '.pytorch')) + # end + + def forward(self, variableInput1, variableInput2): + variableJoin = torch.cat([variableInput1, variableInput2], 1) + + variableConv1 = self.moduleConv1(variableJoin) + variablePool1 = self.modulePool1(variableConv1) + + variableConv2 = self.moduleConv2(variablePool1) + variablePool2 = self.modulePool2(variableConv2) + + variableConv3 = self.moduleConv3(variablePool2) + variablePool3 = self.modulePool3(variableConv3) + + variableConv4 = self.moduleConv4(variablePool3) + variablePool4 = self.modulePool4(variableConv4) + + variableConv5 = self.moduleConv5(variablePool4) + variablePool5 = self.modulePool5(variableConv5) + + variableDeconv5 = self.moduleDeconv5(variablePool5) + variableUpsample5 = self.moduleUpsample5(variableDeconv5) + + variableCombine = variableUpsample5 + variableConv5 + + variableDeconv4 = self.moduleDeconv4(variableCombine) + variableUpsample4 = self.moduleUpsample4(variableDeconv4) + + variableCombine = variableUpsample4 + variableConv4 + + variableDeconv3 = self.moduleDeconv3(variableCombine) + variableUpsample3 = self.moduleUpsample3(variableDeconv3) + + variableCombine = variableUpsample3 + variableConv3 + + variableDeconv2 = self.moduleDeconv2(variableCombine) + variableUpsample2 = self.moduleUpsample2(variableDeconv2) + + variableCombine = variableUpsample2 + variableConv2 + + variableDot1 = SeparableConvolution()(self.modulePad(variableInput1), self.moduleVertical1(variableCombine), self.moduleHorizontal1(variableCombine)) + variableDot2 = SeparableConvolution()(self.modulePad(variableInput2), self.moduleVertical2(variableCombine), self.moduleHorizontal2(variableCombine)) + + return variableDot1 + variableDot2 + # end +# end + +########################################################## + +def process(moduleNetwork, tensorInputFirst, tensorInputSecond, tensorOutput): + assert(tensorInputFirst.size(1) == tensorInputSecond.size(1)) + assert(tensorInputFirst.size(2) == tensorInputSecond.size(2)) + + intWidth = tensorInputFirst.size(2) + intHeight = tensorInputFirst.size(1) + + assert(intWidth <= 1280) # while our approach works with larger images, we do not recommend it unless you are aware of the implications + assert(intHeight <= 720) # while our approach works with larger images, we do not recommend it unless you are aware of the implications + + intPaddingLeft = int(math.floor(51 / 2.0)) + intPaddingTop = int(math.floor(51 / 2.0)) + intPaddingRight = int(math.floor(51 / 2.0)) + intPaddingBottom = int(math.floor(51 / 2.0)) + modulePaddingInput = torch.nn.Module() + modulePaddingOutput = torch.nn.Module() + + if True: + intPaddingWidth = intPaddingLeft + intWidth + intPaddingRight + intPaddingHeight = intPaddingTop + intHeight + intPaddingBottom + + if intPaddingWidth != ((intPaddingWidth >> 7) << 7): + intPaddingWidth = (((intPaddingWidth >> 7) + 1) << 7) # more than necessary + # end + + if intPaddingHeight != ((intPaddingHeight >> 7) << 7): + intPaddingHeight = (((intPaddingHeight >> 7) + 1) << 7) # more than necessary + # end + + intPaddingWidth = intPaddingWidth - (intPaddingLeft + intWidth + intPaddingRight) + intPaddingHeight = intPaddingHeight - (intPaddingTop + intHeight + intPaddingBottom) + + modulePaddingInput = torch.nn.ReplicationPad2d([intPaddingLeft, intPaddingRight + intPaddingWidth, intPaddingTop, intPaddingBottom + intPaddingHeight]) + modulePaddingOutput = torch.nn.ReplicationPad2d([0 - intPaddingLeft, 0 - intPaddingRight - intPaddingWidth, 0 - intPaddingTop, 0 - intPaddingBottom - intPaddingHeight]) + # end + + if True: + tensorInputFirst = tensorInputFirst.cuda() + tensorInputSecond = tensorInputSecond.cuda() + + modulePaddingInput = modulePaddingInput.cuda() + modulePaddingOutput = modulePaddingOutput.cuda() + # end + + if True: + variablePaddingFirst = modulePaddingInput(torch.autograd.Variable(data=tensorInputFirst.view(1, 3, intHeight, intWidth), volatile=True)) + variablePaddingSecond = modulePaddingInput(torch.autograd.Variable(data=tensorInputSecond.view(1, 3, intHeight, intWidth), volatile=True)) + variablePaddingOutput = modulePaddingOutput(moduleNetwork(variablePaddingFirst, variablePaddingSecond)) + + tensorOutput.resize_(3, intHeight, intWidth).copy_(variablePaddingOutput.data[0]) + # end + + if True: + tensorInputFirst.cpu() + tensorInputSecond.cpu() + tensorOutput.cpu() + # end +#end |
