PyTorch Interoperability
PyTorch-ADL wrappers. These implementations provide interoperability with PyTorch dataloaders, modules, etc.
Warning
This module is not automatically imported; you will need to explicitly import it (e.g. `import abstract_dataloader.torch`). Since PyTorch is not declared as a required dependency, you will also need to install `torch` (or install the `torch` extra with `pip install abstract_dataloader[torch]`).
Note
Recursive tree operations such as reshaping and stacking are performed using the `optree` library or, if that is not present, `torch.utils._pytree`, which implements equivalent functionality. If `torch.utils._pytree` is removed in a later version, the constructor will raise `NotImplementedError`, and this fallback will need to be replaced.
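The fallback described above can be sketched as a simple import cascade. This is an illustrative reconstruction of the behavior, not the library's actual code; the function name `get_tree_backend` is hypothetical.

```python
def get_tree_backend() -> str:
    """Return the name of an available tree library (illustrative sketch)."""
    # Prefer optree when it is installed.
    try:
        import optree  # noqa: F401
        return "optree"
    except ImportError:
        pass
    # Fall back to torch's private pytree module, which implements
    # equivalent flatten/unflatten functionality.
    try:
        import torch.utils._pytree  # noqa: F401
        return "torch.utils._pytree"
    except ImportError:
        # Mirrors the documented behavior: no tree library available.
        raise NotImplementedError(
            "No tree library available; install optree "
            "(pip install abstract_dataloader[torch])."
        )
```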
Warning
Custom data container classes such as `@dataclass` are only supported if `optree` is installed and they are registered with `optree`. However, `dict`, `list`, `tuple`, and equivalent types such as `TypedDict` and `NamedTuple` will work out of the box.
abstract_dataloader.torch.Collate
Bases: `Collate[TTransformed, TCollated]`
Generic numpy to PyTorch collation.
Converts numpy arrays to pytorch tensors, and either stacks or concatenates each value.
Type Parameters
- `TTransformed`: input sample type.
- `TCollated`: output collated type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`mode` | `Literal['stack', 'concat']` | whether to stack or concatenate each value | `'concat'` |
Source code in src/abstract_dataloader/torch.py
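The stack-vs-concat distinction can be illustrated with a minimal pure-python sketch, using nested lists in place of numpy arrays and torch tensors (this mirrors the described semantics; it is not the library's implementation):

```python
def collate(samples, mode="concat"):
    """Illustrative collation over a list of dict samples."""
    keys = samples[0].keys()
    if mode == "stack":
        # stack: introduce a new leading batch axis
        return {k: [s[k] for s in samples] for k in keys}
    else:
        # concat: join samples along the existing leading axis
        return {k: [x for s in samples for x in s[k]] for k in keys}

batch = [{"x": [1, 2]}, {"x": [3, 4]}]
collate(batch, mode="stack")   # {"x": [[1, 2], [3, 4]]}
collate(batch, mode="concat")  # {"x": [1, 2, 3, 4]}
```

In the real class, the same choice determines whether values are passed through `torch.stack` (new batch axis) or `torch.cat` (extend the existing axis) after numpy-to-tensor conversion.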
abstract_dataloader.torch.ParallelPipelines
Bases: `Module`, `ParallelPipelines[PRaw, PTransformed, PCollated, PProcessed]`
Transform compositions, modified for PyTorch compatibility.
Any `nn.Module` transforms are registered to a separate `nn.ModuleDict`; the original `.transforms` attribute is maintained with references to the full pipeline. See `generic.ParallelPipelines` for more details about this implementation. `.forward` and `.__call__` should work as expected within PyTorch.
Type Parameters
- `PRaw`, `PTransformed`, `PCollated`, `PProcessed`: see `Pipeline`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`transforms` | `Pipeline` | pipelines to compose. The key indicates the subkey to apply each transform to. | `{}` |
Source code in src/abstract_dataloader/torch.py
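The "key indicates the subkey" convention can be sketched in plain python. The helper name `apply_parallel` and the transforms here are hypothetical; only the keyed-dispatch pattern reflects the documented behavior:

```python
# Each pipeline is keyed by the subkey of the sample it operates on.
transforms = {
    "camera": lambda x: x * 2,
    "lidar": lambda x: x + 1,
}

def apply_parallel(sample, transforms):
    """Apply each transform to the matching subkey of a dict sample."""
    return {k: fn(sample[k]) for k, fn in transforms.items()}

apply_parallel({"camera": 3, "lidar": 5}, transforms)
# {"camera": 6, "lidar": 6}
```

In the real class, any transform that is itself an `nn.Module` would additionally be registered on an `nn.ModuleDict` so its parameters are visible to PyTorch.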
abstract_dataloader.torch.StackedSequencePipeline
Bases: `SequencePipeline[TRaw, TTransformed, TCollated, TProcessed]`
Modify a transform to act on sequences.
Unlike the generic `generic.SequencePipeline` implementation, this class places the sequence axis directly inside each tensor, so that each data type has axes `(batch, sequence, ...)`. For the same input,
[
    [Raw[s=0, t=0], Raw[s=0, t=1], ... Raw[s=0, t=n]]
    [Raw[s=1, t=0], Raw[s=1, t=1], ... Raw[s=1, t=n]]
    ...
    [Raw[s=b, t=0], Raw[s=b, t=1], ... Raw[s=b, t=n]]
]
this transform instead yields output where the sequence is folded into each tensor as a leading `(batch, sequence, ...)` axis, rather than a nested list of per-step outputs.
Info
This class requires that all outputs of `.collate()` are PyTorch tensors. Furthermore, batches must be treated as an additional leading axis by both `.collate` and `.forward`.
Warning
Since the output has an additional axis, it does not necessarily have the same type as the underlying transform!
This is accomplished by appropriately reshaping the data to use the batch-vectorized underlying implementation:
- `.transform`: apply the transform to each sample across the additional sequence axis.
- `.collate`: concatenate all sequences into a single `list[Raw]`, instead of a `list[list[Raw]]`. Then, collate the list, and reshape back into `batch sequence ...` order.
- `.forward`: flatten the collated data back to a `(batch sequence) ...` single leading batch axis, apply the transform, and reshape back.
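The axis bookkeeping in the collate step can be sketched with plain lists in place of tensors. This is an illustrative reconstruction under the stated assumptions (fixed-length sequences, a leading batch axis), not the library's code:

```python
def collate_sequences(batch, collate_fn):
    """Collate a list[list[Raw]] by flattening, collating, and reshaping."""
    b, s = len(batch), len(batch[0])
    # Step 1: concatenate all sequences into a single list[Raw].
    flat = [sample for seq in batch for sample in seq]
    # Step 2: collate once over the merged leading axis of length b * s.
    collated = collate_fn(flat)
    # Step 3: reshape the leading axis back into (batch, sequence, ...).
    return [collated[i * s:(i + 1) * s] for i in range(b)]

batch = [["a0", "a1"], ["b0", "b1"]]
collate_sequences(batch, collate_fn=list)
# [["a0", "a1"], ["b0", "b1"]] -- same data, now with an explicit
# (batch=2, sequence=2) leading structure
```

With real tensors, step 3 would instead be a `.reshape(b, s, ...)` on each leaf, which is why all outputs of `.collate()` must be tensors with a leading batch axis.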
Type Parameters
- `TRaw`, `TTransformed`, `TCollated`, `TProcessed`: see `Pipeline`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`transform` | `Pipeline[TRaw, TTransformed, TCollated, TProcessed]` | pipeline to transform to accept sequences. | required |
Source code in src/abstract_dataloader/torch.py
abstract_dataloader.torch.TransformedDataset
Bases: `Dataset[TTransformed]`, `Generic[TRaw, TTransformed]`
PyTorch-compatible dataset with transformation applied.
Extends `torch.utils.data.Dataset`, implementing a torch "map-style" dataset.
Type Parameters
- `TRaw`: raw data type from the dataloader.
- `TTransformed`: output data type from the provided transform function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`dataset` | `Dataset[TRaw]` | source dataset. | required |
`transform` | `Transform[TRaw, TTransformed]` | transformation to apply to each sample when loading (note that ...) | required |
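A map-style dataset wrapper of this shape can be sketched in a few lines. The real class extends `torch.utils.data.Dataset`; this pure-python sketch only illustrates the lazy, per-sample application of the transform:

```python
class TransformedDatasetSketch:
    """Illustrative map-style dataset applying a transform on access."""

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index):
        # The transform runs lazily, once per sample, at load time.
        return self.transform(self.dataset[index])

    def __len__(self):
        return len(self.dataset)

ds = TransformedDatasetSketch([1, 2, 3], transform=lambda x: x * 10)
ds[1]    # 20
len(ds)  # 3
```

Because the transform runs inside `__getitem__`, it executes in the dataloader worker process when used with `torch.utils.data.DataLoader(num_workers=...)`.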