From c99ce7c4e781712e0252c6127ad1a4e8021cc489 Mon Sep 17 00:00:00 2001 From: junyanz Date: Tue, 18 Apr 2017 03:38:47 -0700 Subject: first commit --- train.py | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 train.py (limited to 'train.py') diff --git a/train.py b/train.py new file mode 100644 index 0000000..e85042f --- /dev/null +++ b/train.py @@ -0,0 +1,52 @@ +import time +from options.train_options import TrainOptions +opt = TrainOptions().parse() # set CUDA_VISIBLE_DEVICES before import torch + +from data.data_loader import CreateDataLoader +from models.models import create_model +from util.visualizer import Visualizer + +data_loader = CreateDataLoader(opt) +dataset = data_loader.load_data() +num_train = len(data_loader) +print('#training images = %d' % num_train) + +model = create_model(opt) +visualizer = Visualizer(opt) + +total_steps = 0 + +for epoch in range(1, opt.niter + opt.niter_decay + 1): + epoch_start_time = time.time() + for i, data in enumerate(dataset): + iter_start_time = time.time() + total_steps += opt.batchSize + epoch_iter = total_steps % num_train + model.set_input(data) + model.optimize_parameters() + + if total_steps % opt.display_freq == 0: + visualizer.display_current_results(model.get_current_visuals(), epoch) + + if total_steps % opt.print_freq == 0: + errors = model.get_current_errors() + visualizer.print_current_errors(epoch, epoch_iter, errors, iter_start_time) + if opt.display_id > 0: + visualizer.plot_current_errors(epoch, epoch_iter, opt, errors) + + if total_steps % opt.save_latest_freq == 0: + print('saving the latest model (epoch %d, total_steps %d)' % + (epoch, total_steps)) + model.save('latest') + + if epoch % opt.save_epoch_freq == 0: + print('saving the model at the end of epoch %d, iters %d' % + (epoch, total_steps)) + model.save('latest') + model.save(epoch) + + print('End of epoch %d / %d \t Time Taken: %d sec' % + (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) + + if epoch > opt.niter: + model.update_learning_rate() -- cgit v1.2.3-70-g09d2