author     StevenLiuWen <liuwen@shanghaitech.edu.cn>  2018-03-13 03:28:06 -0400
committer  StevenLiuWen <liuwen@shanghaitech.edu.cn>  2018-03-13 03:28:06 -0400
commit     fede6ca1dd0077ff509d84bd24028cc7a93bb119 (patch)
tree       af7f6e759b5dec4fc2964daed09e903958b919ed /Codes/constant.py
first commit
Diffstat (limited to 'Codes/constant.py')
-rw-r--r--  Codes/constant.py  153
1 file changed, 153 insertions, 0 deletions
diff --git a/Codes/constant.py b/Codes/constant.py
new file mode 100644
index 0000000..eafeab9
--- /dev/null
+++ b/Codes/constant.py
@@ -0,0 +1,153 @@
+import os
+import argparse
+import configparser
+
+
+def get_dir(directory):
+ """
+ get the directory, if no such directory, then make it.
+
+ @param directory: The new directory.
+ """
+
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+
+ return directory
+
+
+def parser_args():
+ parser = argparse.ArgumentParser(description='Options to run the network.')
+ parser.add_argument('-g', '--gpu', type=str, default='0',
+ help='the id of the gpu device.')
+ parser.add_argument('-i', '--iters', type=int, default=1,
+ help='set the number of iterations, default is 1')
+ parser.add_argument('-b', '--batch', type=int, default=4,
+ help='set the batch size, default is 4.')
+ parser.add_argument('--num_his', type=int, default=4,
+ help='set the time steps, default is 4.')
+
+ parser.add_argument('-d', '--dataset', type=str,
+ help='the name of dataset.')
+ parser.add_argument('--train_folder', type=str, default='',
+ help='set the training folder path.')
+ parser.add_argument('--test_folder', type=str, default='',
+ help='set the testing folder path.')
+
+ parser.add_argument('--config', type=str, default='training_hyper_params/hyper_params.ini',
+ help='the path of training_hyper_params, default is training_hyper_params/hyper_params.ini')
+
+ parser.add_argument('--snapshot_dir', type=str, default='',
+ help='if it is a folder, it is the directory in which to save models; '
+ 'if it is a specific model.ckpt-xxx, the system will load it for testing.')
+ parser.add_argument('--summary_dir', type=str, default='', help='the directory to save summaries.')
+ parser.add_argument('--psnr_dir', type=str, default='', help='the directory to save PSNR results during testing.')
+
+ parser.add_argument('--evaluate', type=str, default='compute_auc',
+ help='the evaluation metric, default is compute_auc')
+
+ return parser.parse_args()
+
+
+class Const(object):
+ class ConstError(TypeError):
+ pass
+
+ class ConstCaseError(ConstError):
+ pass
+
+ def __setattr__(self, name, value):
+ if name in self.__dict__:
+ raise self.ConstError("Can't change const.{}".format(name))
+ if not name.isupper():
+ raise self.ConstCaseError('const name {} is not all uppercase'.format(name))
+
+ self.__dict__[name] = value
+
+ def __str__(self):
+ _str = '<================ Constants information ================>\n'
+ for name, value in self.__dict__.items():
+ print(name, value)
+ _str += '\t{}\t{}\n'.format(name, value)
+
+ return _str
+
+
+args = parser_args()
+const = Const()
+
+# inputs constants
+const.DATASET = args.dataset
+const.TRAIN_FOLDER = args.train_folder
+const.TEST_FOLDER = args.test_folder
+
+const.GPU = args.gpu
+
+const.BATCH_SIZE = args.batch
+const.NUM_HIS = args.num_his
+const.ITERATIONS = args.iters
+
+const.EVALUATE = args.evaluate
+
+# network constants
+const.HEIGHT = 256
+const.WIDTH = 256
+const.FLOWNET_CHECKPOINT = 'flownet2/checkpoints/FlowNetSD/flownet-SD.ckpt-0'
+const.FLOW_HEIGHT = 384
+const.FLOW_WIDTH = 512
+
+# set training hyper-parameters of different datasets
+config = configparser.ConfigParser()
+assert config.read(args.config)
+
+# order of the lp loss, e.g. 1 or 2 for l1 and l2 loss, respectively
+const.L_NUM = config.getint(const.DATASET, 'L_NUM')
+# the power to which each gradient term is raised in GDL loss
+const.ALPHA_NUM = config.getint(const.DATASET, 'ALPHA_NUM')
+# the percentage of the adversarial loss to use in the combined loss
+const.LAM_ADV = config.getfloat(const.DATASET, 'LAM_ADV')
+# the percentage of the lp loss to use in the combined loss
+const.LAM_LP = config.getfloat(const.DATASET, 'LAM_LP')
+# the percentage of the GDL loss to use in the combined loss
+const.LAM_GDL = config.getfloat(const.DATASET, 'LAM_GDL')
+# the percentage of the optical flow difference loss to use in the combined loss
+const.LAM_FLOW = config.getfloat(const.DATASET, 'LAM_FLOW')
+
+# Learning rate of generator
+const.LRATE_G = eval(config.get(const.DATASET, 'LRATE_G'))
+const.LRATE_G_BOUNDARIES = eval(config.get(const.DATASET, 'LRATE_G_BOUNDARIES'))
+
+# Learning rate of discriminator
+const.LRATE_D = eval(config.get(const.DATASET, 'LRATE_D'))
+const.LRATE_D_BOUNDARIES = eval(config.get(const.DATASET, 'LRATE_D_BOUNDARIES'))
+
+
+const.SAVE_DIR = '{dataset}_l_{L_NUM}_alpha_{ALPHA_NUM}_lp_{LAM_LP}_' \
+ 'adv_{LAM_ADV}_gdl_{LAM_GDL}_flow_{LAM_FLOW}'.format(dataset=const.DATASET,
+ L_NUM=const.L_NUM,
+ ALPHA_NUM=const.ALPHA_NUM,
+ LAM_LP=const.LAM_LP, LAM_ADV=const.LAM_ADV,
+ LAM_GDL=const.LAM_GDL, LAM_FLOW=const.LAM_FLOW)
+
+if args.snapshot_dir:
+ # if snapshot_dir points to a specific model.ckpt-xxx checkpoint, it is a single model to load for testing.
+ if os.path.exists(args.snapshot_dir + '.meta') or os.path.exists(args.snapshot_dir + '.data-00000-of-00001') or \
+ os.path.exists(args.snapshot_dir + '.index'):
+ const.SNAPSHOT_DIR = args.snapshot_dir
+ print(const.SNAPSHOT_DIR)
+ else:
+ const.SNAPSHOT_DIR = get_dir(os.path.join('models', const.SAVE_DIR + '_' + args.snapshot_dir))
+else:
+ const.SNAPSHOT_DIR = get_dir(os.path.join('models', const.SAVE_DIR))
+
+if args.summary_dir:
+ const.SUMMARY_DIR = get_dir(os.path.join('summary', const.SAVE_DIR + '_' + args.summary_dir))
+else:
+ const.SUMMARY_DIR = get_dir(os.path.join('summary', const.SAVE_DIR))
+
+if args.psnr_dir:
+ const.PSNR_DIR = get_dir(os.path.join('psnrs', const.SAVE_DIR + '_' + args.psnr_dir))
+else:
+ const.PSNR_DIR = get_dir(os.path.join('psnrs', const.SAVE_DIR))
+
+
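The hyper-parameters above are read with configparser from the file passed via --config (default training_hyper_params/hyper_params.ini), one section per dataset; the LRATE_* entries are strings that eval() turns into Python lists, presumably for a piecewise-constant learning-rate schedule. The sketch below is not part of this commit: it writes a minimal section containing exactly the keys constant.py looks up, with a hypothetical dataset name and placeholder numbers.

# Minimal sketch (assumption, not from the repository) of the per-dataset
# section that constant.py expects in training_hyper_params/hyper_params.ini.
# The section name must match the value passed to --dataset; all numeric
# values below are placeholders.
import configparser
import os

hyper_params = configparser.ConfigParser()
hyper_params['my_dataset'] = {            # hypothetical dataset name
    'L_NUM': '2',                         # lp loss order (2 -> l2)
    'ALPHA_NUM': '1',                     # gradient power in the GDL loss
    'LAM_ADV': '0.05',                    # weight of the adversarial loss
    'LAM_LP': '1.0',                      # weight of the lp loss
    'LAM_GDL': '1.0',                     # weight of the GDL loss
    'LAM_FLOW': '2.0',                    # weight of the flow difference loss
    'LRATE_G': '[0.0002, 0.00002]',       # eval()'d into a list of generator learning rates
    'LRATE_G_BOUNDARIES': '[100000]',     # step boundaries for the rates above
    'LRATE_D': '[0.00002, 0.000002]',     # discriminator learning rates
    'LRATE_D_BOUNDARIES': '[100000]',
}

os.makedirs('training_hyper_params', exist_ok=True)
with open('training_hyper_params/hyper_params.ini', 'w') as f:
    hyper_params.write(f)

With such a file in place, a script that imports const from this module (the entry-point name and folder paths below are hypothetical) would be launched along the lines of:
python train.py -d my_dataset -g 0 -b 4 --num_his 4 --train_folder /path/to/train/frames --test_folder /path/to/test/frames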