# mypy: allow-untyped-defs
import torch
from torch import Tensor

from typing import Iterator, Iterable, Optional, Sequence, List, TypeVar, Generic, Sized, Union

__all__ = [
    "BatchSampler",
    "RandomSampler",
    "Sampler",
    "SequentialSampler",
    "SubsetRandomSampler",
    "WeightedRandomSampler",
]

T_co = TypeVar('T_co', covariant=True)


class Sampler(Generic[T_co]):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices or lists of indices (batches) of dataset elements,
    and may provide a :meth:`__len__` method that returns the length of the returned iterators.

    Args:
        data_source (Dataset): This argument is not used and will be removed in 2.2.0.
            You may still have custom implementation that utilizes it.

    Example:
        >>> # xdoctest: +SKIP
        >>> class AscendingSequenceLengthSampler(Sampler[int]):
        >>>     def __init__(self, data: List[str]) -> None:
        >>>         self.data = data
        >>>
        >>>     def __len__(self) -> int:
        >>>         return len(self.data)
        >>>
        >>>     def __iter__(self) -> Iterator[int]:
        >>>         sizes = torch.tensor([len(x) for x in self.data])
        >>>         yield from torch.argsort(sizes).tolist()
        >>>
        >>> class AscendingSequenceLengthBatchSampler(Sampler[List[int]]):
        >>>     def __init__(self, data: List[str], batch_size: int) -> None:
        >>>         self.data = data
        >>>         self.batch_size = batch_size
        >>>
        >>>     def __len__(self) -> int:
        >>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
        >>>
        >>>     def __iter__(self) -> Iterator[List[int]]:
        >>>         sizes = torch.tensor([len(x) for x in self.data])
        >>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
        >>>             yield batch.tolist()

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source: Optional[Sized] = None) -> None:
        if data_source is not None:
            import warnings

            # The two literals are concatenated at compile time, so the space
            # separating the sentences must be inside one of them.
            warnings.warn("`data_source` argument is not used and will be removed in 2.2.0. "
                          "You may still have custom implementation that utilizes it.")

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

    # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    #
    # Many times we have an abstract class representing a collection/iterable of
    # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
    # implementing a `__len__` method. In such cases, we must make sure to not
    # provide a default implementation, because both straightforward default
    # implementations have their issues:
    #
    #   + `return NotImplemented`:
    #     Calling `len(subclass_instance)` raises:
    #       TypeError: 'NotImplementedType' object cannot be interpreted as an integer
    #
    #   + `raise NotImplementedError`:
    #     This prevents triggering some fallback behavior. E.g., the built-in
    #     `list(X)` tries to call `len(X)` first, and executes a different code
    #     path if the method is not found or `NotImplemented` is returned, while
    #     raising a `NotImplementedError` will propagate and make the call fail
    #     where it could have used `__iter__` to complete the call.
    #
    # Thus, the only two sensible things to do are
    #
    #   + **not** provide a default `__len__`.
    #
    #   + raise a `TypeError` instead, which is what Python uses when users call
    #     a method that is not defined on an object.
    #     (@ssnl verifies that this works on at least Python 3.7.)


class SequentialSampler(Sampler[int]):
    r"""Samples elements sequentially, always in the same order.

    Args:
        data_source (Dataset): dataset to sample from
    """

    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)


class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.

    If with replacement, then user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
        generator (Generator): Generator used in sampling.

    Raises:
        TypeError: if ``replacement`` is not a ``bool``.
        ValueError: if the effective ``num_samples`` is not a positive ``int``
            (``bool`` values are rejected).
    """

    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")

        # ``bool`` is a subclass of ``int``, so it has to be excluded explicitly,
        # matching the checks in WeightedRandomSampler and BatchSampler.
        if not isinstance(self.num_samples, int) or isinstance(self.num_samples, bool) or \
                self.num_samples <= 0:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            # Derive a fresh seed from the global RNG on every call, so each
            # iteration uses a new private generator.
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:
            # Draw in chunks of 32 rather than one tensor of size num_samples,
            # then draw whatever remains.
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64,
                                     generator=generator).tolist()
        else:
            # Emit as many full permutations as fit, then a prefix of one more
            # permutation to reach exactly num_samples indices.
            for _ in range(self.num_samples // n):
                yield from torch.randperm(n, generator=generator).tolist()
            yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

    def __len__(self) -> int:
        return self.num_samples


class SubsetRandomSampler(Sampler[int]):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Args:
        indices (sequence): a sequence of indices
        generator (Generator): Generator used in sampling.
    """

    indices: Sequence[int]

    def __init__(self, indices: Sequence[int], generator=None) -> None:
        self.indices = indices
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        # Convert the permutation to Python ints once instead of materializing a
        # 0-d tensor per element and using it as an index.
        for i in torch.randperm(len(self.indices), generator=self.generator).tolist():
            yield self.indices[i]

    def __len__(self) -> int:
        return len(self.indices)


class WeightedRandomSampler(Sampler[int]):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Args:
        weights (sequence)   : a sequence of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        generator (Generator): Generator used in sampling.

    Example:
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """

    weights: Tensor
    num_samples: int
    replacement: bool

    def __init__(self, weights: Sequence[float], num_samples: int,
                 replacement: bool = True, generator=None) -> None:
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={num_samples}")
        if not isinstance(replacement, bool):
            raise ValueError(f"replacement should be a boolean value, but got replacement={replacement}")

        weights_tensor = torch.as_tensor(weights, dtype=torch.double)
        if len(weights_tensor.shape) != 1:
            raise ValueError("weights should be a 1d sequence but given "
                             f"weights have shape {tuple(weights_tensor.shape)}")

        self.weights = weights_tensor
        self.num_samples = num_samples
        self.replacement = replacement
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
        yield from rand_tensor.tolist()

    def __len__(self) -> int:
        return self.num_samples


class BatchSampler(Sampler[List[int]]):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
        # Since collections.abc.Iterable does not check for `__getitem__`, which
        # is one way for an object to be an iterable, we don't do an `isinstance`
        # check here.
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError(f"batch_size should be a positive integer value, but got batch_size={batch_size}")
        if not isinstance(drop_last, bool):
            raise ValueError(f"drop_last should be a boolean value, but got drop_last={drop_last}")
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
        if self.drop_last:
            sampler_iter = iter(self.sampler)
            while True:
                try:
                    # A short final batch raises StopIteration part-way through
                    # the comprehension, so it is never yielded.
                    batch = [next(sampler_iter) for _ in range(self.batch_size)]
                    yield batch
                except StopIteration:
                    break
        else:
            # Fill a preallocated list in place and hand out a trimmed copy of
            # the trailing partial batch, if any.
            batch = [0] * self.batch_size
            idx_in_batch = 0
            for idx in self.sampler:
                batch[idx_in_batch] = idx
                idx_in_batch += 1
                if idx_in_batch == self.batch_size:
                    yield batch
                    idx_in_batch = 0
                    batch = [0] * self.batch_size
            if idx_in_batch > 0:
                yield batch[:idx_in_batch]

    def __len__(self) -> int:
        # Can only be called if self.sampler has __len__ implemented
        # We cannot enforce this condition, so we turn off typechecking for the
        # implementation below.
        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        if self.drop_last:
            return len(self.sampler) // self.batch_size  # type: ignore[arg-type]
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore[arg-type]