PyTorch
The Squirrel API is designed to support fast streaming of datasets to a multi-rank, distributed system, as often
encountered in modern deep learning applications involving multiple GPUs.
To this end, we can use split_by_rank_pytorch() and split_by_worker_pytorch(), and wrap the final
iterator in a torch DataLoader object:
import torch.utils.data as tud

from squirrel.iterstream.source import IterableSource


def times_two(x: float) -> float:
    return x * 2


samples = list(range(100))
batch_size = 5
num_workers = 4

it = (
    IterableSource(samples)
    .split_by_rank_pytorch()  # shard the stream across distributed ranks
    .async_map(times_two)  # apply times_two to samples concurrently
    .split_by_worker_pytorch()  # shard the stream across DataLoader workers
    .batched(batch_size)  # group samples into batches inside the stream
    .to_torch_iterable()  # make the stream consumable by a torch DataLoader
)

dl = tud.DataLoader(it, num_workers=num_workers)
Note that the rank of the distributed system depends on the torch distributed process group and is automatically determined.
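As a rough sketch (assuming the script is launched with a standard launcher such as torchrun and uses the default env:// initialization; details depend on your setup), the rank that split_by_rank_pytorch() picks up comes from the initialized process group:

import torch.distributed as dist

# Sketch only: join the default process group; split_by_rank_pytorch() can then
# read the rank and world size from torch.distributed to decide which part of
# the stream this process keeps.
dist.init_process_group(backend="gloo")  # "nccl" is typical for GPU training
print(f"rank {dist.get_rank()} of {dist.get_world_size()}")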
Note
split_by_rank_pytorch(), split_by_worker_pytorch() and to_torch_iterable() are simply convenience methods that chain your iterator with PyTorch-specific iterators.
These are implemented as specific Composables.
An example using such a PyTorch-specific Composable, SplitByWorker, is given below.
To see how to chain Composables, see Custom Composable.
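For illustration, here is a minimal sketch of what such a worker-splitting Composable could look like. MySplitByWorker is a hypothetical name, the import location of Composable is assumed, and the SplitByWorker shipped with Squirrel may be implemented differently:

import torch.utils.data as tud

from squirrel.iterstream import Composable


class MySplitByWorker(Composable):
    # Hypothetical sketch: keep every num_workers-th item, offset by the id of
    # the current DataLoader worker; outside a worker process, keep everything.
    def __iter__(self):
        info = tud.get_worker_info()
        if info is None:
            yield from self.source
            return
        for i, item in enumerate(self.source):
            if i % info.num_workers == info.id:
                yield item

Such a class can then be chained like any other Composable, e.g. IterableSource(samples).compose(MySplitByWorker).to_torch_iterable().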
And the same pipeline using the squirrel.driver API:
from torch.utils.data import DataLoader
from squirrel.driver import MessagepackDriver
from squirrel.iterstream.torch_composables import SplitByWorker, TorchIterable

url = ""
it = MessagepackDriver(url).get_iter(key_hooks=[SplitByWorker]).async_map(times_two).batched(batch_size).compose(TorchIterable)
dl = DataLoader(it, num_workers=num_workers)
In this example, key_hooks=[SplitByWorker] ensures that the keys are split between workers before the data is fetched.
We thereby achieve two levels of parallelism: multi-processing provided by torch.utils.data.DataLoader,
and multi-threading inside each process for efficiently fetching samples via get_iter.
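To consume the stream, iterate over the DataLoader as usual. A minimal sketch (exactly how the pre-batched lists are collated depends on the DataLoader's batch_size and collate_fn settings):

for batch in dl:
    # batch holds samples already grouped by .batched(batch_size) inside the
    # stream; run the forward pass / training step here.
    print(batch)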