summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
author     junyanz <junyanz@berkeley.edu>  2017-04-18 03:38:47 -0700
committer  junyanz <junyanz@berkeley.edu>  2017-04-18 03:38:47 -0700
commitc99ce7c4e781712e0252c6127ad1a4e8021cc489 (patch)
treeba99dfd56a47036d9c1f18620abf4efc248839ab /train.py
first commit
Diffstat (limited to 'train.py')
-rw-r--r--  train.py  52
1 files changed, 52 insertions, 0 deletions
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..e85042f
--- /dev/null
+++ b/train.py
@@ -0,0 +1,52 @@
import time
from options.train_options import TrainOptions

# Parse options *before* the remaining imports: parsing sets
# CUDA_VISIBLE_DEVICES, which torch reads once at import time, and the
# modules below (presumably) import torch — keep this ordering.
opt = TrainOptions().parse()  # set CUDA_VISIBLE_DEVICES before import torch

from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
+
# Build the training data pipeline from the parsed options.
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
num_train = len(data_loader)
print('#training images = %d' % num_train)

# Model and visualizer are both configured from the same option namespace.
model = create_model(opt)
visualizer = Visualizer(opt)

# Cumulative count of training images processed across all epochs;
# drives the display/print/save frequency checks in the training loop.
total_steps = 0
+
# Main training loop: opt.niter epochs at the initial learning rate,
# followed by opt.niter_decay epochs during which the rate is decayed.
# Epochs are 1-indexed for display purposes.
for epoch in range(1, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    # Number of training images seen so far in *this* epoch.  Tracked
    # explicitly instead of the previous `total_steps % num_train`, which
    # only equals the in-epoch position when num_train is an exact multiple
    # of opt.batchSize and otherwise drifts further off every epoch.
    epoch_iter = 0

    for i, data in enumerate(dataset):
        iter_start_time = time.time()
        total_steps += opt.batchSize
        epoch_iter += opt.batchSize

        # One optimization step on the current batch.
        model.set_input(data)
        model.optimize_parameters()

        # Push the latest generated images to the visualizer.
        if total_steps % opt.display_freq == 0:
            visualizer.display_current_results(model.get_current_visuals(), epoch)

        # Log (and optionally plot, when a display is attached) the losses.
        if total_steps % opt.print_freq == 0:
            errors = model.get_current_errors()
            visualizer.print_current_errors(epoch, epoch_iter, errors, iter_start_time)
            if opt.display_id > 0:
                visualizer.plot_current_errors(epoch, epoch_iter, opt, errors)

        # Rolling checkpoint so a crash loses at most save_latest_freq steps.
        if total_steps % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, total_steps))
            model.save('latest')

    # End-of-epoch checkpoint: refresh 'latest' and keep an epoch-tagged copy.
    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' %
              (epoch, total_steps))
        model.save('latest')
        model.save(epoch)

    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

    # Begin decaying the learning rate once the constant-rate phase is over.
    if epoch > opt.niter:
        model.update_learning_rate()