
Pytorch Interoperability

Pytorch-ADL wrappers.

These implementations are a superset of the generic components, and provide interoperability with pytorch dataloaders, modules, etc. For example, any Pipeline-related components which could contain pytorch nn.Modules are modified to subclass nn.Module, so that any contained modules are properly registered.

Warning

This module is not automatically imported; you will need to explicitly import it:

from abstract_dataloader import torch as adl_torch

Since pytorch is not declared as a required dependency, you will also need to install torch (or install the torch extra with pip install abstract_dataloader[torch]).

Note

Recursive tree operations such as reshaping and stacking are performed using the optree library, or, if that is not present, torch.utils._pytree, which implements equivalent functionality. If torch.utils._pytree is removed in a later version, the constructor will raise NotImplementedError, and this fallback will need to be replaced.
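
As a rough illustration of the fallback described here, the selection logic could look something like the following (a minimal sketch only; the module's internal _get_treelib helper may be implemented differently):

def _get_treelib():
    """Illustrative sketch of the optree / torch.utils._pytree fallback."""
    try:
        import optree
        return optree
    except ImportError:
        try:
            # private torch API which mirrors the needed tree utilities
            import torch.utils._pytree as pytree
            return pytree
        except ImportError:
            raise NotImplementedError(
                "Neither optree nor torch.utils._pytree is available; "
                "install optree to use these components.")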

Warning

Custom data container classes such as @dataclass are only supported if optree is installed, and they are registered with optree. However, dict, list, tuple, and equivalent types such as TypedDict and NamedTuple will work out of the box.

abstract_dataloader.torch.Collate

Bases: Collate[TTransformed, TCollated]

Generic numpy to pytorch collation.

Converts numpy arrays to pytorch tensors, and either stacks or concatenates each value.

Type Parameters
  • TTransformed: input sample type.
  • TCollated: output collated type.

Parameters:

  • mode (Literal['stack', 'concat'], default: 'concat'): whether to stack or concat during collation.

Source code in src/abstract_dataloader/torch/torch.py
class Collate(spec.Collate[TTransformed, TCollated]):
    """Generic numpy to pytorch collation.

    Converts numpy arrays to pytorch tensors, and either stacks or concatenates
    each value.

    Type Parameters:
        - `TTransformed`: input sample type.
        - `TCollated`: output collated type.

    Args:
        mode: whether to `stack` or `concat` during collation.
    """

    def __init__(self, mode: Literal["stack", "concat"] = "concat") -> None:
        self.mode = mode
        self.treelib = _get_treelib()

    def __call__(self, data: Sequence[TTransformed]) -> TCollated:
        if self.mode == "concat":
            return self.treelib.tree_map(
                lambda *x: torch.concat([torch.from_numpy(s) for s in x]),
                *data)  # type: ignore
        else:
            return self.treelib.tree_map(
                lambda *x: torch.stack([torch.from_numpy(s) for s in x]),
                *data)  # type: ignore
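
For example, collating a batch of dict samples of numpy arrays (the sample structure here is illustrative):

import numpy as np
from abstract_dataloader import torch as adl_torch

samples = [
    {"image": np.zeros((3, 32, 32), dtype=np.float32), "label": np.array([1])}
    for _ in range(4)
]

collate = adl_torch.Collate(mode="stack")
batch = collate(samples)
# batch["image"]: torch.Tensor of shape (4, 3, 32, 32)
# batch["label"]: torch.Tensor of shape (4, 1)
# With mode="concat", values are concatenated along their existing leading
# axis instead, e.g. batch["image"] would have shape (12, 32, 32).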

abstract_dataloader.torch.ComposedPipeline

Bases: Module, ComposedPipeline[TRaw, TRawInner, TTransformed, TCollated, TProcessedInner, TProcessed]

Compose pipeline sequentially with pre and post transforms.

Type Parameters
  • TRaw: initial input type.
  • TRawInner: output of the pre-composed transform, and input to the provided Pipeline.
  • TTransformed, TCollated: intermediate types for the provided Pipeline.
  • TProcessedInner: output of the provided Pipeline, and input to the post-composed transform.
  • TProcessed: output type.

Parameters:

  • pipeline (Pipeline[TRawInner, TTransformed, TCollated, TProcessedInner]): pipeline to compose. Required.
  • pre (Transform[TRaw, TRawInner] | None, default: None): pre-transform to apply on the CPU side; skipped if None.
  • post (Transform[TProcessedInner, TProcessed] | None, default: None): post-transform to apply on the GPU side; skipped if None.

Source code in src/abstract_dataloader/torch/generic.py
class ComposedPipeline(
    torch.nn.Module,
    generic.ComposedPipeline[
        TRaw, TRawInner, TTransformed, TCollated, TProcessedInner, TProcessed]
):
    """Compose pipeline sequentially with pre and post transforms.

    Type Parameters:
        - `TRaw`: initial input type.
        - `TRawInner`: output of the pre-composed transform, and input to the
            provided [`Pipeline`][abstract_dataloader.spec].
        - `TCollated`, `TProcessed`: intermediate values for the provided
            [`Pipeline`][abstract_dataloader.spec].
        - `TProcessedInner`: output of the transforms, and input to the
            post-composed transform.
        - `TProcessed`: output type.

    Args:
        pipeline: pipeline to compose.
        pre: pre-transform to apply on the CPU side; skipped if `None`.
        post: post-transform to apply on the GPU side; skipped if `None`.
    """

    def __init__(
        self, pipeline: spec.Pipeline[
            TRawInner, TTransformed, TCollated, TProcessedInner],
        pre: spec.Transform[TRaw, TRawInner] | None = None,
        post: spec.Transform[TProcessedInner, TProcessed] | None = None
    ) -> None:
        super().__init__()
        self.pipeline = pipeline
        self.pre = pre
        self.post = post

        self.collate = pipeline.collate
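
For example, wrapping an existing pipeline with CPU- and GPU-side transforms (a sketch; my_pipeline, normalize, and to_prediction are hypothetical objects implementing spec.Pipeline and spec.Transform):

from abstract_dataloader import torch as adl_torch

composed = adl_torch.ComposedPipeline(
    pipeline=my_pipeline,   # spec.Pipeline[TRawInner, ..., TProcessedInner]
    pre=normalize,          # spec.Transform applied on the CPU side
    post=to_prediction,     # spec.Transform applied on the GPU side
)
# `composed` still follows the Pipeline protocol, with `pre` applied before
# the wrapped pipeline and `post` applied after it.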

abstract_dataloader.torch.Empty

Bases: Synchronization

Dummy synchronization which does not synchronize sensor pairs.

No samples will be registered, and the trace can only be used as a collection of sensors.

Source code in src/abstract_dataloader/generic/sync.py
class Empty(spec.Synchronization):
    """Dummy synchronization which does not synchronize sensor pairs.

    No samples will be registered, and the trace can only be used as a
    collection of sensors.
    """

    def __call__(
        self, timestamps: dict[str, Float64[np.ndarray, "_N"]]
    ) -> dict[str, UInt32[np.ndarray, "M"]]:
        """Apply synchronization protocol.

        Args:
            timestamps: input sensor timestamps.

        Returns:
            Synchronized index map.
        """
        return {k: np.array([], dtype=np.uint32) for k in timestamps}

__call__

__call__(
    timestamps: dict[str, Float64[ndarray, _N]],
) -> dict[str, UInt32[ndarray, M]]

Apply synchronization protocol.

Parameters:

  • timestamps (dict[str, Float64[ndarray, _N]]): input sensor timestamps. Required.

Returns:

  • dict[str, UInt32[ndarray, M]]: Synchronized index map.

Source code in src/abstract_dataloader/generic/sync.py
def __call__(
    self, timestamps: dict[str, Float64[np.ndarray, "_N"]]
) -> dict[str, UInt32[np.ndarray, "M"]]:
    """Apply synchronization protocol.

    Args:
        timestamps: input sensor timestamps.

    Returns:
        Synchronized index map.
    """
    return {k: np.array([], dtype=np.uint32) for k in timestamps}

abstract_dataloader.torch.Metadata dataclass

Bases: Metadata

Generic metadata with timestamps.

Attributes:

  • timestamps (Float64[ndarray, N]): epoch timestamps.

Source code in src/abstract_dataloader/generic/sequence.py
@dataclass
class Metadata(spec.Metadata):
    """Generic metadata with timestamps.

    Attributes:
        timestamps: epoch timestamps.
    """

    timestamps: Float64[np.ndarray, "N"]

abstract_dataloader.torch.Nearest

Bases: Synchronization

Nearest sample synchronization, with respect to a reference sensor.

Applies the following:

  • Compute the midpoints between consecutive observations of each sensor.
  • Find which bin the reference sensor timestamps fall into.
  • Calculate the resulting time delta between timestamps. If this exceeds tol for any sensor-reference pair, remove this match.

See Synchronization for protocol details.

Parameters:

  • reference (str): reference sensor to synchronize to. Required.
  • tol (float, default: 0.1): synchronization time tolerance, in seconds. Set tol = np.inf to disable this check altogether.

Source code in src/abstract_dataloader/generic/sync.py
class Nearest(spec.Synchronization):
    """Nearest sample synchronization, with respect to a reference sensor.

    Applies the following:

    - Compute the midpoints between consecutive observations of each sensor.
    - Find which bin the reference sensor timestamps fall into.
    - Calculate the resulting time delta between timestamps. If this exceeds
      `tol` for any sensor-reference pair, remove this match.

    See [`Synchronization`][abstract_dataloader.spec.] for protocol details.

    Args:
        reference: reference sensor to synchronize to.
        tol: synchronization time tolerance, in seconds. Setting `tol = np.inf`
            works to disable this check altogether.
    """

    def __init__(self, reference: str, tol: float = 0.1) -> None:
        if tol < 0:
            raise ValueError(
                f"Synchronization tolerance must be positive: {tol} < 0")

        self.tol = tol
        self.reference = reference

    def __call__(
        self, timestamps: dict[str, Float64[np.ndarray, "_N"]]
    ) -> dict[str, UInt32[np.ndarray, "M"]]:
        """Apply synchronization protocol.

        Args:
            timestamps: input sensor timestamps.

        Returns:
            Synchronized index map.
        """
        try:
            t_ref = timestamps[self.reference]
        except KeyError:
            raise KeyError(
                f"Reference sensor {self.reference} was not provided in "
                f"timestamps, with keys: {list(timestamps.keys())}")

        indices = {
            k: np.searchsorted(
                (t_sensor[:-1] + t_sensor[1:]) / 2, t_ref
            ).astype(np.uint32)
            for k, t_sensor in timestamps.items()}
        valid = np.all(np.array([
           np.abs(timestamps[k][i_nearest] - t_ref) < self.tol
        for k, i_nearest in indices.items()]), axis=0)

        return {k: v[valid] for k, v in indices.items()}
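
For example, with illustrative timestamps (values chosen to show the tolerance check):

import numpy as np
from abstract_dataloader import torch as adl_torch

sync = adl_torch.Nearest(reference="camera", tol=0.05)
indices = sync({
    "camera": np.array([0.00, 0.10, 0.20]),
    "lidar": np.array([0.01, 0.09, 0.35]),
})
# The first two camera frames match lidar samples 0 and 1 within 0.05s; the
# third camera frame is dropped, since its nearest lidar sample (t=0.09) is
# 0.11s away.
# indices == {"camera": array([0, 1]), "lidar": array([0, 1])}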

__call__

__call__(
    timestamps: dict[str, Float64[ndarray, _N]],
) -> dict[str, UInt32[ndarray, M]]

Apply synchronization protocol.

Parameters:

  • timestamps (dict[str, Float64[ndarray, _N]]): input sensor timestamps. Required.

Returns:

  • dict[str, UInt32[ndarray, M]]: Synchronized index map.

Source code in src/abstract_dataloader/generic/sync.py
def __call__(
    self, timestamps: dict[str, Float64[np.ndarray, "_N"]]
) -> dict[str, UInt32[np.ndarray, "M"]]:
    """Apply synchronization protocol.

    Args:
        timestamps: input sensor timestamps.

    Returns:
        Synchronized index map.
    """
    try:
        t_ref = timestamps[self.reference]
    except KeyError:
        raise KeyError(
            f"Reference sensor {self.reference} was not provided in "
            f"timestamps, with keys: {list(timestamps.keys())}")

    indices = {
        k: np.searchsorted(
            (t_sensor[:-1] + t_sensor[1:]) / 2, t_ref
        ).astype(np.uint32)
        for k, t_sensor in timestamps.items()}
    valid = np.all(np.array([
       np.abs(timestamps[k][i_nearest] - t_ref) < self.tol
    for k, i_nearest in indices.items()]), axis=0)

    return {k: v[valid] for k, v in indices.items()}

abstract_dataloader.torch.Next

Bases: Synchronization

Next sample synchronization, with respect to a reference sensor.

Applies the following:

  • Find the start time, defined by the earliest time which is observed by all sensors, and the end time, defined by the last time which is observed by all sensors.
  • Truncate the reference sensor's timestamps to this start and end time, and use this as the query timestamps.
  • For each time in the query, find the first sample from each sensor which is after this time.

See Synchronization for protocol details.

Parameters:

  • reference (str): reference sensor to synchronize to. Required.

Source code in src/abstract_dataloader/generic/sync.py
class Next(spec.Synchronization):
    """Next sample synchronization, with respect to a reference sensor.

    Applies the following:

    - Find the start time, defined by the earliest time which is observed by
      all sensors, and the end time, defined by the last time which is observed
      by all sensors.
    - Truncate the reference sensor's timestamps to this start and end time,
      and use this as the query timestamps.
    - For each time in the query, find the first sample from each sensor which
      is after this time.

    See [`Synchronization`][abstract_dataloader.spec.] for protocol details.

    Args:
        reference: reference sensor to synchronize to.
    """

    def __init__(self, reference: str) -> None:
        self.reference = reference

    def __call__(
        self, timestamps: dict[str, Float64[np.ndarray, "_N"]]
    ) -> dict[str, UInt32[np.ndarray, "M"]]:
        """Apply synchronization protocol.

        Args:
            timestamps: input sensor timestamps.

        Returns:
            Synchronized index map.
        """
        try:
            ref_time_all = timestamps[self.reference]
        except KeyError:
            raise KeyError(
                f"Reference sensor {self.reference} was not provided in "
                f"timestamps, with keys: {list(timestamps.keys())}")

        start_time = max(t[0] for t in timestamps.values())
        end_time = min(t[-1] for t in timestamps.values())

        start_idx = np.searchsorted(ref_time_all, start_time)
        end_idx = np.searchsorted(ref_time_all, end_time)
        ref_time = ref_time_all[start_idx:end_idx]
        return {
            k: np.searchsorted(v, ref_time).astype(np.uint32)
            for k, v in timestamps.items()}
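
For example, with illustrative timestamps:

import numpy as np
from abstract_dataloader import torch as adl_torch

sync = adl_torch.Next(reference="camera")
indices = sync({
    "camera": np.array([0.00, 0.10, 0.20, 0.30]),
    "lidar": np.array([0.05, 0.15, 0.25, 0.28]),
})
# The overlapping window is [0.05, 0.28], so only the reference timestamps
# 0.10 and 0.20 are used as queries; each sensor then contributes the first
# sample at or after each query time.
# indices == {"camera": array([1, 2]), "lidar": array([1, 2])}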

__call__

__call__(
    timestamps: dict[str, Float64[ndarray, _N]],
) -> dict[str, UInt32[ndarray, M]]

Apply synchronization protocol.

Parameters:

  • timestamps (dict[str, Float64[ndarray, _N]]): input sensor timestamps. Required.

Returns:

  • dict[str, UInt32[ndarray, M]]: Synchronized index map.

Source code in src/abstract_dataloader/generic/sync.py
def __call__(
    self, timestamps: dict[str, Float64[np.ndarray, "_N"]]
) -> dict[str, UInt32[np.ndarray, "M"]]:
    """Apply synchronization protocol.

    Args:
        timestamps: input sensor timestamps.

    Returns:
        Synchronized index map.
    """
    try:
        ref_time_all = timestamps[self.reference]
    except KeyError:
        raise KeyError(
            f"Reference sensor {self.reference} was not provided in "
            f"timestamps, with keys: {list(timestamps.keys())}")

    start_time = max(t[0] for t in timestamps.values())
    end_time = min(t[-1] for t in timestamps.values())

    start_idx = np.searchsorted(ref_time_all, start_time)
    end_idx = np.searchsorted(ref_time_all, end_time)
    ref_time = ref_time_all[start_idx:end_idx]
    return {
        k: np.searchsorted(v, ref_time).astype(np.uint32)
        for k, v in timestamps.items()}

abstract_dataloader.torch.ParallelPipelines

Bases: Module, ParallelPipelines[PRaw, PTransformed, PCollated, PProcessed]

Transform Compositions, modified for Pytorch compatibility.

Any nn.Module transforms are registered to a separate nn.ModuleDict; the original .transforms attribute is maintained with references to the full pipeline.

See generic.ParallelPipelines for more details about this implementation. .forward and .__call__ should work as expected within pytorch.

Type Parameters
  • PRaw, PTransformed, PCollated, PProcessed: see Pipeline.

Parameters:

  • transforms (Pipeline, default: {}): pipelines to compose. The key indicates the subkey to apply each transform to.

Source code in src/abstract_dataloader/torch/generic.py
class ParallelPipelines(
    torch.nn.Module,
    generic.ParallelPipelines[PRaw, PTransformed, PCollated, PProcessed]
):
    """Transform Compositions, modified for Pytorch compatibility.

    Any [`nn.Module`][?torch.] transforms are registered to a separate
    [`nn.ModuleDict`][?torch.]; the original `.transforms` attribute is
    maintained with references to the full pipeline.

    See [`generic.ParallelPipelines`][abstract_dataloader.]
    for more details about this implementation. `.forward` and `.__call__`
    should work as expected within pytorch.

    Type Parameters:
        - `PRaw`, `PTransformed`, `PCollated`, `PProcessed`: see
          [`Pipeline`][abstract_dataloader.spec.].

    Args:
        transforms: pipelines to compose. The key indicates the subkey to
            apply each transform to.
    """

    def __init__(self, **transforms: spec.Pipeline) -> None:
        super().__init__()
        self.transforms = transforms
        self._transforms = torch.nn.ModuleDict({
            k: v for k, v in transforms.items()
            if isinstance(v, torch.nn.Module)})

    def forward(self, data: PCollated) -> PProcessed:
        # We have to redefine this for some reason to make torch happy.
        # I think `nn.Module` has a generic `forward` implementation which
        # is clobbering `ComposeTransform`.
        return cast(
            PProcessed,
            {k: v.batch(data[k]) for k, v in self.transforms.items()})

    def batch(self, data: PCollated) -> PProcessed:
        """Alias `batch` to `__call__` to `forward` via `nn.Module`."""
        return self(data)
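
For example, composing per-sensor pipelines keyed by the corresponding subkey of the batch (a sketch; camera_pipeline and lidar_pipeline are hypothetical spec.Pipeline implementations):

from abstract_dataloader import torch as adl_torch

pipelines = adl_torch.ParallelPipelines(
    camera=camera_pipeline,   # applied to data["camera"]
    lidar=lidar_pipeline,     # applied to data["lidar"]
)
processed = pipelines(collated)  # {"camera": ..., "lidar": ...}
# Any nn.Module pipelines are also registered on an internal nn.ModuleDict,
# so their parameters appear in pipelines.parameters() as usual.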

batch

batch(data: PCollated) -> PProcessed

Alias batch to __call__ to forward via nn.Module.

Source code in src/abstract_dataloader/torch/generic.py
def batch(self, data: PCollated) -> PProcessed:
    """Alias `batch` to `__call__` to `forward` via `nn.Module`."""
    return self(data)

abstract_dataloader.torch.SequencePipeline

Bases: Module, SequencePipeline[TRaw, TTransformed, TCollated, TProcessed]

Transform which passes an additional sequence axis through.

The given Pipeline is modified to accept Sequence[...] for each data type in its pipeline, and return a list[...] across the additional axis, thus "passing through" the axis.

For example, suppose a sequence dataloader reads

[
    [Raw[s=0, t=0], Raw[s=0, t=1], ... Raw[s=0, t=n]]
    [Raw[s=1, t=0], Raw[s=1, t=1], ... Raw[s=1, t=n]]
    ...
    [Raw[s=b, t=0], Raw[s=b, t=1], ... Raw[s=b, t=n]]
]

for sequence indices t = 0...n and batch samples s = 0...b. The output of the transforms is then batched per sequence index, with the sequence axis on the outside:

[
    Processed[s=0...b] [t=0],
    Processed[s=0...b] [t=1],
    ...
    Processed[s=0...b] [t=n]
]
Type Parameters
  • TRaw, TTransformed, TCollated, TProcessed: see Pipeline.

Parameters:

  • pipeline (Pipeline[TRaw, TTransformed, TCollated, TProcessed]): input pipeline. Required.

Source code in src/abstract_dataloader/torch/generic.py
class SequencePipeline(
    torch.nn.Module,
    generic.SequencePipeline[TRaw, TTransformed, TCollated, TProcessed]
):
    """Transform which passes an additional sequence axis through.

    The given `Pipeline` is modified to accept `Sequence[...]` for each
    data type in its pipeline, and return a `list[...]` across the additional
    axis, thus "passing through" the axis.

    For example, suppose a sequence dataloader reads

    ```
    [
        [Raw[s=0, t=0], Raw[s=0, t=1], ... Raw[s=0, t=n]]
        [Raw[s=1, t=0], Raw[s=1, t=1], ... Raw[s=1, t=n]]
        ...
        [Raw[s=b, t=0], Raw[s=b, t=1], ... Raw[s=b, t=n]]
    ]
    ```

    for sequence length `t = 0...n` and batch sample `s = 0...b`. For sequence
    length `t`, the output of the transforms will be batched with the sequence
    on the outside:

    ```
    [
        Processed[s=0...b] [t=0],
        Processed[s=0...b] [t=1],
        ...
        Processed[s=0...b] [t=n]
    ]
    ```

    Type Parameters:
        - `TRaw`, `TTransformed`, `TCollated`, `TProcessed`: see
          [`Pipeline`][abstract_dataloader.spec.].

    Args:
        pipeline: input pipeline.
    """

    def __init__(
        self, pipeline: spec.Pipeline[
            TRaw, TTransformed, TCollated, TProcessed]
    ) -> None:
        super().__init__()
        self.pipeline = pipeline

abstract_dataloader.torch.StackedSequencePipeline

Bases: Module, Pipeline[Sequence[TRaw], Sequence[TTransformed], TCollated, TProcessed]

Modify a pipeline to act on sequences.

Unlike the generic generic.SequencePipeline implementation, this class places the sequence axis directly inside each tensor, so that each data type has axes (batch, sequence, ...). For the same input,

[
    [Raw[s=0, t=0], Raw[s=0, t=1], ... Raw[s=0, t=n]]
    [Raw[s=1, t=0], Raw[s=1, t=1], ... Raw[s=1, t=n]]
    ...
    [Raw[s=b, t=0], Raw[s=b, t=1], ... Raw[s=b, t=n]]
]

this pipeline instead yields

Processed[s=0...b] [t=0...n].

Info

This class requires that all outputs of .collate() are pytorch tensors. Furthermore, batches must be treated as an additional leading axis by both .collate and .forward.

Warning

Since the output has an additional axis, it does not necessarily have the same type as the underlying transform!

This is accomplished by appropriately reshaping the data to use the batch-vectorized underlying implementation:

  • .sample: apply the pipeline to each sample across the additional sequence axis.
  • .collate: concatenate all sequences into a single list[Raw], instead of a list[list[Raw]]. Then, collate the list, and reshape back into batch sequence ... order.
  • .batch: flatten the collated data back to a (batch sequence) ... single leading batch axis, apply the pipeline, and reshape back.
Type Parameters
  • TRaw, TTransformed, TCollated, TProcessed: see Pipeline.

Parameters:

  • pipeline (Pipeline[TRaw, TTransformed, TCollated, TProcessed]): pipeline to transform to accept sequences. Required.

Source code in src/abstract_dataloader/torch/torch.py
class StackedSequencePipeline(
    torch.nn.Module,
    spec.Pipeline[
        Sequence[TRaw], Sequence[TTransformed], TCollated, TProcessed]
):
    """Modify a pipeline to act on sequences.

    Unlike the generic [`generic.SequencePipeline`][abstract_dataloader.]
    implementation, this class places the sequence axis directly inside each
    tensor, so that each data type has axes `(batch, sequence, ...)`. For the
    same input,

    ```
    [
        [Raw[s=0, t=0], Raw[s=0, t=1], ... Raw[s=0, t=n]]
        [Raw[s=1, t=0], Raw[s=1, t=1], ... Raw[s=1, t=n]]
        ...
        [Raw[s=b, t=0], Raw[s=b, t=1], ... Raw[s=b, t=n]]
    ]
    ```

    this pipeline instead yields

    ```python
    Processed[s=0...b] [t=0...n].
    ```

    !!! info

        This class requires that all outputs of `.collate()` are pytorch
        tensors. Furthermore, batches must be treated as an additional leading
        axis by both `.collate` and `.forward`.

    !!! warning

        Since the output has an additional axis, it does not necessarily have
        the same type as the underlying transform!

    This is accomplished by appropriately reshaping the data to use the
    batch-vectorized underlying implementation:

    - `.sample`: apply the pipeline to each sample across the additional
      sequence axis.
    - `.collate`: concatenate all sequences into a single `list[Raw]`, instead
      of a `list[list[Raw]]`. Then, collate the list, and reshape back into
      `batch sequence ...` order.
    - `.batch`: flatten the collated data back to a `(batch sequence) ...`
      single leading batch axis, apply the pipeline, and reshape back.

    Type Parameters:
        - `PRaw`, `PTransformed`, `PCollated`, `PProcessed`: see
          [`Pipeline`][abstract_dataloader.spec.].

    Args:
        pipeline: pipeline to transform to accept sequences.
    """

    def __init__(
        self, pipeline: spec.Pipeline[
            TRaw, TTransformed, TCollated, TProcessed]
    ) -> None:
        super().__init__()
        self.pipeline = pipeline
        self.treelib = _get_treelib()

    def sample(self, data: Sequence[TRaw]) -> list[TTransformed]:
        return [self.pipeline.sample(x) for x in data]

    def collate(self, data: Sequence[Sequence[TTransformed]]) -> Any:
        data_flat = sum((list(x) for x in data), start=[])
        collated_flat = self.pipeline.collate(data_flat)
        unflattened = self.treelib.tree_map(
            lambda x: x.reshape(len(data), -1, *x.shape[1:]),
            collated_flat)   # type: ignore
        return unflattened

    def batch(self, data: Any) -> Any:
        batch = self.treelib.tree_leaves(data)[0].shape[0]  # type: ignore
        flattened = self.treelib.tree_map(
            lambda x: x.reshape(-1, *x.shape[2:]), data)
        transformed = self.pipeline.batch(cast(TCollated, flattened))
        unflattened = self.treelib.tree_map(
            lambda x: x.reshape(batch, -1, *x.shape[1:]),
            transformed)  # type: ignore
        return unflattened
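
As a rough illustration of the reshaping performed by .batch (a minimal sketch on a single tensor; the class applies the same pattern to every leaf of the collated tree):

import torch

batch, seq, features = 4, 3, 8
collated = torch.randn(batch, seq, features)       # (batch, sequence, ...)

flat = collated.reshape(-1, *collated.shape[2:])    # (batch * sequence, ...)
processed = flat * 2                                # stand-in for pipeline.batch(...)
out = processed.reshape(batch, -1, *processed.shape[1:])  # (batch, sequence, ...)
assert out.shape == (4, 3, 8)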

abstract_dataloader.torch.TransformedDataset

Bases: Dataset[TTransformed], Generic[TRaw, TTransformed]

Pytorch-compatible dataset with transformation applied.

Extends torch.utils.data.Dataset, implementing a torch "map-style" dataset.

Type Parameters
  • TRaw: raw data type from the dataloader.
  • TTransformed: output data type from the provided transform function.

Parameters:

  • dataset (Dataset[TRaw]): source dataset. Required.
  • transform (Transform[TRaw, TTransformed]): transformation to apply to each sample when loading (note that Transform[TRaw, TTransformed] is equivalent to Callable[[TRaw], TTransformed]). Required.

Source code in src/abstract_dataloader/torch/torch.py
class TransformedDataset(Dataset[TTransformed], Generic[TRaw, TTransformed]):
    """Pytorch-compatible dataset with transformation applied.

    Extends [`torch.utils.data.Dataset`][?torch.utils.data.Dataset],
    implementing a torch "map-style" dataset.

    Type Parameters:
        - `TRaw`: raw data type from the dataloader.
        - `TTransformed`: output data type from the provided transform function.

    Args:
        dataset: source dataset.
        transform: transformation to apply to each sample when loading (note
            that `Transform[TRaw, TTransformed]` is equivalent to
            `Callable[[TRaw], TTransformed]`).
    """

    def __init__(
        self, dataset: spec.Dataset[TRaw],
        transform: spec.Transform[TRaw, TTransformed]
    ) -> None:
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index: int | np.integer) -> TTransformed:
        """Map-style dataset indexing.

        Args:
            index: dataset index; passthrough to the underlying `Dataset`.

        Returns:
            Transformed sample.
        """
        return self.transform(self.dataset[index])

    def __len__(self) -> int:
        """Dataset length; passthrough to the underlying `Dataset`."""
        return len(self.dataset)

    def __repr__(self) -> str:
        """Friendly name."""
        return f"Transformed({repr(self.dataset)})"

__getitem__

__getitem__(index: int | integer) -> TTransformed

Map-style dataset indexing.

Parameters:

  • index (int | integer): dataset index; passthrough to the underlying Dataset. Required.

Returns:

  • TTransformed: Transformed sample.

Source code in src/abstract_dataloader/torch/torch.py
def __getitem__(self, index: int | np.integer) -> TTransformed:
    """Map-style dataset indexing.

    Args:
        index: dataset index; passthrough to the underlying `Dataset`.

    Returns:
        Transformed sample.
    """
    return self.transform(self.dataset[index])

__len__

__len__() -> int

Dataset length; passthrough to the underlying Dataset.

Source code in src/abstract_dataloader/torch/torch.py
def __len__(self) -> int:
    """Dataset length; passthrough to the underlying `Dataset`."""
    return len(self.dataset)

__repr__

__repr__() -> str

Friendly name.

Source code in src/abstract_dataloader/torch/torch.py
def __repr__(self) -> str:
    """Friendly name."""
    return f"Transformed({repr(self.dataset)})"

abstract_dataloader.torch.Window

Bases: Sensor[SampleStack, Metadata], Generic[SampleStack, Sample, TMetadata]

Load sensor data across a time window using a sensor transform.

Use this class as a generic transform to give time history to any sensor:

sensor =  ... # implements spec.Sensor
with_history = generic.Window(sensor, past=5, future=1, parallel=7)

In this example, 5 past samples, the current sample, and 1 future sample are loaded on every index:

with_history[i] = [
    sensor[i], sensor[i + 1], ... sensor[i + 5], sensor[i + 6]]
                                        ^
                            # timestamp for synchronization
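
If each sample is a numpy array, a collate_fn can also stack the window into a single array (a sketch; sensor is a hypothetical spec.Sensor whose samples are numpy arrays):

import numpy as np

windowed = generic.Window(sensor, collate_fn=np.stack, past=2, future=0)
# windowed[i] has an extra leading axis of length past + 1 + future = 3.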

Parameters:

  • sensor (Sensor[Sample, TMetadata]): sensor to wrap. Required.
  • collate_fn (Callable[[list[Sample]], SampleStack] | None, default: None): collate function for aggregating a list of samples; if not specified, the samples are simply returned as a list.
  • past (int, default: 0): number of past samples, in addition to the current sample. Set to 0 to disable.
  • future (int, default: 0): number of future samples, in addition to the current sample. Set to 0 to disable.
  • parallel (int | None, default: None): maximum number of samples to load in parallel; if None, all samples are loaded sequentially.

Type Parameters
  • SampleStack: a collated series of consecutive samples. Can simply be list[Sample].
  • Sample: single observation sample type.
  • TMetadata: metadata type for the underlying sensor. Note that the Window wrapper doesn't actually have metadata type TMetadata; this type is just passed through from the sensor which is wrapped.
Source code in src/abstract_dataloader/generic/sequence.py
class Window(
    abstract.Sensor[SampleStack, Metadata],
    Generic[SampleStack, Sample, TMetadata]
):
    """Load sensor data across a time window using a sensor transform.

    Use this class as a generic transform to give time history to any sensor:

    ```python
    sensor =  ... # implements spec.Sensor
    with_history = generic.Window(sensor, past=5, future=1, parallel=7)
    ```

    In this example, 5 past samples, the current sample, and 1 future sample
    are loaded on every index:

    ```python
    with_history[i] = [
        sensor[i], sensor[i + 1], ... sensor[i + 5], sensor[i + 6]]
                                            ^
                                # timestamp for synchronization
    ```

    Args:
        sensor: sensor to wrap.
        collate_fn: collate function for aggregating a list of samples; if not
            specified, the samples are simply returned as a list.
        past: number of past samples, in addition to the current sample. Set
            to `0` to disable.
        future: number of future samples, in addition to the current sample.
            Set to `0` to disable.
        parallel: maximum number of samples to load in parallel; if `None`, all
            samples are loaded sequentially.

    Type Parameters:
        - `SampleStack`: a collated series of consecutive samples. Can simply be
            `list[Sample]`.
        - `Sample`: single observation sample type.
        - `TMetadata`: metadata type for the underlying sensor. Note that the
            `Window` wrapper doesn't actually have metadata type `TMetadata`;
            this type is just passed through from the sensor which is wrapped.
    """

    def __init__(
        self, sensor: spec.Sensor[Sample, TMetadata],
        collate_fn: Callable[[list[Sample]], SampleStack] | None = None,
        past: int = 0, future: int = 0, parallel: int | None = None
    ) -> None:
        self.sensor = sensor
        self.past = past
        self.future = future
        self.parallel = parallel

        if collate_fn is None:
            collate_fn = cast(
                Callable[[list[Sample]], SampleStack], lambda x: x)
        self.collate_fn = collate_fn

        # hack for negative indexing
        _future = None if future == 0 else -future
        self.metadata = Metadata(
            timestamps=sensor.metadata.timestamps[past:_future])

    @classmethod
    def from_partial_sensor(
        cls, sensor: Callable[[str], spec.Sensor[Sample, TMetadata]],
        collate_fn: Callable[[list[Sample]], SampleStack] | None = None,
        past: int = 0, future: int = 0, parallel: int | None = None
    ) -> Callable[[str], "Window[SampleStack, Sample, TMetadata]"]:
        """Partially initialize from partially initialized sensor.

        Use this to create windowed sensor constructors which can be
        applied to different traces to construct a dataset. For example,
        if you have a `sensor_constructor`:

        ```python
        sensor_constructor = ...
        windowed_sensor_constructor = Window.from_partial_sensor(
            sensor_constructor, ...)

        # ... somewhere inside the dataset constructor
        sensor_instance = windowed_sensor_constructor(path_to_trace)
        ```

        Args:
            sensor: sensor *constructor* to wrap.
            collate_fn: collate function for aggregating a list of samples; if
                not specified, the samples are simply returned as a list.
            past: number of past samples, in addition to the current sample.
                Set to `0` to disable.
            future: number of future samples, in addition to the current
                sample. Set to `0` to disable.
            parallel: maximum number of samples to load in parallel; if `None`,
                all samples are loaded sequentially.
        """
        def create_wrapped_sensor(
            path: str
        ) -> Window[SampleStack, Sample, TMetadata]:
            return cls(
                sensor(path), collate_fn=collate_fn, past=past,
                future=future, parallel=parallel)

        return create_wrapped_sensor

    def __getitem__(self, index: int | np.integer) -> SampleStack:
        """Fetch measurements from this sensor, by index.

        Args:
            index: sample index; note that `past` samples are lost at the
                beginning, and `future` at the end to account for the window
                size.

        Returns:
            A set of `past + 1 + future` consecutive samples. Note that there
                is a `past` offset of indices between the wrapped `Window` and
                the underlying sensor!
        """
        window = list(range(index, index + self.past + self.future + 1))

        if self.parallel is not None:
            with ThreadPool(min(len(window), self.parallel)) as p:
                return self.collate_fn(p.map(self.sensor.__getitem__, window))
        else:
            return self.collate_fn(list(map(self.sensor.__getitem__, window)))

    def __repr__(self) -> str:
        """Get friendly name (passing through to the underlying sensor)."""
        return f"{repr(self.sensor)} x [-{self.past}:+{self.future}]"

__getitem__

__getitem__(index: int | integer) -> SampleStack

Fetch measurements from this sensor, by index.

Parameters:

  • index (int | integer): sample index; note that past samples are lost at the beginning, and future at the end, to account for the window size. Required.

Returns:

  • SampleStack: A set of past + 1 + future consecutive samples. Note that there is a past offset of indices between the wrapped Window and the underlying sensor!

Source code in src/abstract_dataloader/generic/sequence.py
def __getitem__(self, index: int | np.integer) -> SampleStack:
    """Fetch measurements from this sensor, by index.

    Args:
        index: sample index; note that `past` samples are lost at the
            beginning, and `future` at the end to account for the window
            size.

    Returns:
        A set of `past + 1 + future` consecutive samples. Note that there
            is a `past` offset of indices between the wrapped `Window` and
            the underlying sensor!
    """
    window = list(range(index, index + self.past + self.future + 1))

    if self.parallel is not None:
        with ThreadPool(min(len(window), self.parallel)) as p:
            return self.collate_fn(p.map(self.sensor.__getitem__, window))
    else:
        return self.collate_fn(list(map(self.sensor.__getitem__, window)))

__repr__

__repr__() -> str

Get friendly name (passing through to the underlying sensor).

Source code in src/abstract_dataloader/generic/sequence.py
def __repr__(self) -> str:
    """Get friendly name (passing through to the underlying sensor)."""
    return f"{repr(self.sensor)} x [-{self.past}:+{self.future}]"

from_partial_sensor classmethod

from_partial_sensor(
    sensor: Callable[[str], Sensor[Sample, TMetadata]],
    collate_fn: Callable[[list[Sample]], SampleStack] | None = None,
    past: int = 0,
    future: int = 0,
    parallel: int | None = None,
) -> Callable[[str], Window[SampleStack, Sample, TMetadata]]

Partially initialize from partially initialized sensor.

Use this to create windowed sensor constructors which can be applied to different traces to construct a dataset. For example, if you have a sensor_constructor:

sensor_constructor = ...
windowed_sensor_constructor = Window.from_partial_sensor(
    sensor_constructor, ...)

# ... somewhere inside the dataset constructor
sensor_instance = windowed_sensor_constructor(path_to_trace)

Parameters:

  • sensor (Callable[[str], Sensor[Sample, TMetadata]]): sensor constructor to wrap. Required.
  • collate_fn (Callable[[list[Sample]], SampleStack] | None, default: None): collate function for aggregating a list of samples; if not specified, the samples are simply returned as a list.
  • past (int, default: 0): number of past samples, in addition to the current sample. Set to 0 to disable.
  • future (int, default: 0): number of future samples, in addition to the current sample. Set to 0 to disable.
  • parallel (int | None, default: None): maximum number of samples to load in parallel; if None, all samples are loaded sequentially.

Source code in src/abstract_dataloader/generic/sequence.py
@classmethod
def from_partial_sensor(
    cls, sensor: Callable[[str], spec.Sensor[Sample, TMetadata]],
    collate_fn: Callable[[list[Sample]], SampleStack] | None = None,
    past: int = 0, future: int = 0, parallel: int | None = None
) -> Callable[[str], "Window[SampleStack, Sample, TMetadata]"]:
    """Partially initialize from partially initialized sensor.

    Use this to create windowed sensor constructors which can be
    applied to different traces to construct a dataset. For example,
    if you have a `sensor_constructor`:

    ```python
    sensor_constructor = ...
    windowed_sensor_constructor = Window.from_partial_sensor(
        sensor_constructor, ...)

    # ... somewhere inside the dataset constructor
    sensor_instance = windowed_sensor_constructor(path_to_trace)
    ```

    Args:
        sensor: sensor *constructor* to wrap.
        collate_fn: collate function for aggregating a list of samples; if
            not specified, the samples are simply returned as a list.
        past: number of past samples, in addition to the current sample.
            Set to `0` to disable.
        future: number of future samples, in addition to the current
            sample. Set to `0` to disable.
        parallel: maximum number of samples to load in parallel; if `None`,
            all samples are loaded sequentially.
    """
    def create_wrapped_sensor(
        path: str
    ) -> Window[SampleStack, Sample, TMetadata]:
        return cls(
            sensor(path), collate_fn=collate_fn, past=past,
            future=future, parallel=parallel)

    return create_wrapped_sensor