summaryrefslogtreecommitdiff
path: root/cli/app/utils/tf_utils.py
blob: 69a38e477bcbbb1d226855568519326badec0841 (plain)
1
2
3
4
5
6
7
8
9
10
from tensorflow.python.training import py_checkpoint_reader

def read_checkpoint(fp_ckpt, key):
  reader = py_checkpoint_reader.NewCheckpointReader(fp_ckpt)
  var_to_shape_map = reader.get_variable_to_shape_map()
  var_to_dtype_map = reader.get_variable_to_dtype_map()
  # for key, value in sorted(var_to_shape_map.items()):
  #   print("tensor: %s (%s) %s" % (key, var_to_dtype_map[key].name, value))
  #   print(reader.get_tensor(key))
  return reader.get_tensor(key)