Pytorch Interoperability
Pytorch-ADL wrappers. These implementations are a superset of the generic components and provide interoperability with pytorch dataloaders, modules, etc. For example, any Pipeline-related components which could contain pytorch nn.Modules are modified to subclass nn.Module in order to properly register them.
Warning

This module is not automatically imported; you will need to explicitly import it, as shown below. Since pytorch is not declared as a required dependency, you will also need to install torch (or install the torch extra with pip install abstract_dataloader[torch]).
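A minimal sketch of the explicit import (the adl_torch alias is only a convention to avoid shadowing pytorch itself):

```python
# Explicitly import the optional torch submodule; it is not loaded by
# `import abstract_dataloader` alone.
from abstract_dataloader import torch as adl_torch
```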
Note

Recursive tree operations such as reshaping and stacking are performed using the optree library, or, if that is not present, torch.utils._pytree, which implements equivalent functionality. If torch.utils._pytree is removed in a later version, the constructor will raise NotImplementedError, and this fallback will need to be replaced.
Warning

Custom data container classes such as @dataclass are only supported if optree is installed and they are registered with optree. However, dict, list, tuple, and equivalent types such as TypedDict and NamedTuple will work out of the box.
abstract_dataloader.torch.Collate
Bases: Collate[TTransformed, TCollated]
Generic numpy to pytorch collation.
Converts numpy arrays to pytorch tensors, and either stacks or concatenates each value.
Type Parameters

- TTransformed: input sample type.
- TCollated: output collated type.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
mode | Literal['stack', 'concat'] | whether to stack or concatenate each value. | 'concat' |
Source code in src/abstract_dataloader/torch/torch.py
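For illustration, a hypothetical usage sketch, assuming the collation is invoked as a callable on a list of transformed samples (e.g. as a dataloader collate_fn):

```python
import numpy as np

from abstract_dataloader import torch as adl_torch

# Hypothetical transformed samples: trees of numpy arrays with matching shapes.
samples = [{"points": np.zeros((4, 3)), "t": np.zeros(1)} for _ in range(8)]

collate = adl_torch.Collate(mode="stack")
batch = collate(samples)  # numpy arrays become stacked torch tensors

# The same object can be passed to a pytorch DataLoader as its collate_fn.
```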
abstract_dataloader.torch.ComposedPipeline
Bases: Module, ComposedPipeline[TRaw, TRawInner, TTransformed, TCollated, TProcessedInner, TProcessed]
Compose pipeline sequentially with pre and post transforms.
Type Parameters
Parameters:

Name | Type | Description | Default |
---|---|---|---|
pipeline | Pipeline[TRawInner, TTransformed, TCollated, TProcessedInner] | pipeline to compose. | required |
pre | Transform[TRaw, TRawInner] or None | pre-transform to apply on the CPU side; skipped if None. | None |
post | Transform[TProcessedInner, TProcessed] or None | post-transform to apply on the GPU side; skipped if None. | None |
Source code in src/abstract_dataloader/torch/generic.py
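A purely illustrative sketch; inner_pipeline and decode_raw are hypothetical user-defined objects implementing the Pipeline and Transform protocols, not part of the library:

```python
from abstract_dataloader import torch as adl_torch

composed = adl_torch.ComposedPipeline(
    pipeline=inner_pipeline,  # hypothetical Pipeline[TRawInner, ..., TProcessedInner]
    pre=decode_raw,           # hypothetical CPU-side Transform; may be None
    post=None,                # no GPU-side post-transform
)
```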
abstract_dataloader.torch.Empty
Bases: Synchronization
Dummy synchronization which does not synchronize sensor pairs.
No samples will be registered, and the trace can only be used as a collection of sensors.
Source code in src/abstract_dataloader/generic/sync.py
abstract_dataloader.torch.Metadata
dataclass
abstract_dataloader.torch.Nearest
Bases: Synchronization
Nearest sample synchronization, with respect to a reference sensor.
Applies the following:
- Compute the midpoints between observations for each sensor.
- Find which bin the reference sensor timestamps fall into.
- Calculate the resulting time delta between timestamps. If this exceeds tol for any sensor-reference pair, remove this match.
See Synchronization for protocol details.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
reference | str | reference sensor to synchronize to. | required |
tol | float | synchronization time tolerance, in seconds. | 0.1 |
Source code in src/abstract_dataloader/generic/sync.py
__call__
Apply synchronization protocol.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
timestamps | dict[str, Float64[ndarray, _N]] | input sensor timestamps. | required |

Returns:

Type | Description |
---|---|
dict[str, UInt32[ndarray, M]] | Synchronized index map. |
Source code in src/abstract_dataloader/generic/sync.py
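A small usage sketch; the sensor names and timestamps here are hypothetical:

```python
import numpy as np

from abstract_dataloader import torch as adl_torch

# Hypothetical timestamps, in seconds; "lidar" acts as the reference sensor.
timestamps = {
    "lidar": np.array([0.00, 0.10, 0.20, 0.30], dtype=np.float64),
    "camera": np.array([0.02, 0.09, 0.21, 0.33], dtype=np.float64),
}

sync = adl_torch.Nearest(reference="lidar", tol=0.05)
indices = sync(timestamps)
# `indices` maps each sensor name to an index array; entry i of each array
# selects that sensor's sample matched to the i-th synchronized reference sample.
```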
abstract_dataloader.torch.Next
Bases: Synchronization
Next sample synchronization, with respect to a reference sensor.
Applies the following:
- Find the start time, defined by the earliest time which is observed by all sensors, and the end time, defined by the last time which is observed by all sensors.
- Truncate the reference sensor's timestamps to this start and end time, and use this as the query timestamps.
- For each time in the query, find the first sample from each sensor which is after this time.
See Synchronization for protocol details.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
reference | str | reference sensor to synchronize to. | required |
Source code in src/abstract_dataloader/generic/sync.py
__call__
Apply synchronization protocol.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
timestamps | dict[str, Float64[ndarray, _N]] | input sensor timestamps. | required |

Returns:

Type | Description |
---|---|
dict[str, UInt32[ndarray, M]] | Synchronized index map. |
Source code in src/abstract_dataloader/generic/sync.py
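Usage mirrors Nearest, minus the tolerance; continuing the hypothetical timestamps from the Nearest example above:

```python
sync = adl_torch.Next(reference="lidar")
indices = sync(timestamps)  # first sample of each sensor after each reference time
```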
abstract_dataloader.torch.ParallelPipelines
Bases: Module, ParallelPipelines[PRaw, PTransformed, PCollated, PProcessed]
Transform Compositions, modified for Pytorch compatibility.
Any nn.Module transforms are registered to a separate nn.ModuleDict; the original .transforms attribute is maintained with references to the full pipeline.
See generic.ParallelPipelines for more details about this implementation; .forward and .__call__ should work as expected within pytorch.
Type Parameters

- PRaw, PTransformed, PCollated, PProcessed: see Pipeline.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
transforms | Pipeline | pipelines to compose. The key indicates the subkey to apply each transform to. | {} |
Source code in src/abstract_dataloader/torch/generic.py
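An illustrative sketch, assuming each pipeline is passed as a keyword argument named after the subkey it applies to (camera_pipeline and lidar_pipeline are hypothetical user-defined Pipelines):

```python
from abstract_dataloader import torch as adl_torch

pipelines = adl_torch.ParallelPipelines(
    camera=camera_pipeline,  # hypothetical Pipeline applied to the "camera" subkey
    lidar=lidar_pipeline,    # hypothetical Pipeline applied to the "lidar" subkey
)

# Since the composition subclasses nn.Module, any nn.Module pipelines are
# registered and move with it, e.g. pipelines.to("cuda").
```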
abstract_dataloader.torch.SequencePipeline
Bases: Module, SequencePipeline[TRaw, TTransformed, TCollated, TProcessed]
Transform which passes an additional sequence axis through.
The given Pipeline is modified to accept Sequence[...] for each data type in its pipeline, and return a list[...] across the additional axis, thus "passing through" the axis.
For example, suppose a sequence dataloader reads
[
    [Raw[s=0, t=0], Raw[s=0, t=1], ... Raw[s=0, t=n]],
    [Raw[s=1, t=0], Raw[s=1, t=1], ... Raw[s=1, t=n]],
    ...
    [Raw[s=b, t=0], Raw[s=b, t=1], ... Raw[s=b, t=n]]
]
for sequence length t = 0...n and batch sample s = 0...b. For sequence length t, the output of the transforms will be batched with the sequence on the outside.
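Concretely, the result is a list over the sequence axis, where each entry is the pipeline output batched across samples; an illustrative sketch in the same notation:

[
    Processed[t=0][s=0...b],
    Processed[t=1][s=0...b],
    ...
    Processed[t=n][s=0...b]
]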
Type Parameters

- TRaw, TTransformed, TCollated, TProcessed: see Pipeline.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
pipeline | Pipeline[TRaw, TTransformed, TCollated, TProcessed] | input pipeline. | required |
Source code in src/abstract_dataloader/torch/generic.py
abstract_dataloader.torch.StackedSequencePipeline
Bases: Module, Pipeline[Sequence[TRaw], Sequence[TTransformed], TCollated, TProcessed]
Modify a pipeline to act on sequences.
Unlike the generic generic.SequencePipeline implementation, this class places the sequence axis directly inside each tensor, so that each data type has axes (batch, sequence, ...). For the same input,
[
    [Raw[s=0, t=0], Raw[s=0, t=1], ... Raw[s=0, t=n]],
    [Raw[s=1, t=0], Raw[s=1, t=1], ... Raw[s=1, t=n]],
    ...
    [Raw[s=b, t=0], Raw[s=b, t=1], ... Raw[s=b, t=n]]
]
this pipeline instead yields
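a single batch in which the sequence becomes a second axis inside each tensor; illustratively, in the same notation:

Processed[s=0...b, t=0...n]    # axes (batch, sequence, ...)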
Info

This class requires that all outputs of .collate() are pytorch tensors. Furthermore, batches must be treated as an additional leading axis by both .collate and .forward.
Warning
Since the output has an additional axis, it does not necessarily have the same type as the underlying transform!
This is accomplished by appropriately reshaping the data to use the batch-vectorized underlying implementation:

- .sample: apply the pipeline to each sample across the additional sequence axis.
- .collate: concatenate all sequences into a single list[Raw], instead of a list[list[Raw]]. Then, collate the list, and reshape back into batch sequence ... order.
- .batch: flatten the collated data back to a single (batch sequence) ... leading batch axis, apply the pipeline, and reshape back.
Type Parameters

- PRaw, PTransformed, PCollated, PProcessed: see Pipeline.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
pipeline | Pipeline[TRaw, TTransformed, TCollated, TProcessed] | pipeline to transform to accept sequences. | required |
Source code in src/abstract_dataloader/torch/torch.py
abstract_dataloader.torch.TransformedDataset
Bases: Dataset[TTransformed], Generic[TRaw, TTransformed]
Pytorch-compatible dataset with transformation applied.
Extends torch.utils.data.Dataset, implementing a torch "map-style" dataset.
Type Parameters

- TRaw: raw data type from the dataloader.
- TTransformed: output data type from the provided transform function.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
dataset | Dataset[TRaw] | source dataset. | required |
transform | Transform[TRaw, TTransformed] | transformation to apply to each sample when loading. | required |
Source code in src/abstract_dataloader/torch/torch.py
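A hypothetical usage sketch, assuming any indexable dataset and any callable satisfying the Transform protocol:

```python
from abstract_dataloader import torch as adl_torch

# Hypothetical raw dataset and transform: a plain list stands in for a
# Dataset[TRaw], and a simple callable for a Transform[TRaw, TTransformed].
raw = list(range(100))
transformed = adl_torch.TransformedDataset(raw, transform=lambda x: 2 * x)

sample = transformed[3]  # the transform is applied on each __getitem__

# As a torch map-style dataset, it should also plug into a standard
# torch.utils.data.DataLoader.
```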
__getitem__
abstract_dataloader.torch.Window
Bases: Sensor[SampleStack, Metadata], Generic[SampleStack, Sample, TMetadata]
Load sensor data across a time window using a sensor transform.
Use this class as a generic transform to give time history to any sensor:
sensor = ... # implements spec.Sensor
with_history = generic.Window(sensor, past=5, future=1, parallel=7)
In this example, 5 past samples, the current sample, and 1 future sample are loaded on every index:
with_history[i] = [
    sensor[i], sensor[i + 1], ... sensor[i + 5], sensor[i + 6]]
                                  ^
                                  # timestamp for synchronization
Parameters:

Name | Type | Description | Default |
---|---|---|---|
sensor | Sensor[Sample, TMetadata] | sensor to wrap. | required |
collate_fn | Callable[[list[Sample]], SampleStack] or None | collate function for aggregating a list of samples; if not specified, the samples are simply returned as a list. | None |
past | int | number of past samples, in addition to the current sample. | 0 |
future | int | number of future samples, in addition to the current sample. | 0 |
parallel | int or None | maximum number of samples to load in parallel. | None |
Type Parameters

- SampleStack: a collated series of consecutive samples. Can simply be list[Sample].
- Sample: single observation sample type.
- TMetadata: metadata type for the underlying sensor. Note that the Window wrapper doesn't actually have metadata type TMetadata; this type is just passed through from the sensor which is wrapped.
Source code in src/abstract_dataloader/generic/sequence.py
__getitem__
Fetch measurements from this sensor, by index.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
index | int or integer | sample index. | required |

Returns:

Type | Description |
---|---|
SampleStack | A set of samples. |
Source code in src/abstract_dataloader/generic/sequence.py
from_partial_sensor
classmethod
from_partial_sensor(
    sensor: Callable[[str], Sensor[Sample, TMetadata]],
    collate_fn: Callable[[list[Sample]], SampleStack] | None = None,
    past: int = 0,
    future: int = 0,
    parallel: int | None = None,
) -> Callable[[str], Window[SampleStack, Sample, TMetadata]]
Partially initialize from a partially initialized sensor.
Use this to create windowed sensor constructors which can be applied to different traces to construct a dataset. For example, if you have a sensor_constructor:
sensor_constructor = ...
windowed_sensor_constructor = Window.from_partial_sensor(
    sensor_constructor, ...)

# ... somewhere inside the dataset constructor
sensor_instance = windowed_sensor_constructor(path_to_trace)
Parameters:

Name | Type | Description | Default |
---|---|---|---|
sensor | Callable[[str], Sensor[Sample, TMetadata]] | sensor constructor to wrap. | required |
collate_fn | Callable[[list[Sample]], SampleStack] or None | collate function for aggregating a list of samples; if not specified, the samples are simply returned as a list. | None |
past | int | number of past samples, in addition to the current sample. | 0 |
future | int | number of future samples, in addition to the current sample. | 0 |
parallel | int or None | maximum number of samples to load in parallel. | None |