Abstract Dataloader Specifications

The specifications here provide "duck type" protocol definitions of key data loading primitives. To implement the specification, users simply "fill in" the methods described here for the types they wish to implement.

Type Parameters

ADL specification protocol types are defined as generics, which are parameterized by other types. These type parameters are documented by a Type Parameters section where applicable.

Composition Rules

ADL protocols which can be composed together are documented by a Composition Rules section.

abstract_dataloader.spec.Metadata

Bases: Protocol

Sensor metadata.

All sensor metadata is expected to be held in memory during training, so great effort should be taken to minimize its memory usage. Any additional information which is not strictly necessary for book-keeping, or which takes more than negligible space, should be loaded as data instead.

Note

This can be a @dataclass, typing.NamedTuple, or a fully custom type - it just has to expose a timestamps attribute.

Attributes:

    timestamps (Float[ndarray, N]): measurement timestamps, in seconds. Nominally in epoch time; must be consistent within each trace (but not necessarily across traces). Suggested type: float64, which gives precision of <1us.
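For instance, a minimal Metadata implementation could be a frozen dataclass holding only the timestamp vector. This is an illustrative sketch (the class name is hypothetical), assuming the Float annotation comes from jaxtyping, as the signature above suggests:

from dataclasses import dataclass

import numpy as np
from jaxtyping import Float


@dataclass(frozen=True)
class TimestampMetadata:
    """Minimal metadata: only the measurement timestamps, in seconds."""

    timestamps: Float[np.ndarray, "N"]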

Source code in src/abstract_dataloader/spec.py
@runtime_checkable
class Metadata(Protocol):
    """Sensor metadata.

    All sensor metadata is expected to be held in memory during training, so
    great effort should be taken to minimize its memory usage. Any additional
    information which is not strictly necessary for book-keeping, or which
    takes more than negligible space, should be loaded as data instead.

    !!! note

        This can be a `@dataclass`, [`typing.NamedTuple`][typing.NamedTuple],
        or a fully custom type - it just has to expose a `timestamps`
        attribute.

    Attributes:
        timestamps: measurement timestamps, in seconds. Nominally in epoch
            time; must be consistent within each trace (but not necessarily
            across traces). Suggested type: `float64`, which gives precision of
            <1us.
    """

    timestamps: Float[np.ndarray, "N"]

abstract_dataloader.spec.Sensor

Bases: Protocol, Generic[TSample, TMetadata]

A sensor, consisting of a synchronous time-series of measurements.

This protocol is parameterized by generic TSample and TMetadata types, which can encode the expected data type of this sensor. For example:

class Point2D(TypedDict):
    x: float
    y: float

def point_transform(point_sensor: Sensor[Point2D, Any]) -> T:
    ...

This encodes an argument, point_sensor, which is expected to be a sensor that reads data with type Point2D, but does not specify a metadata type.

Type Parameters
  • TSample: sample data type which this Sensor returns. As a convention, we suggest returning "batched" data by default, i.e. with a leading singleton axis.
  • TMetadata: metadata type associated with this sensor; must implement Metadata.

Attributes:

    metadata (TMetadata): sensor metadata, including timestamp information.
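A minimal sketch of a conforming sensor, reusing the Point2D TypedDict from the example above and the hypothetical TimestampMetadata sketch from the Metadata section; the one-sample-per-.npy-file layout is an assumption for illustration only:

from collections.abc import Iterator

import numpy as np


class NpyPointSensor:
    """Reads one Point2D sample per .npy file, ordered by timestamp."""

    def __init__(
        self, files: list[str], metadata: "TimestampMetadata"
    ) -> None:
        self.files = files
        self.metadata = metadata

    def __getitem__(self, index: int | np.integer) -> "Point2D":
        xy = np.load(self.files[index])
        return {"x": float(xy[0]), "y": float(xy[1])}

    def __len__(self) -> int:
        return len(self.files)

    def stream(self) -> Iterator["Point2D"]:
        return (self[i] for i in range(len(self)))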

Source code in src/abstract_dataloader/spec.py
@runtime_checkable
class Sensor(Protocol, Generic[TSample, TMetadata]):
    """A sensor, consisting of a synchronous time-series of measurements.

    This protocol is parameterized by generic `TSample` and `TMetadata` types,
    which can encode the expected data type of this sensor. For example:

    ```python
    class Point2D(TypedDict):
        x: float
        y: float

    def point_transform(point_sensor: Sensor[Point2D, Any]) -> T:
        ...
    ```

    This encodes an argument, `point_sensor`, which is expected to be a
    sensor that reads data with type `Point2D`, but does not specify a
    metadata type.

    Type Parameters:
        - `TSample`: sample data type which this `Sensor` returns. As a
            convention, we suggest returning "batched" data by default, i.e.
            with a leading singleton axis.
        - `TMetadata`: metadata type associated with this sensor; must
            implement [`Metadata`][^.].

    Attributes:
        metadata: sensor metadata, including timestamp information.
    """

    metadata: TMetadata

    def stream(self) -> Iterator[TSample]:
        """Stream values recorded by this sensor.

        Returns:
            An iterator yielding successive samples.
        """
        ...

    def __getitem__(self, index: int | np.integer) -> TSample:
        """Fetch measurements from this sensor, by index.

        Args:
            index: sample index, in the sensor scope.

        Returns:
            Loaded sample.
        """
        ...

    def __len__(self) -> int:
        """Total number of measurements."""
        ...

__getitem__

__getitem__(index: int | integer) -> TSample

Fetch measurements from this sensor, by index.

Parameters:

    index (int | integer, required): sample index, in the sensor scope.

Returns:

    TSample: Loaded sample.

Source code in src/abstract_dataloader/spec.py
def __getitem__(self, index: int | np.integer) -> TSample:
    """Fetch measurements from this sensor, by index.

    Args:
        index: sample index, in the sensor scope.

    Returns:
        Loaded sample.
    """
    ...

__len__

__len__() -> int

Total number of measurements.

Source code in src/abstract_dataloader/spec.py
def __len__(self) -> int:
    """Total number of measurements."""
    ...

stream

stream() -> Iterator[TSample]

Stream values recorded by this sensor.

Returns:

    Iterator[TSample]: An iterator yielding successive samples.

Source code in src/abstract_dataloader/spec.py
def stream(self) -> Iterator[TSample]:
    """Stream values recorded by this sensor.

    Returns:
        An iterator yielding successive samples.
    """
    ...

abstract_dataloader.spec.Synchronization

Bases: Protocol

Synchronization protocol for asynchronous time-series.

Defines a rule for creating matching sensor index tuples which correspond to some kind of global index.
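For example, a nearest-neighbor rule might align every sensor to the timestamps of a designated reference sensor. This is only a sketch of one possible policy; the reference-sensor convention is an illustration, not part of the spec:

import numpy as np


class NearestSync:
    """Match each sensor to the timestamps of a reference sensor."""

    def __init__(self, reference: str) -> None:
        self.reference = reference

    def __call__(
        self, timestamps: dict[str, np.ndarray]
    ) -> dict[str, np.ndarray]:
        ref = timestamps[self.reference]
        # For each reference timestamp, take the index of the closest
        # measurement from every sensor; each output has length M = len(ref).
        return {
            name: np.abs(ts[None, :] - ref[:, None]).argmin(axis=1)
            for name, ts in timestamps.items()
        }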

Source code in src/abstract_dataloader/spec.py
@runtime_checkable
class Synchronization(Protocol):
    """Synchronization protocol for asynchronous time-series.

    Defines a rule for creating matching sensor index tuples which correspond
    to some kind of global index.
    """

    def __call__(
        self, timestamps: dict[str, Float[np.ndarray, "_N"]]
    ) -> dict[str, Integer[np.ndarray, "M"]]:
        """Apply synchronization protocol.

        Args:
            timestamps: sensor timestamps. Each key denotes a different sensor
                name, and the value denotes the timestamps for that sensor.

        Returns:
            A dictionary, where keys correspond to each sensor, and values
                correspond to the indices which map global indices to sensor
                indices, i.e. `global[sensor, i] = sensor[sync[sensor][i]]`.
        """
        ...

__call__

__call__(timestamps: dict[str, Float[ndarray, _N]]) -> dict[str, Integer[ndarray, M]]

Apply synchronization protocol.

Parameters:

    timestamps (dict[str, Float[ndarray, _N]], required): sensor timestamps. Each key denotes a different sensor name, and the value denotes the timestamps for that sensor.

Returns:

    dict[str, Integer[ndarray, M]]: A dictionary, where keys correspond to each sensor, and values correspond to the indices which map global indices to sensor indices, i.e. global[sensor, i] = sensor[sync[sensor][i]].

Source code in src/abstract_dataloader/spec.py
def __call__(
    self, timestamps: dict[str, Float[np.ndarray, "_N"]]
) -> dict[str, Integer[np.ndarray, "M"]]:
    """Apply synchronization protocol.

    Args:
        timestamps: sensor timestamps. Each key denotes a different sensor
            name, and the value denotes the timestamps for that sensor.

    Returns:
        A dictionary, where keys correspond to each sensor, and values
            correspond to the indices which map global indices to sensor
            indices, i.e. `global[sensor, i] = sensor[sync[sensor][i]]`.
    """
    ...

abstract_dataloader.spec.Trace

Bases: Protocol, Generic[TSample]

A trace, consisting of multiple simultaneously-recording sensors.

This protocol is parameterized by a generic TSample type, which can encode the expected data type of this trace.

Type Parameters
  • TSample: sample data type which this Trace returns. As a convention, we suggest returning "batched" data by default, i.e. with a leading singleton axis.
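A minimal illustrative sketch of a trace which combines named sensors through a Synchronization rule (error handling and exact typing are omitted for brevity):

import numpy as np

from abstract_dataloader.spec import Sensor, Synchronization


class SimpleTrace:
    """Index synchronized sensor tuples, or fetch a sensor by name."""

    def __init__(
        self, sensors: dict[str, Sensor], sync: Synchronization
    ) -> None:
        self.sensors = sensors
        # Per-sensor indices corresponding to each global index.
        self.indices = sync(
            {k: s.metadata.timestamps for k, s in sensors.items()})

    def __getitem__(self, index: int | np.integer | str):
        if isinstance(index, str):
            return self.sensors[index]
        return {
            k: sensor[self.indices[k][index]]
            for k, sensor in self.sensors.items()}

    def __len__(self) -> int:
        return len(next(iter(self.indices.values())))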
Source code in src/abstract_dataloader/spec.py
@runtime_checkable
class Trace(Protocol, Generic[TSample]):
    """A trace, consisting of multiple simultaneously-recording sensors.

    This protocol is parameterized by a generic `TSample` type, which can
    encode the expected data type of this trace.

    Type Parameters:
        - `TSample`: sample data type which this `Trace` returns. As a
            convention, we suggest returning "batched" data by default, i.e.
            with a leading singleton axis.
    """

    @overload
    def __getitem__(self, index: str) -> Sensor: ...

    @overload
    def __getitem__(self, index: int | np.integer) -> TSample: ...

    def __getitem__(self, index: int | np.integer | str) -> TSample | Sensor:
        """Get item from global index (or fetch a sensor by name).

        !!! info

            For the user's convenience, traces can be indexed by a `str` sensor
            name, returning that [`Sensor`][^^.]. While we are generally wary
            of requiring "quality of life" features, we include this since a
            simple `isinstance(index, str)` check suffices to implement this
            feature.

        Args:
            index: sample index, or sensor name.

        Returns:
            Loaded sample if `index` is an integer type, or the appropriate
            [`Sensor`][^^.] if `index` is a `str`.
        """
        ...

    def __len__(self) -> int:
        """Total number of sensor-tuple samples."""
        ...

__getitem__

__getitem__(index: str) -> Sensor
__getitem__(index: int | integer) -> TSample
__getitem__(index: int | integer | str) -> TSample | Sensor

Get item from global index (or fetch a sensor by name).

Info

For the user's convenience, traces can be indexed by a str sensor name, returning that Sensor. While we are generally wary of requiring "quality of life" features, we include this since a simple isinstance(index, str) check suffices to implement this feature.

Parameters:

    index (int | integer | str, required): sample index, or sensor name.

Returns:

    TSample | Sensor: Loaded sample if index is an integer type, or the appropriate Sensor if index is a str.

Source code in src/abstract_dataloader/spec.py
def __getitem__(self, index: int | np.integer | str) -> TSample | Sensor:
    """Get item from global index (or fetch a sensor by name).

    !!! info

        For the user's convenience, traces can be indexed by a `str` sensor
        name, returning that [`Sensor`][^^.]. While we are generally wary
        of requiring "quality of life" features, we include this since a
        simple `isinstance(index, str)` check suffices to implement this
        feature.

    Args:
        index: sample index, or sensor name.

    Returns:
        Loaded sample if `index` is an integer type, or the appropriate
        [`Sensor`][^^.] if `index` is a `str`.
    """
    ...

__len__

__len__() -> int

Total number of sensor-tuple samples.

Source code in src/abstract_dataloader/spec.py
def __len__(self) -> int:
    """Total number of sensor-tuple samples."""
    ...

abstract_dataloader.spec.Dataset

Bases: Protocol, Generic[TSample]

A dataset, consisting of multiple traces concatenated together.

Trace subtypes Dataset

Due to the type signatures, a Trace is actually a subtype of Dataset. This means that a dataset which implements a collection of traces can also take a collection of datasets!

Type Parameters
  • TSample: sample data type which this Dataset returns. As a convention, we suggest returning "batched" data by default, i.e. with a leading singleton axis.
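As an illustration, a dataset might simply concatenate several traces (or, per the note above, other datasets) into one global index space; a minimal sketch:

import bisect

import numpy as np

from abstract_dataloader.spec import Dataset


class ConcatDataset:
    """Concatenate datasets (or traces) along a single global index."""

    def __init__(self, parts: list[Dataset]) -> None:
        self.parts = parts
        # Cumulative lengths locate the part which owns each global index.
        self.offsets = np.cumsum([len(p) for p in parts])

    def __getitem__(self, index: int | np.integer):
        part = bisect.bisect_right(self.offsets, index)
        start = 0 if part == 0 else self.offsets[part - 1]
        return self.parts[part][index - start]

    def __len__(self) -> int:
        return int(self.offsets[-1]) if self.parts else 0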
Source code in src/abstract_dataloader/spec.py
@runtime_checkable
class Dataset(Protocol, Generic[TSample]):
    """A dataset, consisting of multiple traces concatenated together.

    !!! note "[`Trace`][^.] subtypes [`Dataset`][^.]"

        Due to the type signatures, a [`Trace`][^.] is actually a subtype of
        [`Dataset`][^.]. This means that a dataset which implements a
        collection of traces can also take a collection of datasets!

    Type Parameters:
        - `TSample`: sample data type which this `Dataset` returns. As a
            convention, we suggest returning "batched" data by default, i.e.
            with a leading singleton axis.
    """

    def __getitem__(self, index: int | np.integer) -> TSample:
        """Fetch item from this dataset by global index.

        Args:
            index: sample index.

        Returns:
            Loaded sample.
        """
        ...

    def __len__(self) -> int:
        """Total number of samples in this dataset."""
        ...

__getitem__

__getitem__(index: int | integer) -> TSample

Fetch item from this dataset by global index.

Parameters:

    index (int | integer, required): sample index.

Returns:

    TSample: Loaded sample.

Source code in src/abstract_dataloader/spec.py
def __getitem__(self, index: int | np.integer) -> TSample:
    """Fetch item from this dataset by global index.

    Args:
        index: sample index.

    Returns:
        Loaded sample.
    """
    ...

__len__

__len__() -> int

Total number of samples in this dataset.

Source code in src/abstract_dataloader/spec.py
def __len__(self) -> int:
    """Total number of samples in this dataset."""
    ...

abstract_dataloader.spec.Transform

Bases: Protocol, Generic[TRaw, TTransformed]

Sample or batch-wise transform.

Note

This protocol is a suggestively-named equivalent to Callable[[TRaw], TTransformed] or Callable[[Any], Any].

Composition Rules
  • Transform can be freely composed, as long as each transform's TTransformed matches the next transform's TRaw; this composition is implemented by abstract.Transform.
  • Composed Transforms result in another Transform:
    Transform[T2, T3] (.) Transform[T1, T2] = Transform[T1, T3].
    
Type Parameters
  • TRaw: Input data type.
  • TTransformed: Output data type.
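A short sketch of two compatible transforms over numpy arrays and their manual composition; the classes here are hypothetical, for illustration only:

import numpy as np


class Normalize:
    """Transform[np.ndarray, np.ndarray]: zero-mean, unit-variance."""

    def __call__(self, data: np.ndarray) -> np.ndarray:
        return (data - data.mean()) / (data.std() + 1e-8)


class AsFloat32:
    """Transform[np.ndarray, np.ndarray]: cast to float32."""

    def __call__(self, data: np.ndarray) -> np.ndarray:
        return data.astype(np.float32)


# Normalize's output type matches AsFloat32's input type, so their
# composition is itself a Transform[np.ndarray, np.ndarray].
def composed(data: np.ndarray) -> np.ndarray:
    return AsFloat32()(Normalize()(data))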
Source code in src/abstract_dataloader/spec.py
@runtime_checkable
class Transform(Protocol, Generic[TRaw, TTransformed]):
    """Sample or batch-wise transform.

    !!! note

        This protocol is a suggestively-named equivalent to
        `Callable[[TRaw], TTransformed]` or `Callable[[Any], Any]`.

    Composition Rules:
        - `Transform` can be freely composed, as long as each transform's
          `TTransformed` matches the next transform's `TRaw`; this composition
          is implemented by [`abstract.Transform`][abstract_dataloader.].
        - Composed `Transform`s result in another `Transform`:
          ```
          Transform[T2, T3] (.) Transform[T1, T2] = Transform[T1, T3].
          ```

    Type Parameters:
        - `TRaw`: Input data type.
        - `TTransformed`: Output data type.
    """

    def __call__(self, data: TRaw) -> TTransformed:
        """Apply transform to a single sample.

        Args:
            data: A single `TRaw` data sample.

        Returns:
            A single `TTransformed` data sample.
        """
        ...

__call__

__call__(data: TRaw) -> TTransformed

Apply transform to a single sample.

Parameters:

    data (TRaw, required): A single TRaw data sample.

Returns:

    TTransformed: A single TTransformed data sample.

Source code in src/abstract_dataloader/spec.py
def __call__(self, data: TRaw) -> TTransformed:
    """Apply transform to a single sample.

    Args:
        data: A single `TRaw` data sample.

    Returns:
        A single `TTransformed` data sample.
    """
    ...

abstract_dataloader.spec.Collate

Bases: Protocol, Generic[TTransformed, TCollated]

Data collation.

Note

This protocol is equivalent to Callable[[Sequence[TTransformed]], TCollated]. Collate can also be viewed as a special case of Transform, where the input type TRaw must be a Sequence[...].

Composition Rules
  • Collate can only be composed in parallel, and can never be sequentially composed.
Type Parameters
  • TTransformed: Input data type.
  • TCollated: Output data type.
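For example, a collation which stacks identically-shaped numpy samples along a new leading batch axis (an illustrative sketch, not part of the library):

from collections.abc import Sequence

import numpy as np


class StackArrays:
    """Collate[np.ndarray, np.ndarray]: stack samples into one batch array."""

    def __call__(self, data: Sequence[np.ndarray]) -> np.ndarray:
        return np.stack(list(data), axis=0)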
Source code in src/abstract_dataloader/spec.py
@runtime_checkable
class Collate(Protocol, Generic[TTransformed, TCollated]):
    """Data collation.

    !!! note

        This protocol is equivalent to
        `Callable[[Sequence[TTransformed]], TCollated]`. `Collate` can also
        be viewed as a special case of `Transform`, where the input type
        `TRaw` must be a `Sequence[...]`.

    Composition Rules:
        - `Collate` can only be composed in parallel, and can never be
          sequentially composed.

    Type Parameters:
        - `TTransformed`: Input data type.
        - `TCollated`: Output data type.
    """

    def __call__(self, data: Sequence[TTransformed]) -> TCollated:
        """Collate a set of samples.

        Args:
            data: A set of `TTransformed` samples.

        Returns:
            A `TCollated` batch.
        """
        ...

__call__

__call__(data: Sequence[TTransformed]) -> TCollated

Collate a set of samples.

Parameters:

    data (Sequence[TTransformed], required): A set of TTransformed samples.

Returns:

    TCollated: A TCollated batch.

Source code in src/abstract_dataloader/spec.py
def __call__(self, data: Sequence[TTransformed]) -> TCollated:
    """Collate a set of samples.

    Args:
        data: A set of `TTransformed` samples.

    Returns:
        A `TCollated` batch.
    """
    ...

abstract_dataloader.spec.Pipeline

Bases: Protocol, Generic[TRaw, TTransformed, TCollated, TProcessed]

Dataloader transform pipeline.

This protocol is parameterized by four type variables which encode the different data formats at each stage in the pipeline. This forms a Raw -> Transformed -> Collated -> Processed pipeline with three transforms:

  • sample: a sample to sample transform; can be sequentially assembled from one or more Transforms.
  • collate: a list-of-samples to batch transform. Can use exactly one Collate.
  • batch: a batch to batch transform; can be sequentially assembled from one or more Transforms.
Composition Rules
  • A full Pipeline can be sequentially pre-composed and/or post-composed with one or more Transforms; this is implemented by generic.ComposedPipeline.
  • Pipelines can always be composed in parallel; this is implemented by generic.ParallelPipelines, with a pytorch nn.Module-compatible version in torch.ParallelPipelines.
Type Parameters
  • TRaw: Input data format.
  • TTransformed: Data after the first transform step.
  • TCollated: Data after the second collate step.
  • TProcessed: Output data format.
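A minimal illustrative sketch of such a pipeline over numpy arrays; a real pipeline would typically transfer the collated batch to the GPU before batch, which is omitted here:

from collections.abc import Sequence

import numpy as np


class ArrayPipeline:
    """Raw array -> normalized sample -> stacked batch -> float32 batch."""

    def sample(self, data: np.ndarray) -> np.ndarray:
        # Per-sample transform, nominally CPU-side in dataloader workers.
        return (data - data.mean()) / (data.std() + 1e-8)

    def collate(self, data: Sequence[np.ndarray]) -> np.ndarray:
        # Aggregate individual samples along a new leading batch axis.
        return np.stack(list(data), axis=0)

    def batch(self, data: np.ndarray) -> np.ndarray:
        # Batch-wise transform, nominally GPU-side.
        return data.astype(np.float32)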
Source code in src/abstract_dataloader/spec.py
@runtime_checkable
class Pipeline(
    Protocol, Generic[TRaw, TTransformed, TCollated, TProcessed]
):
    """Dataloader transform pipeline.

    This protocol is parameterized by four type variables which encode the
    different data formats at each stage in the pipeline. This forms a
    `Raw -> Transformed -> Collated -> Processed` pipeline with three
    transforms:

    - [`sample`][.]: a sample to sample transform; can be sequentially
      assembled from one or more [`Transform`][^.]s.
    - [`collate`][.]: a list-of-samples to batch transform. Can use exactly one
      [`Collate`][^.].
    - [`batch`][.]: a batch to batch transform; can be sequentially assembled
      from one or more [`Transform`][^.]s.

    Composition Rules:
        - A full `Pipeline` can be sequentially pre-composed and/or
          post-composed with one or more [`Transform`][^.]s; this is
          implemented by [`generic.ComposedPipeline`][abstract_dataloader.].
        - `Pipeline`s can always be composed in parallel; this is implemented
          by [`generic.ParallelPipelines`][abstract_dataloader.], with a
          pytorch [`nn.Module`][torch.]-compatible version in
          [`torch.ParallelPipelines`][abstract_dataloader.].

    Type Parameters:
        - `TRaw`: Input data format.
        - `TTransformed`: Data after the first `transform` step.
        - `TCollated`: Data after the second `collate` step.
        - `TProcessed`: Output data format.
    """

    def sample(self, data: TRaw) -> TTransformed:
        """Transform single samples.

        - Operates on single samples, nominally on the CPU-side of a
          dataloader.
        - This method is both sequentially and parallel composable.

        Args:
            data: A single `TRaw` data sample.

        Returns:
            A single `TTransformed` data sample.
        """
        ...

    def collate(self, data: Sequence[TTransformed]) -> TCollated:
        """Collate a list of data samples into a GPU-ready batch.

        - Operates on the CPU-side of the dataloader, and is responsible for
          aggregating individual samples into a batch (but not transferring to
          the GPU).
        - Analogous to the `collate_fn` of a
          [pytorch dataloader](https://pytorch.org/docs/stable/data.html).
        - This method is not sequentially composable.

        Args:
            data: A sequence of `TTransformed` data samples.

        Returns:
            A `TCollated` collection of the input sequence.
        """
        ...

    def batch(self, data: TCollated) -> TProcessed:
        """Transform data batch.

        - Operates on a batch of data, nominally on the GPU-side of a
          dataloader.
        - This method is both sequentially and parallel composable.

        !!! info "Implementation as `torch.nn.Module`"

            If these `Transforms` require GPU state, it may be helpful to
            implement it as a `torch.nn.Module`. In this case, `batch`
            should redirect to `__call__`, which in turn redirects to
            [`nn.Module.forward`][torch.] in order to handle any registered
            pytorch hooks.

        Args:
            data: A `TCollated` batch of data, nominally already sent to the
                GPU.

        Returns:
            The `TProcessed` output, ready for the downstream model.
        """
        ...

batch

batch(data: TCollated) -> TProcessed

Transform data batch.

  • Operates on a batch of data, nominally on the GPU-side of a dataloader.
  • This method is both sequentially and parallel composable.

Implementation as torch.nn.Module

If these Transforms require GPU state, it may be helpful to implement it as a torch.nn.Module. In this case, batch should redirect to __call__, which in turn redirects to nn.Module.forward in order to handle any registered pytorch hooks.
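A sketch of this pattern, assuming pytorch is available; the module shown is hypothetical:

import torch
from torch import nn


class BatchStage(nn.Module):
    """Batch transform carrying GPU state, implemented as an nn.Module."""

    def __init__(self) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))  # example learnable state

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        return data * self.scale

    def batch(self, data: torch.Tensor) -> torch.Tensor:
        # Redirect through __call__ so any registered hooks are honored.
        return self(data)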

Parameters:

    data (TCollated, required): A TCollated batch of data, nominally already sent to the GPU.

Returns:

    TProcessed: The TProcessed output, ready for the downstream model.

Source code in src/abstract_dataloader/spec.py
def batch(self, data: TCollated) -> TProcessed:
    """Transform data batch.

    - Operates on a batch of data, nominally on the GPU-side of a
      dataloader.
    - This method is both sequentially and parallel composable.

    !!! info "Implementation as `torch.nn.Module`"

        If these `Transforms` require GPU state, it may be helpful to
        implement it as a `torch.nn.Module`. In this case, `batch`
        should redirect to `__call__`, which in turn redirects to
        [`nn.Module.forward`][torch.] in order to handle any registered
        pytorch hooks.

    Args:
        data: A `TCollated` batch of data, nominally already sent to the
            GPU.

    Returns:
        The `TProcessed` output, ready for the downstream model.
    """
    ...

collate

collate(data: Sequence[TTransformed]) -> TCollated

Collate a list of data samples into a GPU-ready batch.

  • Operates on the CPU-side of the dataloader, and is responsible for aggregating individual samples into a batch (but not transferring to the GPU).
  • Analogous to the collate_fn of a pytorch dataloader.
  • This method is not sequentially composable.

Parameters:

    data (Sequence[TTransformed], required): A sequence of TTransformed data samples.

Returns:

    TCollated: A TCollated collection of the input sequence.

Source code in src/abstract_dataloader/spec.py
def collate(self, data: Sequence[TTransformed]) -> TCollated:
    """Collate a list of data samples into a GPU-ready batch.

    - Operates on the CPU-side of the dataloader, and is responsible for
      aggregating individual samples into a batch (but not transferring to
      the GPU).
    - Analogous to the `collate_fn` of a
      [pytorch dataloader](https://pytorch.org/docs/stable/data.html).
    - This method is not sequentially composable.

    Args:
        data: A sequence of `TTransformed` data samples.

    Returns:
        A `TCollated` collection of the input sequence.
    """
    ...

sample

sample(data: TRaw) -> TTransformed

Transform single samples.

  • Operates on single samples, nominally on the CPU-side of a dataloader.
  • This method is both sequentially and parallel composable.

Parameters:

    data (TRaw, required): A single TRaw data sample.

Returns:

    TTransformed: A single TTransformed data sample.

Source code in src/abstract_dataloader/spec.py
def sample(self, data: TRaw) -> TTransformed:
    """Transform single samples.

    - Operates on single samples, nominally on the CPU-side of a
      dataloader.
    - This method is both sequentially and parallel composable.

    Args:
        data: A single `TRaw` data sample.

    Returns:
        A single `TTransformed` data sample.
    """
    ...