summaryrefslogtreecommitdiff
path: root/neural_style.py
diff options
context:
space:
mode:
authorCameron <cysmith1010@gmail.com>2016-10-09 23:43:44 -0600
committerGitHub <noreply@github.com>2016-10-09 23:43:44 -0600
commitad912ceb7204f18d8cae8e476e4d62892a4a9442 (patch)
treed9e7af787caab5f7cb05cfe7a360795a04d02077 /neural_style.py
parent84822c308c19dccad2c43d6b887bf498ac4729d0 (diff)
Fixed optimizer bug
Diffstat (limited to 'neural_style.py')
-rw-r--r--neural_style.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/neural_style.py b/neural_style.py
index fffa51f..c3b69df 100644
--- a/neural_style.py
+++ b/neural_style.py
@@ -582,7 +582,7 @@ def stylize(content_img, style_imgs, init_img, frame=None):
optimizer = get_optimizer(L_total)
if args.optimizer == 'adam':
- minimize_with_adam(sess, net, optimizer, init_img)
+ minimize_with_adam(sess, net, optimizer, init_img, L_total)
elif args.optimizer == 'lbfgs':
minimize_with_lbfgs(sess, net, optimizer, init_img)
@@ -603,9 +603,9 @@ def minimize_with_lbfgs(sess, net, optimizer, init_img):
sess.run(net['input'].assign(init_img))
optimizer.minimize(sess)
-def minimize_with_adam(sess, net, optimizer, init_img):
+def minimize_with_adam(sess, net, optimizer, init_img, loss):
if args.verbose: print('MINIMIZING LOSS USING: ADAM OPTIMIZER')
- train_op = optimizer.minimize(L_total)
+ train_op = optimizer.minimize(loss)
init_op = tf.initialize_all_variables()
sess.run(init_op)
sess.run(net['input'].assign(init_img))
@@ -634,7 +634,7 @@ def write_video_output(frame, output_img):
def write_image_output(output_img, content_img, style_imgs, init_img):
out_dir = os.path.join(args.img_output_dir, args.img_name)
maybe_make_directory(out_dir)
- img_path = os.path.join(out_dir, 'output.png')
+ img_path = os.path.join(out_dir, args.img_name+'.png')
content_path = os.path.join(out_dir, 'content.png')
init_path = os.path.join(out_dir, 'init.png')
@@ -643,7 +643,7 @@ def write_image_output(output_img, content_img, style_imgs, init_img):
write_image(init_path, init_img)
index = 0
for style_img in style_imgs:
- path = os.path.join(out_dir, str(index)+'_style.png')
+ path = os.path.join(out_dir, 'style_'+str(index)+'.png')
write_image(path, style_img)
index += 1