summaryrefslogtreecommitdiff
path: root/Codes/flownet2/src/flownet_c/train.py
diff options
context:
space:
mode:
authorStevenLiuWen <liuwen@shanghaitech.edu.cn>2018-03-13 03:28:06 -0400
committerStevenLiuWen <liuwen@shanghaitech.edu.cn>2018-03-13 03:28:06 -0400
commitfede6ca1dd0077ff509d84bd24028cc7a93bb119 (patch)
treeaf7f6e759b5dec4fc2964daed09e903958b919ed /Codes/flownet2/src/flownet_c/train.py
first commit
Diffstat (limited to 'Codes/flownet2/src/flownet_c/train.py')
-rw-r--r--Codes/flownet2/src/flownet_c/train.py19
1 files changed, 19 insertions, 0 deletions
diff --git a/Codes/flownet2/src/flownet_c/train.py b/Codes/flownet2/src/flownet_c/train.py
new file mode 100644
index 0000000..9296ac7
--- /dev/null
+++ b/Codes/flownet2/src/flownet_c/train.py
@@ -0,0 +1,19 @@
+from ..dataloader import load_batch
+from ..dataset_configs import FLYING_CHAIRS_DATASET_CONFIG
+from ..training_schedules import LONG_SCHEDULE
+from .flownet_c import FlowNetC
+
+# Create a new network
+net = FlowNetC()
+
+# Load a batch of data
+input_a, input_b, flow = load_batch(FLYING_CHAIRS_DATASET_CONFIG, 'sample', net.global_step)
+
+# Train on the data
+net.train(
+ log_dir='./logs/flownet_c',
+ training_schedule=LONG_SCHEDULE,
+ input_a=input_a,
+ input_b=input_b,
+ flow=flow
+)