Driver

Driver is the Squirrel component used for accessing data. Combined with the iterstream functionality, drivers provide a powerful and intuitive way of loading data:

from torch.utils.data import DataLoader

from squirrel.driver import MessagepackDriver
from squirrel.iterstream.torch_composables import TorchIterable

url = "path/to/my/messagepack/dataset"
driver = MessagepackDriver(url)  # a driver that reads messagepack-serialized data
train_data = (
    driver.get_iter()  # returns a Composable, i.e. an iterable with 'iterstream' powers
    .filter(lambda dct: not dct["is_bad_sample"])  # skip unwanted samples
    .async_map(augment_image)  # possible to use a thread/process pool, or run on dask
    .batched(size=100)
    .compose(TorchIterable)  # ready for training!
)

train_loader = DataLoader(train_data, batch_size=None)
# ... have fun training your model

You can pass storage_options to any driver to customize the storage backend. Drivers differ in the way they provide access to data.
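For example, for data on S3 you could forward fsspec-style options. This is only a sketch: the bucket path is hypothetical, "anon" is an s3fs option requesting anonymous access, and the valid keys depend on your storage backend.

from squirrel.driver import MessagepackDriver

driver = MessagepackDriver(
    "s3://my-bucket/my-dataset",  # hypothetical location
    storage_options={"anon": True},  # fsspec/s3fs option: anonymous (unauthenticated) access
)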

IterDriver

Most drivers provide a way to iterate over the parts of the underlying data source. Such drivers inherit from the IterDriver base class, and their get_iter() method returns an iterable of these parts. Semantically, the “parts” are dataset-dependent and can be anything: a single sentence in a text corpus, a single row in a csv file, or a single (image, label) pair in an image classification dataset.

Let’s see an IterDriver in action:

import tempfile

from squirrel.driver import IterDriver
from squirrel.iterstream import Composable, IterableSource


class MyDriver(IterDriver):
    """Driver that loads lines of a text file."""

    name = "my_iter_driver"

    def __init__(self, txt_path: str):
        self.txt_path = txt_path

    def get_iter(self) -> Composable:
        with open(self.txt_path, "r") as f:
            # read eagerly so that the file can be closed before iteration starts
            return IterableSource([line.strip() for line in f])

Note

It is required to define the name class variable if this driver is intended to be registered with a source in a Catalog. When loading the driver of a source (via the get_driver() method), the driver name defined in the source is checked against the names of all available drivers to find the target driver.

To see how you can register your custom driver so that it can be used with a Catalog, refer to the Plugin Tutorial.
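As a rough sketch of what this looks like (assuming MyDriver has already been registered as a plugin per that tutorial; the source name and path are made up):

from squirrel.catalog import Catalog, Source

catalog = Catalog()
# driver_name must match the name class variable defined on MyDriver
catalog["my_corpus"] = Source(
    driver_name="my_iter_driver",
    driver_kwargs={"txt_path": "corpus.txt"},  # hypothetical path
)
driver = catalog["my_corpus"].get_driver()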

# prepare a text "corpus" and read from it
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
    f.write("Lorem ipsum dolor sit amet\n")
    f.write("consetetur sadipscing elitr\n")
    f.flush()

    driver = MyDriver(f.name)
    lines = driver.get_iter().collect()
    assert len(lines) == 2
    assert lines[0] == "Lorem ipsum dolor sit amet"
    assert lines[1] == "consetetur sadipscing elitr"

    # get_iter() returns a squirrel.iterstream.Composable, so we can use iterstream
    # functionality directly (in fact, we already used collect() above)

    upper_lines = driver.get_iter().map(str.upper).collect()
    assert upper_lines[0] == "LOREM IPSUM DOLOR SIT AMET"
    assert upper_lines[1] == "CONSETETUR SADIPSCING ELITR"

MapDriver

Some data sources inherently have a (key, value) mapping between keys and the data parts they identify. Maybe our csv file has an index column with a unique value for each row, or our image dataset consists of separate image files (identified by their file names). In such cases, given a key, it is possible to retrieve the corresponding dataset part. Squirrel provides the MapDriver base class for this use case:

import tempfile
import typing as t

import pandas as pd

from squirrel.driver import MapDriver


class MyDriver(MapDriver):
    """Driver that loads rows of a csv file by their index."""

    name = "my_map_driver"

    def __init__(self, csv_path: str, index_col: str):
        self.csv_path = csv_path
        self.df = pd.read_csv(csv_path, index_col=index_col)

    def get(self, key: int) -> t.Dict:
        return self.df.loc[key].to_dict()

    def keys(self) -> t.Iterator[int]:
        yield from self.df.index


with tempfile.TemporaryDirectory() as tmp_dir:
    df = pd.get_dummies(list("abca"))
    csv_path = f"{tmp_dir}/dummy.csv"
    df.to_csv(csv_path, index_label="index")

    driver = MyDriver(csv_path, index_col="index")
    sample = driver.get(0)
    assert sample["a"] == 1
    assert sample["b"] == 0
    assert sample["c"] == 0

Even though we only implemented the get() and keys() methods, it is possible to call get_iter() as well. When called, MapDriver iterates over the keys returned by keys() and calls get() for each key. See the method reference for more details.
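Continuing with the driver from the example above (the csv was read into memory in __init__, so the driver stays usable after the with block):

rows = driver.get_iter().collect()  # iterates over keys() and calls get() for each key
assert len(rows) == 4  # one dict per row of the dummy dataframe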

For this simple example, a custom driver works well. In general, it is better to use the CsvDriver with .csv files.
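For illustration only, and leaving the exact constructor parameters and row format to the CsvDriver API reference, the equivalent could be as short as:

from squirrel.driver import CsvDriver

rows = CsvDriver(csv_path).get_iter().collect()  # csv_path as in the example above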

StoreDriver

For common data access scenarios, it is much simpler to delegate low-level data operations to a Store. StoreDriver lets the underlying store handle get() and keys() calls.

For example, MessagepackDriver can load messagepack-serialized data by using the SquirrelStore behind the scenes.
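The following sketch shows this interplay, writing a shard through a SquirrelStore and reading it back through the driver (it assumes the store and serializer classes shipped with squirrel):

import tempfile

from squirrel.driver import MessagepackDriver
from squirrel.serialization import MessagepackSerializer
from squirrel.store import SquirrelStore

with tempfile.TemporaryDirectory() as tmp_dir:
    # write a shard of samples through the store ...
    store = SquirrelStore(tmp_dir, serializer=MessagepackSerializer())
    store.set(value=[{"x": i} for i in range(5)], key="shard_0")

    # ... and read it back through the driver, which delegates to the same store
    driver = MessagepackDriver(tmp_dir)
    assert len(list(driver.keys())) == 1
    samples = driver.get_iter().collect()
    assert len(samples) == 5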

FileDriver

FileDriver can be used to access individual files. Let’s save and reload a torch model using FileDriver:

import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F

from squirrel.driver.file import FileDriver


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))


my_model = Model()

with tempfile.TemporaryDirectory() as temp_dir:
    # trace your model to TorchScript and save using FileDriver
    model_path = f"{temp_dir}/my_model.pt"

    with FileDriver(model_path).open(mode="wb", create_if_not_exists=True) as f:
        my_scripted_model = torch.jit.script(my_model)
        torch.jit.save(my_scripted_model, f)

    # now, load the model back
    with FileDriver(model_path).open(mode="rb") as f:
        model_reloaded = torch.jit.load(f)

    # test that model outputs are the same
    batch = torch.rand(16, 1, 100, 100)
    assert torch.equal(my_model(batch), model_reloaded(batch))

Further reading

Drivers can be registered as part of a Source in a Catalog.

squirrel-datasets provides drivers to load data from various datasets.