Store¶
Store manages the storage and retrieval of data and serves as an abstraction layer under StoreDriver to ease the implementation of custom drivers.
The Squirrel store API defines three methods:
Store.set(): Used to store a value with a key.
Store.get(): Used to retrieve a previously stored value.
Store.keys(): Returns all the keys for which the store has a value.
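To make the three-method contract concrete, here is a minimal sketch of a dictionary-backed store. The class name InMemoryStore and its signatures are hypothetical and chosen only for illustration; this is not how Squirrel implements its stores.

```python
from typing import Any, Dict, Iterator


class InMemoryStore:
    """Hypothetical toy store exposing set(), get() and keys()."""

    def __init__(self) -> None:
        self._data: Dict[str, Any] = {}

    def set(self, value: Any, key: str) -> None:
        # Store a value under the given key, overwriting any previous value.
        self._data[key] = value

    def get(self, key: str) -> Any:
        # Retrieve a previously stored value; raises KeyError for unknown keys.
        return self._data[key]

    def keys(self) -> Iterator[str]:
        # Yield every key for which the store holds a value.
        yield from self._data


toy_store = InMemoryStore()
toy_store.set({"label": 1}, key="a")
toy_store.set({"label": 2}, key="b")
```

A persistent store would replace the dictionary with a backend such as a filesystem while keeping the same three methods.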
Store vs Driver
A Store permits persisting values via the set() method, whereas a Driver can only read from a data source and cannot write to it.
If you only want to load data, you can:
Use one of the Squirrel Datasets drivers. For this, your dataset must be in a supported format and structure.
Implement a custom driver. By implementing your own driver, you can use the Squirrel API in the way that suits you best. Implementing a driver is straightforward: have a look at the custom IterDriver and MapDriver implementations.
If you need to write your data back after processing or transforming it, you should use a Store. We recommend using the SquirrelStore, since it comes with performance benefits as well as serialization and sharding support. Keep reading this page for more information about SquirrelStore.
SquirrelStore¶
SquirrelStore is the recommended store to use with Squirrel. It comes with several optimizations to improve read/write speed and reduce storage size.
With SquirrelStore, it is possible to:
Save shards (i.e. collections of samples) in the store and retrieve them quickly (see Performance Benchmark)
Serialize shards using a SquirrelSerializer instance
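As an illustration of what serializing a shard involves, the following sketch round-trips a shard through bytes using only the standard library. JsonGzipSerializer and its method names are hypothetical stand-ins; they do not reproduce the SquirrelSerializer interface or the MessagePack format.

```python
import gzip
import json
from typing import Any, Dict, List

Shard = List[Dict[str, Any]]


class JsonGzipSerializer:
    """Hypothetical serializer: JSON-encode a shard, then gzip the bytes."""

    def serialize(self, shard: Shard) -> bytes:
        # Encode the whole shard at once so it can be written as a single blob.
        return gzip.compress(json.dumps(shard).encode("utf-8"))

    def deserialize(self, payload: bytes) -> Shard:
        # Reverse the two steps: decompress first, then decode the JSON.
        return json.loads(gzip.decompress(payload).decode("utf-8"))


serializer = JsonGzipSerializer()
shard = [{"label": 1, "pixels": [0, 0, 0]}, {"label": 2, "pixels": [1, 1, 1]}]
payload = serializer.serialize(shard)
restored = serializer.deserialize(payload)
```

Serializing a whole shard at once, rather than one sample at a time, means each stored object corresponds to a single read or write.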
A SquirrelStore can be initialized as follows:
import tempfile
from squirrel.serialization import MessagepackSerializer
from squirrel.store import SquirrelStore
tmpdir = tempfile.TemporaryDirectory()
msg_store = SquirrelStore(url=tmpdir.name, serializer=MessagepackSerializer())
You can also get a store instance from a driver. This is the recommended approach unless low-level control is needed.
from squirrel.driver import MessagepackDriver
driver = MessagepackDriver(tmpdir.name)
store = driver.store
Sample and Shard¶
squirrel.store.SquirrelStore uses a concept called sharding to store and load data efficiently.
A shard is a collection of samples: it stores a predetermined number of samples in a fixed order.
Samples can be any Python object.
Each sample represents a single training example and can be, for example, a dictionary containing a numpy array.
Each shard is identified by a unique key.
A sample is of type Dict[str, Any], and a shard is a list thereof, i.e. List[Dict[str, Any]].
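The relationship between samples, shards, and keys can be sketched with plain Python types. The helper make_shards and the key format below are assumptions made for illustration, not part of the Squirrel API.

```python
from typing import Any, Dict, List

Sample = Dict[str, Any]
Shard = List[Sample]


def make_shards(samples: List[Sample], shard_size: int) -> List[Shard]:
    """Split samples into consecutive, non-overlapping shards (hypothetical helper)."""
    if shard_size <= 0:
        raise ValueError("shard_size must be positive")
    return [
        samples[start : start + shard_size]
        for start in range(0, len(samples), shard_size)
    ]


toy_samples = [{"label": i % 2, "metadata": {"key": i}} for i in range(25)]
toy_shards = make_shards(toy_samples, shard_size=10)

# Each shard is addressed by a unique key; the last shard may be smaller.
shards_by_key = {f"shard_{i}": shard for i, shard in enumerate(toy_shards)}
```

Because the slices step through the sample list in strides of shard_size, every sample lands in exactly one shard and the original order is preserved.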
Writing samples as shards using SquirrelStore¶
Approach 1: Write/read shards sequentially¶
import numpy as np
def get_sample(i):
    return {
        "image": np.random.random((3, 3, 3)),
        "label": np.random.choice([1, 2]),
        "metadata": {"key": i},
    }
N_SAMPLES, N_SHARDS = 100, 10
samples = [get_sample(i) for i in range(N_SAMPLES)]
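# Split the samples into N_SHARDS consecutive, non-overlapping shards of 10 samples each
shards = [samples[i * 10 : (i + 1) * 10] for i in range(N_SHARDS)]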
Shards can be saved by using the set() method.
for i, shard in enumerate(shards):
    store.set(
        shard,
        key=f"shard_{i}",  # the key is optional; if omitted, a random key is used
    )
assert len(list(store.keys())) == N_SHARDS
Let’s check out a sample:
for key in store.keys():
    shard = store.get(key)
    for sample in shard:
        print(sample)
        break
    break
# Clean up
tmpdir.cleanup()
Approach 2: Write/read shards asynchronously using iterstream¶
SquirrelStore does not buffer any data: as soon as set() is called, the data is written to the store. Because of this, writing to the store can be easily parallelized. In the following example, we use async_map from the iterstream module to write shards to the store in parallel and also to read from the store in parallel.
from squirrel.iterstream import IterableSource
tmpdir = tempfile.TemporaryDirectory()
store = MessagepackDriver(tmpdir.name).store
# Note: no keys are provided for the shards here, so random keys will be used
IterableSource(shards).async_map(store.set).join()
assert len(list(store.keys())) == 10
samples = IterableSource(store.keys()).async_map(store.get).flatten().collect()
assert len(samples) == 100
# Clean up
tmpdir.cleanup()
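To see why unbuffered writes parallelize easily, consider the following standard-library sketch. FileStore is a hypothetical toy class that writes each shard to its own JSON file, and ThreadPoolExecutor stands in for async_map; neither reflects how SquirrelStore or iterstream is implemented.

```python
import json
import tempfile
import uuid
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Iterator, Optional


class FileStore:
    """Hypothetical store that writes each value to its own JSON file."""

    def __init__(self, root: str) -> None:
        self.root = Path(root)

    def set(self, value: Any, key: Optional[str] = None) -> None:
        # No buffering: the value is on disk by the time set() returns.
        key = key or uuid.uuid4().hex  # random key when none is given
        (self.root / f"{key}.json").write_text(json.dumps(value))

    def get(self, key: str) -> Any:
        return json.loads((self.root / f"{key}.json").read_text())

    def keys(self) -> Iterator[str]:
        for path in self.root.glob("*.json"):
            yield path.stem


toy_shards = [[{"id": 10 * s + i} for i in range(10)] for s in range(10)]

with tempfile.TemporaryDirectory() as root:
    file_store = FileStore(root)
    with ThreadPoolExecutor(max_workers=4) as pool:
        # Each set() call touches a different file, so the calls do not interfere.
        list(pool.map(file_store.set, toy_shards))
        loaded = list(pool.map(file_store.get, list(file_store.keys())))
    n_keys = len(loaded)
    ids = sorted(sample["id"] for shard in loaded for sample in shard)
```

Since every shard maps to an independent key and file, no shared write buffer needs to be coordinated between workers. The order in which shards come back depends on the key listing, which is why the ids are sorted before being compared.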