summaryrefslogtreecommitdiff
path: root/Codes/flownet2/src/flownet_css/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'Codes/flownet2/src/flownet_css/train.py')
-rw-r--r--Codes/flownet2/src/flownet_css/train.py22
1 files changed, 22 insertions, 0 deletions
diff --git a/Codes/flownet2/src/flownet_css/train.py b/Codes/flownet2/src/flownet_css/train.py
new file mode 100644
index 0000000..2964f3e
--- /dev/null
+++ b/Codes/flownet2/src/flownet_css/train.py
@@ -0,0 +1,22 @@
+from ..dataloader import load_batch
+from ..dataset_configs import FLYING_CHAIRS_DATASET_CONFIG
+from ..training_schedules import LONG_SCHEDULE
+from .flownet_css import FlowNetCSS
+
+# Create a new network
+net = FlowNetCSS()
+
+# Load a batch of data
+input_a, input_b, flow = load_batch(FLYING_CHAIRS_DATASET_CONFIG, 'sample', net.global_step)
+
+# Train on the data
+net.train(
+ log_dir='./logs/flownet_css',
+ training_schedule=LONG_SCHEDULE,
+ input_a=input_a,
+ input_b=input_b,
+ flow=flow,
+ # Load trained weights for CS part of network
+ checkpoints={
+ './checkpoints/FlowNetCS/flownet-CS.ckpt-0': ('FlowNetCSS/FlowNetCS', 'FlowNetCSS')}
+)