IterStream

Squirrel provides an API for chaining iterables. The entry point is IterableSource (squirrel.iterstream.source.IterableSource).

Example Workflow

from squirrel.iterstream import IterableSource
import time

it = IterableSource([1, 2, 3, 4])
for item in it:
    print(item)

IterableSource provides several methods to conveniently transform the stream, given an iterable as input:

it = IterableSource([1, 2, 3, 4]).map(lambda x: x + 1).async_map(lambda x: x ** 2).filter(lambda x: x % 2 == 0)
for item in it:
    print(item)

async_map() applies the provided function asynchronously. More on this in the following sections. In addition to explicitly iterating over the items, it's also possible to call collect() to collect all items in a list, or join() to iterate over the items without returning anything.
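
For example, the same stream can be collected into a list, or simply drained:

res = IterableSource([1, 2, 3, 4]).map(lambda x: x + 1).collect()
print(res)  # [2, 3, 4, 5]

# join() consumes the stream without returning anything
IterableSource([1, 2, 3, 4]).map(print).join()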

Items in the stream can be shuffled in a buffer and batched:

it = IterableSource(range(10)).shuffle(size=5).map(lambda x: x+1).batched(batchsize=3, drop_last_if_not_full=True)
for item in it:
    print(item)

Note that the argument drop_last_if_not_full (default True) drops the last batch if it contains fewer items than the batchsize argument; so, only 3 batches will be printed above.
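
To keep the last partial batch instead, set drop_last_if_not_full=False:

it = IterableSource(range(10)).batched(batchsize=3, drop_last_if_not_full=False)
for batch in it:
    print(batch)  # four batches; the last one contains only the leftover item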

Items in IterableSource can be composed by providing a Composable in the compose() method:

from squirrel.iterstream import Composable

class MyIter(Composable):
    def __init__(self):
        super().__init__()

    def __iter__(self):
        # self.source is the upstream iterable; it is set when compose() chains this step
        for i in iter(self.source):
            yield f"_{i}", i

it = IterableSource([1, 2, 3]).compose(MyIter)
for item in it:
    print(item)

Combining multiple iterables can be achieved using IterableSamplerSource:

from squirrel.iterstream import IterableSamplerSource

it1 = IterableSource([1, 2, 3]).map(lambda x: x + 1)
it2 = [1, 2, 3]

res = IterableSamplerSource(iterables=[it1, it2], probs=[.7, .3]).collect()
print(res)
assert sum(res) == 15

Note that you can pass the probabilities of sampling from each iterable. When an iterable is exhausted, the probabilities are re-normalized over the remaining ones.
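
Because of this re-normalization, every item of every iterable is eventually yielded:

# once the shorter iterable is exhausted, sampling continues from the longer one
res = IterableSamplerSource(iterables=[IterableSource(range(100)), [0, 1]], probs=[.9, .1]).collect()
assert len(res) == 102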

Asynchronous execution

Much of iterstream's speed comes from squirrel.iterstream.base.Composable.async_map(). This method applies the provided function to each item in the stream asynchronously, and can therefore offer a large speed-up.

from concurrent.futures import ThreadPoolExecutor
tpool = ThreadPoolExecutor()

def io_bound(item):
    print(f"{item} io_bound")
    time.sleep(1)
    return item

it = IterableSource([1, 2, 3]).async_map(io_bound, executor=tpool).async_map(io_bound)
t1 = time.time()
for i in it:
    print(i)
print(time.time() - t1)

async_map instantiates a concurrent.futures.ThreadPoolExecutor if the argument executor is None (the default). It also accepts a concurrent.futures.ProcessPoolExecutor, which is a good choice for CPU-bound operations on a single machine.
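
For example, a minimal sketch of CPU-bound work with a process pool (note that the mapped function must be picklable, i.e. defined at module level):

from concurrent.futures import ProcessPoolExecutor

def cpu_bound(item):
    # simulate a CPU-heavy computation
    return sum(i * i for i in range(100_000)) + item

if __name__ == "__main__":
    with ProcessPoolExecutor() as pool:
        print(IterableSource([1, 2, 3]).async_map(cpu_bound, executor=pool).collect())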

Cluster-mode

Scaling out to a dask cluster only requires changing a single line of code:

from dask.distributed import Client
client = Client()

it = IterableSource([1, 2, 3]).async_map(io_bound, executor=client)
t1 = time.time()
for item in it:
    print(item)
print(time.time() - t1)

In this example, a task is submitted for each item and the result is gathered. An alternative is to call dask_map instead of async_map, which transforms items in the stream into dask.delayed.Delayed objects. This pattern makes it possible to load and transform the data on a dask cluster and only load the fully processed data onto the local machine.

it = IterableSource([1, 2, 3]).dask_map(io_bound).dask_map(lambda item: item + 1).materialize_dask()
t1 = time.time()
for item in it:
    print(item)
print(time.time() - t1)

Note that after the first call to dask_map, you can chain further dask_map calls, which then operate on the dask.delayed.Delayed objects, so that the data and the operations live on the dask cluster until materialize_dask is called.

Just-in-time compilation with numba

Squirrel uses numba to jit-compile an iterator to speed up computation in the main process.

it = IterableSource([1, 2, 3]).numba_map(lambda x: x + 1)
for item in it:
    print(item)

Here, the iterator itself is passed to the numba decorator @numba.jit, so the speed-up is provided entirely by numba. Note that squirrel does not compile the user-defined function. You may achieve a comparable speed-up by compiling your function yourself and passing it to map():

from numba import jit

@jit(nopython=True)
def runtime_transformation(x):
    return x

it = IterableSource([1, 2, 3]).map(runtime_transformation)
for item in it:
    print(item)

Compared to the other mapping options (map, async_map, and dask_map), numba is more performant in some cases but not in others, and is highly sensitive to the actual data type and computation at hand. Therefore, we recommend reading the official numba documentation and benchmarking your workload before choosing this option in production.
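
For instance, a simple timing comparison along these lines can inform the choice (a sketch; results will vary with data type and workload):

def bench(stream):
    t0 = time.time()
    stream.join()  # drain the stream without collecting results
    return time.time() - t0

print("map:      ", bench(IterableSource(range(10 ** 5)).map(lambda x: x + 1)))
print("numba_map:", bench(IterableSource(range(10 ** 5)).numba_map(lambda x: x + 1)))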

Note

Since numba supports only a limited set of Python types, which naturally does not include squirrel-defined objects, squirrel has to force object mode; that is, the decorator used in squirrel has the form @numba.jit(forceobj=True).

PyTorch Distributed Dataloading

The Squirrel API is designed to support fast streaming of datasets to a multi-rank, distributed system, as often encountered in modern deep learning applications involving multiple GPUs. To this end, we can use the SplitByRank and SplitByWorker composables and wrap the final iterator in a torch DataLoader object:

import torch.utils.data as tud
from squirrel.iterstream.source import IterableSource
from squirrel.iterstream.torch_composables import SplitByRank, SplitByWorker, TorchIterable

def times_two(x: float) -> float:
    return x * 2

samples = list(range(100))
batch_size = 5
num_workers = 4
it = (
    IterableSource(samples)
    .compose(SplitByRank)
    .async_map(times_two)
    .compose(SplitByWorker)
    .batched(batch_size)
    .compose(TorchIterable)
)
dl = tud.DataLoader(it, num_workers=num_workers)

Note that the rank of the distributed system depends on the torch distributed process group and is automatically determined.
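
For illustration, a minimal sketch of where the rank comes from (assuming a single-process "gloo" group; in a real job, a launcher such as torchrun starts one process per rank and initializes the group accordingly):

import torch.distributed as dist

dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
# SplitByRank picks up the rank from this process group, so each rank sees a disjoint slice of the stream
print(dist.get_rank(), dist.get_world_size())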

The same can be achieved through the squirrel.driver API:

from squirrel.driver import MessagepackDriver

url = ""
it = (
    MessagepackDriver(url)
    .get_iter(key_hooks=[SplitByWorker])
    .async_map(times_two)
    .batched(batch_size)
    .compose(TorchIterable)
)
dl = tud.DataLoader(it, num_workers=num_workers)

In this example, key_hooks=[SplitByWorker] ensures that keys are split between workers before fetching the data, giving us two levels of parallelism: multi-processing provided by torch.utils.data.DataLoader, and multi-threading inside each process for efficiently fetching samples via get_iter.

Performance Monitoring

In squirrel, the performance of an iterstream can be measured and logged. This is done by inserting the monitor() method into the chain; it can be added at any step of the examples above. For example, you can add .monitor(callback=wandb.log) right after async_map(times_two). The combined performance of all preceding steps is then measured at that point, and the calculated metrics are passed to any user-specified callback such as wandb.log().

The following is a complete example:

import wandb
import mlflow
import numpy as np

def times_two(x: float) -> float:
    return x * 2

samples = [np.random.rand(10, 10) for i in range(10 ** 4)]
batch_size = 5

with wandb.init(): # or mlflow.start_run()
    it = (
        IterableSource(samples)
        .async_map(times_two)
        .monitor(wandb.log) # or mlflow.log_metrics
        .batched(batch_size)
    )
    it.collect() # or it.take(<some int>).join()

This creates an iterstream with the same transformation logic as without monitor, but the metrics calculated at the async_map step are sent to the callback function wandb.log. (The calculated metrics are of type Dict[str, Union[int, float]], so any function accepting such an argument can be plugged in as the callback of monitor.)

By default, monitor calculates two metrics: IOPS and throughput. This can be configured by passing the data class squirrel.iterstream.metrics.MetricsConf to the metrics_conf argument of monitor. For details, see squirrel.iterstream.metrics.
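
For example, a sketch of tracking only IOPS (assuming the boolean iops and throughput switches on MetricsConf; see squirrel.iterstream.metrics for the exact fields):

from squirrel.iterstream.metrics import MetricsConf

conf = MetricsConf(iops=True, throughput=False)  # assumed fields, for illustration

it = (
    IterableSource(samples)
    .async_map(times_two)
    .monitor(wandb.log, metrics_conf=conf)
    .batched(batch_size)
)
it.collect()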

Monitoring at different locations in an iterstream within one run can be achieved by inserting monitor with different prefixes:

with wandb.init(): # or mlflow.start_run()
    it = (
        IterableSource(samples)
        .monitor(wandb.log, prefix="(before async_map) ")
        .async_map(times_two)
        .monitor(wandb.log, prefix="(after async_map) ") # or mlflow.log_metrics
        .batched(batch_size)
    )
    it.collect() # or it.take(<some int>).join()

This generates 4 metrics instead of 2: each original metric bifurcates into two with different prefixes, tracking at which point in the stream the metrics are generated. (This does not interfere with metrics_conf, which determines which metrics are calculated in each monitor.)