summaryrefslogtreecommitdiff
path: root/nets.py
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2018-11-18 15:03:27 +0100
committerJules Laplace <julescarbon@gmail.com>2018-11-18 15:03:27 +0100
commit0d2314ef1ce689a8281f89ffd1bcfc3a677cc3cd (patch)
treecf59fc211cbbfe685ac55505cd3eda58a71064ce /nets.py
just code
Diffstat (limited to 'nets.py')
-rw-r--r--nets.py150
1 files changed, 150 insertions, 0 deletions
diff --git a/nets.py b/nets.py
new file mode 100644
index 0000000..70ea3e9
--- /dev/null
+++ b/nets.py
@@ -0,0 +1,150 @@
+## -*- coding: utf-8 -*-
+import tensorflow as tf
+
+from utils import BatchNorm, Conv3D
+
+stp = [[0,0], [1,1], [1,1], [1,1], [0,0]]
+sp = [[0,0], [0,0], [1,1], [1,1], [0,0]]
+
+def FR_16L(x, is_train):
+ x = Conv3D(tf.pad(x, sp, mode='CONSTANT'), [1,3,3,3,64], [1,1,1,1,1], 'VALID', name='conv1')
+
+ F = 64
+ G = 32
+ for r in range(3):
+ t = BatchNorm(x, is_train, name='Rbn'+str(r+1)+'a')
+ t = tf.nn.relu(t)
+ t = Conv3D(t, [1,1,1,F,F], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'a')
+
+ t = BatchNorm(t, is_train, name='Rbn'+str(r+1)+'b')
+ t = tf.nn.relu(t)
+ t = Conv3D(tf.pad(t, stp, mode='CONSTANT'), [3,3,3,F,G], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'b')
+
+ x = tf.concat([x, t], 4)
+ F += G
+ for r in range(3,6):
+ t = BatchNorm(x, is_train, name='Rbn'+str(r+1)+'a')
+ t = tf.nn.relu(t)
+ t = Conv3D(t, [1,1,1,F,F], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'a')
+
+ t = BatchNorm(t, is_train, name='Rbn'+str(r+1)+'b')
+ t = tf.nn.relu(t)
+ t = Conv3D(tf.pad(t, sp, mode='CONSTANT'), [3,3,3,F,G], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'b')
+
+ x = tf.concat([x[:,1:-1], t], 4)
+ F += G
+
+ x = BatchNorm(x, is_train, name='fbn1')
+ x = tf.nn.relu(x)
+ x = Conv3D(tf.pad(x, sp, mode='CONSTANT'), [1,3,3,256,256], [1,1,1,1,1], 'VALID', name='conv2')
+ x = tf.nn.relu(x)
+
+ r = Conv3D(x, [1,1,1,256,256], [1,1,1,1,1], 'VALID', name='rconv1')
+ r = tf.nn.relu(r)
+ r = Conv3D(r, [1,1,1,256,3*16], [1,1,1,1,1], 'VALID', name='rconv2')
+
+ f = Conv3D(x, [1,1,1,256,512], [1,1,1,1,1], 'VALID', name='fconv1')
+ f = tf.nn.relu(f)
+ f = Conv3D(f, [1,1,1,512,1*5*5*16], [1,1,1,1,1], 'VALID', name='fconv2')
+
+ ds_f = tf.shape(f)
+ f = tf.reshape(f, [ds_f[0], ds_f[1], ds_f[2], ds_f[3], 25, 16])
+ f = tf.nn.softmax(f, dim=4)
+
+ return f, r
+
+def FR_28L(x, is_train):
+ x = Conv3D(tf.pad(x, sp, mode='CONSTANT'), [1,3,3,3,64], [1,1,1,1,1], 'VALID', name='conv1')
+
+ F = 64
+ G = 16
+ for r in range(9):
+ t = BatchNorm(x, is_train, name='Rbn'+str(r+1)+'a')
+ t = tf.nn.relu(t)
+ t = Conv3D(t, [1,1,1,F,F], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'a')
+
+ t = BatchNorm(t, is_train, name='Rbn'+str(r+1)+'b')
+ t = tf.nn.relu(t)
+ t = Conv3D(tf.pad(t, stp, mode='CONSTANT'), [3,3,3,F,G], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'b')
+
+ x = tf.concat([x, t], 4)
+ F += G
+ for r in range(9,12):
+ t = BatchNorm(x, is_train, name='Rbn'+str(r+1)+'a')
+ t = tf.nn.relu(t)
+ t = Conv3D(t, [1,1,1,F,F], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'a')
+
+ t = BatchNorm(t, is_train, name='Rbn'+str(r+1)+'b')
+ t = tf.nn.relu(t)
+ t = Conv3D(tf.pad(t, sp, mode='CONSTANT'), [3,3,3,F,G], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'b')
+
+ x = tf.concat([x[:,1:-1], t], 4)
+ F += G
+
+ x = BatchNorm(x, is_train, name='fbn1')
+ x = tf.nn.relu(x)
+ x = Conv3D(tf.pad(x, sp, mode='CONSTANT'), [1,3,3,256,256], [1,1,1,1,1], 'VALID', name='conv2')
+
+ x = tf.nn.relu(x)
+
+ r = Conv3D(x, [1,1,1,256,256], [1,1,1,1,1], 'VALID', name='rconv1')
+ r = tf.nn.relu(r)
+ r = Conv3D(r, [1,1,1,256,3*16], [1,1,1,1,1], 'VALID', name='rconv2')
+
+ f = Conv3D(x, [1,1,1,256,512], [1,1,1,1,1], 'VALID', name='fconv1')
+ f = tf.nn.relu(f)
+ f = Conv3D(f, [1,1,1,512,1*5*5*16], [1,1,1,1,1], 'VALID', name='fconv2')
+
+ ds_f = tf.shape(f)
+ f = tf.reshape(f, [ds_f[0], ds_f[1], ds_f[2], ds_f[3], 25, 16])
+ f = tf.nn.softmax(f, dim=4)
+
+ return f, r
+
+def FR_52L(x, is_train):
+ x = Conv3D(tf.pad(x, sp, mode='CONSTANT'), [1,3,3,3,64], [1,1,1,1,1], 'VALID', name='conv1')
+
+ F = 64
+ G = 16
+ for r in range(0,21):
+ t = BatchNorm(x, is_train, name='Rbn'+str(r+1)+'a')
+ t = tf.nn.relu(t)
+ t = Conv3D(t, [1,1,1,F,F], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'a')
+
+ t = BatchNorm(t, is_train, name='Rbn'+str(r+1)+'b')
+ t = tf.nn.relu(t)
+ t = Conv3D(tf.pad(t, stp, mode='CONSTANT'), [3,3,3,F,G], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'b')
+
+ x = tf.concat([x, t], 4)
+ F += G
+ for r in range(21,24):
+ t = BatchNorm(x, is_train, name='Rbn'+str(r+1)+'a')
+ t = tf.nn.relu(t)
+ t = Conv3D(t, [1,1,1,F,F], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'a')
+
+ t = BatchNorm(t, is_train, name='Rbn'+str(r+1)+'b')
+ t = tf.nn.relu(t)
+ t = Conv3D(tf.pad(t, sp, mode='CONSTANT'), [3,3,3,F,G], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'b')
+
+ x = tf.concat([x[:,1:-1], t], 4)
+ F += G
+
+ x = BatchNorm(x, is_train, name='fbn1')
+ x = tf.nn.relu(x)
+ x = Conv3D(tf.pad(x, sp, mode='CONSTANT'), [1,3,3,448,256], [1,1,1,1,1], 'VALID', name='conv2')
+
+ x = tf.nn.relu(x)
+
+ r = Conv3D(x, [1,1,1,256,256], [1,1,1,1,1], 'VALID', name='rconv1')
+ r = tf.nn.relu(r)
+ r = Conv3D(r, [1,1,1,256,3*16], [1,1,1,1,1], 'VALID', name='rconv2')
+
+ f = Conv3D(x, [1,1,1,256,512], [1,1,1,1,1], 'VALID', name='fconv1')
+ f = tf.nn.relu(f)
+ f = Conv3D(f, [1,1,1,512,1*5*5*16], [1,1,1,1,1], 'VALID', name='fconv2')
+
+ ds_f = tf.shape(f)
+ f = tf.reshape(f, [ds_f[0], ds_f[1], ds_f[2], ds_f[3], 25, 16])
+ f = tf.nn.softmax(f, dim=4)
+
+ return f, r \ No newline at end of file