Store

Store manages the storage and retrieval of data. It serves as an abstraction layer under StoreDriver, which eases the implementation of custom drivers.

The Squirrel store API defines three methods:

  • Store.set(): Used to store a value with a key.

  • Store.get(): Used to retrieve a previously stored value.

  • Store.keys(): Returns all the keys for which the store has a value.
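To make the three-method contract concrete, here is a minimal sketch of a dict-backed store. `InMemoryStore` is hypothetical and not part of Squirrel. Generating a random key when none is given mirrors the behaviour described later on this page for SquirrelStore. Returning the key from `set()` is an assumption made for convenience in this sketch.

```python
import uuid
from typing import Any, Dict, Iterator, Optional


class InMemoryStore:
    """Hypothetical dict-backed store illustrating the set/get/keys contract.

    Not part of Squirrel; a sketch for illustration only.
    """

    def __init__(self) -> None:
        self._data: Dict[str, Any] = {}

    def set(self, value: Any, key: Optional[str] = None) -> str:
        # Assumption: a random key is generated when none is given.
        if key is None:
            key = uuid.uuid4().hex
        self._data[key] = value
        return key  # returned only for convenience in this sketch

    def get(self, key: str) -> Any:
        # Raises KeyError if nothing was stored under `key`.
        return self._data[key]

    def keys(self) -> Iterator[str]:
        # Yield every key for which the store holds a value.
        yield from self._data.keys()


store = InMemoryStore()
store.set([1, 2, 3], key="a")
random_key = store.set({"x": 1})
print(sorted(store.keys()) == sorted(["a", random_key]))  # True
```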

Store vs Driver

A Store can persist values via its set() method, whereas a Driver can only read from a data source and cannot write to it.

If you only want to load data, you can:

  • Use one of the Squirrel Datasets drivers. For this, your dataset must be in a supported format and structure.

  • Implement a custom driver. By implementing your own driver, you can use the Squirrel API in the way that suits you best. Implementing a driver is straightforward; have a look at the custom IterDriver and MapDriver implementations for guidance.
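The read-only nature of a driver can be sketched in plain Python. `ReadOnlyMappingDriver` is hypothetical: it is not Squirrel's MapDriver or IterDriver, and its `get_iter()` method only loosely echoes the name used by Squirrel drivers, not their real signature. What it illustrates is that the class exposes ways to read from a data source and deliberately offers no `set()`.

```python
from typing import Any, Iterator, List, Mapping


class ReadOnlyMappingDriver:
    """Hypothetical read-only driver over an in-memory mapping of key -> shard.

    It only reads from its source; there is intentionally no set() method.
    """

    def __init__(self, source: Mapping[str, List[Any]]) -> None:
        self._source = source

    def keys(self) -> Iterator[str]:
        # Keys come straight from the underlying data source.
        yield from self._source

    def get_iter(self) -> Iterator[Any]:
        # Stream samples shard by shard, in the source's key order.
        for key in self.keys():
            yield from self._source[key]


driver = ReadOnlyMappingDriver({"shard_0": [{"x": 0}, {"x": 1}], "shard_1": [{"x": 2}]})
print([s["x"] for s in driver.get_iter()])  # [0, 1, 2]
```

Because the class has no write path, any attempt to persist transformed data has to go through a Store instead, which is exactly the distinction drawn above.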

If you need to rewrite your data after processing/transforming it, you should use a Store. We recommend using the SquirrelStore, since it comes with performance benefits as well as serialization and sharding support. Keep reading this page to get more information about SquirrelStore.

SquirrelStore

SquirrelStore is the recommended store to use with Squirrel. It comes with several optimizations to improve read/write speed and reduce storage size.

With SquirrelStore, it is possible to:

  • Save shards (i.e. collections of samples) in the store and retrieve them quickly (see Performance Benchmark)

  • Serialize shards using a SquirrelSerializer instance
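Serialization turns a whole shard into bytes that can be written in one go and read back later. The sketch below uses the standard library's `pickle` for encoding and `gzip` for compression as a stand-in. `PickleGzipSerializer` and its `serialize`/`deserialize` method names are hypothetical and are not the SquirrelSerializer interface; Squirrel ships its own serializers such as MessagepackSerializer.

```python
import gzip
import pickle
from typing import Any, Dict, List

Shard = List[Dict[str, Any]]


class PickleGzipSerializer:
    """Hypothetical serializer: pickle for encoding, gzip for compression."""

    def serialize(self, shard: Shard) -> bytes:
        # Encode the entire shard as one compressed blob.
        return gzip.compress(pickle.dumps(shard))

    def deserialize(self, blob: bytes) -> Shard:
        # Reverse the two steps: decompress, then unpickle.
        return pickle.loads(gzip.decompress(blob))


serializer = PickleGzipSerializer()
shard = [{"label": i, "metadata": {"key": i}} for i in range(3)]
blob = serializer.serialize(shard)
assert serializer.deserialize(blob) == shard  # lossless round trip
```

Note that `pickle` should only be used to load data you trust; a format such as MessagePack avoids executing arbitrary code on load.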

A SquirrelStore can be initialized as follows:

import tempfile

from squirrel.serialization import MessagepackSerializer
from squirrel.store import SquirrelStore

tmpdir = tempfile.TemporaryDirectory()
msg_store = SquirrelStore(url=tmpdir.name, serializer=MessagepackSerializer())
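Conceptually, a store pointed at a `url` can keep one file per key under that location. The following is a minimal sketch of that idea, assuming a local directory and `pickle` encoding. `DirectoryStore` is hypothetical; its file layout and suffix are assumptions and do not describe how SquirrelStore actually lays out data.

```python
import os
import pickle
import tempfile
import uuid
from typing import Any, Iterator, Optional


class DirectoryStore:
    """Hypothetical file-per-key store rooted at a local directory."""

    SUFFIX = ".pkl"  # assumption: one pickle file per key

    def __init__(self, url: str) -> None:
        self.url = url
        os.makedirs(url, exist_ok=True)

    def _path(self, key: str) -> str:
        return os.path.join(self.url, key + self.SUFFIX)

    def set(self, value: Any, key: Optional[str] = None) -> None:
        key = key if key is not None else uuid.uuid4().hex  # random key if omitted
        # Each call writes one complete file; nothing is buffered in memory.
        with open(self._path(key), "wb") as f:
            pickle.dump(value, f)

    def get(self, key: str) -> Any:
        with open(self._path(key), "rb") as f:
            return pickle.load(f)

    def keys(self) -> Iterator[str]:
        # Derive keys from file names, ignoring unrelated files.
        for name in os.listdir(self.url):
            if name.endswith(self.SUFFIX):
                yield name[: -len(self.SUFFIX)]


with tempfile.TemporaryDirectory() as d:
    dir_store = DirectoryStore(d)
    dir_store.set([{"x": 1}], key="shard_0")
    print(list(dir_store.keys()), dir_store.get("shard_0"))  # ['shard_0'] [{'x': 1}]
```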

You can also get a store instance from a driver. This is the recommended approach unless low-level control is needed.

from squirrel.driver import MessagepackDriver

driver = MessagepackDriver(tmpdir.name)
store = driver.store

Sample and Shard

squirrel.store.SquirrelStore uses a concept called sharding to store and load data efficiently. A shard is a collection of samples: it stores a predetermined number of samples in a fixed order, and each shard is identified by a unique key. A sample represents a single example for model training. Samples can be any Python object, for example a dictionary containing a NumPy array. Typically, a sample is of type Dict[str, Any] and a shard is a list thereof, i.e. List[Dict[str, Any]].
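The sample and shard types described above can be written as type aliases, together with a helper that groups samples into fixed-order shards of a predetermined size. `make_shards` is a hypothetical helper for illustration, not a Squirrel function.

```python
from typing import Any, Dict, Iterable, Iterator, List

Sample = Dict[str, Any]  # one training example
Shard = List[Sample]     # a fixed-order collection of samples


def make_shards(samples: Iterable[Sample], shard_size: int) -> Iterator[Shard]:
    """Group samples into consecutive shards of at most `shard_size` samples.

    Order is preserved; the last shard is smaller if the samples do not
    divide evenly.
    """
    if shard_size < 1:
        raise ValueError("shard_size must be >= 1")
    shard: Shard = []
    for sample in samples:
        shard.append(sample)
        if len(shard) == shard_size:
            yield shard
            shard = []
    if shard:  # emit the remainder, if any
        yield shard


shards = list(make_shards(({"key": i} for i in range(25)), shard_size=10))
print([len(s) for s in shards])  # [10, 10, 5]
```

Working on an iterable rather than a list means samples can be sharded as they are produced, without holding the whole dataset in memory.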

Writing samples as shards using SquirrelStore

Approach 1: Write/read shards sequentially

import numpy as np


def get_sample(i):
    return {
        "image": np.random.random((3, 3, 3)),
        "label": np.random.choice([1, 2]),
        "metadata": {"key": i},
    }


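N_SAMPLES, N_SHARDS = 100, 10
SHARD_SIZE = N_SAMPLES // N_SHARDS  # 10 samples per shard
samples = [get_sample(i) for i in range(N_SAMPLES)]
# split into consecutive, non-overlapping shards of SHARD_SIZE samples each
shards = [samples[i : i + SHARD_SIZE] for i in range(0, N_SAMPLES, SHARD_SIZE)]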

Shards can be saved by using the set() method.

for i, shard in enumerate(shards):
    store.set(
        shard,
        key=f"shard_{i}",  # optional: if omitted, a random key is used
    )

assert len(list(store.keys())) == N_SHARDS

Let’s check out a sample:

for key in store.keys():
    shard = store.get(key)
    for sample in shard:
        print(sample)
        break
    break

# Clean up
tmpdir.cleanup()
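The nested loop above reads a single sample and breaks out. To stream every sample sequentially instead, the shards can be chained lazily. This is a minimal sketch that assumes only that the store object has `keys()` and `get(key)`, with `get(key)` returning an iterable of samples; a plain dict of lists satisfies this and is used here as a stand-in. `iter_samples` is a hypothetical helper, and samples are yielded in whatever order `keys()` returns.

```python
from itertools import chain
from typing import Any, Iterator


def iter_samples(store: Any) -> Iterator[Any]:
    """Lazily yield every sample of every shard, shard by shard."""
    # The generator expression fetches each shard only when it is reached.
    return chain.from_iterable(store.get(key) for key in store.keys())


stored = {"shard_0": [{"key": 0}, {"key": 1}], "shard_1": [{"key": 2}]}
print([s["key"] for s in iter_samples(stored)])  # [0, 1, 2]
first_sample = next(iter_samples(stored))  # same result as the nested for/break
print(first_sample)  # {'key': 0}
```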

Approach 2: Write/read shards asynchronously using iterstream

SquirrelStore does not buffer any data: as soon as set() is called, the data is written to the store. Because of this, writing to the store can easily be parallelized. In the following example, we use async_map from the iterstream module to write shards to the store in parallel, and also to read them back in parallel.

from squirrel.iterstream import IterableSource

tmpdir = tempfile.TemporaryDirectory()
store = MessagepackDriver(tmpdir.name).store

# note that we are not providing keys for the shards here, random keys will be used
IterableSource(shards).async_map(store.set).join()
assert len(list(store.keys())) == N_SHARDS

samples = IterableSource(store.keys()).async_map(store.get).flatten().collect()
assert len(samples) == N_SAMPLES

# Clean up
tmpdir.cleanup()
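The same write-in-parallel, read-in-parallel pattern can be sketched with only the standard library, using `concurrent.futures.ThreadPoolExecutor.map` in the role that async_map plays above. `ThreadSafeDictStore` is a hypothetical in-memory stand-in for a real store, and the worker count is an arbitrary choice.

```python
import threading
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional


class ThreadSafeDictStore:
    """Hypothetical in-memory store; a lock guards the shared dict."""

    def __init__(self) -> None:
        self._data: Dict[str, Any] = {}
        self._lock = threading.Lock()

    def set(self, value: Any, key: Optional[str] = None) -> None:
        key = key if key is not None else uuid.uuid4().hex  # random key if omitted
        with self._lock:
            self._data[key] = value

    def get(self, key: str) -> Any:
        with self._lock:
            return self._data[key]

    def keys(self) -> List[str]:
        with self._lock:
            return list(self._data)


shards = [[{"key": i * 10 + j} for j in range(10)] for i in range(10)]
store = ThreadSafeDictStore()

with ThreadPoolExecutor(max_workers=4) as pool:
    # Write: each shard is an independent set() call, so calls can run concurrently.
    # list() forces completion and re-raises any exception from a worker.
    list(pool.map(store.set, shards))
    # Read: fetch shards concurrently, then flatten them into one list of samples.
    samples = [s for shard in pool.map(store.get, store.keys()) for s in shard]

print(len(store.keys()), len(samples))  # 10 100
```

Threads suit this workload because writing to and reading from storage is mostly I/O-bound.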