# path: root/Codes/unet.py
# blob: ac4c6aa9da6d53fe3f4693210afcb1d7f5964056
import tensorflow as tf
from tensorflow.contrib.layers import conv2d, max_pool2d, conv2d_transpose


def _crop_to_match(tensor, reference):
    """Crop the height/width of an NHWC `tensor` to those of `reference`.

    'SAME' max pooling rounds a spatial size that is not divisible by the pool
    size up, so the decoder's transposed convolution can produce a feature map
    that is larger than the encoder map it is concatenated with. The surplus
    rows/columns are trimmed from the bottom/right. For pool_size=2, TF puts
    the single padding row/column at the end, so this trim is exact. For
    larger pool sizes the skip alignment may be off by a pixel.

    :param tensor: rank-4 tensor to crop, shape [batch, h, w, c]
    :param reference: rank-4 tensor whose height/width are the target size
    :return: `tensor` itself when the static sizes already match, otherwise
        `tensor[:, :ref_h, :ref_w, :]`
    """
    ref_hw = reference.get_shape().as_list()[1:3]
    if None not in ref_hw:
        if tensor.get_shape().as_list()[1:3] == ref_hw:
            return tensor  # already aligned: leave the graph untouched
        # Static bounds keep the static shape information intact.
        return tensor[:, :ref_hw[0], :ref_hw[1], :]
    # Spatial size only known at run time: slice with dynamic bounds.
    ref_shape = tf.shape(reference)
    return tensor[:, :ref_shape[1], :ref_shape[2], :]


def unet(inputs, layers, features_root=64, filter_size=3, pool_size=2, output_channel=1):
    """
    Build a U-Net with three parts:

    - an encoder of two-conv stages separated by max pooling;
    - a decoder of transposed-conv upsampling stages;
    - skip connections that concatenate each encoder feature map onto the
      decoder at the same resolution.

    :param inputs: input tensor, shape [None, height, width, channel]
    :param layers: number of resolution levels (>= 1); the input is pooled
        ``layers - 1`` times
    :param features_root: number of features in the first level; level ``i``
        uses ``2**i * features_root`` features
    :param filter_size: kernel size of each conv layer
    :param pool_size: kernel size and stride of each max-pooling layer, and of
        the matching transposed convolution in the decoder
    :param output_channel: number of channels of the output tensor
    :return: a tensor, shape [None, height, width, output_channel], squashed
        into (-1, 1) by tanh
    :raises ValueError: if ``layers`` < 1
    """
    if layers < 1:
        raise ValueError("layers must be >= 1, got {}".format(layers))

    # Encoder: two convs per level; keep the last one as the skip connection.
    in_node = inputs
    skips = []  # shallowest level first
    for layer in range(layers):
        features = 2 ** layer * features_root

        conv1 = conv2d(inputs=in_node, num_outputs=features, kernel_size=filter_size)
        conv2 = conv2d(inputs=conv1, num_outputs=features, kernel_size=filter_size)
        skips.append(conv2)

        if layer < layers - 1:
            # Pass stride explicitly: max_pool2d defaults to stride=2, which
            # would disagree with the stride-`pool_size` upsampling below.
            # (A strided conv2d is a possible learned alternative to pooling.)
            in_node = max_pool2d(inputs=conv2, kernel_size=pool_size,
                                 stride=pool_size, padding='SAME')

    # Decoder: upsample, merge with the skip at the same level, then two convs.
    in_node = skips[-1]
    for layer in range(layers - 2, -1, -1):
        features = 2 ** layer * features_root  # same width as the skip at this level

        h_deconv = conv2d_transpose(inputs=in_node, num_outputs=features,
                                    kernel_size=pool_size, stride=pool_size)
        # Sizes not divisible by pool_size were rounded up by 'SAME' pooling;
        # trim the upsampled map back to the skip's size so concat succeeds.
        h_deconv = _crop_to_match(h_deconv, skips[layer])
        h_deconv_concat = tf.concat([skips[layer], h_deconv], axis=3)

        conv1 = conv2d(inputs=h_deconv_concat, num_outputs=features, kernel_size=filter_size)
        in_node = conv2d(inputs=conv1, num_outputs=features, kernel_size=filter_size)

    # Linear projection to the requested channel count, bounded by tanh.
    output = conv2d(inputs=in_node, num_outputs=output_channel, kernel_size=filter_size, activation_fn=None)
    output = tf.tanh(output)
    return output