Source code for torchbearer.variational.datasets

import os
import shutil
import zipfile

import numpy as np
from PIL import Image
from torch.utils.data import Dataset


def make_dataset(dir, extensions):
    """Recursively collect the paths of all files under ``dir`` with an allowed extension.

    Args:
        dir (str): Directory to walk; sub-directories are included.
        extensions (str or iterable of str): Allowed filename suffixes, e.g. ``('.png', '.jpg')``. A single
            string is treated as one suffix. Matching is case-insensitive on both the filename and the suffix.

    Returns:
        list of str: Paths of every matching file, ordered by directory and then by file name.
    """
    # Normalise to a tuple of lowercase suffixes: ``str.endswith`` accepts a tuple but not a list,
    # so this works whether callers pass a str, list or tuple.
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in sorted(fnames):
            if fname.lower().endswith(suffixes):
                images.append(os.path.join(root, fname))
    return images
class SimpleImageFolder(Dataset):
    def __init__(self, root, loader=None, extensions=None, transform=None, target_transform=None):
        """
        Auto-encoding dataset over every image found (recursively) inside a folder. Each item is an
        (image, image) pair, with optional independent transforms applied to either side.

        Args:
            root (str): Root directory of dataset containing all aligned images
            loader (function, optional): Callable that takes a file path and returns the loaded image. Defaults
                to torchvision's ``default_loader`` (see torchvision.datasets.folder)
            extensions (:obj:`list` of :obj:`str`, optional): File extensions that are accepted. Defaults to
                torchvision's ``IMG_EXTENSIONS``
            transform (``Transform``, optional): A function/transform that takes in a PIL image and returns a
                transformed version for the input side. E.g, ``transforms.RandomCrop``
            target_transform (``Transform``, optional): A function/transform that takes in a PIL image and
                returns a transformed version for the target side. E.g, ``transforms.RandomCrop``
        """
        from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS

        if loader is None:
            loader = default_loader
        if extensions is None:
            extensions = IMG_EXTENSIONS

        found = make_dataset(root, extensions)

        self.root = root
        self.loader = loader
        self.extensions = extensions
        self.samples = found
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Position of the image in ``self.samples``

        Returns:
            tuple: (sample, target), both derived from the same loaded image; each side has its own
            transform applied when one was given.
        """
        image = self.loader(self.samples[index])
        source = image if self.transform is None else self.transform(image)
        target = image if self.target_transform is None else self.target_transform(image)
        return source, target

    def __len__(self):
        return len(self.samples)
class CelebA(SimpleImageFolder):
    def __init__(self, root, transform=None, target_transform=None):
        """
        `CelebA <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ auto-encoding dataset. Uses the default
        image loader and image extensions of :class:`SimpleImageFolder`.

        Args:
            root (str): Root directory of dataset containing all aligned images in 'root'
            transform (``Transform``, optional): A function/transform that takes in a PIL image and returns a
                transformed version. E.g, ``transforms.RandomCrop``
            target_transform (``Transform``, optional): A function/transform that takes in a PIL image and
                returns a transformed version. E.g, ``transforms.RandomCrop``
        """
        super().__init__(
            root,
            transform=transform,
            target_transform=target_transform,
        )

    def __getitem__(self, index):
        return super().__getitem__(index)
class CelebA_HQ(SimpleImageFolder):
    def __init__(self, root, as_npy=False, transform=None, target_transform=None):
        """
        CelebA_HQ, high quality version of `celebA <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_
        auto-encoding dataset as introduced by `Progressive GAN <https://arxiv.org/abs/1710.10196>`_

        Args:
            root (str): Root directory of dataset containing all hq images in 'root'
            as_npy (bool, optional): If True, assume images are stored in numpy arrays ('.npy' files). Else assume
                a standard image format
            transform (``Transform``, optional): A function/transform that takes in a PIL image and returns a
                transformed version. E.g, ``transforms.RandomCrop``
            target_transform (``Transform``, optional): A function/transform that takes in a PIL image and
                returns a transformed version for the target side. E.g, ``transforms.RandomCrop``
        """
        from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS
        if as_npy:
            loader = self.npy_loader
            # A tuple rather than a list: extension checks built on ``str.endswith`` accept tuples only.
            extensions = ('.npy',)
        else:
            loader = default_loader
            extensions = IMG_EXTENSIONS
        super(CelebA_HQ, self).__init__(root, loader, extensions, transform, target_transform)

    @staticmethod
    def npy_loader(path):
        """Load one image stored as a numpy array and return it as a PIL image.

        Takes the first entry along axis 0 and moves the next axis to the end, i.e. it expects a 4-D array
        such as (1, C, H, W) and produces (H, W, C) — presumably channel-first uint8 data; TODO confirm.
        """
        img = np.load(path)[0].transpose([1, 2, 0])
        pil_image = Image.fromarray(img)
        return pil_image

    def __getitem__(self, index):
        return super(CelebA_HQ, self).__getitem__(index)
class dSprites(Dataset):
    def __init__(self, root, download=False, transform=None):
        """
        `dSprites <https://github.com/deepmind/dsprites-dataset>`_ Dataset

        Args:
            root (str): Directory holding the unpacked dSprites arrays ('imgs.npy', 'latents_values.npy' and
                'latents_classes.npy'), or the directory to download and unpack
                'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz' into
            download (bool, optional): If true, downloads the dataset from the internet and unpacks it in the root
                directory. If the unpacked arrays are already present, nothing is downloaded again.
            transform (``Transform``, optional): A function/transform that takes in a PIL image and returns a
                transformed version. E.g, ``transforms.RandomCrop``
        """
        super(dSprites, self).__init__()
        self.file = root
        self.transform = transform
        if download:
            self.download()
        self.data = self.load_data()
        # Number of values each of the 6 latent factors can take. The bases are the mixed-radix place
        # values that turn a latent code into a flat index (see get_img_by_latent).
        self.latents_sizes = np.array([1, 3, 6, 40, 32, 32])
        self.latents_bases = np.concatenate((self.latents_sizes[::-1].cumprod()[::-1][1:], np.array([1, ])))
        self.latents_values = np.load(os.path.join(self.file, "latents_values.npy"))
        self.latents_classes = np.load(os.path.join(self.file, "latents_classes.npy"))

    def download(self):
        """Download the dSprites ``.npz`` archive into ``self.file`` and unpack it there.

        Does nothing when every array read by this class ('imgs.npy', 'latents_values.npy',
        'latents_classes.npy') is already present in ``self.file``.
        """
        required = ("imgs.npy", "latents_values.npy", "latents_classes.npy")
        if all(os.path.exists(os.path.join(self.file, name)) for name in required):
            return

        import urllib.request as request

        data_url = ('https://github.com/deepmind/dsprites-dataset/blob/master/'
                    'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz?raw=true')
        archive = os.path.join(self.file, "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz")
        os.makedirs(self.file, exist_ok=True)
        with request.urlopen(data_url) as response, open(archive, 'wb') as out_file:
            shutil.copyfileobj(response, out_file)
        # An .npz file is a zip archive of .npy members; extracting it is expected to yield the
        # individual arrays that load_data and __init__ read.
        with zipfile.ZipFile(archive, 'r') as zip_ref:
            zip_ref.extractall(self.file)

    def get_img_by_latent(self, latent_code):
        """
        Returns the image defined by the latent code

        Args:
            latent_code (:obj:`list` of :obj:`int`): Latent code of length 6 defining each generative factor

        Returns:
            Image defined by given code
        """
        def latent_to_index(latents):
            return np.dot(latents, self.latents_bases).astype(int)

        idx = latent_to_index(latent_code)
        return self.__getitem__(idx)[0]

    def load_data(self):
        """Load and return the image array stored in ``<root>/imgs.npy``."""
        root = os.path.join(self.file, "imgs.npy")
        data = np.load(root)
        return data

    def __getitem__(self, index):
        data = self.data[index]
        # Pixels are scaled by 255 (presumably 0/1 masks) and forced to uint8 so that PIL's 8-bit greyscale
        # 'L' mode reads the buffer correctly even if the stored dtype is bool or a wider integer.
        data = Image.fromarray((data * 255).astype(np.uint8), mode='L')
        if self.transform is not None:
            data = self.transform(data)
        # Auto-encoding pair: input and target are the same (transformed) image.
        return data, data

    def __len__(self):
        return self.data.shape[0]