 network.py | 196
 run.py     | 246
 2 files changed, 243 insertions(+), 199 deletions(-)
diff --git a/network.py b/network.py
new file mode 100644
index 0000000..6915731
--- /dev/null
+++ b/network.py
@@ -0,0 +1,196 @@
+import math
+import torch
+import torch.utils.serialization
+
+from SeparableConvolution import SeparableConvolution # the custom SeparableConvolution layer
+
+torch.cuda.set_device(1) # change this if you have multiple graphics cards; set_device actually switches the device, whereas a bare torch.cuda.device(1) only constructs an unused context manager
+
+torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance
+
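+# SepConv in a nutshell: an encoder/decoder estimates, for every output pixel, a
+# pair of 1D kernels (vertical and horizontal) per input frame; applying them via
+# the SeparableConvolution layer and summing the two results yields the frame
+# interpolated halfway between the inputs
+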
+class Network(torch.nn.Module):
+ def __init__(self, model_name):
+ super(Network, self).__init__()
+
+ def Basic(intInput, intOutput):
+ return torch.nn.Sequential(
+ torch.nn.Conv2d(in_channels=intInput, out_channels=intOutput, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=intOutput, out_channels=intOutput, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=intOutput, out_channels=intOutput, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False)
+ )
+ # end
+
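+		# each Subnet regresses one set of per-pixel 1D kernel coefficients; its 51
+		# output channels are the 51 taps of the separable kernels, and the final
+		# upsampling returns the estimate to the resolution of the input frames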
+ def Subnet():
+ return torch.nn.Sequential(
+ torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=64, out_channels=51, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Upsample(scale_factor=2, mode='bilinear'),
+ torch.nn.Conv2d(in_channels=51, out_channels=51, kernel_size=3, stride=1, padding=1)
+ )
+ # end
+
+ self.moduleConv1 = Basic(6, 32)
+ self.modulePool1 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
+
+ self.moduleConv2 = Basic(32, 64)
+ self.modulePool2 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
+
+ self.moduleConv3 = Basic(64, 128)
+ self.modulePool3 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
+
+ self.moduleConv4 = Basic(128, 256)
+ self.modulePool4 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
+
+ self.moduleConv5 = Basic(256, 512)
+ self.modulePool5 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
+
+ self.moduleDeconv5 = Basic(512, 512)
+ self.moduleUpsample5 = torch.nn.Sequential(
+ torch.nn.Upsample(scale_factor=2, mode='bilinear'),
+ torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False)
+ )
+
+ self.moduleDeconv4 = Basic(512, 256)
+ self.moduleUpsample4 = torch.nn.Sequential(
+ torch.nn.Upsample(scale_factor=2, mode='bilinear'),
+ torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False)
+ )
+
+ self.moduleDeconv3 = Basic(256, 128)
+ self.moduleUpsample3 = torch.nn.Sequential(
+ torch.nn.Upsample(scale_factor=2, mode='bilinear'),
+ torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False)
+ )
+
+ self.moduleDeconv2 = Basic(128, 64)
+ self.moduleUpsample2 = torch.nn.Sequential(
+ torch.nn.Upsample(scale_factor=2, mode='bilinear'),
+ torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False)
+ )
+
+ self.moduleVertical1 = Subnet()
+ self.moduleVertical2 = Subnet()
+ self.moduleHorizontal1 = Subnet()
+ self.moduleHorizontal2 = Subnet()
+
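+		# replication-pad each input by floor(51 / 2) = 25 pixels per side so the
+		# 51-tap separable convolution is defined at every output location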
+ self.modulePad = torch.nn.ReplicationPad2d([ int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)) ])
+
+ self.load_state_dict(torch.load('./network-' + model_name + '.pytorch'))
+ # end
+
+ def forward(self, variableInput1, variableInput2):
+ variableJoin = torch.cat([variableInput1, variableInput2], 1)
+
+ variableConv1 = self.moduleConv1(variableJoin)
+ variablePool1 = self.modulePool1(variableConv1)
+
+ variableConv2 = self.moduleConv2(variablePool1)
+ variablePool2 = self.modulePool2(variableConv2)
+
+ variableConv3 = self.moduleConv3(variablePool2)
+ variablePool3 = self.modulePool3(variableConv3)
+
+ variableConv4 = self.moduleConv4(variablePool3)
+ variablePool4 = self.modulePool4(variableConv4)
+
+ variableConv5 = self.moduleConv5(variablePool4)
+ variablePool5 = self.modulePool5(variableConv5)
+
+ variableDeconv5 = self.moduleDeconv5(variablePool5)
+ variableUpsample5 = self.moduleUpsample5(variableDeconv5)
+
+ variableCombine = variableUpsample5 + variableConv5
+
+ variableDeconv4 = self.moduleDeconv4(variableCombine)
+ variableUpsample4 = self.moduleUpsample4(variableDeconv4)
+
+ variableCombine = variableUpsample4 + variableConv4
+
+ variableDeconv3 = self.moduleDeconv3(variableCombine)
+ variableUpsample3 = self.moduleUpsample3(variableDeconv3)
+
+ variableCombine = variableUpsample3 + variableConv3
+
+ variableDeconv2 = self.moduleDeconv2(variableCombine)
+ variableUpsample2 = self.moduleUpsample2(variableDeconv2)
+
+ variableCombine = variableUpsample2 + variableConv2
+
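+		# convolve each padded input with its estimated vertical and horizontal 1D
+		# kernels; the interpolated frame is the sum of the two filtered frames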
+ variableDot1 = SeparableConvolution()(self.modulePad(variableInput1), self.moduleVertical1(variableCombine), self.moduleHorizontal1(variableCombine))
+ variableDot2 = SeparableConvolution()(self.modulePad(variableInput2), self.moduleVertical2(variableCombine), self.moduleHorizontal2(variableCombine))
+
+ return variableDot1 + variableDot2
+ # end
+# end
+
+##########################################################
+
+def process(moduleNetwork, tensorInputFirst, tensorInputSecond, tensorOutput):
+ assert(tensorInputFirst.size(1) == tensorInputSecond.size(1))
+ assert(tensorInputFirst.size(2) == tensorInputSecond.size(2))
+
+ intWidth = tensorInputFirst.size(2)
+ intHeight = tensorInputFirst.size(1)
+
+ assert(intWidth <= 1280) # while our approach works with larger images, we do not recommend it unless you are aware of the implications
+ assert(intHeight <= 720) # while our approach works with larger images, we do not recommend it unless you are aware of the implications
+
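+	# pad by 25 = floor(51 / 2) pixels on each side, the same margin that
+	# modulePad applies inside the network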
+ intPaddingLeft = int(math.floor(51 / 2.0))
+ intPaddingTop = int(math.floor(51 / 2.0))
+ intPaddingRight = int(math.floor(51 / 2.0))
+ intPaddingBottom = int(math.floor(51 / 2.0))
+ modulePaddingInput = torch.nn.Module()
+ modulePaddingOutput = torch.nn.Module()
+
+ if True:
+ intPaddingWidth = intPaddingLeft + intWidth + intPaddingRight
+ intPaddingHeight = intPaddingTop + intHeight + intPaddingBottom
+
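+		# round the padded size up to a multiple of 128 so the five pooling stages
+		# of the encoder divide it evenly; as noted below, this is more padding
+		# than strictly necessary (a multiple of 32 would already suffice)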
+ if intPaddingWidth != ((intPaddingWidth >> 7) << 7):
+ intPaddingWidth = (((intPaddingWidth >> 7) + 1) << 7) # more than necessary
+ # end
+
+ if intPaddingHeight != ((intPaddingHeight >> 7) << 7):
+ intPaddingHeight = (((intPaddingHeight >> 7) + 1) << 7) # more than necessary
+ # end
+
+ intPaddingWidth = intPaddingWidth - (intPaddingLeft + intWidth + intPaddingRight)
+ intPaddingHeight = intPaddingHeight - (intPaddingTop + intHeight + intPaddingBottom)
+
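+		# the output padding below is negative, which makes ReplicationPad2d crop
+		# the network output back to the original input resolution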
+ modulePaddingInput = torch.nn.ReplicationPad2d([intPaddingLeft, intPaddingRight + intPaddingWidth, intPaddingTop, intPaddingBottom + intPaddingHeight])
+ modulePaddingOutput = torch.nn.ReplicationPad2d([0 - intPaddingLeft, 0 - intPaddingRight - intPaddingWidth, 0 - intPaddingTop, 0 - intPaddingBottom - intPaddingHeight])
+ # end
+
+ if True:
+ tensorInputFirst = tensorInputFirst.cuda()
+ tensorInputSecond = tensorInputSecond.cuda()
+
+ modulePaddingInput = modulePaddingInput.cuda()
+ modulePaddingOutput = modulePaddingOutput.cuda()
+ # end
+
+ if True:
+ variablePaddingFirst = modulePaddingInput(torch.autograd.Variable(data=tensorInputFirst.view(1, 3, intHeight, intWidth), volatile=True))
+ variablePaddingSecond = modulePaddingInput(torch.autograd.Variable(data=tensorInputSecond.view(1, 3, intHeight, intWidth), volatile=True))
+ variablePaddingOutput = modulePaddingOutput(moduleNetwork(variablePaddingFirst, variablePaddingSecond))
+
+ tensorOutput.resize_(3, intHeight, intWidth).copy_(variablePaddingOutput.data[0])
+ # end
+
+	if True:
+		tensorInputFirst = tensorInputFirst.cpu() # .cpu() is not in-place; without the assignments these calls would be no-ops
+		tensorInputSecond = tensorInputSecond.cpu()
+		tensorOutput = tensorOutput.cpu()
+	# end
+# end
diff --git a/run.py b/run.py
index 0020c08..bbbeafc 100644
--- a/run.py
+++ b/run.py
@@ -2,18 +2,12 @@
import sys
import getopt
-import math
import numpy
import torch
-import torch.utils.serialization
import PIL
import PIL.Image
-
-from SeparableConvolution import SeparableConvolution # the custom SeparableConvolution layer
-
-torch.cuda.device(1) # change this if you have a multiple graphics cards and you want to utilize them
-
-torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance
+from datetime import datetime
+from moviepy.video.io.ffmpeg_reader import FFMPEG_VideoReader # the video branch below relies on moviepy's ffmpeg wrappers
+from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter
+from network import Network, process
##########################################################
@@ -21,6 +15,8 @@ arguments_strModel = 'lf'
arguments_strFirst = './images/first.png'
arguments_strSecond = './images/second.png'
arguments_strOut = './result.png'
+arguments_strVideo = None # input video is optional; the single-image branch runs when it is unset
+arguments_strVideoOut = datetime.now().strftime("sepconv_%Y%m%d_%H%M.mp4")
+arguments_steps = 0
for strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:] + '=' for strParameter in sys.argv[1::2] ])[0]:
if strOption == '--model':
@@ -35,201 +31,53 @@ for strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:]
elif strOption == '--out':
arguments_strOut = strArgument # path to where the output should be stored
- # end
-# end
-
-##########################################################
-
-class Network(torch.nn.Module):
- def __init__(self):
- super(Network, self).__init__()
-
- def Basic(intInput, intOutput):
- return torch.nn.Sequential(
- torch.nn.Conv2d(in_channels=intInput, out_channels=intOutput, kernel_size=3, stride=1, padding=1),
- torch.nn.ReLU(inplace=False),
- torch.nn.Conv2d(in_channels=intOutput, out_channels=intOutput, kernel_size=3, stride=1, padding=1),
- torch.nn.ReLU(inplace=False),
- torch.nn.Conv2d(in_channels=intOutput, out_channels=intOutput, kernel_size=3, stride=1, padding=1),
- torch.nn.ReLU(inplace=False)
- )
- # end
-
- def Subnet():
- return torch.nn.Sequential(
- torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
- torch.nn.ReLU(inplace=False),
- torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
- torch.nn.ReLU(inplace=False),
- torch.nn.Conv2d(in_channels=64, out_channels=51, kernel_size=3, stride=1, padding=1),
- torch.nn.ReLU(inplace=False),
- torch.nn.Upsample(scale_factor=2, mode='bilinear'),
- torch.nn.Conv2d(in_channels=51, out_channels=51, kernel_size=3, stride=1, padding=1)
- )
- # end
-
- self.moduleConv1 = Basic(6, 32)
- self.modulePool1 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
-
- self.moduleConv2 = Basic(32, 64)
- self.modulePool2 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
-
- self.moduleConv3 = Basic(64, 128)
- self.modulePool3 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
-
- self.moduleConv4 = Basic(128, 256)
- self.modulePool4 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
-
- self.moduleConv5 = Basic(256, 512)
- self.modulePool5 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
+ elif strOption == '--video-out':
+ arguments_strVideoOut = strArgument # path to where the video should be stored
- self.moduleDeconv5 = Basic(512, 512)
- self.moduleUpsample5 = torch.nn.Sequential(
- torch.nn.Upsample(scale_factor=2, mode='bilinear'),
- torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
- torch.nn.ReLU(inplace=False)
- )
-
- self.moduleDeconv4 = Basic(512, 256)
- self.moduleUpsample4 = torch.nn.Sequential(
- torch.nn.Upsample(scale_factor=2, mode='bilinear'),
- torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
- torch.nn.ReLU(inplace=False)
- )
-
- self.moduleDeconv3 = Basic(256, 128)
- self.moduleUpsample3 = torch.nn.Sequential(
- torch.nn.Upsample(scale_factor=2, mode='bilinear'),
- torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
- torch.nn.ReLU(inplace=False)
- )
-
- self.moduleDeconv2 = Basic(128, 64)
- self.moduleUpsample2 = torch.nn.Sequential(
- torch.nn.Upsample(scale_factor=2, mode='bilinear'),
- torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
- torch.nn.ReLU(inplace=False)
- )
-
- self.moduleVertical1 = Subnet()
- self.moduleVertical2 = Subnet()
- self.moduleHorizontal1 = Subnet()
- self.moduleHorizontal2 = Subnet()
-
- self.modulePad = torch.nn.ReplicationPad2d([ int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)) ])
-
- self.load_state_dict(torch.load('./network-' + arguments_strModel + '.pytorch'))
- # end
-
- def forward(self, variableInput1, variableInput2):
- variableJoin = torch.cat([variableInput1, variableInput2], 1)
-
- variableConv1 = self.moduleConv1(variableJoin)
- variablePool1 = self.modulePool1(variableConv1)
-
- variableConv2 = self.moduleConv2(variablePool1)
- variablePool2 = self.modulePool2(variableConv2)
-
- variableConv3 = self.moduleConv3(variablePool2)
- variablePool3 = self.modulePool3(variableConv3)
-
- variableConv4 = self.moduleConv4(variablePool3)
- variablePool4 = self.modulePool4(variableConv4)
-
- variableConv5 = self.moduleConv5(variablePool4)
- variablePool5 = self.modulePool5(variableConv5)
-
- variableDeconv5 = self.moduleDeconv5(variablePool5)
- variableUpsample5 = self.moduleUpsample5(variableDeconv5)
-
- variableCombine = variableUpsample5 + variableConv5
-
- variableDeconv4 = self.moduleDeconv4(variableCombine)
- variableUpsample4 = self.moduleUpsample4(variableDeconv4)
-
- variableCombine = variableUpsample4 + variableConv4
-
- variableDeconv3 = self.moduleDeconv3(variableCombine)
- variableUpsample3 = self.moduleUpsample3(variableDeconv3)
-
- variableCombine = variableUpsample3 + variableConv3
-
- variableDeconv2 = self.moduleDeconv2(variableCombine)
- variableUpsample2 = self.moduleUpsample2(variableDeconv2)
-
- variableCombine = variableUpsample2 + variableConv2
-
- variableDot1 = SeparableConvolution()(self.modulePad(variableInput1), self.moduleVertical1(variableCombine), self.moduleHorizontal1(variableCombine))
- variableDot2 = SeparableConvolution()(self.modulePad(variableInput2), self.moduleVertical2(variableCombine), self.moduleHorizontal2(variableCombine))
-
- return variableDot1 + variableDot2
+	elif strOption == '--video':
+		arguments_strVideo = strArgument # path to an input video whose framerate should be doubled
+
+	elif strOption == '--steps':
+		arguments_steps = int(strArgument) # number of bisection steps for recursive interpolation
# end
# end
-moduleNetwork = Network().cuda()
-
-##########################################################
-
-tensorInputFirst = torch.FloatTensor(numpy.rollaxis(numpy.asarray(PIL.Image.open(arguments_strFirst))[:, :, ::-1], 2, 0).astype(numpy.float32) / 255.0)
-tensorInputSecond = torch.FloatTensor(numpy.rollaxis(numpy.asarray(PIL.Image.open(arguments_strSecond))[:, :, ::-1], 2, 0).astype(numpy.float32) / 255.0)
+moduleNetwork = Network(arguments_strModel).cuda()
tensorOutput = torch.FloatTensor()
-assert(tensorInputFirst.size(1) == tensorInputSecond.size(1))
-assert(tensorInputFirst.size(2) == tensorInputSecond.size(2))
-
-intWidth = tensorInputFirst.size(2)
-intHeight = tensorInputFirst.size(1)
-
-assert(intWidth <= 1280) # while our approach works with larger images, we do not recommend it unless you are aware of the implications
-assert(intHeight <= 720) # while our approach works with larger images, we do not recommend it unless you are aware of the implications
-
-intPaddingLeft = int(math.floor(51 / 2.0))
-intPaddingTop = int(math.floor(51 / 2.0))
-intPaddingRight = int(math.floor(51 / 2.0))
-intPaddingBottom = int(math.floor(51 / 2.0))
-modulePaddingInput = torch.nn.Sequential()
-modulePaddingOutput = torch.nn.Sequential()
-
-if True:
- intPaddingWidth = intPaddingLeft + intWidth + intPaddingRight
- intPaddingHeight = intPaddingTop + intHeight + intPaddingBottom
-
- if intPaddingWidth != ((intPaddingWidth >> 7) << 7):
- intPaddingWidth = (((intPaddingWidth >> 7) + 1) << 7) # more than necessary
- # end
-
- if intPaddingHeight != ((intPaddingHeight >> 7) << 7):
- intPaddingHeight = (((intPaddingHeight >> 7) + 1) << 7) # more than necessary
- # end
-
- intPaddingWidth = intPaddingWidth - (intPaddingLeft + intWidth + intPaddingRight)
- intPaddingHeight = intPaddingHeight - (intPaddingTop + intHeight + intPaddingBottom)
-
- modulePaddingInput = torch.nn.ReplicationPad2d([intPaddingLeft, intPaddingRight + intPaddingWidth, intPaddingTop, intPaddingBottom + intPaddingHeight])
- modulePaddingOutput = torch.nn.ReplicationPad2d([0 - intPaddingLeft, 0 - intPaddingRight - intPaddingWidth, 0 - intPaddingTop, 0 - intPaddingBottom - intPaddingHeight])
-# end
-
-if True:
- tensorInputFirst = tensorInputFirst.cuda()
- tensorInputSecond = tensorInputSecond.cuda()
- tensorOutput = tensorOutput.cuda()
-
- modulePaddingInput = modulePaddingInput.cuda()
- modulePaddingOutput = modulePaddingOutput.cuda()
-# end
-
-if True:
- variablePaddingFirst = modulePaddingInput(torch.autograd.Variable(data=tensorInputFirst.view(1, 3, intHeight, intWidth), volatile=True))
- variablePaddingSecond = modulePaddingInput(torch.autograd.Variable(data=tensorInputSecond.view(1, 3, intHeight, intWidth), volatile=True))
- variablePaddingOutput = modulePaddingOutput(moduleNetwork(variablePaddingFirst, variablePaddingSecond))
-
- tensorOutput.resize_(3, intHeight, intWidth).copy_(variablePaddingOutput.data[0])
-# end
-
-if True:
- tensorInputFirst = tensorInputFirst.cpu()
- tensorInputSecond = tensorInputSecond.cpu()
- tensorOutput = tensorOutput.cpu()
-# end
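+
+# process_tree recursively bisects the interval between the two input frames: it
+# interpolates the midpoint, then recurses on each half; for a power-of-two steps
+# this yields 2 * steps - 1 in-between frames in temporal order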
+def process_tree(moduleNetwork, tensorInputFirst, tensorInputSecond, tensorOutput, steps):
+	process(moduleNetwork, tensorInputFirst, tensorInputSecond, tensorOutput)
+	frameMiddle = (numpy.rollaxis(tensorOutput.clamp(0.0, 1.0).numpy(), 0, 3)[:,:,::-1] * 255.0).astype(numpy.uint8)
+	if steps < 2:
+		return [frameMiddle]
+	else:
+		tensorMiddle = tensorOutput.clamp(0.0, 1.0) # clamp returns a new FloatTensor, so the recursive calls can safely overwrite tensorOutput; process() expects tensors, not uint8 arrays
+		framesLeft = process_tree(moduleNetwork, tensorInputFirst, tensorMiddle, tensorOutput, steps // 2)
+		framesRight = process_tree(moduleNetwork, tensorMiddle, tensorInputSecond, tensorOutput, steps // 2)
+		return framesLeft + [frameMiddle] + framesRight
+	# end
+# end
-PIL.Image.fromarray((numpy.rollaxis(tensorOutput.clamp(0.0, 1.0).numpy(), 0, 3)[:, :, ::-1] * 255.0).astype(numpy.uint8)).save(arguments_strOut)
\ No newline at end of file
+if arguments_strVideo and arguments_strVideoOut:
+	reader = FFMPEG_VideoReader(arguments_strVideo, False)
+	writer = FFMPEG_VideoWriter(arguments_strVideoOut, reader.size, reader.fps * 2)
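+	# the writer runs at twice the source fps because one interpolated frame is
+	# inserted between every pair of consecutive source frames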
+ reader.initialize()
+	nextFrame = reader.read_frame()
+	writer.write_frame(nextFrame) # write the first source frame once up front; each loop iteration then appends one interpolated frame and one source frame, avoiding duplicates
+	for intFrame in range(reader.nframes - 1):
+		firstFrame = nextFrame
+		nextFrame = reader.read_frame()
+		tensorInputFirst = torch.FloatTensor(numpy.rollaxis(firstFrame[:,:,::-1], 2, 0).astype(numpy.float32) / 255.0)
+		tensorInputSecond = torch.FloatTensor(numpy.rollaxis(nextFrame[:,:,::-1], 2, 0).astype(numpy.float32) / 255.0)
+		process(moduleNetwork, tensorInputFirst, tensorInputSecond, tensorOutput)
+		writer.write_frame((numpy.rollaxis(tensorOutput.clamp(0.0, 1.0).numpy(), 0, 3)[:,:,::-1] * 255.0).astype(numpy.uint8))
+		writer.write_frame(nextFrame)
+	# end
+	writer.close()
+else:
+ # Process image
+ tensorInputFirst = torch.FloatTensor(numpy.rollaxis(numpy.asarray(PIL.Image.open(arguments_strFirst))[:,:,::-1], 2, 0).astype(numpy.float32) / 255.0)
+ tensorInputSecond = torch.FloatTensor(numpy.rollaxis(numpy.asarray(PIL.Image.open(arguments_strSecond))[:,:,::-1], 2, 0).astype(numpy.float32) / 255.0)
+ if arguments_steps == 0:
+ process(moduleNetwork, tensorInputFirst, tensorInputSecond, tensorOutput)
+ PIL.Image.fromarray((numpy.rollaxis(tensorOutput.clamp(0.0, 1.0).numpy(), 0, 3)[:,:,::-1] * 255.0).astype(numpy.uint8)).save(arguments_strOut)
+	else:
+		tree = process_tree(moduleNetwork, tensorInputFirst, tensorInputSecond, tensorOutput, arguments_steps)
+		writer = FFMPEG_VideoWriter(arguments_strVideoOut, (tensorInputFirst.size(2), tensorInputFirst.size(1)), 25) # size the writer to the input images rather than a fixed resolution
+		writer.write_frame((numpy.rollaxis(tensorInputFirst.numpy(), 0, 3)[:,:,::-1] * 255.0).astype(numpy.uint8)) # the writer expects uint8 numpy frames, not FloatTensors
+		for frame in tree:
+			writer.write_frame(frame)
+		# end
+		writer.write_frame((numpy.rollaxis(tensorInputSecond.numpy(), 0, 3)[:,:,::-1] * 255.0).astype(numpy.uint8))
+		writer.close()
+	# end
+# end
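+
+# example invocations (paths and output names are placeholders):
+#   python run.py --model lf --first ./images/first.png --second ./images/second.png --out ./result.png
+#   python run.py --model lf --first ./images/first.png --second ./images/second.png --steps 4 --video-out ./interp.mp4
+#   python run.py --model lf --video ./input.mp4 --video-out ./doubled.mp4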