IterStream
Squirrel provides an API for chaining iterables.
The functionality is provided by IterableSource (squirrel.iterstream.source.IterableSource).
Example Workflow
from squirrel.iterstream import IterableSource
import time
it = IterableSource([1, 2, 3, 4])
for item in it:
    print(item)
IterableSource has several methods for conveniently loading and transforming data, given an iterable as the input:
it = IterableSource([1, 2, 3, 4]).map(lambda x: x + 1).async_map(lambda x: x ** 2).filter(lambda x: x % 2 == 0)
for item in it:
    print(item)
async_map() applies the provided function asynchronously. More on this in the following sections.
In addition to explicitly iterating over the items, it’s also possible to call collect() to collect all items in a list, or join() to iterate over items without returning anything.
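To make the chaining pattern concrete, the following is a minimal pure-Python sketch of a lazily evaluated stream with map, filter, collect and join. MiniStream is a hypothetical stand-in written for illustration only; it is not squirrel's implementation and it omits async_map entirely.

```python
from typing import Any, Callable, Iterable, Iterator, List


class MiniStream:
    """Hypothetical stand-in for a chainable stream (not squirrel's code)."""

    def __init__(self, source: Iterable[Any]) -> None:
        self.source = source

    def __iter__(self) -> Iterator[Any]:
        return iter(self.source)

    def map(self, fn: Callable[[Any], Any]) -> "MiniStream":
        # Wrap the upstream in a generator; fn is only called during iteration.
        return MiniStream(fn(x) for x in self.source)

    def filter(self, predicate: Callable[[Any], bool]) -> "MiniStream":
        return MiniStream(x for x in self.source if predicate(x))

    def collect(self) -> List[Any]:
        # Consume the stream and return every item in a list.
        return list(self)

    def join(self) -> None:
        # Consume the stream for its side effects and return nothing.
        for _ in self:
            pass


evens = (
    MiniStream([1, 2, 3, 4])
    .map(lambda x: x + 1)
    .map(lambda x: x ** 2)
    .filter(lambda x: x % 2 == 0)
    .collect()
)

seen: List[int] = []
MiniStream([1, 2, 3]).map(seen.append).join()
```

In this sketch, collect() is useful when the results fit in memory, while join() suits pipelines whose last step writes or logs each item.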
Items in the stream can be shuffled in a buffer and batched:
it = IterableSource(range(10)).shuffle(size=5).map(lambda x: x+1).batched(batchsize=3, drop_last_if_not_full=True)
for item in it:
    print(item)
Note that the argument drop_last_if_not_full (default True) drops the last batch if it contains fewer than batchsize items; so, only 3 batches of 3 items each are printed above, and the one leftover item is dropped.
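The behaviour of a shuffle buffer followed by batching can be sketched in plain Python as below. The functions shuffle_buffer and batched are hypothetical helpers written for this illustration; squirrel's own shuffle() and batched() may differ in detail, and only the drop_last_if_not_full semantics described above are taken from the text.

```python
import random
from typing import Any, Iterable, Iterator, List, Optional


def shuffle_buffer(
    items: Iterable[Any], size: int, rng: Optional[random.Random] = None
) -> Iterator[Any]:
    """Yield items in a locally shuffled order using a bounded buffer."""
    rng = rng or random.Random()
    buffer: List[Any] = []
    for item in items:
        buffer.append(item)
        if len(buffer) >= size:
            # Move a randomly chosen element to the end and emit it.
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    # The source is exhausted: emit whatever is left, in random order.
    rng.shuffle(buffer)
    yield from buffer


def batched(
    items: Iterable[Any], batchsize: int, drop_last_if_not_full: bool = True
) -> Iterator[List[Any]]:
    """Group items into lists of length batchsize."""
    batch: List[Any] = []
    for item in items:
        batch.append(item)
        if len(batch) == batchsize:
            yield batch
            batch = []
    # A partially filled final batch is emitted only if explicitly requested.
    if batch and not drop_last_if_not_full:
        yield batch


shuffled = shuffle_buffer(range(10), size=5, rng=random.Random(0))
full_batches = list(batched(shuffled, batchsize=3))
all_batches = list(batched(range(10), batchsize=3, drop_last_if_not_full=False))
```

A small buffer keeps memory use bounded but only mixes items that are close together in the stream; a larger buffer approaches a full shuffle.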
Items in IterableSource can be composed by providing a Composable in the compose() method:
from squirrel.iterstream import Composable
class MyIter(Composable):
    def __init__(self):
        super().__init__()

    def __iter__(self):
        for i in iter(self.source):
            yield f"_{i}", i
it = IterableSource([1, 2, 3]).compose(MyIter)
for item in it:
    print(item)
Combining multiple iterables can be achieved using IterableSamplerSource:
from squirrel.iterstream import IterableSamplerSource
it1 = IterableSource([1, 2, 3]).map(lambda x: x + 1)
it2 = [1, 2, 3]
res = IterableSamplerSource(iterables=[it1, it2], probs=[.7, .3]).collect()
print(res)
assert sum(res) == 15
Note that you can pass the probability of sampling from each iterator. When an iterator is exhausted, the probabilities of the remaining iterators are renormalized.
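The sampling and renormalization step can be illustrated with a short standalone sketch. The function sample_from is a hypothetical helper, not part of squirrel; it assumes that renormalization simply means dropping the exhausted iterator's weight and continuing with the rest.

```python
import random
from typing import Any, Iterable, Iterator, Optional, Sequence


def sample_from(
    iterables: Sequence[Iterable[Any]],
    probs: Sequence[float],
    rng: Optional[random.Random] = None,
) -> Iterator[Any]:
    """Interleave several iterables, picking the next source at random."""
    rng = rng or random.Random()
    iterators = [iter(it) for it in iterables]
    weights = list(probs)
    while iterators:
        # random.choices treats weights as relative, so after an entry is
        # removed the remaining weights are implicitly renormalized.
        idx = rng.choices(range(len(iterators)), weights=weights, k=1)[0]
        try:
            yield next(iterators[idx])
        except StopIteration:
            # This source is exhausted: drop it together with its weight.
            del iterators[idx]
            del weights[idx]


res = list(sample_from([[2, 3, 4], [1, 2, 3]], probs=[0.7, 0.3], rng=random.Random(0)))
```

Every item from every source is eventually yielded; the probabilities only influence the order in which the sources are interleaved.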
Asynchronous execution
Part of the speed of iterstream comes from squirrel.iterstream.base.Composable.async_map(). This method applies the callback function you specify to each item in the stream asynchronously, which can offer a large speed-up, particularly for I/O-bound work.
import time
from concurrent.futures import ThreadPoolExecutor
tpool = ThreadPoolExecutor()
def io_bound(item):
    print(f"{item} io_bound")
    time.sleep(1)
    return item
it = IterableSource([1, 2, 3]).async_map(io_bound, executor=tpool).async_map(io_bound)
t1 = time.time()
for i in it:
    print(i)
print(time.time() - t1)
async_map instantiates a concurrent.futures.ThreadPoolExecutor if the argument executor is None (default). It also accepts concurrent.futures.ProcessPoolExecutor, which is a good choice when performing cpu-bound operations on a single machine.
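The general idea behind an executor-backed map can be sketched with the standard library alone. The generator async_map_sketch below is a hypothetical illustration, not squirrel's implementation: it assumes a bounded queue of pending futures (the buffer parameter is invented for this example) and yields results in submission order.

```python
import time
from collections import deque
from concurrent.futures import Executor, Future, ThreadPoolExecutor
from typing import Any, Callable, Deque, Iterable, Iterator


def async_map_sketch(
    items: Iterable[Any],
    fn: Callable[[Any], Any],
    executor: Executor,
    buffer: int = 8,
) -> Iterator[Any]:
    """Apply fn to each item on the executor, keeping input order."""
    pending: Deque[Future] = deque()
    for item in items:
        # Submit work eagerly so that several calls are in flight at once.
        pending.append(executor.submit(fn, item))
        if len(pending) >= buffer:
            # Block on the oldest future only; newer ones keep running.
            yield pending.popleft().result()
    # Drain the futures that are still outstanding.
    while pending:
        yield pending.popleft().result()


def slow_identity(x: int) -> int:
    time.sleep(0.05)  # stand-in for an I/O-bound call
    return x


with ThreadPoolExecutor(max_workers=4) as pool:
    start = time.time()
    results = list(async_map_sketch(range(8), slow_identity, pool))
    elapsed = time.time() - start
```

Because the calls overlap, elapsed is expected to be well below the sum of the individual sleep times, although the exact figure depends on the machine.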
Cluster-mode
Scaling out to a dask cluster only requires changing a single line of code:
from dask.distributed import Client
client = Client()
it = IterableSource([1, 2, 3]).async_map(io_bound, executor=client)
t1 = time.time()
for item in it:
    print(item)
print(time.time() - t1)
In this example, a task is submitted for each item and its result is gathered. An alternative would be to call dask_map instead of async_map, which transforms items in the stream into dask.delayed.Delayed objects. This pattern makes it possible to load and transform the data in a dask cluster and only load the fully ready data into the local machine.
it = IterableSource([1, 2, 3]).dask_map(io_bound).dask_map(lambda item: item + 1).materialize_dask()
t1 = time.time()
for item in it:
    print(item)
print(time.time() - t1)
Note that after calling dask_map for the first time, you can chain more dask_map calls, which then operate on the dask.delayed.Delayed objects, so that the data and the operations live on the dask cluster until materialize_dask is called.
Just-in-time compilation with numba
Squirrel uses numba to jit-compile an iterator to speed up computation in the main process.
it = IterableSource([1, 2, 3]).numba_map(lambda x: x + 1)
for item in it:
    print(item)
Here, the iterator itself is passed to the numba decorator @numba.jit, so any speed-up comes entirely from numba. Note that squirrel does not compile the user-defined function. You may achieve a comparable speed-up by compiling your function yourself and passing it to map():
from numba import jit
@jit(nopython=True)
def runtime_transformation(x):
    return x
it = IterableSource([1, 2, 3]).map(runtime_transformation)
for item in it:
    print(item)
Compared to the other three options, numba is more performant in some cases but not in others, and it is highly sensitive to the actual data type and computation at hand. Therefore, we recommend that you read the official numba documentation and run a benchmark before choosing this option in production.
Note
Since numba supports only a limited set of python object types, which naturally does not include squirrel-defined objects, we have to force object mode in numba. This means that the decorator used in squirrel has the following form: @numba.jit(forceobj=True).
PyTorch Distributed Dataloading
The Squirrel API is designed to support fast streaming of datasets to a multi-rank, distributed system, as often encountered in modern deep learning applications involving multiple GPUs. To this end, we can use the SplitByWorker and SplitByRank composables and wrap the final iterator in a torch DataLoader object:
import torch.utils.data as tud
from squirrel.iterstream.source import IterableSource
from squirrel.iterstream.torch_composables import SplitByRank, SplitByWorker, TorchIterable
def times_two(x: float) -> float:
    return x * 2
samples = list(range(100))
batch_size = 5
num_workers = 4
it = (
    IterableSource(samples)
    .compose(SplitByRank)
    .async_map(times_two)
    .compose(SplitByWorker)
    .batched(batch_size)
    .compose(TorchIterable)
)
dl = tud.DataLoader(it, num_workers=num_workers)
Note that the rank of the distributed system depends on the torch distributed process group and is automatically determined.
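One common way to split a stream across ranks or workers is round-robin slicing, sketched below in plain Python. This is only an assumption about how such a split can be done; it is not taken from squirrel's SplitByRank or SplitByWorker code. The helper split_round_robin is hypothetical, and the index and total are passed in explicitly here, whereas in real torch code they would typically come from torch.distributed.get_rank() / get_world_size() or torch.utils.data.get_worker_info().

```python
from itertools import islice
from typing import Any, Iterable, Iterator, List


def split_round_robin(items: Iterable[Any], index: int, total: int) -> Iterator[Any]:
    """Yield every total-th item, starting at position index."""
    if not 0 <= index < total:
        raise ValueError("index must satisfy 0 <= index < total")
    return islice(items, index, None, total)


num_workers = 4
samples = list(range(100))

# Each worker sees a disjoint slice; together they cover all samples once.
parts: List[List[int]] = [
    list(split_round_robin(samples, worker, num_workers))
    for worker in range(num_workers)
]
```

Splitting before the expensive steps of a pipeline means each process only fetches and transforms its own share of the data.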
The same can be achieved using the squirrel.driver API:
from squirrel.driver import MessagepackDriver
url = ""
it = MessagepackDriver(url).get_iter(key_hooks=[SplitByWorker]).async_map(times_two).batched(batch_size).compose(TorchIterable)
dl = tud.DataLoader(it, num_workers=num_workers)
In this example, key_hooks=[SplitByWorker] ensures that keys are split between workers before the data is fetched, and we achieve two levels of parallelism: multi-processing provided by torch.utils.data.DataLoader, and multi-threading inside each process for efficiently fetching samples by get_iter.
Performance Monitoring
In squirrel, the performance of an iterstream can be calculated and logged. This is done by adding an extra method, monitor(), to the iterstream chain. It can be added at any step in the above example where it is defined. For example, you can add .monitor(callback=wandb.log) right after async_map(times_two). The performance of all the previous steps combined is then calculated at this point, and the calculated metrics are passed to any user-specified callback such as wandb.log().
The following is a complete example:
import wandb
import mlflow
import numpy as np
from squirrel.iterstream import IterableSource
def times_two(x: float) -> float:
    return x * 2
samples = [np.random.rand(10, 10) for i in range(10 ** 4)]
batch_size = 5
with wandb.init():  # or mlflow.start_run()
    it = (
        IterableSource(samples)
        .async_map(times_two)
        .monitor(wandb.log)  # or mlflow.log_metrics
        .batched(batch_size)
    )
    it.collect()  # or it.take(<some int>).join()
This creates an iterstream with the same transformation logic as the one without monitor, except that the metrics calculated at the async_map step are sent to the callback function wandb.log. (The calculated metrics are of type Dict[str, [int, float]], so any function that accepts such an argument can be used as the callback of monitor.)
By default, monitor calculates two metrics: IOPS and throughput. However, this can be configured by passing a data class squirrel.metrics.MetricsConf to the argument metrics_conf of monitor. For details, see squirrel.iterstream.metrics.
Monitoring at different locations in an iterstream in one run can be achieved by inserting monitor with a different prefix at each location:
with wandb.init():  # or mlflow.start_run()
    it = (
        IterableSource(samples)
        .monitor(wandb.log, prefix="(before async_map) ")
        .async_map(times_two)
        .monitor(wandb.log, prefix="(after async_map) ")  # or mlflow.log_metrics
        .batched(batch_size)
    )
    it.collect()  # or it.take(<some int>).join()
This generates 4 metrics instead of 2: each original metric is split into two, with different prefixes that track the point at which the metrics were generated. (This does not interfere with metrics_conf, which determines which metrics are used in each monitor.)
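To show how a monitoring step with a callback and a prefix might work, the sketch below implements a pass-through generator in plain Python. It is an illustration under stated assumptions, not squirrel's monitor(): the function monitor_sketch, the metric key items_per_second and the window_size parameter are all hypothetical, and squirrel's actual IOPS and throughput calculations are not reproduced here.

```python
import time
from typing import Any, Callable, Dict, Iterable, Iterator, List, Union

Metrics = Dict[str, Union[int, float]]


def monitor_sketch(
    items: Iterable[Any],
    callback: Callable[[Metrics], Any],
    prefix: str = "",
    window_size: int = 5,
) -> Iterator[Any]:
    """Pass items through unchanged and report a rate every window_size items."""
    count = 0
    window_start = time.perf_counter()
    for item in items:
        count += 1
        if count % window_size == 0:
            elapsed = time.perf_counter() - window_start
            # Guard against a zero interval on very coarse clocks.
            rate = window_size / elapsed if elapsed > 0 else 0.0
            # The prefix tells the receiver where in the chain this was measured.
            callback({f"{prefix}items_per_second": rate})
            window_start = time.perf_counter()
        yield item


logged: List[Metrics] = []
stream = monitor_sketch(range(10), logged.append, prefix="(after map) ", window_size=5)
passed_through = list(stream)
```

Any callable that accepts a dictionary of numeric metrics, such as the logged.append used here, can serve as the callback in this sketch.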