summaryrefslogtreecommitdiff
path: root/Codes/flownet2/src/flownet_sd/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'Codes/flownet2/src/flownet_sd/train.py')
-rw-r--r--Codes/flownet2/src/flownet_sd/train.py19
1 files changed, 19 insertions, 0 deletions
diff --git a/Codes/flownet2/src/flownet_sd/train.py b/Codes/flownet2/src/flownet_sd/train.py
new file mode 100644
index 0000000..86c64e5
--- /dev/null
+++ b/Codes/flownet2/src/flownet_sd/train.py
@@ -0,0 +1,19 @@
+from ..dataloader import load_batch
+from ..dataset_configs import FLYING_CHAIRS_DATASET_CONFIG
+from ..training_schedules import LONG_SCHEDULE
+from .flownet_sd import FlowNetSD
+
+# Create a new network
+net = FlowNetSD()
+
+# Load a batch of data
+input_a, input_b, flow = load_batch(FLYING_CHAIRS_DATASET_CONFIG, 'sample', net.global_step)
+
+# Train on the data
+net.train(
+ log_dir='./logs/flownet_sd_sample',
+ training_schedule=LONG_SCHEDULE,
+ input_a=input_a,
+ input_b=input_b,
+ flow=flow
+)