| -rw-r--r-- | Code/avg_runner.py     |  28 |
| -rw-r--r-- | Code/constants.py      |  31 |
| -rw-r--r-- | Code/d_model.py        |  13 |
| -rw-r--r-- | Code/g_model.py        |  35 |
| -rw-r--r-- | Code/loss_functions.py |  10 |
| -rw-r--r-- | Code/tfutils.py        |  12 |
| -rw-r--r-- | Code/utils.py          |  34 |
| -rwxr-xr-x | process.sh             |  10 |
| -rwxr-xr-x | recursive.sh           |  13 |
| -rw-r--r-- | report.txt             | 306 |
| -rwxr-xr-x | run.sh                 |  12 |
| -rw-r--r-- | tf_upgrade.py          | 742 |
12 files changed, 1207 insertions, 39 deletions
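
The bulk of the changes below are a mechanical TensorFlow pre-1.0 to 1.0 API migration, applied by the bundled tf_upgrade.py and logged in report.txt. As a quick reference, here is a minimal sketch (not taken from the repo) of the post-upgrade idioms used throughout, assuming TensorFlow 1.x; the tensors and the summary path are illustrative only, and the old names are noted in comments.

import tensorflow as tf

# tf.train.SummaryWriter -> tf.summary.FileWriter, plus the GPU memory-growth
# session config that avg_runner.py now uses
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
writer = tf.summary.FileWriter('/tmp/summaries', graph=sess.graph)

# tf.scalar_summary / tf.merge_summary -> tf.summary.scalar / tf.summary.merge
loss = tf.constant(0.5)
loss_summary = tf.summary.scalar('loss', loss)
summaries = tf.summary.merge([loss_summary])

# tf.pack -> tf.stack, and tf.concat now takes axis= / values= keywords
a = tf.constant([1.0, 2.0])
b = tf.constant([3.0, 4.0])
stacked = tf.stack([a, b])                 # was tf.pack([a, b])
joined = tf.concat(axis=0, values=[a, b])  # was tf.concat(0, [a, b])

writer.add_summary(sess.run(summaries), global_step=0)
print(sess.run([stacked, joined]))
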
diff --git a/Code/avg_runner.py b/Code/avg_runner.py
index 6809187..ed72b63 100644
--- a/Code/avg_runner.py
+++ b/Code/avg_runner.py
@@ -3,7 +3,11 @@
 import getopt
 import sys
 import os
-from utils import get_train_batch, get_test_batch
+import pprint
+pp = pprint.PrettyPrinter(indent=2)
+
+from glob import glob
+from utils import get_train_batch, get_test_batch, get_all_clips
 import constants as c
 from g_model import GeneratorModel
 from d_model import DiscriminatorModel
@@ -26,8 +30,11 @@ class AVGRunner:
         self.num_steps = num_steps
         self.num_test_rec = num_test_rec
 
-        self.sess = tf.Session()
-        self.summary_writer = tf.train.SummaryWriter(c.SUMMARY_SAVE_DIR, graph=self.sess.graph)
+        config = tf.ConfigProto()
+        config.gpu_options.allow_growth = True
+        self.sess = tf.Session(config=config)
+        #self.sess = tf.Session()
+        self.summary_writer = tf.summary.FileWriter(c.SUMMARY_SAVE_DIR, graph=self.sess.graph)
 
         if c.ADVERSARIAL:
             print 'Init discriminator...'
@@ -97,11 +104,20 @@ class AVGRunner:
         self.g_model.test_batch(
             batch, self.global_step, num_rec_out=self.num_test_rec)
 
+    def process(self):
+        """
+        Process a directory of images using the generator network.
+        """
+        batch = get_all_clips(c.PROCESS_DIR)
+        self.g_model.test_batch(
+            batch, self.global_step, num_rec_out=self.num_test_rec, process_only=True)
+
 
 def usage():
     print 'Options:'
     print '-l/--load_path= <Relative/path/to/saved/model>'
     print '-t/--test_dir= <Directory of test images>'
+    print '-p/--process_dir= <Directory to process>'
     print '-r/--recursions= <# recursive predictions to make on test>'
     print '-a/--adversarial= <{t/f}> (Whether to use adversarial training. Default=True)'
     print '-n/--name= <Subdirectory of ../Data/Save/*/ in which to save output of this run>'
@@ -123,6 +139,7 @@ def main():
     load_path = None
     test_only = False
+    process_only = False
     num_test_rec = 1  # number of recursive predictions to make on test
     num_steps = 1000001
 
     try:
@@ -140,6 +157,9 @@ def main():
             load_path = arg
         if opt in ('-t', '--test_dir'):
             c.set_test_dir(arg)
+        if opt in ('-p', '--process_dir'):
+            c.set_process_dir(arg)
+            process_only = True
         if opt in ('-r', '--recursions'):
             num_test_rec = int(arg)
         if opt in ('-a', '--adversarial'):
@@ -177,6 +197,8 @@ def main():
 
     runner = AVGRunner(num_steps, load_path, num_test_rec)
     if test_only:
         runner.test()
+    elif process_only:
+        runner.process()
     else:
         runner.train()
diff --git a/Code/constants.py b/Code/constants.py
index 761448b..5d5446b 100644
--- a/Code/constants.py
+++ b/Code/constants.py
@@ -42,18 +42,22 @@ def clear_dir(directory):
         except Exception as e:
             print(e)
 
+def get_process_frame_dims():
+    img_path = glob(os.path.join(PROCESS_DIR, '*'))[0]
+    img = imread(img_path, mode='RGB')
+    shape = np.shape(img)
+    return shape[0], shape[1]
+
 def get_test_frame_dims():
     img_path = glob(os.path.join(TEST_DIR, '*/*'))[0]
     img = imread(img_path, mode='RGB')
     shape = np.shape(img)
-    return shape[0], shape[1]
 
 
 def get_train_frame_dims():
     img_path = glob(os.path.join(TRAIN_DIR, '*/*'))[0]
     img = imread(img_path, mode='RGB')
     shape = np.shape(img)
-    return shape[0], shape[1]
 
 
 def set_test_dir(directory):
@@ -67,15 +71,28 @@ def set_test_dir(directory):
     TEST_DIR = directory
     FULL_HEIGHT, FULL_WIDTH = get_test_frame_dims()
 
+def set_process_dir(directory):
+    """
+    Edits all constants dependent on PROCESS_DIR.
+
+    @param directory: The new process directory.
+ """ + global PROCESS_DIR, FULL_HEIGHT, FULL_WIDTH + + TEST_DIR = directory + FULL_HEIGHT, FULL_WIDTH = get_process_frame_dims() + # root directory for all data DATA_DIR = get_dir('../Data/') # directory of unprocessed training frames -TRAIN_DIR = os.path.join(DATA_DIR, 'Ms_Pacman/Train/') +TRAIN_DIR = os.path.join(DATA_DIR, 'WoodFlat/Train/') # directory of unprocessed test frames -TEST_DIR = os.path.join(DATA_DIR, 'Ms_Pacman/Test/') +TEST_DIR = os.path.join(DATA_DIR, 'WoodFlat/Test/') +# directory of all the images we want to process +PROCESS_DIR = os.path.join(DATA_DIR, 'WoodFlat/Process/') # Directory of processed training clips. # hidden so finder doesn't freeze w/ so many files. DON'T USE `ls` COMMAND ON THIS DIR! -TRAIN_DIR_CLIPS = get_dir(os.path.join(DATA_DIR, '.Clips/')) +TRAIN_DIR_CLIPS = get_dir(os.path.join(DATA_DIR, 'Clips/WoodClips/')) # For processing clips. l2 diff between frames must be greater than this MOVEMENT_THRESHOLD = 100 @@ -83,8 +100,8 @@ MOVEMENT_THRESHOLD = 100 NUM_CLIPS = len(glob(TRAIN_DIR_CLIPS + '*')) # the height and width of the full frames to test on. Set in avg_runner.py or process_data.py main. -FULL_HEIGHT = 210 -FULL_WIDTH = 160 +FULL_HEIGHT = 320 +FULL_WIDTH = 180 # the height and width of the patches to train on TRAIN_HEIGHT = TRAIN_WIDTH = 32 diff --git a/Code/d_model.py b/Code/d_model.py index 1345ceb..02b882d 100644 --- a/Code/d_model.py +++ b/Code/d_model.py @@ -89,8 +89,8 @@ class DiscriminatorModel: name='train_op') # add summaries to visualize in TensorBoard - loss_summary = tf.scalar_summary('loss_D', self.global_loss) - self.summaries = tf.merge_summary([loss_summary]) + loss_summary = tf.summary.scalar('loss_D', self.global_loss) + self.summaries = tf.summary.merge([loss_summary]) def build_feed_dict(self, input_frames, gt_output_frames, generator): """ @@ -123,14 +123,19 @@ class DiscriminatorModel: for scale_num in xrange(self.num_scale_nets): scale_net = self.scale_nets[scale_num] + broken = 0 # resize gt_output_frames scaled_gt_output_frames = np.empty([batch_size, scale_net.height, scale_net.width, 3]) for i, img in enumerate(gt_output_frames): # for skimage.transform.resize, images need to be in range [0, 1], so normalize to # [0, 1] before resize and back to [-1, 1] after sknorm_img = (img / 2) + 0.5 - resized_frame = resize(sknorm_img, [scale_net.height, scale_net.width, 3]) - scaled_gt_output_frames[i] = (resized_frame - 0.5) * 2 + try: + resized_frame = resize(sknorm_img, [scale_net.height, scale_net.width, 3]) + scaled_gt_output_frames[i-broken] = (resized_frame - 0.5) * 2 + except: + broken += 1 + #print str(broken) + " " + "broken images" # combine with resized gt_output_frames to get inputs for prediction scaled_input_frames = np.concatenate([g_scale_preds[scale_num], diff --git a/Code/g_model.py b/Code/g_model.py index 5dc8265..8ff75a1 100644 --- a/Code/g_model.py +++ b/Code/g_model.py @@ -114,7 +114,7 @@ class GeneratorModel: if scale_num > 0: last_gen_frames = tf.image.resize_images( last_gen_frames,[scale_height, scale_width]) - inputs = tf.concat(3, [inputs, last_gen_frames]) + inputs = tf.concat(axis=3, values=[inputs, last_gen_frames]) # generated frame predictions preds = inputs @@ -196,7 +196,7 @@ class GeneratorModel: name='train_op') # train loss summary - loss_summary = tf.scalar_summary('train_loss_G', self.global_loss) + loss_summary = tf.summary.scalar('train_loss_G', self.global_loss) self.summaries_train.append(loss_summary) ## @@ -215,22 +215,22 @@ class GeneratorModel: self.sharpdiff_error_test = 
             self.sharpdiff_error_test = sharp_diff_error(self.scale_preds_test[-1], self.gt_frames_test)
 
             # train error summaries
-            summary_psnr_train = tf.scalar_summary('train_PSNR',
+            summary_psnr_train = tf.summary.scalar('train_PSNR',
                                                    self.psnr_error_train)
-            summary_sharpdiff_train = tf.scalar_summary('train_SharpDiff',
+            summary_sharpdiff_train = tf.summary.scalar('train_SharpDiff',
                                                         self.sharpdiff_error_train)
             self.summaries_train += [summary_psnr_train, summary_sharpdiff_train]
 
             # test error
-            summary_psnr_test = tf.scalar_summary('test_PSNR',
+            summary_psnr_test = tf.summary.scalar('test_PSNR',
                                                   self.psnr_error_test)
-            summary_sharpdiff_test = tf.scalar_summary('test_SharpDiff',
+            summary_sharpdiff_test = tf.summary.scalar('test_SharpDiff',
                                                        self.sharpdiff_error_test)
             self.summaries_test += [summary_psnr_test, summary_sharpdiff_test]
 
         # add summaries to visualize in TensorBoard
-        self.summaries_train = tf.merge_summary(self.summaries_train)
-        self.summaries_test = tf.merge_summary(self.summaries_test)
+        self.summaries_train = tf.summary.merge(self.summaries_train)
+        self.summaries_test = tf.summary.merge(self.summaries_test)
 
     def train_step(self, batch, discriminator=None):
         """
@@ -307,13 +307,17 @@ class GeneratorModel:
                 scale_width = int(self.width_train * scale_factor)
 
                 # resize gt_output_frames for scale and append to scale_gts_train
+                broken = 0
                 scaled_gt_frames = np.empty([c.BATCH_SIZE, scale_height, scale_width, 3])
                 for i, img in enumerate(gt_frames):
                     # for skimage.transform.resize, images need to be in range [0, 1], so normalize
                     # to [0, 1] before resize and back to [-1, 1] after
                     sknorm_img = (img / 2) + 0.5
-                    resized_frame = resize(sknorm_img, [scale_height, scale_width, 3])
-                    scaled_gt_frames[i] = (resized_frame - 0.5) * 2
+                    try:
+                        resized_frame = resize(sknorm_img, [scale_height, scale_width, 3])
+                        scaled_gt_frames[i-broken] = (resized_frame - 0.5) * 2
+                    except:
+                        broken += 1
 
                 scale_gts.append(scaled_gt_frames)
                 # for every clip in the batch, save the inputs, scale preds and scale gts
@@ -342,7 +346,7 @@ class GeneratorModel:
 
         return global_step
 
-    def test_batch(self, batch, global_step, num_rec_out=1, save_imgs=True):
+    def test_batch(self, batch, global_step, num_rec_out=1, save_imgs=True, process_only=False):
         """
         Runs a training step using the global loss on each of the scale networks.
 
@@ -408,7 +412,14 @@ class GeneratorModel:
         # Save images
         ##
 
-        if save_imgs:
+        if process_only:
+            pred_dir = c.get_dir(os.path.join(
+                c.IMG_SAVE_DIR, 'Process/Step_' + str(global_step)))
+            for rec_num in xrange(num_rec_out):
+                gen_img = rec_preds[rec_num][0]
+                imsave(os.path.join(pred_dir, 'gen_{:05d}.png'.format(rec_num)), gen_img)
+
+        elif save_imgs:
             for pred_num in xrange(len(input_frames)):
                 pred_dir = c.get_dir(os.path.join(
                     c.IMG_SAVE_DIR, 'Tests/Step_' + str(global_step), str(pred_num)))
diff --git a/Code/loss_functions.py b/Code/loss_functions.py
index 994d226..0d33b49 100644
--- a/Code/loss_functions.py
+++ b/Code/loss_functions.py
@@ -61,7 +61,7 @@ def lp_loss(gen_frames, gt_frames, l_num):
         scale_losses.append(tf.reduce_sum(tf.abs(gen_frames[i] - gt_frames[i])**l_num))
 
     # condense into one tensor and avg
-    return tf.reduce_mean(tf.pack(scale_losses))
+    return tf.reduce_mean(tf.stack(scale_losses))
 
 
 def gdl_loss(gen_frames, gt_frames, alpha):
@@ -80,8 +80,8 @@ def gdl_loss(gen_frames, gt_frames, alpha):
     # create filters [-1, 1] and [[1],[-1]] for diffing to the left and down respectively.
     pos = tf.constant(np.identity(3), dtype=tf.float32)
     neg = -1 * pos
-    filter_x = tf.expand_dims(tf.pack([neg, pos]), 0)  # [-1, 1]
-    filter_y = tf.pack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)])  # [[1],[-1]]
+    filter_x = tf.expand_dims(tf.stack([neg, pos]), 0)  # [-1, 1]
+    filter_y = tf.stack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)])  # [[1],[-1]]
     strides = [1, 1, 1, 1]  # stride of (1, 1)
     padding = 'SAME'
 
@@ -96,7 +96,7 @@ def gdl_loss(gen_frames, gt_frames, alpha):
         scale_losses.append(tf.reduce_sum((grad_diff_x ** alpha + grad_diff_y ** alpha)))
 
     # condense into one tensor and avg
-    return tf.reduce_mean(tf.pack(scale_losses))
+    return tf.reduce_mean(tf.stack(scale_losses))
 
 
 def adv_loss(preds, labels):
@@ -115,4 +115,4 @@ def adv_loss(preds, labels):
         scale_losses.append(loss)
 
     # condense into one tensor and avg
-    return tf.reduce_mean(tf.pack(scale_losses))
+    return tf.reduce_mean(tf.stack(scale_losses))
diff --git a/Code/tfutils.py b/Code/tfutils.py
index 22baf95..0d2d8dd 100644
--- a/Code/tfutils.py
+++ b/Code/tfutils.py
@@ -92,12 +92,12 @@ def batch_pad_to_bounding_box(images, offset_height, offset_width, target_height
     rpad = np.zeros([batch_size, target_height, num_rpad, channels])
 
     padded = images
-    if num_tpad > 0 and num_bpad > 0: padded = tf.concat(1, [tpad, padded, bpad])
-    elif num_tpad > 0: padded = tf.concat(1, [tpad, padded])
-    elif num_bpad > 0: padded = tf.concat(1, [padded, bpad])
-    if num_lpad > 0 and num_rpad > 0: padded = tf.concat(2, [lpad, padded, rpad])
-    elif num_lpad > 0: padded = tf.concat(2, [lpad, padded])
-    elif num_rpad > 0: padded = tf.concat(2, [padded, rpad])
+    if num_tpad > 0 and num_bpad > 0: padded = tf.concat(axis=1, values=[tpad, padded, bpad])
+    elif num_tpad > 0: padded = tf.concat(axis=1, values=[tpad, padded])
+    elif num_bpad > 0: padded = tf.concat(axis=1, values=[padded, bpad])
+    if num_lpad > 0 and num_rpad > 0: padded = tf.concat(axis=2, values=[lpad, padded, rpad])
+    elif num_lpad > 0: padded = tf.concat(axis=2, values=[lpad, padded])
+    elif num_rpad > 0: padded = tf.concat(axis=2, values=[padded, rpad])
 
     return padded
 
diff --git a/Code/utils.py b/Code/utils.py
index 39a7e11..d5c5d05 100644
--- a/Code/utils.py
+++ b/Code/utils.py
@@ -90,6 +90,36 @@ def get_full_clips(data_dir, num_clips, num_rec_out=1):
     return clips
 
 
+def get_all_clips(data_dir):
+    """
+    Loads all clips from a directory.
+
+    @param data_dir: The directory of the data to read. Should be either c.TRAIN_DIR or c.TEST_DIR.
+    @param num_clips: The number of clips to read.
+    @param num_rec_out: The number of outputs to predict. Outputs > 1 are computed recursively,
+                        using the previously-generated frames as input. Default = 1.
+
+    @return: An array of shape
+             [num_clips, c.TRAIN_HEIGHT, c.TRAIN_WIDTH, (3 * (c.HIST_LEN + num_rec_out))].
+             A batch of frame sequences with values normalized in range [-1, 1].
+    """
+    # get all the clips
+    clip_frame_paths = sorted(glob(os.path.join(data_dir, '*')))
+
+    clips = np.empty([len(clip_frame_paths),
+                      c.FULL_HEIGHT,
+                      c.FULL_WIDTH,
+                      (3 * (c.HIST_LEN + num_rec_out))])
+
+    # read in frames
+    for frame_num, frame_path in enumerate(clip_frame_paths):
+        frame = imread(frame_path, mode='RGB')
+        norm_frame = normalize_frames(frame)
+
+        clips[clip_num, :, :, frame_num * 3:(frame_num + 1) * 3] = norm_frame
+
+    return clips
+
 def process_clip():
     """
     Gets a clip from the train dataset, cropped randomly to c.TRAIN_HEIGHT x c.TRAIN_WIDTH.
@@ -193,8 +223,8 @@ def sharp_diff_error(gen_frames, gt_frames): # TODO: Could this be simplified with one filter [[-1, 2], [0, -1]]? pos = tf.constant(np.identity(3), dtype=tf.float32) neg = -1 * pos - filter_x = tf.expand_dims(tf.pack([neg, pos]), 0) # [-1, 1] - filter_y = tf.pack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)]) # [[1],[-1]] + filter_x = tf.expand_dims(tf.stack([neg, pos]), 0) # [-1, 1] + filter_y = tf.stack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)]) # [[1],[-1]] strides = [1, 1, 1, 1] # stride of (1, 1) padding = 'SAME' diff --git a/process.sh b/process.sh new file mode 100755 index 0000000..1967c22 --- /dev/null +++ b/process.sh @@ -0,0 +1,10 @@ +cd Code + +python process_data.py \ + --num_clips=5000000 \ + --train_dir=../Data/GoodFlat/Train/ \ + --clips_dir=../Data/Clips/GoodFlat/ \ + --overwrite + +cd .. + diff --git a/recursive.sh b/recursive.sh new file mode 100755 index 0000000..304e8c1 --- /dev/null +++ b/recursive.sh @@ -0,0 +1,13 @@ +cd Code + +python avg_runner.py \ + --test_only \ + --recursions 16 \ + --name=woodclips \ + --load_path=/home/lens/code/Adversarial_Video_Generation/Save/Models/woodclipsmodel.ckpt-60000 \ + --test_dir=/home/lens/code/Adversarial_Video_Generation/Data/WoodFlat/ + +# --train_dir=../Data/WoodFlat/ \ + +cd .. + diff --git a/report.txt b/report.txt new file mode 100644 index 0000000..ef6775f --- /dev/null +++ b/report.txt @@ -0,0 +1,306 @@ +================================================================================ +Input tree: 'Code' +================================================================================ +-------------------------------------------------------------------------------- +Processing file 'Code/tfutils_test.py' + outputting to 'Code2/tfutils_test.py' +-------------------------------------------------------------------------------- + + +-------------------------------------------------------------------------------- +Processing file 'Code/d_scale_model.py' + outputting to 'Code2/d_scale_model.py' +-------------------------------------------------------------------------------- + + +-------------------------------------------------------------------------------- +Processing file 'Code/constants.py' + outputting to 'Code2/constants.py' +-------------------------------------------------------------------------------- + + +-------------------------------------------------------------------------------- +Processing file 'Code/loss_functions.py' + outputting to 'Code2/loss_functions.py' +-------------------------------------------------------------------------------- + +'Code/loss_functions.py' Line 64 +-------------------------------------------------------------------------------- + +Renamed function 'tf.pack' to 'tf.stack' + + Old: return tf.reduce_mean(tf.pack(scale_losses)) + ~~~~~~~ + New: return tf.reduce_mean(tf.stack(scale_losses)) + ~~~~~~~~ + +'Code/loss_functions.py' Line 99 +-------------------------------------------------------------------------------- + +Renamed function 'tf.pack' to 'tf.stack' + + Old: return tf.reduce_mean(tf.pack(scale_losses)) + ~~~~~~~ + New: return tf.reduce_mean(tf.stack(scale_losses)) + ~~~~~~~~ + +'Code/loss_functions.py' Line 83 +-------------------------------------------------------------------------------- + +Renamed function 'tf.pack' to 'tf.stack' + + Old: filter_x = tf.expand_dims(tf.pack([neg, pos]), 0) # [-1, 1] + ~~~~~~~ + New: filter_x = tf.expand_dims(tf.stack([neg, pos]), 0) # [-1, 1] + ~~~~~~~~ + +'Code/loss_functions.py' Line 84 
+-------------------------------------------------------------------------------- + +Renamed function 'tf.pack' to 'tf.stack' + + Old: filter_y = tf.pack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)]) # [[1],[-1]] + ~~~~~~~ + New: filter_y = tf.stack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)]) # [[1],[-1]] + ~~~~~~~~ + +'Code/loss_functions.py' Line 118 +-------------------------------------------------------------------------------- + +Renamed function 'tf.pack' to 'tf.stack' + + Old: return tf.reduce_mean(tf.pack(scale_losses)) + ~~~~~~~ + New: return tf.reduce_mean(tf.stack(scale_losses)) + ~~~~~~~~ + + +-------------------------------------------------------------------------------- +Processing file 'Code/avg_runner.py' + outputting to 'Code2/avg_runner.py' +-------------------------------------------------------------------------------- + + +-------------------------------------------------------------------------------- +Processing file 'Code/process_data.py' + outputting to 'Code2/process_data.py' +-------------------------------------------------------------------------------- + + +-------------------------------------------------------------------------------- +Processing file 'Code/tfutils.py' + outputting to 'Code2/tfutils.py' +-------------------------------------------------------------------------------- + +'Code/tfutils.py' Line 96 +-------------------------------------------------------------------------------- + +Added keyword 'concat_dim' to reordered function 'tf.concat' +Added keyword 'values' to reordered function 'tf.concat' + + Old: elif num_tpad > 0: padded = tf.concat(1, [tpad, padded]) + + New: elif num_tpad > 0: padded = tf.concat(axis=1, values=[tpad, padded]) + ~~~~~ ~~~~~~~ + +'Code/tfutils.py' Line 97 +-------------------------------------------------------------------------------- + +Added keyword 'concat_dim' to reordered function 'tf.concat' +Added keyword 'values' to reordered function 'tf.concat' + + Old: elif num_bpad > 0: padded = tf.concat(1, [padded, bpad]) + + New: elif num_bpad > 0: padded = tf.concat(axis=1, values=[padded, bpad]) + ~~~~~ ~~~~~~~ + +'Code/tfutils.py' Line 98 +-------------------------------------------------------------------------------- + +Added keyword 'concat_dim' to reordered function 'tf.concat' +Added keyword 'values' to reordered function 'tf.concat' + + Old: if num_lpad > 0 and num_rpad > 0: padded = tf.concat(2, [lpad, padded, rpad]) + + New: if num_lpad > 0 and num_rpad > 0: padded = tf.concat(axis=2, values=[lpad, padded, rpad]) + ~~~~~ ~~~~~~~ + +'Code/tfutils.py' Line 99 +-------------------------------------------------------------------------------- + +Added keyword 'concat_dim' to reordered function 'tf.concat' +Added keyword 'values' to reordered function 'tf.concat' + + Old: elif num_lpad > 0: padded = tf.concat(2, [lpad, padded]) + + New: elif num_lpad > 0: padded = tf.concat(axis=2, values=[lpad, padded]) + ~~~~~ ~~~~~~~ + +'Code/tfutils.py' Line 100 +-------------------------------------------------------------------------------- + +Added keyword 'concat_dim' to reordered function 'tf.concat' +Added keyword 'values' to reordered function 'tf.concat' + + Old: elif num_rpad > 0: padded = tf.concat(2, [padded, rpad]) + + New: elif num_rpad > 0: padded = tf.concat(axis=2, values=[padded, rpad]) + ~~~~~ ~~~~~~~ + +'Code/tfutils.py' Line 95 +-------------------------------------------------------------------------------- + +Added keyword 'concat_dim' to reordered function 'tf.concat' +Added keyword 
'values' to reordered function 'tf.concat' + + Old: if num_tpad > 0 and num_bpad > 0: padded = tf.concat(1, [tpad, padded, bpad]) + + New: if num_tpad > 0 and num_bpad > 0: padded = tf.concat(axis=1, values=[tpad, padded, bpad]) + ~~~~~ ~~~~~~~ + + +-------------------------------------------------------------------------------- +Processing file 'Code/loss_functions_test.py' + outputting to 'Code2/loss_functions_test.py' +-------------------------------------------------------------------------------- + + +-------------------------------------------------------------------------------- +Processing file 'Code/g_model.py' + outputting to 'Code2/g_model.py' +-------------------------------------------------------------------------------- + +'Code/g_model.py' Line 225 +-------------------------------------------------------------------------------- + +Renamed function 'tf.scalar_summary' to 'tf.summary.scalar' + + Old: summary_psnr_test = tf.scalar_summary('test_PSNR', + ~~~~~~~~~~~~~~~~~ + New: summary_psnr_test = tf.summary.scalar('test_PSNR', + ~~~~~~~~~~~~~~~~~ + +'Code/g_model.py' Line 227 +-------------------------------------------------------------------------------- + +Renamed function 'tf.scalar_summary' to 'tf.summary.scalar' + + Old: summary_sharpdiff_test = tf.scalar_summary('test_SharpDiff', + ~~~~~~~~~~~~~~~~~ + New: summary_sharpdiff_test = tf.summary.scalar('test_SharpDiff', + ~~~~~~~~~~~~~~~~~ + +'Code/g_model.py' Line 199 +-------------------------------------------------------------------------------- + +Renamed function 'tf.scalar_summary' to 'tf.summary.scalar' + + Old: loss_summary = tf.scalar_summary('train_loss_G', self.global_loss) + ~~~~~~~~~~~~~~~~~ + New: loss_summary = tf.summary.scalar('train_loss_G', self.global_loss) + ~~~~~~~~~~~~~~~~~ + +'Code/g_model.py' Line 232 +-------------------------------------------------------------------------------- + +Renamed function 'tf.merge_summary' to 'tf.summary.merge' + + Old: self.summaries_train = tf.merge_summary(self.summaries_train) + ~~~~~~~~~~~~~~~~ + New: self.summaries_train = tf.summary.merge(self.summaries_train) + ~~~~~~~~~~~~~~~~ + +'Code/g_model.py' Line 233 +-------------------------------------------------------------------------------- + +Renamed function 'tf.merge_summary' to 'tf.summary.merge' + + Old: self.summaries_test = tf.merge_summary(self.summaries_test) + ~~~~~~~~~~~~~~~~ + New: self.summaries_test = tf.summary.merge(self.summaries_test) + ~~~~~~~~~~~~~~~~ + +'Code/g_model.py' Line 117 +-------------------------------------------------------------------------------- + +Added keyword 'concat_dim' to reordered function 'tf.concat' +Added keyword 'values' to reordered function 'tf.concat' + + Old: inputs = tf.concat(3, [inputs, last_gen_frames]) + + New: inputs = tf.concat(axis=3, values=[inputs, last_gen_frames]) + ~~~~~ ~~~~~~~ + +'Code/g_model.py' Line 218 +-------------------------------------------------------------------------------- + +Renamed function 'tf.scalar_summary' to 'tf.summary.scalar' + + Old: summary_psnr_train = tf.scalar_summary('train_PSNR', + ~~~~~~~~~~~~~~~~~ + New: summary_psnr_train = tf.summary.scalar('train_PSNR', + ~~~~~~~~~~~~~~~~~ + +'Code/g_model.py' Line 220 +-------------------------------------------------------------------------------- + +Renamed function 'tf.scalar_summary' to 'tf.summary.scalar' + + Old: summary_sharpdiff_train = tf.scalar_summary('train_SharpDiff', + ~~~~~~~~~~~~~~~~~ + New: summary_sharpdiff_train = tf.summary.scalar('train_SharpDiff', + 
~~~~~~~~~~~~~~~~~ + + +-------------------------------------------------------------------------------- +Processing file 'Code/utils.py' + outputting to 'Code2/utils.py' +-------------------------------------------------------------------------------- + +'Code/utils.py' Line 196 +-------------------------------------------------------------------------------- + +Renamed function 'tf.pack' to 'tf.stack' + + Old: filter_x = tf.expand_dims(tf.pack([neg, pos]), 0) # [-1, 1] + ~~~~~~~ + New: filter_x = tf.expand_dims(tf.stack([neg, pos]), 0) # [-1, 1] + ~~~~~~~~ + +'Code/utils.py' Line 197 +-------------------------------------------------------------------------------- + +Renamed function 'tf.pack' to 'tf.stack' + + Old: filter_y = tf.pack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)]) # [[1],[-1]] + ~~~~~~~ + New: filter_y = tf.stack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)]) # [[1],[-1]] + ~~~~~~~~ + + +-------------------------------------------------------------------------------- +Processing file 'Code/d_model.py' + outputting to 'Code2/d_model.py' +-------------------------------------------------------------------------------- + +'Code/d_model.py' Line 92 +-------------------------------------------------------------------------------- + +Renamed function 'tf.scalar_summary' to 'tf.summary.scalar' + + Old: loss_summary = tf.scalar_summary('loss_D', self.global_loss) + ~~~~~~~~~~~~~~~~~ + New: loss_summary = tf.summary.scalar('loss_D', self.global_loss) + ~~~~~~~~~~~~~~~~~ + +'Code/d_model.py' Line 93 +-------------------------------------------------------------------------------- + +Renamed function 'tf.merge_summary' to 'tf.summary.merge' + + Old: self.summaries = tf.merge_summary([loss_summary]) + ~~~~~~~~~~~~~~~~ + New: self.summaries = tf.summary.merge([loss_summary]) + ~~~~~~~~~~~~~~~~ + + @@ -0,0 +1,12 @@ +cd Code + +python avg_runner.py \ + --recursions=16 \ + --name=woodclips \ + --load_path=/home/lens/code/Adversarial_Video_Generation/Save/Models/woodclipsmodel.ckpt-60000 \ + --test_dir=/home/lens/code/Adversarial_Video_Generation/Data/WoodFlat/ + +# --train_dir=../Data/WoodFlat/ \ + +cd .. + diff --git a/tf_upgrade.py b/tf_upgrade.py new file mode 100644 index 0000000..1f88335 --- /dev/null +++ b/tf_upgrade.py @@ -0,0 +1,742 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Upgrader for Python scripts from pre-1.0 TensorFlow to 1.0 TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import ast +import collections +import os +import shutil +import sys +import tempfile +import traceback + + +class APIChangeSpec(object): + """This class defines the transformations that need to happen. 
+ + This class must provide the following fields: + + * `function_keyword_renames`: maps function names to a map of old -> new + argument names + * `function_renames`: maps function names to new function names + * `change_to_function`: a set of function names that have changed (for + notifications) + * `function_reorders`: maps functions whose argument order has changed to the + list of arguments in the new order + * `function_handle`: maps function names to custom handlers for the function + + For an example, see `TFAPIChangeSpec`. + """ + + +class _FileEditTuple( + collections.namedtuple("_FileEditTuple", + ["comment", "line", "start", "old", "new"])): + """Each edit that is recorded by a _FileEditRecorder. + + Fields: + comment: A description of the edit and why it was made. + line: The line number in the file where the edit occurs (1-indexed). + start: The line number in the file where the edit occurs (0-indexed). + old: text string to remove (this must match what was in file). + new: text string to add in place of `old`. + """ + + __slots__ = () + + +class _FileEditRecorder(object): + """Record changes that need to be done to the file.""" + + def __init__(self, filename): + # all edits are lists of chars + self._filename = filename + + self._line_to_edit = collections.defaultdict(list) + self._errors = [] + + def process(self, text): + """Process a list of strings, each corresponding to the recorded changes. + + Args: + text: A list of lines of text (assumed to contain newlines) + Returns: + A tuple of the modified text and a textual description of what is done. + Raises: + ValueError: if substitution source location does not have expected text. + """ + + change_report = "" + + # Iterate of each line + for line, edits in self._line_to_edit.items(): + offset = 0 + # sort by column so that edits are processed in order in order to make + # indexing adjustments cumulative for changes that change the string + # length + edits.sort(key=lambda x: x.start) + + # Extract each line to a list of characters, because mutable lists + # are editable, unlike immutable strings. 
+ char_array = list(text[line - 1]) + + # Record a description of the change + change_report += "%r Line %d\n" % (self._filename, line) + change_report += "-" * 80 + "\n\n" + for e in edits: + change_report += "%s\n" % e.comment + change_report += "\n Old: %s" % (text[line - 1]) + + # Make underscore buffers for underlining where in the line the edit was + change_list = [" "] * len(text[line - 1]) + change_list_new = [" "] * len(text[line - 1]) + + # Iterate for each edit + for e in edits: + # Create effective start, end by accounting for change in length due + # to previous edits + start_eff = e.start + offset + end_eff = start_eff + len(e.old) + + # Make sure the edit is changing what it should be changing + old_actual = "".join(char_array[start_eff:end_eff]) + if old_actual != e.old: + raise ValueError("Expected text %r but got %r" % + ("".join(e.old), "".join(old_actual))) + # Make the edit + char_array[start_eff:end_eff] = list(e.new) + + # Create the underline highlighting of the before and after + change_list[e.start:e.start + len(e.old)] = "~" * len(e.old) + change_list_new[start_eff:end_eff] = "~" * len(e.new) + + # Keep track of how to generate effective ranges + offset += len(e.new) - len(e.old) + + # Finish the report comment + change_report += " %s\n" % "".join(change_list) + text[line - 1] = "".join(char_array) + change_report += " New: %s" % (text[line - 1]) + change_report += " %s\n\n" % "".join(change_list_new) + return "".join(text), change_report, self._errors + + def add(self, comment, line, start, old, new, error=None): + """Add a new change that is needed. + + Args: + comment: A description of what was changed + line: Line number (1 indexed) + start: Column offset (0 indexed) + old: old text + new: new text + error: this "edit" is something that cannot be fixed automatically + Returns: + None + """ + + self._line_to_edit[line].append( + _FileEditTuple(comment, line, start, old, new)) + if error: + self._errors.append("%s:%d: %s" % (self._filename, line, error)) + + +class _ASTCallVisitor(ast.NodeVisitor): + """AST Visitor that processes function calls. + + Updates function calls from old API version to new API version using a given + change spec. + """ + + def __init__(self, filename, lines, api_change_spec): + self._filename = filename + self._file_edit = _FileEditRecorder(filename) + self._lines = lines + self._api_change_spec = api_change_spec + + def process(self, lines): + return self._file_edit.process(lines) + + def generic_visit(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def _rename_functions(self, node, full_name): + function_renames = self._api_change_spec.function_renames + try: + new_name = function_renames[full_name] + self._file_edit.add("Renamed function %r to %r" % (full_name, new_name), + node.lineno, node.col_offset, full_name, new_name) + except KeyError: + pass + + def _get_attribute_full_path(self, node): + """Traverse an attribute to generate a full name e.g. tf.foo.bar. + + Args: + node: A Node of type Attribute. + + Returns: + a '.'-delimited full-name or None if the tree was not a simple form. + i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c". + """ + curr = node + items = [] + while not isinstance(curr, ast.Name): + if not isinstance(curr, ast.Attribute): + return None + items.append(curr.attr) + curr = curr.value + items.append(curr.id) + return ".".join(reversed(items)) + + def _find_true_position(self, node): + """Return correct line number and column offset for a given node. 
+ + This is necessary mainly because ListComp's location reporting reports + the next token after the list comprehension list opening. + + Args: + node: Node for which we wish to know the lineno and col_offset + """ + import re + find_open = re.compile("^\s*(\\[).*$") + find_string_chars = re.compile("['\"]") + + if isinstance(node, ast.ListComp): + # Strangely, ast.ListComp returns the col_offset of the first token + # after the '[' token which appears to be a bug. Workaround by + # explicitly finding the real start of the list comprehension. + line = node.lineno + col = node.col_offset + # loop over lines + while 1: + # Reverse the text to and regular expression search for whitespace + text = self._lines[line - 1] + reversed_preceding_text = text[:col][::-1] + # First find if a [ can be found with only whitespace between it and + # col. + m = find_open.match(reversed_preceding_text) + if m: + new_col_offset = col - m.start(1) - 1 + return line, new_col_offset + else: + if (reversed_preceding_text == "" or + reversed_preceding_text.isspace()): + line = line - 1 + prev_line = self._lines[line - 1] + # TODO(aselle): + # this is poor comment detection, but it is good enough for + # cases where the comment does not contain string literal starting/ + # ending characters. If ast gave us start and end locations of the + # ast nodes rather than just start, we could use string literal + # node ranges to filter out spurious #'s that appear in string + # literals. + comment_start = prev_line.find("#") + if comment_start == -1: + col = len(prev_line) - 1 + elif find_string_chars.search(prev_line[comment_start:]) is None: + col = comment_start + else: + return None, None + else: + return None, None + # Most other nodes return proper locations (with notably does not), but + # it is not possible to use that in an argument. + return node.lineno, node.col_offset + + def visit_Call(self, node): # pylint: disable=invalid-name + """Handle visiting a call node in the AST. + + Args: + node: Current Node + """ + + # Find a simple attribute name path e.g. "tf.foo.bar" + full_name = self._get_attribute_full_path(node.func) + + # Make sure the func is marked as being part of a call + node.func.is_function_for_call = True + + if full_name: + # Call special handlers + function_handles = self._api_change_spec.function_handle + if full_name in function_handles: + function_handles[full_name](self._file_edit, node) + + # Examine any non-keyword argument and make it into a keyword argument + # if reordering required. 
+ function_reorders = self._api_change_spec.function_reorders + function_keyword_renames = ( + self._api_change_spec.function_keyword_renames) + + if full_name in function_reorders: + reordered = function_reorders[full_name] + for idx, arg in enumerate(node.args): + lineno, col_offset = self._find_true_position(arg) + if lineno is None or col_offset is None: + self._file_edit.add( + "Failed to add keyword %r to reordered function %r" % + (reordered[idx], full_name), + arg.lineno, + arg.col_offset, + "", + "", + error="A necessary keyword argument failed to be inserted.") + else: + keyword_arg = reordered[idx] + if (full_name in function_keyword_renames and + keyword_arg in function_keyword_renames[full_name]): + keyword_arg = function_keyword_renames[full_name][keyword_arg] + self._file_edit.add("Added keyword %r to reordered function %r" % + (reordered[idx], full_name), lineno, col_offset, + "", keyword_arg + "=") + + # Examine each keyword argument and convert it to the final renamed form + renamed_keywords = ({} if full_name not in function_keyword_renames else + function_keyword_renames[full_name]) + for keyword in node.keywords: + argkey = keyword.arg + argval = keyword.value + + if argkey in renamed_keywords: + argval_lineno, argval_col_offset = self._find_true_position(argval) + if argval_lineno is not None and argval_col_offset is not None: + # TODO(aselle): We should scan backward to find the start of the + # keyword key. Unfortunately ast does not give you the location of + # keyword keys, so we are forced to infer it from the keyword arg + # value. + key_start = argval_col_offset - len(argkey) - 1 + key_end = key_start + len(argkey) + 1 + if (self._lines[argval_lineno - 1][key_start:key_end] == argkey + + "="): + self._file_edit.add("Renamed keyword argument from %r to %r" % + (argkey, + renamed_keywords[argkey]), argval_lineno, + argval_col_offset - len(argkey) - 1, + argkey + "=", renamed_keywords[argkey] + "=") + continue + self._file_edit.add( + "Failed to rename keyword argument from %r to %r" % + (argkey, renamed_keywords[argkey]), + argval.lineno, + argval.col_offset - len(argkey) - 1, + "", + "", + error="Failed to find keyword lexographically. Fix manually.") + + ast.NodeVisitor.generic_visit(self, node) + + def visit_Attribute(self, node): # pylint: disable=invalid-name + """Handle bare Attributes i.e. [tf.foo, tf.bar]. + + Args: + node: Node that is of type ast.Attribute + """ + full_name = self._get_attribute_full_path(node) + if full_name: + self._rename_functions(node, full_name) + if full_name in self._api_change_spec.change_to_function: + if not hasattr(node, "is_function_for_call"): + new_text = full_name + "()" + self._file_edit.add("Changed %r to %r" % (full_name, new_text), + node.lineno, node.col_offset, full_name, new_text) + + ast.NodeVisitor.generic_visit(self, node) + + +class ASTCodeUpgrader(object): + """Handles upgrading a set of Python files using a given API change spec.""" + + def __init__(self, api_change_spec): + if not isinstance(api_change_spec, APIChangeSpec): + raise TypeError("Must pass APIChangeSpec to ASTCodeUpgrader, got %s" % + type(api_change_spec)) + self._api_change_spec = api_change_spec + + def process_file(self, in_filename, out_filename): + """Process the given python file for incompatible changes. 
+ + Args: + in_filename: filename to parse + out_filename: output file to write to + Returns: + A tuple representing number of files processed, log of actions, errors + """ + + # Write to a temporary file, just in case we are doing an implace modify. + with open(in_filename, "r") as in_file, \ + tempfile.NamedTemporaryFile("w", delete=False) as temp_file: + ret = self.process_opened_file(in_filename, in_file, out_filename, + temp_file) + + shutil.move(temp_file.name, out_filename) + return ret + + # Broad exceptions are required here because ast throws whatever it wants. + # pylint: disable=broad-except + def process_opened_file(self, in_filename, in_file, out_filename, out_file): + """Process the given python file for incompatible changes. + + This function is split out to facilitate StringIO testing from + tf_upgrade_test.py. + + Args: + in_filename: filename to parse + in_file: opened file (or StringIO) + out_filename: output file to write to + out_file: opened file (or StringIO) + Returns: + A tuple representing number of files processed, log of actions, errors + """ + process_errors = [] + text = "-" * 80 + "\n" + text += "Processing file %r\n outputting to %r\n" % (in_filename, + out_filename) + text += "-" * 80 + "\n\n" + + parsed_ast = None + lines = in_file.readlines() + try: + parsed_ast = ast.parse("".join(lines)) + except Exception: + text += "Failed to parse %r\n\n" % in_filename + text += traceback.format_exc() + if parsed_ast: + visitor = _ASTCallVisitor(in_filename, lines, self._api_change_spec) + visitor.visit(parsed_ast) + out_text, new_text, process_errors = visitor.process(lines) + text += new_text + if out_file: + out_file.write(out_text) + text += "\n" + return 1, text, process_errors + + # pylint: enable=broad-except + + def process_tree(self, root_directory, output_root_directory, + copy_other_files): + """Processes upgrades on an entire tree of python files in place. + + Note that only Python files. If you have custom code in other languages, + you will need to manually upgrade those. + + Args: + root_directory: Directory to walk and process. + output_root_directory: Directory to use as base. + copy_other_files: Copy files that are not touched by this converter. + + Returns: + A tuple of files processed, the report string ofr all files, and errors + """ + + # make sure output directory doesn't exist + if output_root_directory and os.path.exists(output_root_directory): + print("Output directory %r must not already exist." 
% + (output_root_directory)) + sys.exit(1) + + # make sure output directory does not overlap with root_directory + norm_root = os.path.split(os.path.normpath(root_directory)) + norm_output = os.path.split(os.path.normpath(output_root_directory)) + if norm_root == norm_output: + print("Output directory %r same as input directory %r" % + (root_directory, output_root_directory)) + sys.exit(1) + + # Collect list of files to process (we do this to correctly handle if the + # user puts the output directory in some sub directory of the input dir) + files_to_process = [] + files_to_copy = [] + for dir_name, _, file_list in os.walk(root_directory): + py_files = [f for f in file_list if f.endswith(".py")] + copy_files = [f for f in file_list if not f.endswith(".py")] + for filename in py_files: + fullpath = os.path.join(dir_name, filename) + fullpath_output = os.path.join(output_root_directory, + os.path.relpath(fullpath, + root_directory)) + files_to_process.append((fullpath, fullpath_output)) + if copy_other_files: + for filename in copy_files: + fullpath = os.path.join(dir_name, filename) + fullpath_output = os.path.join(output_root_directory, + os.path.relpath( + fullpath, root_directory)) + files_to_copy.append((fullpath, fullpath_output)) + + file_count = 0 + tree_errors = [] + report = "" + report += ("=" * 80) + "\n" + report += "Input tree: %r\n" % root_directory + report += ("=" * 80) + "\n" + + for input_path, output_path in files_to_process: + output_directory = os.path.dirname(output_path) + if not os.path.isdir(output_directory): + os.makedirs(output_directory) + file_count += 1 + _, l_report, l_errors = self.process_file(input_path, output_path) + tree_errors += l_errors + report += l_report + for input_path, output_path in files_to_copy: + output_directory = os.path.dirname(output_path) + if not os.path.isdir(output_directory): + os.makedirs(output_directory) + shutil.copy(input_path, output_path) + return file_count, report, tree_errors + + +class TFAPIChangeSpec(APIChangeSpec): + """List of maps that describe what changed in the API.""" + + def __init__(self): + # Maps from a function name to a dictionary that describes how to + # map from an old argument keyword to the new argument keyword. 
+ self.function_keyword_renames = { + "tf.batch_matmul": { + "adj_x": "adjoint_a", + "adj_y": "adjoint_b", + }, + "tf.count_nonzero": { + "reduction_indices": "axis" + }, + "tf.reduce_all": { + "reduction_indices": "axis" + }, + "tf.reduce_any": { + "reduction_indices": "axis" + }, + "tf.reduce_max": { + "reduction_indices": "axis" + }, + "tf.reduce_mean": { + "reduction_indices": "axis" + }, + "tf.reduce_min": { + "reduction_indices": "axis" + }, + "tf.reduce_prod": { + "reduction_indices": "axis" + }, + "tf.reduce_sum": { + "reduction_indices": "axis" + }, + "tf.reduce_logsumexp": { + "reduction_indices": "axis" + }, + "tf.expand_dims": { + "dim": "axis" + }, + "tf.argmax": { + "dimension": "axis" + }, + "tf.argmin": { + "dimension": "axis" + }, + "tf.reduce_join": { + "reduction_indices": "axis" + }, + "tf.sparse_concat": { + "concat_dim": "axis" + }, + "tf.sparse_split": { + "split_dim": "axis" + }, + "tf.sparse_reduce_sum": { + "reduction_axes": "axis" + }, + "tf.reverse_sequence": { + "seq_dim": "seq_axis", + "batch_dim": "batch_axis" + }, + "tf.sparse_reduce_sum_sparse": { + "reduction_axes": "axis" + }, + "tf.squeeze": { + "squeeze_dims": "axis" + }, + "tf.split": { + "split_dim": "axis", + "num_split": "num_or_size_splits" + }, + "tf.concat": { + "concat_dim": "axis" + }, + } + + # Mapping from function to the new name of the function + self.function_renames = { + "tf.inv": "tf.reciprocal", + "tf.contrib.deprecated.scalar_summary": "tf.summary.scalar", + "tf.contrib.deprecated.histogram_summary": "tf.summary.histogram", + "tf.listdiff": "tf.setdiff1d", + "tf.list_diff": "tf.setdiff1d", + "tf.mul": "tf.multiply", + "tf.neg": "tf.negative", + "tf.sub": "tf.subtract", + "tf.train.SummaryWriter": "tf.summary.FileWriter", + "tf.scalar_summary": "tf.summary.scalar", + "tf.histogram_summary": "tf.summary.histogram", + "tf.audio_summary": "tf.summary.audio", + "tf.image_summary": "tf.summary.image", + "tf.merge_summary": "tf.summary.merge", + "tf.merge_all_summaries": "tf.summary.merge_all", + "tf.image.per_image_whitening": "tf.image.per_image_standardization", + "tf.all_variables": "tf.global_variables", + "tf.VARIABLES": "tf.GLOBAL_VARIABLES", + "tf.initialize_all_variables": "tf.global_variables_initializer", + "tf.initialize_variables": "tf.variables_initializer", + "tf.initialize_local_variables": "tf.local_variables_initializer", + "tf.batch_matrix_diag": "tf.matrix_diag", + "tf.batch_band_part": "tf.band_part", + "tf.batch_set_diag": "tf.set_diag", + "tf.batch_matrix_transpose": "tf.matrix_transpose", + "tf.batch_matrix_determinant": "tf.matrix_determinant", + "tf.batch_matrix_inverse": "tf.matrix_inverse", + "tf.batch_cholesky": "tf.cholesky", + "tf.batch_cholesky_solve": "tf.cholesky_solve", + "tf.batch_matrix_solve": "tf.matrix_solve", + "tf.batch_matrix_triangular_solve": "tf.matrix_triangular_solve", + "tf.batch_matrix_solve_ls": "tf.matrix_solve_ls", + "tf.batch_self_adjoint_eig": "tf.self_adjoint_eig", + "tf.batch_self_adjoint_eigvals": "tf.self_adjoint_eigvals", + "tf.batch_svd": "tf.svd", + "tf.batch_fft": "tf.fft", + "tf.batch_ifft": "tf.ifft", + "tf.batch_fft2d": "tf.fft2d", + "tf.batch_ifft2d": "tf.ifft2d", + "tf.batch_fft3d": "tf.fft3d", + "tf.batch_ifft3d": "tf.ifft3d", + "tf.select": "tf.where", + "tf.complex_abs": "tf.abs", + "tf.batch_matmul": "tf.matmul", + "tf.pack": "tf.stack", + "tf.unpack": "tf.unstack", + "tf.op_scope": "tf.name_scope", + } + + self.change_to_function = { + "tf.ones_initializer", + "tf.zeros_initializer", + } + + # Functions that were 
reordered should be changed to the new keyword args + # for safety, if positional arguments are used. If you have reversed the + # positional arguments yourself, this could do the wrong thing. + self.function_reorders = { + "tf.split": ["axis", "num_or_size_splits", "value", "name"], + "tf.sparse_split": ["axis", "num_or_size_splits", "value", "name"], + "tf.concat": ["concat_dim", "values", "name"], + "tf.svd": ["tensor", "compute_uv", "full_matrices", "name"], + "tf.nn.softmax_cross_entropy_with_logits": [ + "logits", "labels", "dim", "name" + ], + "tf.nn.sparse_softmax_cross_entropy_with_logits": [ + "logits", "labels", "name" + ], + "tf.nn.sigmoid_cross_entropy_with_logits": ["logits", "labels", "name"], + "tf.op_scope": ["values", "name", "default_name"], + } + + # Specially handled functions. + self.function_handle = {"tf.reverse": self._reverse_handler} + + @staticmethod + def _reverse_handler(file_edit_recorder, node): + # TODO(aselle): Could check for a literal list of bools and try to convert + # them to indices. + comment = ("ERROR: tf.reverse has had its argument semantics changed " + "significantly the converter cannot detect this reliably, so " + "you need to inspect this usage manually.\n") + file_edit_recorder.add( + comment, + node.lineno, + node.col_offset, + "tf.reverse", + "tf.reverse", + error="tf.reverse requires manual check.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + description="""Convert a TensorFlow Python file to 1.0 + +Simple usage: + tf_convert.py --infile foo.py --outfile bar.py + tf_convert.py --intree ~/code/old --outtree ~/code/new +""") + parser.add_argument( + "--infile", + dest="input_file", + help="If converting a single file, the name of the file " + "to convert") + parser.add_argument( + "--outfile", + dest="output_file", + help="If converting a single file, the output filename.") + parser.add_argument( + "--intree", + dest="input_tree", + help="If converting a whole tree of files, the directory " + "to read from (relative or absolute).") + parser.add_argument( + "--outtree", + dest="output_tree", + help="If converting a whole tree of files, the output " + "directory (relative or absolute).") + parser.add_argument( + "--copyotherfiles", + dest="copy_other_files", + help=("If converting a whole tree of files, whether to " + "copy the other files."), + type=bool, + default=False) + parser.add_argument( + "--reportfile", + dest="report_filename", + help=("The name of the file where the report log is " + "stored." + "(default: %(default)s)"), + default="report.txt") + args = parser.parse_args() + + upgrade = ASTCodeUpgrader(TFAPIChangeSpec()) + report_text = None + report_filename = args.report_filename + files_processed = 0 + if args.input_file: + files_processed, report_text, errors = upgrade.process_file( + args.input_file, args.output_file) + files_processed = 1 + elif args.input_tree: + files_processed, report_text, errors = upgrade.process_tree( + args.input_tree, args.output_tree, args.copy_other_files) + else: + parser.print_help() + if report_text: + open(report_filename, "w").write(report_text) + print("TensorFlow 1.0 Upgrade Script") + print("-----------------------------") + print("Converted %d files\n" % files_processed) + print("Detected %d errors that require attention" % len(errors)) + print("-" * 80) + print("\n".join(errors)) + print("\nMake sure to read the detailed log %r\n" % report_filename) |
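
For reference, the upgrade report above can be regenerated programmatically from the classes defined in tf_upgrade.py; a minimal sketch, assuming it is run from the repository root and that the output directory Code2/ does not already exist (process_tree refuses to overwrite it):

from tf_upgrade import ASTCodeUpgrader, TFAPIChangeSpec

# Rewrite every .py file under Code/ into Code2/ and collect a report like report.txt.
upgrader = ASTCodeUpgrader(TFAPIChangeSpec())
file_count, report, errors = upgrader.process_tree('Code', 'Code2', copy_other_files=False)

with open('report.txt', 'w') as f:
    f.write(report)
print('Converted %d files; %d errors need manual attention' % (file_count, len(errors)))
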
