blob: 9c12dfcfa768461efdc839791ae32033d2082e02 (
plain)
1
2
3
4
5
6
7
8
9
10
|
from tensorflow.train import NewCheckpointReader
def read_checkpoint(fp_ckpt, key=None):
    """Read one tensor, or every tensor, from a TensorFlow checkpoint.

    Args:
        fp_ckpt: Checkpoint location handed to ``NewCheckpointReader``
            (presumably the checkpoint prefix, e.g. ``model.ckpt-1000``;
            confirm against callers).
        key: Name of the variable to read. When ``None``, every variable
            stored in the checkpoint is read instead.

    Returns:
        The value returned by ``reader.get_tensor(key)`` when ``key`` is
        given (presumably a NumPy array). Otherwise, a dict mapping each
        variable name, in sorted order, to its value.
    """
    reader = NewCheckpointReader(fp_ckpt)
    if key is not None:
        return reader.get_tensor(key)
    # The shape map's keys are the variable names stored in the checkpoint;
    # sort them so the returned dict has a deterministic order.
    names = sorted(reader.get_variable_to_shape_map())
    return {name: reader.get_tensor(name) for name in names}
|