squirrel.iterstream.torch_composables

Module Contents

Classes

SplitByRank

Composable to split data between ranks of a multi-rank loading setup

SplitByWorker

Composable to split data between PyTorch workers of a single rank

TorchIterable

Mixin-Composable to have squirrel pipeline inherit from PyTorch IterableDataset

Functions

skip_k(rank: int, world_size: int) → Callable[[Iterable], Iterator]

Returns a callable that takes an iterable and applies a skipping operation on it.

Attributes

logger

squirrel.iterstream.torch_composables.logger
class squirrel.iterstream.torch_composables.SplitByRank(source: Optional[Iterable] = None, torch_dist_group: Optional[str] = None)

Bases: squirrel.iterstream.base.Composable

Composable to split data between ranks of a multi-rank loading setup

Init the SplitByRank composable.

__iter__(self) → Iterator

Method to iterate over the source and yield the elements that will be processed by a particular node.
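The splitting can be pictured with a small stand-in class. The real SplitByRank presumably obtains its rank and world size from the distributed process group (the `torch_dist_group` argument suggests torch.distributed); in this sketch both are passed in explicitly, and the round-robin assignment is an assumption about the splitting strategy:

```python
from itertools import islice
from typing import Iterable, Iterator


class SplitByRankSketch:
    """Hypothetical stand-in for SplitByRank: rank and world size are
    passed explicitly instead of being read from torch.distributed."""

    def __init__(self, source: Iterable, rank: int, world_size: int) -> None:
        self.source = source
        self.rank = rank
        self.world_size = world_size

    def __iter__(self) -> Iterator:
        # Round-robin: rank r keeps elements r, r + world_size, r + 2 * world_size, ...
        return islice(iter(self.source), self.rank, None, self.world_size)


data = list(range(8))
shards = [list(SplitByRankSketch(data, r, 3)) for r in range(3)]
# Together the three ranks cover the data exactly once, with no overlap.
```

Each rank thus iterates over a disjoint subset of the stream, which is the property needed to avoid duplicate samples in multi-rank data loading.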

class squirrel.iterstream.torch_composables.SplitByWorker(source: Optional[Iterable] = None)

Bases: squirrel.iterstream.base.Composable

Composable to split data between PyTorch workers of a single rank

Init the SplitByWorker composable.

__iter__(self) → Iterator

Method to iterate over the source and yield the elements that will be processed by a particular worker.
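In PyTorch, each DataLoader worker can query its id via `torch.utils.data.get_worker_info()`, which returns `None` in the main process. The sketch below imitates that contract with a stub so it runs without PyTorch; the per-worker round-robin slice is an assumption about how the splitting is done:

```python
from itertools import islice
from types import SimpleNamespace
from typing import Iterable, Iterator, Optional


def get_worker_info_stub() -> Optional[SimpleNamespace]:
    """Stand-in for torch.utils.data.get_worker_info(): pretend we are
    worker 1 of 2 inside a DataLoader worker process."""
    return SimpleNamespace(id=1, num_workers=2)


def split_by_worker(source: Iterable) -> Iterator:
    info = get_worker_info_stub()
    if info is None:
        # Single-process loading: no splitting, yield everything.
        yield from source
    else:
        # Each worker keeps every num_workers-th element, offset by its id.
        yield from islice(iter(source), info.id, None, info.num_workers)


print(list(split_by_worker(range(6))))  # worker 1 of 2 -> [1, 3, 5]
```

Without such splitting, every DataLoader worker would iterate the full source and each sample would be yielded `num_workers` times.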

class squirrel.iterstream.torch_composables.TorchIterable(source: Optional[Iterable] = None)

Bases: squirrel.iterstream.base.Composable, torch.utils.data.IterableDataset

Mixin-Composable to have squirrel pipeline inherit from PyTorch IterableDataset

Init the TorchIterable composable.

__iter__(self) → Iterator

Method to iterate over the source.
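The point of the mixin is that the pipeline satisfies PyTorch's `IterableDataset` interface, so it can be handed directly to `torch.utils.data.DataLoader`. A minimal sketch of the pattern, with stand-in base classes so it runs without PyTorch or squirrel installed:

```python
from typing import Iterable, Iterator, Optional


class IterableDatasetStub:
    """Stand-in for torch.utils.data.IterableDataset."""


class ComposableStub:
    """Minimal stand-in for squirrel.iterstream.base.Composable."""

    def __init__(self, source: Optional[Iterable] = None) -> None:
        self.source = source


class TorchIterableSketch(ComposableStub, IterableDatasetStub):
    """Mixin: a Composable that is also an IterableDataset, so a squirrel
    pipeline can be passed to a DataLoader as-is."""

    def __iter__(self) -> Iterator:
        # Simply delegate iteration to the wrapped source.
        yield from self.source


ds = TorchIterableSketch(source=range(4))
print(list(ds))  # [0, 1, 2, 3]
```

Because `ds` is an instance of the `IterableDataset` stand-in, a DataLoader-style consumer that type-checks its dataset would accept it while iteration still flows through the composable pipeline.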

squirrel.iterstream.torch_composables.skip_k(rank: int, world_size: int) → Callable[[Iterable], Iterator]

Returns a callable that takes an iterable and applies a skipping operation on it.

Parameters
  • rank – int denoting the rank of the distributed training process.

  • world_size – int denoting the full world size.
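The documented signature can be sketched with `itertools.islice`: rank r keeps every world_size-th element starting at offset r. The exact skipping strategy of the real `skip_k` is an assumption here; this is only a minimal illustration of the stated semantics:

```python
from itertools import islice
from typing import Callable, Iterable, Iterator


def skip_k(rank: int, world_size: int) -> Callable[[Iterable], Iterator]:
    """Return a callable that keeps every world_size-th element of an
    iterable, starting at position `rank` (round-robin sharding sketch)."""

    def _skip(it: Iterable) -> Iterator:
        return islice(it, rank, None, world_size)

    return _skip


# Each rank receives a disjoint slice of the stream:
print(list(skip_k(0, 3)(range(10))))  # [0, 3, 6, 9]
print(list(skip_k(1, 3)(range(10))))  # [1, 4, 7]
```

Returning a callable (rather than an iterator) lets the same sharding rule be applied lazily to any iterable in a pipeline.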