diff options
Diffstat (limited to 'Codes/flownet2/src/flownet_css/train.py')
| -rw-r--r-- | Codes/flownet2/src/flownet_css/train.py | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/Codes/flownet2/src/flownet_css/train.py b/Codes/flownet2/src/flownet_css/train.py new file mode 100644 index 0000000..2964f3e --- /dev/null +++ b/Codes/flownet2/src/flownet_css/train.py @@ -0,0 +1,22 @@ +from ..dataloader import load_batch +from ..dataset_configs import FLYING_CHAIRS_DATASET_CONFIG +from ..training_schedules import LONG_SCHEDULE +from .flownet_css import FlowNetCSS + +# Create a new network +net = FlowNetCSS() + +# Load a batch of data +input_a, input_b, flow = load_batch(FLYING_CHAIRS_DATASET_CONFIG, 'sample', net.global_step) + +# Train on the data +net.train( + log_dir='./logs/flownet_css', + training_schedule=LONG_SCHEDULE, + input_a=input_a, + input_b=input_b, + flow=flow, + # Load trained weights for CS part of network + checkpoints={ + './checkpoints/FlowNetCS/flownet-CS.ckpt-0': ('FlowNetCSS/FlowNetCS', 'FlowNetCSS')} +) |
