summaryrefslogtreecommitdiff
path: root/cli/app/commands
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2019-12-10 23:56:36 +0100
committerJules Laplace <julescarbon@gmail.com>2019-12-10 23:56:36 +0100
commit1cdbf220659d847fdd3855a62f9cba080347271f (patch)
treed8a28ea0f02e1cb75fe0ba6747aa3f63808d79ab /cli/app/commands
parentac6cac6a6e985f7a87d39a99a721dd37e04768d5 (diff)
store filenames
Diffstat (limited to 'cli/app/commands')
-rw-r--r--cli/app/commands/biggan/search.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/cli/app/commands/biggan/search.py b/cli/app/commands/biggan/search.py
index 34ed6a4..47d91f7 100644
--- a/cli/app/commands/biggan/search.py
+++ b/cli/app/commands/biggan/search.py
@@ -66,8 +66,9 @@ def cli(ctx, opt_fp_in, opt_dims, opt_video, opt_tag):
out_file = h5py.File(join(fp_inverses, 'dataset.hdf5'), 'w')
out_images = out_file.create_dataset('xtrain', (len(paths), 3, 512, 512,), dtype='float32')
out_labels = out_file.create_dataset('ytrain', (len(paths), vocab_size,), dtype='float32')
-
+ out_fns = out_file.create_dataset('fn', (len(paths),), dtype=h5py.string_dtype())
for path, index in enumerate(paths):
+ out_labels[index] = os.path.basename(path)
fp_frames = find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, path, opt_dims, out_images, out_labels, index)
if opt_video:
export_video(fp_frames)