author    iver56 <ij@aiascience.com>  2018-03-13 13:15:07 +0100
committer iver56 <ij@aiascience.com>  2018-03-13 13:15:07 +0100
commit    b4ee6eafae19ce7fc2e2036acc8532db8c9d186c (patch)
tree      4354bb6ea01a671b206cbb098fa2ac08dc61ed28
parent    079da5c02fd99ef35d7cad0e20c2924b7c2bcffd (diff)
Fix multiprocessing for Windows by using the __name__ == '__main__' idiom
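On Windows, Python's multiprocessing starts worker processes with the spawn method, which imports the main module afresh in every child. Any code at module top level therefore runs again in each worker, so a script that creates workers at import time ends up re-executing itself in every child instead of training. Wrapping the entry point in the `if __name__ == '__main__':` guard, as this commit does to train.py, keeps the re-imported module inert in the children. A minimal sketch of the guarded pattern, using a toy worker function rather than this repository's training code:

    import multiprocessing

    def square(x):
        # Runs in a child process; on Windows each child re-imports this module.
        return x * x

    if __name__ == '__main__':
        # Without this guard, the re-import on Windows would reach Pool(...)
        # again inside every child and spawn processes without end.
        with multiprocessing.Pool(processes=2) as pool:
            print(pool.map(square, range(8)))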
-rw-r--r--  .gitignore  1
-rw-r--r--  train.py    85
2 files changed, 44 insertions, 42 deletions
diff --git a/.gitignore b/.gitignore
index faba6c9..4fdef3e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -40,3 +40,4 @@ test/.coverage
*/**/*.dylib*
test/data/legacy_serialized.pt
*~
+.idea
diff --git a/train.py b/train.py
index 61b596a..ee8cff1 100644
--- a/train.py
+++ b/train.py
@@ -4,54 +4,55 @@ from data.data_loader import CreateDataLoader
 from models.models import create_model
 from util.visualizer import Visualizer
 
-opt = TrainOptions().parse()
-data_loader = CreateDataLoader(opt)
-dataset = data_loader.load_data()
-dataset_size = len(data_loader)
-print('#training images = %d' % dataset_size)
+if __name__ == '__main__':
+    opt = TrainOptions().parse()
+    data_loader = CreateDataLoader(opt)
+    dataset = data_loader.load_data()
+    dataset_size = len(data_loader)
+    print('#training images = %d' % dataset_size)
 
-model = create_model(opt)
-visualizer = Visualizer(opt)
-total_steps = 0
+    model = create_model(opt)
+    visualizer = Visualizer(opt)
+    total_steps = 0
 
-for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
-    epoch_start_time = time.time()
-    iter_data_time = time.time()
-    epoch_iter = 0
+    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
+        epoch_start_time = time.time()
+        iter_data_time = time.time()
+        epoch_iter = 0
 
-    for i, data in enumerate(dataset):
-        iter_start_time = time.time()
-        if total_steps % opt.print_freq == 0:
-            t_data = iter_start_time - iter_data_time
-        visualizer.reset()
-        total_steps += opt.batchSize
-        epoch_iter += opt.batchSize
-        model.set_input(data)
-        model.optimize_parameters()
+        for i, data in enumerate(dataset):
+            iter_start_time = time.time()
+            if total_steps % opt.print_freq == 0:
+                t_data = iter_start_time - iter_data_time
+            visualizer.reset()
+            total_steps += opt.batchSize
+            epoch_iter += opt.batchSize
+            model.set_input(data)
+            model.optimize_parameters()
 
-        if total_steps % opt.display_freq == 0:
-            save_result = total_steps % opt.update_html_freq == 0
-            visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
+            if total_steps % opt.display_freq == 0:
+                save_result = total_steps % opt.update_html_freq == 0
+                visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
 
-        if total_steps % opt.print_freq == 0:
-            errors = model.get_current_errors()
-            t = (time.time() - iter_start_time) / opt.batchSize
-            visualizer.print_current_errors(epoch, epoch_iter, errors, t, t_data)
-            if opt.display_id > 0:
-                visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors)
+            if total_steps % opt.print_freq == 0:
+                errors = model.get_current_errors()
+                t = (time.time() - iter_start_time) / opt.batchSize
+                visualizer.print_current_errors(epoch, epoch_iter, errors, t, t_data)
+                if opt.display_id > 0:
+                    visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors)
 
-        if total_steps % opt.save_latest_freq == 0:
-            print('saving the latest model (epoch %d, total_steps %d)' %
+            if total_steps % opt.save_latest_freq == 0:
+                print('saving the latest model (epoch %d, total_steps %d)' %
+                      (epoch, total_steps))
+                model.save('latest')
+
+            iter_data_time = time.time()
+        if epoch % opt.save_epoch_freq == 0:
+            print('saving the model at the end of epoch %d, iters %d' %
                   (epoch, total_steps))
             model.save('latest')
-
-        iter_data_time = time.time()
-    if epoch % opt.save_epoch_freq == 0:
-        print('saving the model at the end of epoch %d, iters %d' %
-              (epoch, total_steps))
-        model.save('latest')
-        model.save(epoch)
+            model.save(epoch)
 
-    print('End of epoch %d / %d \t Time Taken: %d sec' %
-          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
-    model.update_learning_rate()
+        print('End of epoch %d / %d \t Time Taken: %d sec' %
+              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
+        model.update_learning_rate()
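The guard matters for this script in particular because PyTorch's torch.utils.data.DataLoader launches worker processes whenever num_workers > 0, and on Windows each of those workers re-imports train.py. A self-contained sketch of the same pattern around a DataLoader; the toy dataset below is a stand-in for this repository's CreateDataLoader pipeline, not its actual API:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    if __name__ == '__main__':
        # Toy dataset standing in for the repository's data pipeline.
        dataset = TensorDataset(torch.arange(100).float())
        # num_workers > 0 spawns worker processes; on Windows each one
        # re-imports this module, so the __main__ guard is required.
        loader = DataLoader(dataset, batch_size=10, num_workers=2)
        for (batch,) in loader:
            pass  # a training step would consume the batch here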