import math
import torch
from torch.utils.data import TensorDataset, Dataset
import random
def train_valid_splitter(x, y, split, shuffle=True):
    """ Split whole-dataset data and label tensors into training and validation tensors

    The first ``floor(num_samples * split)`` samples (after the optional shuffle) become the
    validation set and the remaining samples become the training set.

    :param x: Data tensor for whole dataset, indexed by sample along dimension 0
    :type x: torch.Tensor
    :param y: Label tensor for whole dataset, indexed by sample along dimension 0
    :type y: torch.Tensor
    :param split: Fraction of dataset to be used for validation, must lie in [0, 1]
    :type split: float
    :param shuffle: If True apply one shared random permutation to x and y before splitting
    :type shuffle: bool
    :return: Training and validation tensors (training data, training labels, validation data, validation labels)
    :rtype: tuple
    :raises ValueError: If split is outside [0, 1] or x and y hold a different number of samples
    """
    # A negative fraction would produce a negative slice bound and silently swap the
    # partition sizes; a fraction above 1 would silently leave the training set empty.
    if not 0 <= split <= 1:
        raise ValueError("split must be in the range [0, 1], got {}".format(split))

    num_samples = x.size(0)
    if y.size(0) != num_samples:
        raise ValueError(
            "x and y must contain the same number of samples, got {} and {}".format(num_samples, y.size(0))
        )

    num_valid_samples = math.floor(num_samples * split)

    if shuffle:
        # The same permutation is applied to both tensors so data/label pairs stay aligned
        indices = torch.randperm(num_samples)
        x, y = x[indices], y[indices]

    x_val, y_val = x[:num_valid_samples], y[:num_valid_samples]
    x_train, y_train = x[num_valid_samples:], y[num_valid_samples:]
    return x_train, y_train, x_val, y_val
def get_train_valid_sets(x, y, validation_data, validation_split, shuffle=True):
    """ Build training and validation datasets from whole-dataset tensors

    Explicit ``validation_data`` takes precedence. Otherwise, when ``validation_split`` is a float,
    x and y are split with ``train_valid_splitter``. If neither applies no validation dataset is made.

    :param x: Data tensor for dataset
    :type x: torch.Tensor
    :param y: Label tensor for dataset
    :type y: torch.Tensor
    :param validation_data: Optional validation data (x_val, y_val) to be used instead of splitting x and y tensors
    :type validation_data: (torch.Tensor, torch.Tensor)
    :param validation_split: Fraction of dataset to be used for validation (only honoured when it is a float)
    :type validation_split: float
    :param shuffle: If True randomize tensor order before splitting else do not randomize
    :type shuffle: bool
    :return: Training dataset and validation dataset (the latter is None when no validation tensors exist)
    :rtype: tuple
    """
    if validation_data is not None:
        # Caller supplied held-out tensors: x and y are used for training untouched
        val_x, val_y = validation_data
    elif isinstance(validation_split, float):
        x, y, val_x, val_y = train_valid_splitter(x, y, validation_split, shuffle=shuffle)
    else:
        val_x = val_y = None

    train_dataset = TensorDataset(x, y)

    # A validation dataset is only created when both validation tensors are present
    if val_x is None or val_y is None:
        return train_dataset, None
    return train_dataset, TensorDataset(val_x, val_y)
class DatasetValidationSplitter:
    def __init__(self, dataset_len, split_fraction, shuffle_seed=None):
        """ Generates training and validation split indices for a given dataset length and creates training and
        validation datasets using these splits

        The indices are generated once, at construction time, and stored in ``train_ids`` and ``valid_ids``.

        :param dataset_len: The length of the dataset to be split into training and validation
        :type dataset_len: int
        :param split_fraction: The fraction of the whole dataset to be used for validation, must lie in [0, 1]
        :type split_fraction: float
        :param shuffle_seed: Optional random seed for the shuffling process
        :raises ValueError: If split_fraction is outside [0, 1]
        """
        # A negative fraction would produce a negative slice bound and silently swap the
        # partition sizes; a fraction above 1 would silently leave the training split empty.
        if not 0 <= split_fraction <= 1:
            raise ValueError("split_fraction must be in the range [0, 1], got {}".format(split_fraction))

        super().__init__()
        self.dataset_len = dataset_len
        self.split_fraction = split_fraction
        self.valid_ids = None
        self.train_ids = None
        self._gen_split_ids(shuffle_seed)

    def _gen_split_ids(self, seed):
        """ Shuffle the index range of the dataset and partition it into validation and training IDs

        :param seed: Seed for the shuffle, or None to draw from the global ``random`` generator
        """
        all_ids = list(range(self.dataset_len))

        # A seeded split uses its own generator so that building a splitter does not reseed the
        # process-wide ``random`` state; the resulting permutation is the same one that
        # ``random.seed(seed)`` followed by ``random.shuffle`` would give.
        rng = random if seed is None else random.Random(seed)
        rng.shuffle(all_ids)

        num_valid_ids = math.floor(self.dataset_len * self.split_fraction)
        self.valid_ids = all_ids[:num_valid_ids]
        self.train_ids = all_ids[num_valid_ids:]

    def get_train_dataset(self, dataset):
        """ Creates a training dataset from existing dataset

        :param dataset: Dataset to be split into a training dataset
        :type dataset: torch.utils.data.Dataset
        :return: Training dataset split from whole dataset
        :rtype: torch.utils.data.Dataset
        """
        return SubsetDataset(dataset, self.train_ids)

    def get_val_dataset(self, dataset):
        """ Creates a validation dataset from existing dataset

        :param dataset: Dataset to be split into a validation dataset
        :type dataset: torch.utils.data.Dataset
        :return: Validation dataset split from whole dataset
        :rtype: torch.utils.data.Dataset
        """
        return SubsetDataset(dataset, self.valid_ids)
class SubsetDataset(Dataset):
    def __init__(self, dataset, ids):
        """ View over an existing dataset that exposes only the samples at the given IDs

        :param dataset: Complete dataset that samples are fetched from
        :type dataset: torch.utils.data.Dataset
        :param ids: IDs (indices into the complete dataset) that make up the subset
        :type ids: list
        """
        super().__init__()
        self.dataset = dataset
        self.ids = ids

    def __getitem__(self, index):
        # Map the position within the subset to the ID in the wrapped dataset
        source_id = self.ids[index]
        return self.dataset.__getitem__(source_id)

    def __len__(self):
        # The subset has exactly one sample per stored ID
        return len(self.ids)