squirrel.iterstream.multiplexer

Multiplexing module to combine and sample from a set of composables.

Module Contents

Classes

Multiplexer

Multiplexing over a list of composables.

MultiplexingStrategy

Enum of possible multiplexing strategies.

MuxIterator

Convenience dataclass to keep track of iterator position and reinits.

Attributes

logger

squirrel.iterstream.multiplexer.logger
class squirrel.iterstream.multiplexer.Multiplexer(composables: List[squirrel.iterstream.Composable], mux_strategy: MultiplexingStrategy, sampling_probas: Optional[List[float]] = None, max_reinits: Optional[int] = None, seed: Optional[int] = None, proba_threshold: float = 0.001, **kwargs)

Bases: squirrel.iterstream.Composable

Multiplexing over a list of composables.

The challenge of multiplexing over iterstreams is that the length of each stream is unknown. Hence the only way to control the sampling ratio is to fix the dataset distribution ratio at the output of the iterator, i.e. the ratio of data points after collecting the samples from the composables. In other words: if T_{eff} is the total number of tokens in the training dataset, then

w_{i, eff} = t_{i, eff} / T_{eff}

is the effective ratio of dataset i, where t_{i, eff} is the number of tokens it contributes to the final output dataset. With t_i denoting the number of tokens in dataset i, the oversampling ratio is defined as

e_i = t_{i, eff} / t_i.

The quantity e_i is also the number of epochs that dataset i is being iterated over during training on T_{eff} tokens.

Hence we have two knobs we can control in the multiplexing algorithm:

  1. The sampling probabilities w_{i, eff}

  2. The maximum number of oversampling epochs e_i

Note that these two quantities are not independent of each other: fixing both at the same time does not guarantee that both conditions are fulfilled.
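
To make the relationship between the two knobs concrete, the following sketch evaluates the formulas above for hypothetical token counts. The numbers and the helper names `oversampling_epochs` and `max_total_tokens` are illustrative assumptions and are not part of squirrel.

```python
from typing import List


def oversampling_epochs(tokens: List[float], ratios: List[float], total_eff: float) -> List[float]:
    """Return e_i = t_{i, eff} / t_i, with t_{i, eff} = w_{i, eff} * T_eff."""
    return [w * total_eff / t for w, t in zip(ratios, tokens)]


def max_total_tokens(tokens: List[float], ratios: List[float], max_epochs: float) -> float:
    """Largest T_eff for which every dataset with w > 0 satisfies e_i <= max_epochs."""
    return min(max_epochs * t / w for w, t in zip(ratios, tokens) if w > 0)


# Hypothetical datasets: t_i tokens each, with target output ratios w_{i, eff}.
tokens = [100, 400, 500]
ratios = [0.5, 0.25, 0.25]

epochs = oversampling_epochs(tokens, ratios, total_eff=1000)
print(epochs)  # [5.0, 0.625, 0.5]: the small dataset is repeated five times

# Capping every dataset at 2 epochs while keeping the ratios limits T_eff.
print(max_total_tokens(tokens, ratios, max_epochs=2))  # 400.0
```

In this example, requesting 1000 output tokens at these ratios and a cap of 2 epochs cannot both be satisfied, which is the dependency between the two knobs.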

Initializes a multiplexer of a list of data iterstream composables.

Parameters
  • composables

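    List of composables corresponding to different data sources. Note: Ensure that the composables are properly split according to multiprocessing workers and/or GPU ranks. There are several ways to do this, and we encourage reading the documentation for the corresponding driver or the composable you are using. As an example, you can load a messagepack driver using:

    ```python
    from squirrel.iterstream.multiplexer import Multiplexer, MultiplexingStrategy

    # MessagepackDriver, SplitByWorker and local_msgpack_url are assumed to be imported/defined elsewhere.
    c0 = MessagepackDriver(url=local_msgpack_url).get_iter(key_hooks=[SplitByWorker])
    c1 = MessagepackDriver(url=local_msgpack_url).get_iter(key_hooks=[SplitByWorker])

    # mux_strategy is a required argument; ROUND_ROBIN is used here for illustration only.
    mux = Multiplexer([c0, c1], MultiplexingStrategy.ROUND_ROBIN).to_torch_iterable(False, False)
    ```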

  • mux_strategy – Multiplexing strategy.

  • sampling_probas – (optional) list of floats that determine the sampling ratios for each dataset, i.e. w_{i, eff} as defined above.

  • max_reinits – (optional) int that determines the maximum oversampling epochs over all datasets.

  • seed – (optional) int for the random number generator.

  • proba_threshold – float threshold on the sampling probability, used to decide when a composable should be ignored in sampling.

Note that the algorithm stops as soon as max_reinits is reached or all composables have been reinitialized at least once.
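
The stopping rule can be illustrated with a plain-Python sketch of probability-based multiplexing. This is not squirrel's implementation: `sample_mux` is a hypothetical helper, the sources are in-memory lists instead of composables, and it assumes that `max_reinits` counts as reached as soon as any single source has been restarted that many times.

```python
import random
from typing import Any, Iterator, Optional, Sequence


def sample_mux(
    sources: Sequence[Sequence[Any]],
    probas: Sequence[float],
    max_reinits: Optional[int] = None,
    seed: Optional[int] = None,
) -> Iterator[Any]:
    """Yield items by picking a source at random according to probas."""
    rng = random.Random(seed)
    # Only sources with non-zero probability take part in the stopping rule.
    active = [i for i, p in enumerate(probas) if p > 0]
    iterators = {i: iter(sources[i]) for i in active}
    reinits = {i: 0 for i in active}
    weights = [probas[i] for i in active]

    while True:
        i = rng.choices(active, weights=weights, k=1)[0]
        try:
            yield next(iterators[i])
        except StopIteration:
            # Source i is exhausted: count the restart and check both stop conditions.
            reinits[i] += 1
            if max_reinits is not None and reinits[i] >= max_reinits:
                return
            if all(count >= 1 for count in reinits.values()):
                return
            iterators[i] = iter(sources[i])


out = list(sample_mux([[1, 2], [10, 20, 30]], probas=[0.5, 0.5], seed=0))
# Without max_reinits, every source has been fully consumed at least once.
print(sorted(set(out)))  # [1, 2, 10, 20, 30]
```

Because the run only ends once each active source has been exhausted at least once, shorter sources are repeated in the meantime, which is where oversampling comes from.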

__iter__() → Iterator[Dict[str, str]]

Returns the multiplexed iterator.

property num_samples: List[int]

Return number of samples seen from each composable.

property reinit_counts: List[int]

Return number of reinits for non-zero probability composables.
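
Since `num_samples` reports how many samples were drawn from each composable, it can be used to compare the achieved output ratios with the requested `sampling_probas`. The helper below is a hypothetical sketch that operates on a plain list of counts. It counts samples rather than tokens, so it approximates w_{i, eff} only if samples are of comparable size.

```python
from typing import List


def effective_ratios(num_samples: List[int]) -> List[float]:
    """Return the share of the output contributed by each composable."""
    total = sum(num_samples)
    if total == 0:
        # Nothing has been sampled yet, so no ratio can be computed.
        return [0.0 for _ in num_samples]
    return [n / total for n in num_samples]


# Hypothetical counts, e.g. as read from the num_samples property after iterating.
print(effective_ratios([500, 250, 250]))  # [0.5, 0.25, 0.25]
```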

class squirrel.iterstream.multiplexer.MultiplexingStrategy

Bases: enum.Enum

Enum of possible multiplexing strategies.

ROUND_ROBIN = 'RoundRobin'
SAMPLING = 'sampling'
UNIFORM_WITH_REPLACEMENT = 'uniform_with_replacement'
class squirrel.iterstream.multiplexer.MuxIterator

Convenience dataclass to keep track of iterator position and reinits.

index: int
it: Iterator
reinits: int
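
As an illustration of how such a bookkeeping dataclass can be used, the sketch below cycles over in-memory lists in round-robin order and restarts exhausted iterators. `MuxIteratorSketch` and `round_robin` are hypothetical names that only mirror the three fields listed above. The sketch assumes that `index` is the position of the source in the input list, and it is not squirrel's implementation.

```python
from dataclasses import dataclass
from typing import Any, Iterator, List, Sequence, Tuple


@dataclass
class MuxIteratorSketch:
    index: int        # assumed: position of the source in the input list
    it: Iterator      # current iterator over that source
    reinits: int = 0  # how often the iterator has been restarted


def round_robin(sources: Sequence[Sequence[Any]]) -> Tuple[List[Any], List[int]]:
    """Take one item per source in turn until every source was restarted once."""
    states = [MuxIteratorSketch(index=i, it=iter(s)) for i, s in enumerate(sources)]
    out: List[Any] = []
    while not all(state.reinits >= 1 for state in states):
        for state in states:
            try:
                out.append(next(state.it))
            except StopIteration:
                # Exhausted: record the restart and begin a new pass over the source.
                state.reinits += 1
                state.it = iter(sources[state.index])
    return out, [state.reinits for state in states]


items, reinit_counts = round_robin([[1, 2], [10, 20, 30]])
print(items)          # [1, 10, 2, 20, 30, 1]
print(reinit_counts)  # [1, 1]
```

The shorter source starts its second pass while the longer one is still on its first.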