
Generic Component Implementations

Generic "ready-to-use" implementations of common components.

Other generic and largely reusable components can be added to this submodule.

Note

NumPy (and jaxtyping) are the only dependencies; to keep the abstract_dataloader lightweight and flexible, components should only be added here if they do not require any additional dependencies.

abstract_dataloader.generic.ComposedPipeline

Bases: Pipeline[TRaw, TTransformed, TCollated, TProcessed], Generic[TRaw, TRawInner, TTransformed, TCollated, TProcessedInner, TProcessed]

Compose pipeline sequentially with pre and post transforms.

Type Parameters
  • TRaw: initial input type.
  • TRawInner: output of the pre-composed transform, and input to the provided Pipeline.
  • TTransformed, TCollated: intermediate values for the provided Pipeline.
  • TProcessedInner: output of the provided Pipeline, and input to the post-composed transform.
  • TProcessed: output type.

Parameters:
  • pipeline (Pipeline[TRawInner, TTransformed, TCollated, TProcessedInner]): pipeline to compose. Required.
  • pre (Transform[TRaw, TRawInner] | None): pre-transform to apply on the CPU side; skipped if None. Default: None.
  • post (Transform[TProcessedInner, TProcessed] | None): post-transform to apply on the GPU side; skipped if None. Default: None.
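
For example, a minimal usage sketch; the toy `AddOne` pipeline and the two lambdas below are hypothetical stand-ins for user-defined components satisfying the `Pipeline` and `Transform` protocols:

```python
from typing import Sequence

import numpy as np

from abstract_dataloader import generic


class AddOne:
    """Toy pipeline: add 1 per sample, stack on collate, sum on batch."""

    def sample(self, data: np.ndarray) -> np.ndarray:
        return data + 1

    def collate(self, data: Sequence[np.ndarray]) -> np.ndarray:
        return np.stack(list(data))

    def batch(self, data: np.ndarray) -> np.ndarray:
        return data.sum(axis=-1)


composed = generic.ComposedPipeline(
    pipeline=AddOne(),
    pre=lambda raw: np.asarray(raw, dtype=np.float64),  # CPU-side pre-transform
    post=lambda processed: processed * 10.0)            # GPU-side post-transform

s = composed.sample([1.0, 2.0])    # pre, then AddOne.sample -> [2., 3.]
c = composed.collate([s, s])       # forwarded directly to AddOne.collate
print(composed.batch(c))           # AddOne.batch, then post -> [50. 50.]
```

Note that `collate` is simply forwarded to the wrapped pipeline, so only the sample and batch stages are affected by `pre` and `post`.
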
Source code in src/abstract_dataloader/generic/composition.py
class ComposedPipeline(
    spec.Pipeline[TRaw, TTransformed, TCollated, TProcessed],
    Generic[
        TRaw, TRawInner, TTransformed,
        TCollated, TProcessedInner, TProcessed]
):
    """Compose pipeline sequentially with pre and post transforms.

    Type Parameters:
        - `TRaw`: initial input type.
        - `TRawInner`: output of the pre-composed transform, and input to the
            provided [`Pipeline`][abstract_dataloader.spec].
        - `TTransformed`, `TCollated`: intermediate values for the provided
            [`Pipeline`][abstract_dataloader.spec].
        - `TProcessedInner`: output of the provided `Pipeline`, and input to
            the post-composed transform.
        - `TProcessed`: output type.

    Args:
        pipeline: pipeline to compose.
        pre: pre-transform to apply on the CPU side; skipped if `None`.
        post: post-transform to apply on the GPU side; skipped if `None`.
    """

    def __init__(
        self, pipeline: spec.Pipeline[
            TRawInner, TTransformed, TCollated, TProcessedInner],
        pre: spec.Transform[TRaw, TRawInner] | None = None,
        post: spec.Transform[TProcessedInner, TProcessed] | None = None
    ) -> None:
        self.pipeline = pipeline
        self.pre = pre
        self.post = post

        self.collate = pipeline.collate

    def sample(self, data: TRaw) -> TTransformed:
        """Transform single samples.

        Args:
            data: A single `TRaw` data sample.

        Returns:
            A single `TTransformed` data sample.
        """
        if self.pre is None:
            transformed = cast(TRawInner, data)
        else:
            transformed = self.pre(data)
        return self.pipeline.sample(transformed)

    def batch(self, data: TCollated) -> TProcessed:
        """Transform data batch.

        Args:
            data: A `TCollated` batch of data, nominally already sent to the
                GPU.

        Returns:
            The `TProcessed` output, ready for the downstream model.
        """
        transformed = self.pipeline.batch(data)
        if self.post is None:
            return cast(TProcessed, transformed)
        else:
            return self.post(transformed)

batch

batch(data: TCollated) -> TProcessed

Transform data batch.

Parameters:
  • data (TCollated): A TCollated batch of data, nominally already sent to the GPU. Required.

Returns:
  • TProcessed: The TProcessed output, ready for the downstream model.

Source code in src/abstract_dataloader/generic/composition.py
def batch(self, data: TCollated) -> TProcessed:
    """Transform data batch.

    Args:
        data: A `TCollated` batch of data, nominally already sent to the
            GPU.

    Returns:
        The `TProcessed` output, ready for the downstream model.
    """
    transformed = self.pipeline.batch(data)
    if self.post is None:
        return cast(TProcessed, transformed)
    else:
        return self.post(transformed)

sample

sample(data: TRaw) -> TTransformed

Transform single samples.

Parameters:
  • data (TRaw): A single TRaw data sample. Required.

Returns:
  • TTransformed: A single TTransformed data sample.

Source code in src/abstract_dataloader/generic/composition.py
def sample(self, data: TRaw) -> TTransformed:
    """Transform single samples.

    Args:
        data: A single `TRaw` data sample.

    Returns:
        A single `TTransformed` data sample.
    """
    if self.pre is None:
        transformed = cast(TRawInner, data)
    else:
        transformed = self.pre(data)
    return self.pipeline.sample(transformed)

abstract_dataloader.generic.Empty

Bases: Synchronization

Dummy synchronization which does not synchronize sensor pairs.

No samples will be registered, and the trace can only be used as a collection of sensors.
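
As a sketch of the behavior, the returned index map is always empty, whatever timestamps are passed in:

```python
import numpy as np

from abstract_dataloader import generic

sync = generic.Empty()
indices = sync({
    "lidar": np.array([0.0, 0.1, 0.2]),
    "camera": np.array([0.05, 0.15]),
})
print(indices)  # {'lidar': array([], dtype=uint32), 'camera': array([], dtype=uint32)}
```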

Source code in src/abstract_dataloader/generic/sync.py
class Empty(spec.Synchronization):
    """Dummy synchronization which does not synchronize sensor pairs.

    No samples will be registered, and the trace can only be used as a
    collection of sensors.
    """

    def __call__(
        self, timestamps: dict[str, Float64[np.ndarray, "_N"]]
    ) -> dict[str, UInt32[np.ndarray, "M"]]:
        """Apply synchronization protocol.

        Args:
            timestamps: input sensor timestamps.

        Returns:
            Synchronized index map.
        """
        return {k: np.array([], dtype=np.uint32) for k in timestamps}

__call__

__call__(
    timestamps: dict[str, Float64[ndarray, _N]],
) -> dict[str, UInt32[ndarray, M]]

Apply synchronization protocol.

Parameters:
  • timestamps (dict[str, Float64[ndarray, _N]]): input sensor timestamps. Required.

Returns:
  • dict[str, UInt32[ndarray, M]]: Synchronized index map.

Source code in src/abstract_dataloader/generic/sync.py
def __call__(
    self, timestamps: dict[str, Float64[np.ndarray, "_N"]]
) -> dict[str, UInt32[np.ndarray, "M"]]:
    """Apply synchronization protocol.

    Args:
        timestamps: input sensor timestamps.

    Returns:
        Synchronized index map.
    """
    return {k: np.array([], dtype=np.uint32) for k in timestamps}

abstract_dataloader.generic.Metadata dataclass

Bases: Metadata

Generic metadata with timestamps.

Attributes:
  • timestamps (Float64[ndarray, N]): epoch timestamps.
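
A minimal construction; timestamps are expected as epoch times in a float64 array:

```python
import numpy as np

from abstract_dataloader import generic

meta = generic.Metadata(timestamps=np.array([0.0, 0.1, 0.2], dtype=np.float64))
print(meta.timestamps.shape)  # (3,)
```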

Source code in src/abstract_dataloader/generic/sequence.py
@dataclass
class Metadata(spec.Metadata):
    """Generic metadata with timestamps.

    Attributes:
        timestamps: epoch timestamps.
    """

    timestamps: Float64[np.ndarray, "N"]

abstract_dataloader.generic.Nearest

Bases: Synchronization

Nearest sample synchronization, with respect to a reference sensor.

Applies the following:

  • Compute the midpoints between consecutive observations for each sensor.
  • Find which bin the reference sensor timestamps fall into.
  • Calculate the resulting time delta between timestamps. If this exceeds tol for any sensor-reference pair, remove this match.

See Synchronization for protocol details.

Parameters:
  • reference (str): reference sensor to synchronize to. Required.
  • tol (float): synchronization time tolerance, in seconds. Setting tol = np.inf disables this check altogether. Default: 0.1.
  • margin (tuple[int | float, int | float]): time margin (in seconds; float) or index margin (in samples; int) to apply to the start and end time relative to the reference sensor, excluding samples within this margin. Default: (0, 0).
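
For instance, a minimal sketch with the default tolerance and margin; the timestamp values here are arbitrary:

```python
import numpy as np

from abstract_dataloader import generic

sync = generic.Nearest(reference="lidar", tol=0.1)
indices = sync({
    "lidar": np.array([0.0, 0.1, 0.2, 0.3, 0.4]),
    "camera": np.array([0.02, 0.13, 0.45]),
})
# The reference sample at t=0.3 is dropped, since its nearest camera sample
# (t=0.45) is more than `tol` away; every other pair is within tolerance.
print(indices["lidar"])   # [0 1 2 4]
print(indices["camera"])  # [0 1 1 2]
```
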
Source code in src/abstract_dataloader/generic/sync.py
class Nearest(spec.Synchronization):
    """Nearest sample synchronization, with respect to a reference sensor.

    Applies the following:

    - Compute the midpoints between consecutive observations for each sensor.
    - Find which bin the reference sensor timestamps fall into.
    - Calculate the resulting time delta between timestamps. If this exceeds
      `tol` for any sensor-reference pair, remove this match.

    See [`Synchronization`][abstract_dataloader.spec.] for protocol details.

    Args:
        reference: reference sensor to synchronize to.
        tol: synchronization time tolerance, in seconds. Setting `tol = np.inf`
            disables this check altogether.
        margin: time margin (in seconds; `float`) or index margin
            (in samples; `int`) to apply to the start and end time relative to
            the reference sensor, excluding samples within this margin.
    """

    def __init__(
        self, reference: str, tol: float = 0.1,
        margin: tuple[int | float, int | float] = (0, 0)
    ) -> None:
        if tol < 0:
            raise ValueError(
                f"Synchronization tolerance must be positive: {tol} < 0")

        self.tol = tol
        self.reference = reference
        self.margin = margin

    def __call__(
        self, timestamps: dict[str, Float64[np.ndarray, "_N"]]
    ) -> dict[str, UInt32[np.ndarray, "M"]]:
        """Apply synchronization protocol.

        Args:
            timestamps: input sensor timestamps.

        Returns:
            Synchronized index map.
        """
        try:
            t_ref = timestamps[self.reference]
        except KeyError:
            raise KeyError(
                f"Reference sensor {self.reference} was not provided in "
                f"timestamps, with keys: {list(timestamps.keys())}")

        if isinstance(self.margin[0], float):
            t_ref = t_ref[np.argmax(t_ref > t_ref[0] + self.margin[0]):]
        elif isinstance(self.margin[0], int) and self.margin[0] > 0:
            t_ref = t_ref[self.margin[0]:]

        if isinstance(self.margin[1], float):
            t_ref = t_ref[
                :-np.argmax((t_ref < t_ref[-1] - self.margin[1])[::-1])]
        elif isinstance(self.margin[1], int) and self.margin[1] > 0:
            t_ref = t_ref[:-self.margin[1]]

        indices = {
            k: np.searchsorted(
                (t_sensor[:-1] + t_sensor[1:]) / 2, t_ref
            ).astype(np.uint32)
            for k, t_sensor in timestamps.items()}

        valid = np.all(np.array([
           np.abs(timestamps[k][i_nearest] - t_ref) < self.tol
        for k, i_nearest in indices.items()]), axis=0)

        return {k: v[valid] for k, v in indices.items()}

__call__

__call__(
    timestamps: dict[str, Float64[ndarray, _N]],
) -> dict[str, UInt32[ndarray, M]]

Apply synchronization protocol.

Parameters:
  • timestamps (dict[str, Float64[ndarray, _N]]): input sensor timestamps. Required.

Returns:
  • dict[str, UInt32[ndarray, M]]: Synchronized index map.

Source code in src/abstract_dataloader/generic/sync.py
def __call__(
    self, timestamps: dict[str, Float64[np.ndarray, "_N"]]
) -> dict[str, UInt32[np.ndarray, "M"]]:
    """Apply synchronization protocol.

    Args:
        timestamps: input sensor timestamps.

    Returns:
        Synchronized index map.
    """
    try:
        t_ref = timestamps[self.reference]
    except KeyError:
        raise KeyError(
            f"Reference sensor {self.reference} was not provided in "
            f"timestamps, with keys: {list(timestamps.keys())}")

    if isinstance(self.margin[0], float):
        t_ref = t_ref[np.argmax(t_ref > t_ref[0] + self.margin[0]):]
    elif isinstance(self.margin[0], int) and self.margin[0] > 0:
        t_ref = t_ref[self.margin[0]:]

    if isinstance(self.margin[1], float):
        t_ref = t_ref[
            :-np.argmax((t_ref < t_ref[-1] - self.margin[1])[::-1])]
    elif isinstance(self.margin[1], int) and self.margin[1] > 0:
        t_ref = t_ref[:-self.margin[1]]

    indices = {
        k: np.searchsorted(
            (t_sensor[:-1] + t_sensor[1:]) / 2, t_ref
        ).astype(np.uint32)
        for k, t_sensor in timestamps.items()}

    valid = np.all(np.array([
       np.abs(timestamps[k][i_nearest] - t_ref) < self.tol
    for k, i_nearest in indices.items()]), axis=0)

    return {k: v[valid] for k, v in indices.items()}

abstract_dataloader.generic.Next

Bases: Synchronization

Next sample synchronization, with respect to a reference sensor.

Applies the following:

  • Find the start time, defined by the earliest time which is observed by all sensors, and the end time, defined by the last time which is observed by all sensors.
  • Truncate the reference sensor's timestamps to this start and end time, and use this as the query timestamps.
  • For each time in the query, find the first sample from each sensor which is after this time.

See Synchronization for protocol details.

Parameters:
  • reference (str): reference sensor to synchronize to. Required.
  • margin (tuple[int | float, int | float]): time margin (in seconds; float) or index margin (in samples; int) to apply to the start and end, excluding samples within this margin. Default: (0, 0).
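
For instance, a minimal sketch with the default margin; the timestamp values here are arbitrary:

```python
import numpy as np

from abstract_dataloader import generic

sync = generic.Next(reference="lidar")
indices = sync({
    "lidar": np.array([0.0, 0.1, 0.2, 0.3, 0.4]),
    "camera": np.array([0.05, 0.15, 0.35]),
})
# The reference timestamps are truncated to [0.1, 0.2, 0.3] (the span covered
# by all sensors); each query then maps to the first sample at or after it.
print(indices["lidar"])   # [1 2 3]
print(indices["camera"])  # [1 2 2]
```
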
Source code in src/abstract_dataloader/generic/sync.py
class Next(spec.Synchronization):
    """Next sample synchronization, with respect to a reference sensor.

    Applies the following:

    - Find the start time, defined by the earliest time which is observed by
      all sensors, and the end time, defined by the last time which is observed
      by all sensors.
    - Truncate the reference sensor's timestamps to this start and end time,
      and use this as the query timestamps.
    - For each time in the query, find the first sample from each sensor which
      is after this time.

    See [`Synchronization`][abstract_dataloader.spec.] for protocol details.

    Args:
        reference: reference sensor to synchronize to.
        margin: time margin (in seconds; `float`) or index margin
            (in samples; `int`) to apply to the start and end, excluding
            samples within this margin.
    """

    def __init__(
        self, reference: str,
        margin: tuple[int | float, int | float] = (0, 0)
    ) -> None:
        self.reference = reference
        self.margin = margin

    def __call__(
        self, timestamps: dict[str, Float64[np.ndarray, "_N"]]
    ) -> dict[str, UInt32[np.ndarray, "M"]]:
        """Apply synchronization protocol.

        Args:
            timestamps: input sensor timestamps.

        Returns:
            Synchronized index map.
        """
        try:
            ref_time_all = timestamps[self.reference]
        except KeyError:
            raise KeyError(
                f"Reference sensor {self.reference} was not provided in "
                f"timestamps, with keys: {list(timestamps.keys())}")

        start_time = max(t[0] for t in timestamps.values())
        end_time = min(t[-1] for t in timestamps.values())

        if isinstance(self.margin[0], float):
            start_time += self.margin[0]
        if isinstance(self.margin[1], float):
            end_time -= self.margin[1]

        start_idx = np.searchsorted(ref_time_all, start_time)
        end_idx = np.searchsorted(ref_time_all, end_time)

        if isinstance(self.margin[0], int):
            start_idx += self.margin[0]
        if isinstance(self.margin[1], int):
            end_idx -= self.margin[1]

        ref_time = ref_time_all[start_idx:end_idx]
        return {
            k: np.searchsorted(v, ref_time).astype(np.uint32)
            for k, v in timestamps.items()}

__call__

__call__(
    timestamps: dict[str, Float64[ndarray, _N]],
) -> dict[str, UInt32[ndarray, M]]

Apply synchronization protocol.

Parameters:
  • timestamps (dict[str, Float64[ndarray, _N]]): input sensor timestamps. Required.

Returns:
  • dict[str, UInt32[ndarray, M]]: Synchronized index map.

Source code in src/abstract_dataloader/generic/sync.py
def __call__(
    self, timestamps: dict[str, Float64[np.ndarray, "_N"]]
) -> dict[str, UInt32[np.ndarray, "M"]]:
    """Apply synchronization protocol.

    Args:
        timestamps: input sensor timestamps.

    Returns:
        Synchronized index map.
    """
    try:
        ref_time_all = timestamps[self.reference]
    except KeyError:
        raise KeyError(
            f"Reference sensor {self.reference} was not provided in "
            f"timestamps, with keys: {list(timestamps.keys())}")

    start_time = max(t[0] for t in timestamps.values())
    end_time = min(t[-1] for t in timestamps.values())

    if isinstance(self.margin[0], float):
        start_time += self.margin[0]
    if isinstance(self.margin[1], float):
        end_time -= self.margin[1]

    start_idx = np.searchsorted(ref_time_all, start_time)
    end_idx = np.searchsorted(ref_time_all, end_time)

    if isinstance(self.margin[0], int):
        start_idx += self.margin[0]
    if isinstance(self.margin[1], int):
        end_idx -= self.margin[1]

    ref_time = ref_time_all[start_idx:end_idx]
    return {
        k: np.searchsorted(v, ref_time).astype(np.uint32)
        for k, v in timestamps.items()}

abstract_dataloader.generic.ParallelPipelines

Bases: Pipeline[PRaw, PTransformed, PCollated, PProcessed]

Compose multiple transforms in parallel.

For example, with transforms {"radar": radar_tf, "lidar": lidar_tf, ...}, the composed transform performs:

{
    "radar": radar_tf.transform(data["radar"]),
    "lidar": lidar_tf.transform(data["lidar"]),
    ...
}

Note

This implies that the type parameters must be dict[str, Any], so this class is parameterized by a separate set of Composed(Raw|Transformed|Collated|Processed) types with this bound.

Tip

See torch.ParallelPipelines for an implementation which is compatible with nn.Module-based pipelines.

Type Parameters
  • PRaw, PTransformed, PCollated, PProcessed: see Pipeline.

Parameters:
  • transforms (Pipeline): transforms to compose. The key indicates the subkey to apply each transform to. Default: {}.
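
A minimal sketch; the `Scale` class below is a hypothetical stand-in for a real per-modality pipeline:

```python
from typing import Sequence

import numpy as np

from abstract_dataloader import generic


class Scale:
    """Toy per-modality pipeline: scale samples, stack on collate."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def sample(self, data: np.ndarray) -> np.ndarray:
        return data * self.factor

    def collate(self, data: Sequence[np.ndarray]) -> np.ndarray:
        return np.stack(list(data))

    def batch(self, data: np.ndarray) -> np.ndarray:
        return data


parallel = generic.ParallelPipelines(radar=Scale(2.0), lidar=Scale(0.5))
transformed = parallel.sample({"radar": np.ones(3), "lidar": np.ones(3)})
batch = parallel.batch(parallel.collate([transformed, transformed]))
print(batch["radar"].shape)  # (2, 3): each subkey is handled by its own pipeline
```
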
Source code in src/abstract_dataloader/generic/composition.py
class ParallelPipelines(
    spec.Pipeline[PRaw, PTransformed, PCollated, PProcessed],
):
    """Compose multiple transforms in parallel.

    For example, with transforms `{"radar": radar_tf, "lidar": lidar_tf, ...}`,
    the composed transform performs:

    ```python
    {
        "radar": radar_tf.transform(data["radar"]),
        "lidar": lidar_tf.transform(data["lidar"]),
        ...
    }
    ```

    !!! note

        This implies that the type parameters must be `dict[str, Any]`, so this
        class is parameterized by a separate set of
        `Composed(Raw|Transformed|Collated|Processed)` types with this bound.

    !!! tip

        See [`torch.ParallelPipelines`][abstract_dataloader.] for an
        implementation which is compatible with [`nn.Module`][torch.]-based
        pipelines.

    Type Parameters:
        - `PRaw`, `PTransformed`, `PCollated`, `PProcessed`: see
          [`Pipeline`][abstract_dataloader.spec.].

    Args:
        transforms: transforms to compose. The key indicates the subkey to
            apply each transform to.
    """

    def __init__(self, **transforms: spec.Pipeline) -> None:
        self.transforms = transforms

    def sample(self, data: PRaw) -> PTransformed:
        return cast(
            PTransformed,
            {k: v.sample(data[k]) for k, v in self.transforms.items()})

    def collate(self, data: Sequence[PTransformed]) -> PCollated:
        return cast(PCollated, {
            k: v.collate([x[k] for x in data])
            for k, v in self.transforms.items()
        })

    def batch(self, data: PCollated) -> PProcessed:
        return cast(
            PProcessed,
            {k: v.batch(data[k]) for k, v in self.transforms.items()})

abstract_dataloader.generic.SequencePipeline

Bases: Pipeline[Sequence[TRaw], Sequence[TTransformed], Sequence[TCollated], Sequence[TProcessed]], Generic[TRaw, TTransformed, TCollated, TProcessed]

Transform which passes an additional sequence axis through.

The given Pipeline is modified to accept Sequence[...] for each data type in its pipeline, and return a list[...] across the additional axis, thus "passing through" the axis.

For example, suppose a sequence dataloader reads

[
    [Raw[s=0, t=0], Raw[s=0, t=1], ... Raw[s=0, t=n]]
    [Raw[s=1, t=0], Raw[s=1, t=1], ... Raw[s=1, t=n]]
    ...
    [Raw[s=b, t=0], Raw[s=b, t=1], ... Raw[s=b, t=n]]
]

for sequence length t = 0...n and batch sample s = 0...b. For sequence length t, the output of the transforms will be batched with the sequence on the outside:

[
    Processed[s=0...b] [t=0],
    Processed[s=0...b] [t=1],
    ...
    Processed[s=0...b] [t=n]
]
Type Parameters
  • TRaw, TTransformed, TCollated, TProcessed: see Pipeline.

Parameters:
  • pipeline (Pipeline[TRaw, TTransformed, TCollated, TProcessed]): input pipeline. Required.
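
A minimal sketch; `Frames` is a hypothetical toy pipeline whose collate stage stacks arrays:

```python
from typing import Sequence

import numpy as np

from abstract_dataloader import generic


class Frames:
    """Toy pipeline: identity sample/batch, stack arrays on collate."""

    def sample(self, data: np.ndarray) -> np.ndarray:
        return data

    def collate(self, data: Sequence[np.ndarray]) -> np.ndarray:
        return np.stack(list(data))

    def batch(self, data: np.ndarray) -> np.ndarray:
        return data


seq = generic.SequencePipeline(Frames())

# Two batch samples (s = 0, 1), each a sequence of three frames (t = 0, 1, 2).
raw = [[np.full(2, 10 * s + t) for t in range(3)] for s in range(2)]

transformed = [seq.sample(r) for r in raw]  # per-sample, still sequences
collated = seq.collate(transformed)         # list of 3 arrays, one per t
print(len(collated), collated[0].shape)     # 3 (2, 2): sequence axis on the outside
```
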
Source code in src/abstract_dataloader/generic/sequence.py
class SequencePipeline(
    spec.Pipeline[
        Sequence[TRaw], Sequence[TTransformed],
        Sequence[TCollated], Sequence[TProcessed]],
    Generic[TRaw, TTransformed, TCollated, TProcessed]
):
    """Transform which passes an additional sequence axis through.

    The given `Pipeline` is modified to accept `Sequence[...]` for each
    data type in its pipeline, and return a `list[...]` across the additional
    axis, thus "passing through" the axis.

    For example, suppose a sequence dataloader reads

    ```
    [
        [Raw[s=0, t=0], Raw[s=0, t=1], ... Raw[s=0, t=n]]
        [Raw[s=1, t=0], Raw[s=1, t=1], ... Raw[s=1, t=n]]
        ...
        [Raw[s=b, t=0], Raw[s=b, t=1], ... Raw[s=b, t=n]]
    ]
    ```

    for sequence length `t = 0...n` and batch sample `s = 0...b`. For sequence
    length `t`, the output of the transforms will be batched with the sequence
    on the outside:

    ```
    [
        Processed[s=0...b] [t=0],
        Processed[s=0...b] [t=1],
        ...
        Processed[s=0...b] [t=n]
    ]
    ```

    Type Parameters:
        - `TRaw`, `TTransformed`, `TCollated`, `TProcessed`: see
          [`Pipeline`][abstract_dataloader.spec.].

    Args:
        pipeline: input pipeline.
    """

    def __init__(
        self, pipeline: spec.Pipeline[
            TRaw, TTransformed, TCollated, TProcessed]
    ) -> None:
        self.pipeline = pipeline

    def sample(self, data: Sequence[TRaw]) -> list[TTransformed]:
        return [self.pipeline.sample(x) for x in data]

    def collate(
        self, data: Sequence[Sequence[TTransformed]]
    ) -> list[TCollated]:
        return [self.pipeline.collate(x) for x in zip(*data)]

    def batch(self, data: Sequence[TCollated]) -> list[TProcessed]:
        return [self.pipeline.batch(x) for x in data]

abstract_dataloader.generic.Window

Bases: Sensor[SampleStack, Metadata], Generic[SampleStack, Sample, TMetadata]

Load sensor data across a time window using a sensor transform.

Use this class as a generic transform to give time history to any sensor:

sensor =  ... # implements spec.Sensor
with_history = generic.Window(sensor, past=5, future=1, parallel=7)

In this example, 5 past samples, the current sample, and 1 future sample are loaded on every index:

with_history[i] = [
    sensor[i], sensor[i + 1], ... sensor[i + 5], sensor[i + 6]]
                                        ^
                            # timestamp for synchronization
Type Parameters
  • SampleStack: a collated series of consecutive samples. Can simply be list[Sample].
  • Sample: single observation sample type.
  • TMetadata: metadata type for the underlying sensor. Note that the Window wrapper doesn't actually have metadata type TMetadata; this type is just passed through from the sensor which is wrapped.

Parameters:
  • sensor (Sensor[Sample, TMetadata]): sensor to wrap. Required.
  • collate_fn (Callable[[list[Sample]], SampleStack] | None): collate function for aggregating a list of samples; if not specified, the samples are simply returned as a list. Default: None.
  • past (int): number of past samples, in addition to the current sample. Set to 0 to disable. Default: 0.
  • future (int): number of future samples, in addition to the current sample. Set to 0 to disable. Default: 0.
  • crop (bool): if True, crop the first past and last future samples in the reported metadata to ensure that all samples are fully valid. Default: True.
  • parallel (int | None): maximum number of samples to load in parallel; if None, all samples are loaded sequentially. Default: None.
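
A fuller sketch with a stacking collate function; the `Ramp` toy sensor below is hypothetical, and just duck-types the `spec.Sensor` protocol:

```python
import numpy as np

from abstract_dataloader import generic


class Ramp:
    """Toy sensor: sample i is a length-3 array filled with i."""

    def __init__(self, n: int = 10) -> None:
        self.metadata = generic.Metadata(
            timestamps=np.arange(n, dtype=np.float64))

    def __len__(self) -> int:
        return self.metadata.timestamps.shape[0]

    def __getitem__(self, index) -> np.ndarray:
        return np.full(3, float(index))


windowed = generic.Window(Ramp(), collate_fn=np.stack, past=2, future=1)

# Cropped index 0 corresponds to sensor indices 0..3; the reported timestamp
# is that of sensor index 2 (= past), which anchors synchronization.
print(windowed[0].shape)                   # (4, 3)
print(windowed.metadata.timestamps.shape)  # (7,) = 10 - past - future
```
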
Source code in src/abstract_dataloader/generic/sequence.py
class Window(
    abstract.Sensor[SampleStack, Metadata],
    Generic[SampleStack, Sample, TMetadata]
):
    """Load sensor data across a time window using a sensor transform.

    Use this class as a generic transform to give time history to any sensor:

    ```python
    sensor =  ... # implements spec.Sensor
    with_history = generic.Window(sensor, past=5, future=1, parallel=7)
    ```

    In this example, 5 past samples, the current sample, and 1 future sample
    are loaded on every index:

    ```python
    with_history[i] = [
        sensor[i], sensor[i + 1], ... sensor[i + 5], sensor[i + 6]]
                                            ^
                                # timestamp for synchronization
    ```

    Type Parameters:
        - `SampleStack`: a collated series of consecutive samples. Can simply be
            `list[Sample]`.
        - `Sample`: single observation sample type.
        - `TMetadata`: metadata type for the underlying sensor. Note that the
            `Window` wrapper doesn't actually have metadata type `TMetadata`;
            this type is just passed through from the sensor which is wrapped.

    Args:
        sensor: sensor to wrap.
        collate_fn: collate function for aggregating a list of samples; if not
            specified, the samples are simply returned as a list.
        past: number of past samples, in addition to the current sample. Set
            to `0` to disable.
        future: number of future samples, in addition to the current sample.
            Set to `0` to disable.
        crop: if `True`, crop the first `past` and last `future` samples in
            the reported metadata to ensure that all samples are fully valid.
        parallel: maximum number of samples to load in parallel; if `None`, all
            samples are loaded sequentially.
    """

    def __init__(
        self, sensor: spec.Sensor[Sample, TMetadata],
        collate_fn: Callable[[list[Sample]], SampleStack] | None = None,
        past: int = 0, future: int = 0, crop: bool = True,
        parallel: int | None = None
    ) -> None:
        self.sensor = sensor
        self.past = past
        self.future = future
        self.parallel = parallel

        if collate_fn is None:
            collate_fn = cast(
                Callable[[list[Sample]], SampleStack], lambda x: x)
        self.collate_fn = collate_fn

        self.cropped = crop
        if crop:
            # hack for negative indexing
            _future = None if future == 0 else -future
            self.metadata = Metadata(
                timestamps=sensor.metadata.timestamps[past:_future])
        else:
            self.metadata = Metadata(timestamps=sensor.metadata.timestamps)

    @classmethod
    def from_partial_sensor(
        cls, sensor: Callable[[str], spec.Sensor[Sample, TMetadata]],
        collate_fn: Callable[[list[Sample]], SampleStack] | None = None,
        past: int = 0, future: int = 0, crop: bool = True,
        parallel: int | None = None
    ) -> Callable[[str], "Window[SampleStack, Sample, TMetadata]"]:
        """Partially initialize from partially initialized sensor.

        Use this to create windowed sensor constructors which can be
        applied to different traces to construct a dataset. For example,
        if you have a `sensor_constructor`:

        ```python
        sensor_constructor = ...
        windowed_sensor_constructor = Window.from_partial_sensor(
            sensor_constructor, ...)

        # ... somewhere inside the dataset constructor
        sensor_instance = windowed_sensor_constructor(path_to_trace)
        ```

        Args:
            sensor: sensor *constructor* to wrap.
            collate_fn: collate function for aggregating a list of samples; if
                not specified, the samples are simply returned as a list.
            past: number of past samples, in addition to the current sample.
                Set to `0` to disable.
            future: number of future samples, in addition to the current
                sample. Set to `0` to disable.
            crop: if `True`, crop the first `past` and last `future` samples in
                the reported metadata to ensure that all samples are fully
                valid.
            parallel: maximum number of samples to load in parallel; if `None`,
                all samples are loaded sequentially.
        """
        def create_wrapped_sensor(
            path: str
        ) -> Window[SampleStack, Sample, TMetadata]:
            return cls(
                sensor(path), collate_fn=collate_fn, past=past,
                future=future, crop=crop, parallel=parallel)

        return create_wrapped_sensor

    def __getitem__(self, index: int | np.integer) -> SampleStack:
        """Fetch measurements from this sensor, by index.

        !!! warning

            Note that `past` samples are lost at the beginning, and `future`
            samples at the end to account for the window size!

            If `crop=True`, these lost samples are taken into account by the
            `Window` wrapper; if `crop=False`, the caller must handle this.

        Args:
            index: sample index.

        Returns:
            A set of `past + 1 + future` consecutive samples. Note that there
                is a `past` offset of indices between the wrapped `Window` and
                the underlying sensor!

        Raises:
            IndexError: if `crop=False`, and the requested index is out of
                bounds (i.e., in the first `past` or last `future` samples).
        """
        if self.cropped:
            window = list(range(index, index + self.past + self.future + 1))
        else:
            window = list(range(index - self.past, index + self.future + 1))

        if window[0] < 0 or window[-1] >= len(self.sensor):
            raise IndexError(
                f"Requested invalid index {index} for uncropped "
                f"Window(past={self.past}, future={self.future}).")

        if self.parallel is not None:
            with ThreadPool(min(len(window), self.parallel)) as p:
                return self.collate_fn(p.map(self.sensor.__getitem__, window))
        else:
            return self.collate_fn(list(map(self.sensor.__getitem__, window)))

    def __repr__(self) -> str:
        """Get friendly name (passing through to the underlying sensor)."""
        return f"{repr(self.sensor)} x [-{self.past}:+{self.future}]"

__getitem__

__getitem__(index: int | integer) -> SampleStack

Fetch measurements from this sensor, by index.

Warning

Note that past samples are lost at the beginning, and future samples at the end to account for the window size!

If crop=True, these lost samples are taken into account by the Window wrapper; if crop=False, the caller must handle this.

Parameters:
  • index (int | integer): sample index. Required.

Returns:
  • SampleStack: A set of past + 1 + future consecutive samples. Note that there is a past offset of indices between the wrapped Window and the underlying sensor!

Raises:
  • IndexError: if crop=False, and the requested index is out of bounds (i.e., in the first past or last future samples).
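
To illustrate the index handling with crop=False, reusing the hypothetical Ramp toy sensor from the sketch above:

```python
# With crop=False, Window indices align with the underlying sensor, but the
# first `past` and last `future` indices are invalid and raise IndexError.
uncropped = generic.Window(
    Ramp(), collate_fn=np.stack, past=2, future=1, crop=False)

print(uncropped[2].shape)  # (4, 3): sensor indices 0..3, centered on sensor[2]
try:
    uncropped[0]           # would need sensor[-2..1], which is out of bounds
except IndexError as exc:
    print(exc)
```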

Source code in src/abstract_dataloader/generic/sequence.py
def __getitem__(self, index: int | np.integer) -> SampleStack:
    """Fetch measurements from this sensor, by index.

    !!! warning

        Note that `past` samples are lost at the beginning, and `future`
        samples at the end to account for the window size!

        If `crop=True`, these lost samples are taken into account by the
        `Window` wrapper; if `crop=False`, the caller must handle this.

    Args:
        index: sample index.

    Returns:
        A set of `past + 1 + future` consecutive samples. Note that there
            is a `past` offset of indices between the wrapped `Window` and
            the underlying sensor!

    Raises:
        IndexError: if `crop=False`, and the requested index is out of
            bounds (i.e., in the first `past` or last `future` samples).
    """
    if self.cropped:
        window = list(range(index, index + self.past + self.future + 1))
    else:
        window = list(range(index - self.past, index + self.future + 1))

    if window[0] < 0 or window[-1] >= len(self.sensor):
        raise IndexError(
            f"Requested invalid index {index} for uncropped "
            f"Window(past={self.past}, future={self.future}).")

    if self.parallel is not None:
        with ThreadPool(min(len(window), self.parallel)) as p:
            return self.collate_fn(p.map(self.sensor.__getitem__, window))
    else:
        return self.collate_fn(list(map(self.sensor.__getitem__, window)))

__repr__

__repr__() -> str

Get friendly name (passing through to the underlying sensor).

Source code in src/abstract_dataloader/generic/sequence.py
def __repr__(self) -> str:
    """Get friendly name (passing through to the underlying sensor)."""
    return f"{repr(self.sensor)} x [-{self.past}:+{self.future}]"

from_partial_sensor classmethod

from_partial_sensor(
    sensor: Callable[[str], Sensor[Sample, TMetadata]],
    collate_fn: Callable[[list[Sample]], SampleStack] | None = None,
    past: int = 0,
    future: int = 0,
    crop: bool = True,
    parallel: int | None = None,
) -> Callable[[str], Window[SampleStack, Sample, TMetadata]]

Partially initialize from partially initialized sensor.

Use this to create windowed sensor constructors which can be applied to different traces to construct a dataset. For example, if you have a sensor_constructor:

sensor_constructor = ...
windowed_sensor_constructor = Window.from_partial_sensor(
    sensor_constructor, ...)

# ... somewhere inside the dataset constructor
sensor_instance = windowed_sensor_constructor(path_to_trace)

Parameters:
  • sensor (Callable[[str], Sensor[Sample, TMetadata]]): sensor constructor to wrap. Required.
  • collate_fn (Callable[[list[Sample]], SampleStack] | None): collate function for aggregating a list of samples; if not specified, the samples are simply returned as a list. Default: None.
  • past (int): number of past samples, in addition to the current sample. Set to 0 to disable. Default: 0.
  • future (int): number of future samples, in addition to the current sample. Set to 0 to disable. Default: 0.
  • crop (bool): if True, crop the first past and last future samples in the reported metadata to ensure that all samples are fully valid. Default: True.
  • parallel (int | None): maximum number of samples to load in parallel; if None, all samples are loaded sequentially. Default: None.
Source code in src/abstract_dataloader/generic/sequence.py
@classmethod
def from_partial_sensor(
    cls, sensor: Callable[[str], spec.Sensor[Sample, TMetadata]],
    collate_fn: Callable[[list[Sample]], SampleStack] | None = None,
    past: int = 0, future: int = 0, crop: bool = True,
    parallel: int | None = None
) -> Callable[[str], "Window[SampleStack, Sample, TMetadata]"]:
    """Partially initialize from partially initialized sensor.

    Use this to create windowed sensor constructors which can be
    applied to different traces to construct a dataset. For example,
    if you have a `sensor_constructor`:

    ```python
    sensor_constructor = ...
    windowed_sensor_constructor = Window.from_partial_sensor(
        sensor_constructor, ...)

    # ... somewhere inside the dataset constructor
    sensor_instance = windowed_sensor_constructor(path_to_trace)
    ```

    Args:
        sensor: sensor *constructor* to wrap.
        collate_fn: collate function for aggregating a list of samples; if
            not specified, the samples are simply returned as a list.
        past: number of past samples, in addition to the current sample.
            Set to `0` to disable.
        future: number of future samples, in addition to the current
            sample. Set to `0` to disable.
        crop: if `True`, crop the first `past` and last `future` samples in
            the reported metadata to ensure that all samples are fully
            valid.
        parallel: maximum number of samples to load in parallel; if `None`,
            all samples are loaded sequentially.
    """
    def create_wrapped_sensor(
        path: str
    ) -> Window[SampleStack, Sample, TMetadata]:
        return cls(
            sensor(path), collate_fn=collate_fn, past=past,
            future=future, crop=crop, parallel=parallel)

    return create_wrapped_sensor