summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2020-01-12 03:44:43 +0100
committerJules Laplace <julescarbon@gmail.com>2020-01-12 03:44:43 +0100
commit85b8aea622c973a5e1643b04c13d39719fefca0e (patch)
tree56b5e7c6556fb1f1b368a72b1f25be20b0edd97d
parentd55b87d84bc4bf680a0b6dbf20907eb05c7d19ea (diff)
new checkpoint reader
-rw-r--r--cli/app/utils/tf_utils.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/cli/app/utils/tf_utils.py b/cli/app/utils/tf_utils.py
index 0453644..9c12dfc 100644
--- a/cli/app/utils/tf_utils.py
+++ b/cli/app/utils/tf_utils.py
@@ -1,9 +1,9 @@
-from tensorflow.train import py_checkpoint_reader
+from tensorflow.train import NewCheckpointReader
def read_checkpoint(fp_ckpt, key):
- reader = py_checkpoint_reader.NewCheckpointReader(fp_ckpt)
- var_to_shape_map = reader.get_variable_to_shape_map()
- var_to_dtype_map = reader.get_variable_to_dtype_map()
+ reader = NewCheckpointReader(fp_ckpt)
+ # var_to_shape_map = reader.get_variable_to_shape_map()
+ # var_to_dtype_map = reader.get_variable_to_dtype_map()
# for key, value in sorted(var_to_shape_map.items()):
# print("tensor: %s (%s) %s" % (key, var_to_dtype_map[key].name, value))
# print(reader.get_tensor(key))