diff options
| author | Jun-Yan Zhu <junyanz@users.noreply.github.com> | 2017-07-04 10:08:46 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2017-07-04 10:08:46 -0700 |
| commit | e77d1352c0618adf8abf348b04647dd86e8890c1 (patch) | |
| tree | 90f76c58d41db6015a302cba9e4406f52c1dcf20 /util | |
| parent | 233630e79d79901faff420eb0ae481b35d952f97 (diff) | |
| parent | 06ff0d70ed118691fb906f815e78ad20902b90c7 (diff) | |
Merge pull request #56 from TariqAHassan/get_data
Add Tools to Easily Obtain CycleGAN or Pix2Pix Data
Diffstat (limited to 'util')
| -rw-r--r-- | util/get_data.py | 115 |
1 files changed, 115 insertions, 0 deletions
diff --git a/util/get_data.py b/util/get_data.py new file mode 100644 index 0000000..6325605 --- /dev/null +++ b/util/get_data.py @@ -0,0 +1,115 @@ +from __future__ import print_function +import os +import tarfile +import requests +from warnings import warn +from zipfile import ZipFile +from bs4 import BeautifulSoup +from os.path import abspath, isdir, join, basename + + +class GetData(object): + """ + + Download CycleGAN or Pix2Pix Data. + + Args: + technique : str + One of: 'cyclegan' or 'pix2pix'. + verbose : bool + If True, print additional information. + + Examples: + >>> from util.get_data import GetData + >>> gd = GetData(technique='cyclegan') + >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. + + """ + + def __init__(self, technique='cyclegan', verbose=True): + url_dict = { + 'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets', + 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' + } + self.url = url_dict.get(technique.lower()) + self._verbose = verbose + + def _print(self, text): + if self._verbose: + print(text) + + @staticmethod + def _get_options(r): + soup = BeautifulSoup(r.text, 'lxml') + options = [h.text for h in soup.find_all('a', href=True) + if h.text.endswith(('.zip', 'tar.gz'))] + return options + + def _present_options(self): + r = requests.get(self.url) + options = self._get_options(r) + print('Options:\n') + for i, o in enumerate(options): + print("{0}: {1}".format(i, o)) + choice = input("\nPlease enter the number of the " + "dataset above you wish to download:") + return options[int(choice)] + + def _download_data(self, dataset_url, save_path): + if not isdir(save_path): + os.makedirs(save_path) + + base = basename(dataset_url) + temp_save_path = join(save_path, base) + + with open(temp_save_path, "wb") as f: + r = requests.get(dataset_url) + f.write(r.content) + + if base.endswith('.tar.gz'): + obj = tarfile.open(temp_save_path) + elif base.endswith('.zip'): + obj = ZipFile(temp_save_path, 'r') + else: + raise ValueError("Unknown File Type: {0}.".format(base)) + + self._print("Unpacking Data...") + obj.extractall(save_path) + obj.close() + os.remove(temp_save_path) + + def get(self, save_path, dataset=None): + """ + + Download a dataset. + + Args: + save_path : str + A directory to save the data to. + dataset : str, optional + A specific dataset to download. + Note: this must include the file extension. + If None, options will be presented for you + to choose from. + + Returns: + save_path_full : str + The absolute path to the downloaded data. + + """ + if dataset is None: + selected_dataset = self._present_options() + else: + selected_dataset = dataset + + save_path_full = join(save_path, selected_dataset.split('.')[0]) + + if isdir(save_path_full): + warn("\n'{0}' already exists. Voiding Download.".format( + save_path_full)) + else: + self._print('Downloading Data...') + url = "{0}/{1}".format(self.url, selected_dataset) + self._download_data(url, save_path=save_path) + + return abspath(save_path_full) |
