| author | Jules Laplace <julescarbon@gmail.com> | 2018-06-25 15:27:50 +0200 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2018-06-25 15:27:50 +0200 |
| commit | cb25a532a89f4d971193038bb7dce1e4436381f9 (patch) | |
| tree | 918a73028e9879b97035e5e49165b562c9087a81 /run.py | |
| parent | 9cdabbe592c822586b075c6b7659de995613025e (diff) | |
split out network, do a morph dewd
Diffstat (limited to 'run.py')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | run.py | 246 |
1 file changed, 47 insertions(+), 199 deletions(-)
```diff
@@ -2,18 +2,12 @@
 
 import sys
 import getopt
-import math
 import numpy
 import torch
-import torch.utils.serialization
 import PIL
 import PIL.Image
-
-from SeparableConvolution import SeparableConvolution # the custom SeparableConvolution layer
-
-torch.cuda.device(1) # change this if you have a multiple graphics cards and you want to utilize them
-
-torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance
+from datetime import datetime
+from network import Network, process
 
 ##########################################################
 
@@ -21,6 +15,8 @@ arguments_strModel = 'lf'
 arguments_strFirst = './images/first.png'
 arguments_strSecond = './images/second.png'
 arguments_strOut = './result.png'
+arguments_strVideoOut = datetime.now().strftime("sepconv_%Y%m%d_%H%M.mp4")
+arguments_steps = 0
 
 for strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:] + '=' for strParameter in sys.argv[1::2] ])[0]:
 	if strOption == '--model':
@@ -35,201 +31,53 @@ for strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:]
 	elif strOption == '--out':
 		arguments_strOut = strArgument # path to where the output should be stored
 
-	# end
-# end
-
-##########################################################
-
-class Network(torch.nn.Module):
-	def __init__(self):
-		super(Network, self).__init__()
-
-		def Basic(intInput, intOutput):
-			return torch.nn.Sequential(
-				torch.nn.Conv2d(in_channels=intInput, out_channels=intOutput, kernel_size=3, stride=1, padding=1),
-				torch.nn.ReLU(inplace=False),
-				torch.nn.Conv2d(in_channels=intOutput, out_channels=intOutput, kernel_size=3, stride=1, padding=1),
-				torch.nn.ReLU(inplace=False),
-				torch.nn.Conv2d(in_channels=intOutput, out_channels=intOutput, kernel_size=3, stride=1, padding=1),
-				torch.nn.ReLU(inplace=False)
-			)
-		# end
-
-		def Subnet():
-			return torch.nn.Sequential(
-				torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
-				torch.nn.ReLU(inplace=False),
-				torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
-				torch.nn.ReLU(inplace=False),
-				torch.nn.Conv2d(in_channels=64, out_channels=51, kernel_size=3, stride=1, padding=1),
-				torch.nn.ReLU(inplace=False),
-				torch.nn.Upsample(scale_factor=2, mode='bilinear'),
-				torch.nn.Conv2d(in_channels=51, out_channels=51, kernel_size=3, stride=1, padding=1)
-			)
-		# end
-
-		self.moduleConv1 = Basic(6, 32)
-		self.modulePool1 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
-
-		self.moduleConv2 = Basic(32, 64)
-		self.modulePool2 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
-
-		self.moduleConv3 = Basic(64, 128)
-		self.modulePool3 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
-
-		self.moduleConv4 = Basic(128, 256)
-		self.modulePool4 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
-
-		self.moduleConv5 = Basic(256, 512)
-		self.modulePool5 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
+	elif strOption == '--video-out':
+		arguments_strVideoOut = strArgument # path to where the video should be stored
 
-		self.moduleDeconv5 = Basic(512, 512)
-		self.moduleUpsample5 = torch.nn.Sequential(
-			torch.nn.Upsample(scale_factor=2, mode='bilinear'),
-			torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
-			torch.nn.ReLU(inplace=False)
-		)
-
-		self.moduleDeconv4 = Basic(512, 256)
-		self.moduleUpsample4 = torch.nn.Sequential(
-			torch.nn.Upsample(scale_factor=2, mode='bilinear'),
-			torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
-			torch.nn.ReLU(inplace=False)
-		)
-
-		self.moduleDeconv3 = Basic(256, 128)
-		self.moduleUpsample3 = torch.nn.Sequential(
-			torch.nn.Upsample(scale_factor=2, mode='bilinear'),
-			torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
-			torch.nn.ReLU(inplace=False)
-		)
-
-		self.moduleDeconv2 = Basic(128, 64)
-		self.moduleUpsample2 = torch.nn.Sequential(
-			torch.nn.Upsample(scale_factor=2, mode='bilinear'),
-			torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
-			torch.nn.ReLU(inplace=False)
-		)
-
-		self.moduleVertical1 = Subnet()
-		self.moduleVertical2 = Subnet()
-		self.moduleHorizontal1 = Subnet()
-		self.moduleHorizontal2 = Subnet()
-
-		self.modulePad = torch.nn.ReplicationPad2d([ int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)) ])
-
-		self.load_state_dict(torch.load('./network-' + arguments_strModel + '.pytorch'))
-	# end
-
-	def forward(self, variableInput1, variableInput2):
-		variableJoin = torch.cat([variableInput1, variableInput2], 1)
-
-		variableConv1 = self.moduleConv1(variableJoin)
-		variablePool1 = self.modulePool1(variableConv1)
-
-		variableConv2 = self.moduleConv2(variablePool1)
-		variablePool2 = self.modulePool2(variableConv2)
-
-		variableConv3 = self.moduleConv3(variablePool2)
-		variablePool3 = self.modulePool3(variableConv3)
-
-		variableConv4 = self.moduleConv4(variablePool3)
-		variablePool4 = self.modulePool4(variableConv4)
-
-		variableConv5 = self.moduleConv5(variablePool4)
-		variablePool5 = self.modulePool5(variableConv5)
-
-		variableDeconv5 = self.moduleDeconv5(variablePool5)
-		variableUpsample5 = self.moduleUpsample5(variableDeconv5)
-
-		variableCombine = variableUpsample5 + variableConv5
-
-		variableDeconv4 = self.moduleDeconv4(variableCombine)
-		variableUpsample4 = self.moduleUpsample4(variableDeconv4)
-
-		variableCombine = variableUpsample4 + variableConv4
-
-		variableDeconv3 = self.moduleDeconv3(variableCombine)
-		variableUpsample3 = self.moduleUpsample3(variableDeconv3)
-
-		variableCombine = variableUpsample3 + variableConv3
-
-		variableDeconv2 = self.moduleDeconv2(variableCombine)
-		variableUpsample2 = self.moduleUpsample2(variableDeconv2)
-
-		variableCombine = variableUpsample2 + variableConv2
-
-		variableDot1 = SeparableConvolution()(self.modulePad(variableInput1), self.moduleVertical1(variableCombine), self.moduleHorizontal1(variableCombine))
-		variableDot2 = SeparableConvolution()(self.modulePad(variableInput2), self.moduleVertical2(variableCombine), self.moduleHorizontal2(variableCombine))
-
-		return variableDot1 + variableDot2
+	elif strOption == '--steps':
+		arguments_steps = int(strArgument)
 	# end
 # end
 
-moduleNetwork = Network().cuda()
-
-##########################################################
-
-tensorInputFirst = torch.FloatTensor(numpy.rollaxis(numpy.asarray(PIL.Image.open(arguments_strFirst))[:, :, ::-1], 2, 0).astype(numpy.float32) / 255.0)
-tensorInputSecond = torch.FloatTensor(numpy.rollaxis(numpy.asarray(PIL.Image.open(arguments_strSecond))[:, :, ::-1], 2, 0).astype(numpy.float32) / 255.0)
+moduleNetwork = Network(arguments_strModel).cuda()
 tensorOutput = torch.FloatTensor()
 
-assert(tensorInputFirst.size(1) == tensorInputSecond.size(1))
-assert(tensorInputFirst.size(2) == tensorInputSecond.size(2))
-
-intWidth = tensorInputFirst.size(2)
-intHeight = tensorInputFirst.size(1)
-
-assert(intWidth <= 1280) # while our approach works with larger images, we do not recommend it unless you are aware of the implications
-assert(intHeight <= 720) # while our approach works with larger images, we do not recommend it unless you are aware of the implications
-
-intPaddingLeft = int(math.floor(51 / 2.0))
-intPaddingTop = int(math.floor(51 / 2.0))
-intPaddingRight = int(math.floor(51 / 2.0))
-intPaddingBottom = int(math.floor(51 / 2.0))
-modulePaddingInput = torch.nn.Sequential()
-modulePaddingOutput = torch.nn.Sequential()
-
-if True:
-	intPaddingWidth = intPaddingLeft + intWidth + intPaddingRight
-	intPaddingHeight = intPaddingTop + intHeight + intPaddingBottom
-
-	if intPaddingWidth != ((intPaddingWidth >> 7) << 7):
-		intPaddingWidth = (((intPaddingWidth >> 7) + 1) << 7) # more than necessary
-	# end
-
-	if intPaddingHeight != ((intPaddingHeight >> 7) << 7):
-		intPaddingHeight = (((intPaddingHeight >> 7) + 1) << 7) # more than necessary
-	# end
-
-	intPaddingWidth = intPaddingWidth - (intPaddingLeft + intWidth + intPaddingRight)
-	intPaddingHeight = intPaddingHeight - (intPaddingTop + intHeight + intPaddingBottom)
-
-	modulePaddingInput = torch.nn.ReplicationPad2d([intPaddingLeft, intPaddingRight + intPaddingWidth, intPaddingTop, intPaddingBottom + intPaddingHeight])
-	modulePaddingOutput = torch.nn.ReplicationPad2d([0 - intPaddingLeft, 0 - intPaddingRight - intPaddingWidth, 0 - intPaddingTop, 0 - intPaddingBottom - intPaddingHeight])
-# end
-
-if True:
-	tensorInputFirst = tensorInputFirst.cuda()
-	tensorInputSecond = tensorInputSecond.cuda()
-	tensorOutput = tensorOutput.cuda()
-
-	modulePaddingInput = modulePaddingInput.cuda()
-	modulePaddingOutput = modulePaddingOutput.cuda()
-# end
-
-if True:
-	variablePaddingFirst = modulePaddingInput(torch.autograd.Variable(data=tensorInputFirst.view(1, 3, intHeight, intWidth), volatile=True))
-	variablePaddingSecond = modulePaddingInput(torch.autograd.Variable(data=tensorInputSecond.view(1, 3, intHeight, intWidth), volatile=True))
-	variablePaddingOutput = modulePaddingOutput(moduleNetwork(variablePaddingFirst, variablePaddingSecond))
-
-	tensorOutput.resize_(3, intHeight, intWidth).copy_(variablePaddingOutput.data[0])
-# end
-
-if True:
-	tensorInputFirst = tensorInputFirst.cpu()
-	tensorInputSecond = tensorInputSecond.cpu()
-	tensorOutput = tensorOutput.cpu()
-# end
+def process_tree(moduleNetwork, tensorInputFirst, tensorInputSecond, tensorOutput, steps):
+	process(moduleNetwork, tensorInputFirst, tensorInputSecond, tensorOutput)
+	tensorMiddle = (numpy.rollaxis(tensorOutput.clamp(0.0, 1.0).numpy(), 0, 3)[:,:,::-1] * 255.0).astype(numpy.uint8)
+	if steps < 2:
+		return [tensorMiddle]
+	else:
+		tensorLeft = process_tree(moduleNetwork, tensorInputFirst, tensorMiddle, tensorOutput, steps / 2)
+		tensorRight = process_tree(moduleNetwork, tensorMiddle, tensorInputSecond, tensorOutput, steps / 2)
+		return tensorLeft + [tensorMiddle] + tensorRight
 
-PIL.Image.fromarray((numpy.rollaxis(tensorOutput.clamp(0.0, 1.0).numpy(), 0, 3)[:, :, ::-1] * 255.0).astype(numpy.uint8)).save(arguments_strOut)
\ No newline at end of file
+if arguments_strVideo and arguments_strVideoOut:
+	reader = FFMPEG_VideoReader(arguments_strVideo, False)
+	writer = FFMPEG_VideoWriter(arguments_strVideoOut, reader.size, reader.fps*2)
+	reader.initialize()
+	nextFrame = reader.read_frame()
+	for x in range(0, reader.nframes):
+		firstFrame = nextFrame
+		nextFrame = reader.read_frame()
+		tensorInputFirst = torch.FloatTensor(numpy.rollaxis(firstFrame[:,:,::-1], 2, 0) / 255.0)
+		tensorInputSecond = torch.FloatTensor(numpy.rollaxis(nextFrame[:,:,::-1], 2, 0) / 255.0)
+		process(moduleNetwork, tensorInputFirst, tensorInputSecond, tensorOutput)
+		writer.write_frame(firstFrame)
+		writer.write_frame((numpy.rollaxis(tensorOutput.clamp(0.0, 1.0).numpy(), 0, 3)[:,:,::-1] * 255.0).astype(numpy.uint8))
+		writer.write_frame(nextFrame)
+	writer.close()
+else:
+	# Process image
+	tensorInputFirst = torch.FloatTensor(numpy.rollaxis(numpy.asarray(PIL.Image.open(arguments_strFirst))[:,:,::-1], 2, 0).astype(numpy.float32) / 255.0)
+	tensorInputSecond = torch.FloatTensor(numpy.rollaxis(numpy.asarray(PIL.Image.open(arguments_strSecond))[:,:,::-1], 2, 0).astype(numpy.float32) / 255.0)
+	if arguments_steps == 0:
+		process(moduleNetwork, tensorInputFirst, tensorInputSecond, tensorOutput)
+		PIL.Image.fromarray((numpy.rollaxis(tensorOutput.clamp(0.0, 1.0).numpy(), 0, 3)[:,:,::-1] * 255.0).astype(numpy.uint8)).save(arguments_strOut)
+	else:
+		tree = process_tree(moduleNetwork, tensorInputFirst, tensorInputSecond, tensorOutput, arguments_steps)
+		writer = FFMPEG_VideoWriter(arguments_strVideoOut, (1024, 512), 25)
+		writer.write_frame(tensorInputFirst)
+		for frame in tree:
+			writer.write_frame(frame)
+		writer.write_frame(tensorInputSecond)
```
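The model itself now lives in a separate `network.py`, imported above as `from network import Network, process`, which is outside the scope of this diff. As a point of reference, the sketch below reconstructs the interface its call sites imply — `Network(arguments_strModel)` and `process(moduleNetwork, tensorInputFirst, tensorInputSecond, tensorOutput)` — by reassembling the lines this commit deletes from `run.py`; it is an assumption, not the committed module.

```python
# network.py -- hypothetical reconstruction, assembled from the lines this
# commit removes from run.py; the committed module may differ.
import math

import torch

from SeparableConvolution import SeparableConvolution # the custom SeparableConvolution layer

class Network(torch.nn.Module):
	def __init__(self, strModel):
		super(Network, self).__init__()

		# ... the Basic/Subnet encoder-decoder modules deleted above,
		# unchanged except that the model name is now a constructor
		# argument instead of the global arguments_strModel ...

		self.load_state_dict(torch.load('./network-' + strModel + '.pytorch'))
	# end

	# def forward(self, variableInput1, variableInput2): the forward pass
	# deleted above, returning variableDot1 + variableDot2
# end

def process(moduleNetwork, tensorInputFirst, tensorInputSecond, tensorOutput):
	intWidth = tensorInputFirst.size(2)
	intHeight = tensorInputFirst.size(1)

	intPadding = int(math.floor(51 / 2.0)) # half of the 51-tap separable filters

	# round the padded extent up to a multiple of 128, as the deleted code did
	intPaddedWidth = intPadding + intWidth + intPadding
	intPaddedHeight = intPadding + intHeight + intPadding
	intExtraWidth = (128 - (intPaddedWidth % 128)) % 128
	intExtraHeight = (128 - (intPaddedHeight % 128)) % 128

	modulePaddingInput = torch.nn.ReplicationPad2d([ intPadding, intPadding + intExtraWidth, intPadding, intPadding + intExtraHeight ]).cuda()
	modulePaddingOutput = torch.nn.ReplicationPad2d([ 0 - intPadding, 0 - intPadding - intExtraWidth, 0 - intPadding, 0 - intPadding - intExtraHeight ]).cuda()

	variableFirst = modulePaddingInput(torch.autograd.Variable(data=tensorInputFirst.cuda().view(1, 3, intHeight, intWidth), volatile=True))
	variableSecond = modulePaddingInput(torch.autograd.Variable(data=tensorInputSecond.cuda().view(1, 3, intHeight, intWidth), volatile=True))
	variableOutput = modulePaddingOutput(moduleNetwork(variableFirst, variableSecond))

	# copy the interpolated frame back into the caller's CPU-side buffer
	tensorOutput.resize_(3, intHeight, intWidth).copy_(variableOutput.data[0].cpu())
# end
```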
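Two names the new code relies on never appear in the committed file: `arguments_strVideo` has no default and no `--video` branch in the getopt loop, and `FFMPEG_VideoReader`/`FFMPEG_VideoWriter` (moviepy's ffmpeg wrappers) are used without an import, so the script as committed raises a `NameError` as soon as it reaches `if arguments_strVideo`. A presumed fix-up, labeled as such, would be:

```python
# Assumed repair, not part of this commit: the imports and the --video
# option that the video branch appears to expect.
from moviepy.video.io.ffmpeg_reader import FFMPEG_VideoReader
from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter

arguments_strVideo = '' # path to an input video; empty falls through to image mode

# and, inside the getopt loop:
#	elif strOption == '--video':
#		arguments_strVideo = strArgument # path to the video to interpolate
```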
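The `--steps` morph works by binary subdivision: `process_tree` interpolates the midpoint of its two endpoints, recurses on each half with a halved step count until `steps < 2`, and the in-order concatenation `tensorLeft + [tensorMiddle] + tensorRight` returns the frames in temporal order, so a power-of-two `--steps N` yields 2N−1 in-between frames. Two caveats in the committed version: the recursion feeds `tensorMiddle`, an HWC uint8 numpy frame, back into a parameter that is otherwise a CHW FloatTensor, and every call shares the single `tensorOutput` buffer (also, under Python 3, `steps / 2` becomes a float). A hedged sketch with the conversions made explicit, assuming `process` fills `tensorOutput` as above — `to_frame`/`to_tensor` are hypothetical helpers mirroring conversions already present in `run.py`:

```python
# Hedged sketch, not the committed code; drops into run.py where numpy,
# torch, and process are already in scope.
def to_frame(tensorInput):
	# CHW float tensor in [0, 1] -> HWC uint8 frame (undoing the BGR flip)
	return (numpy.rollaxis(tensorInput.clamp(0.0, 1.0).numpy(), 0, 3)[:, :, ::-1] * 255.0).astype(numpy.uint8)
# end

def to_tensor(frame):
	# HWC uint8 frame -> CHW float tensor, mirroring the image loader
	return torch.FloatTensor(numpy.rollaxis(frame[:, :, ::-1], 2, 0).astype(numpy.float32) / 255.0)
# end

def process_tree(moduleNetwork, tensorInputFirst, tensorInputSecond, tensorOutput, steps):
	process(moduleNetwork, tensorInputFirst, tensorInputSecond, tensorOutput)
	frameMiddle = to_frame(tensorOutput) # copy out before the shared buffer is reused
	if steps < 2:
		return [frameMiddle]
	# end

	tensorMiddle = to_tensor(frameMiddle)
	framesLeft = process_tree(moduleNetwork, tensorInputFirst, tensorMiddle, tensorOutput, steps // 2) # // stays an int under Python 3
	framesRight = process_tree(moduleNetwork, tensorMiddle, tensorInputSecond, tensorOutput, steps // 2)
	return framesLeft + [frameMiddle] + framesRight
# end
```

By the same token, the final writer block would need `to_frame(tensorInputFirst)` and `to_frame(tensorInputSecond)` rather than raw FloatTensors, and could take the writer size from the input frames instead of the hard-coded `(1024, 512)`.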
