squirrel.iterstream.torch_composables

Module Contents

Classes

SplitByRank

Composable to split data between ranks of a multi-rank loading setup

SplitByWorker

Composable to split data between PyTorch workers of a single rank

TorchIterable

Mixin-Composable that lets a squirrel pipeline inherit from PyTorch's IterableDataset

Functions

skip_k(rank, world_size) → Callable[[Iterable], Iterator]

Returns a callable that takes an iterable and applies a skipping operation on it.

Attributes

logger

squirrel.iterstream.torch_composables.logger
class squirrel.iterstream.torch_composables.SplitByRank(torch_dist_group: Optional[str] = None)

Bases: squirrel.iterstream.Composable

Composable to split data between ranks of a multi-rank loading setup

Init the SplitByRank composable.

__iter__() → Iterator

Method to iterate over the source and yield the elements that will be processed by a particular node
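
For illustration, a minimal sketch of composing SplitByRank into a pipeline; it assumes a squirrel IterableSource and an initialized torch.distributed process group, and the interleaved splitting scheme shown in the comments is an assumption:

    from squirrel.iterstream import IterableSource
    from squirrel.iterstream.torch_composables import SplitByRank

    # Assumes the process group has already been initialized by the launcher
    # (e.g. torchrun). Each rank then iterates a disjoint slice of the source
    # stream; without an initialized group, all elements are assumed to pass
    # through unsplit.
    it = IterableSource(range(8)).compose(SplitByRank)
    print(list(it))  # e.g. rank 0 of 2: [0, 2, 4, 6]; rank 1 of 2: [1, 3, 5, 7]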

class squirrel.iterstream.torch_composables.SplitByWorker

Bases: squirrel.iterstream.Composable

Composable to split data between PyTorch workers of a single rank

Init the SplitByWorker composable.

__iter__() → Iterator

Method to iterate over the source and yield the elements that will be processed by a particular worker
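
A minimal sketch of the intended use together with a multi-worker DataLoader; the interleaved per-worker split is an assumption, and outside a DataLoader worker process the composable is assumed to yield the full stream:

    import torch
    from squirrel.iterstream import IterableSource
    from squirrel.iterstream.torch_composables import SplitByWorker, TorchIterable

    # Each DataLoader worker produces a disjoint share of the stream, so no
    # sample is duplicated across the two workers.
    ds = IterableSource(range(8)).compose(SplitByWorker).compose(TorchIterable)
    dl = torch.utils.data.DataLoader(ds, num_workers=2, batch_size=None)
    print(sorted(int(x) for x in dl))  # [0, 1, 2, 3, 4, 5, 6, 7]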

class squirrel.iterstream.torch_composables.TorchIterable(enforce_rank_check: bool = True, enforce_worker_check: bool = True)

Bases: squirrel.iterstream.Composable, torch.utils.data.IterableDataset

Mixin-Composable that lets a squirrel pipeline inherit from PyTorch's IterableDataset

Init the TorchIterable composable.

__iter__() → Iterator

Method to iterate over the source
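
A minimal sketch of a full pipeline handed to a batching DataLoader; the behaviour attributed to the enforce_* flags in the comments is an assumption:

    import torch
    from squirrel.iterstream import IterableSource
    from squirrel.iterstream.torch_composables import SplitByWorker, TorchIterable

    def square(x: int) -> int:
        return x * x

    # TorchIterable is composed last so the resulting object is a
    # torch.utils.data.IterableDataset. With the default
    # enforce_worker_check=True, the class is assumed to refuse running with
    # num_workers > 1 unless SplitByWorker is present (which would otherwise
    # duplicate every sample once per worker); enforce_rank_check is assumed
    # to guard the multi-rank case analogously.
    ds = (
        IterableSource(range(16))
        .map(square)
        .compose(SplitByWorker)
        .compose(TorchIterable)
    )
    loader = torch.utils.data.DataLoader(ds, num_workers=2, batch_size=4)
    for batch in loader:
        print(batch)  # tensors of shape (4,), batched per worker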

squirrel.iterstream.torch_composables.skip_k(rank: int, world_size: int) → Callable[[Iterable], Iterator]

Returns a callable that takes an iterable and applies a skipping operation on it.

Parameters
  • rank – int denoting the rank of the distributed training process.

  • world_size – int denoting the full world size.
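
A minimal sketch of calling skip_k directly; the strided, rank-offset selection shown in the comment is an assumption about the skipping scheme:

    from squirrel.iterstream.torch_composables import skip_k

    # Build a skipper for rank 1 in a world of size 3; it is assumed to keep
    # every third element starting at index 1.
    keep_for_rank_1 = skip_k(rank=1, world_size=3)
    print(list(keep_for_rank_1(range(10))))  # assumed output: [1, 4, 7]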