abstract_dataloader.ext.lightning

Dataloader / PyTorch Bridge.
Warning

PyTorch Lightning must be installed to use this module. It is not included in any extras; you will need to `pip install lightning` or add it to your dependencies.
The provided data module is based on the following assumptions:

- All splits use the same transform `Pipeline`, but each has a different `Dataset`. This means that if the transform applies any data augmentations, the `Dataset` should pass some `meta` information (e.g., whether it is in training mode) as part of the data.
- In-training visualizations are always rendered from the same relatively small set of samples taken from the validation set.
- The same dataloader settings are applied to all splits.
Info

Only sample-to-sample (`.transform`) and sample-to-batch (`.collate`) transforms are applied in the dataloader; the training loop is responsible for applying batch-to-batch (`.forward`) transforms.
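For orientation, here is a minimal sketch of how the data module is typically wired up; `MyDataset` and `my_pipeline` are hypothetical stand-ins for an ADL-compliant dataset and transform `Pipeline`:

```python
# A minimal sketch; `MyDataset` and `my_pipeline` are hypothetical
# stand-ins for an ADL-compliant dataset and transform Pipeline.
from abstract_dataloader.ext.lightning import ADLDataModule

dm = ADLDataModule(
    dataset={
        "train": lambda: MyDataset(split="train"),  # lazy constructors...
        "val": lambda: MyDataset(split="val"),
        "test": MyDataset(split="test"),            # ...or eager datasets
    },
    transforms=my_pipeline,  # Pipeline[Raw, Transformed, Collated, Processed]
    batch_size=32,
    num_workers=8,
)

# The dataloaders only run .transform and .collate; the training loop
# remains responsible for the batch-to-batch .forward stage, e.g.
# processed = dm.transforms.forward(collated_batch)
```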
abstract_dataloader.ext.lightning.ADLDataModule

Bases: `LightningDataModule`, `Generic[Raw, Transformed, Collated, Processed]`

PyTorch dataloader wrapper for ADL-compliant datasets.
Info

Train/val/test splits are not all required to be present; if any are missing, the corresponding `.{split}_dataloader()` will raise an error if called. Arbitrary split names are also allowed, though `train`, `val`, and `test` are expected by the `ADLDataModule.{train|val|test}_dataloader()` methods required by PyTorch Lightning.
Note
The underlying (transformed) dataset is cached (i.e. the same dataset object will be used on each call), but the dataloader container is not.
Type Parameters

- `Raw`: raw data loaded from the dataset.
- `Transformed`: data following the CPU-side transform.
- `Collated`: data format after collation; should consist of PyTorch tensors.
- `Processed`: data after the GPU-side transform.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dataset` | `Mapping[str, Callable[[], Dataset[Raw]] \| Dataset[Raw]]` | datasets or dataset constructors for each split. | required |
| `transforms` | `Pipeline[Raw, Transformed, Collated, Processed]` | data transforms to apply. | required |
| `batch_size` | `int` | dataloader batch size. | `32` |
| `samples` | `int \| Sequence[int]` | number of validation-set samples to prefetch for visualizations (or a list of indices to use). Note that these samples are always held in memory! Set to `0` to disable. | `0` |
| `num_workers` | `int` | number of worker processes during data loading and CPU-side processing. | `32` |
| `prefetch_factor` | `int` | number of batches to fetch per worker. | `2` |
| `subsample` | `Mapping[str, int \| float \| None]` | sample only a (low-discrepancy) subset of samples for each split specified here, instead of using all samples. | `{}` |
Attributes:

| Name | Type | Description |
|---|---|---|
| `transforms` | `Pipeline[Raw, Transformed, Collated, Processed]` | data transforms which should be applied to the data; in particular, the batch-to-batch (`.forward`) stage must be applied by the training loop. |
Source code in src/abstract_dataloader/ext/lightning.py
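To make the `samples` and `subsample` parameters concrete, a sketch reusing the hypothetical `MyDataset` and `my_pipeline` from above:

```python
dm = ADLDataModule(
    dataset={
        "train": lambda: MyDataset(split="train"),
        "val": lambda: MyDataset(split="val"),
    },
    transforms=my_pipeline,
    batch_size=16,
    samples=8,                 # hold 8 validation samples in memory for visualization
    subsample={"train": 0.1},  # low-discrepancy 10% subset of the training split
)
```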
samples (cached property)

Validation samples for rendering visualizations.

If a simple `samples: int` is specified, these samples are taken uniformly, `len(val) // samples` apart, with padding on either side.
Warning

While this property is cached, accessing it for the first time triggers a full load of the dataset's validation split!
Returns:

| Type | Description |
|---|---|
| `Collated \| None` | Pre-loaded validation samples, nominally for generating visualizations; `None` if `samples=0`. |
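A sketch of how these cached samples might be consumed in a validation-time visualization hook; `render` is a hypothetical helper, and note the first-access cost described above:

```python
batch = dm.samples  # Collated | None; first access loads the full val split
if batch is not None:
    render(batch)   # e.g., log images/plots for these fixed samples
```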
dataset (cached)

dataset(
    split: Literal["train", "val", "test"] = "train",
) -> TransformedDataset[Raw, Transformed]

Get the dataset for a given split, with the sample transform applied.
Info

If a split is requested and `subsample` is specified for that split, a subsample transform (via `SampledDataset`) is also applied.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `split` | `Literal['train', 'val', 'test']` | target split. | `'train'` |
Returns:

| Type | Description |
|---|---|
| `TransformedDataset[Raw, Transformed]` | Dataset for that split, using the partially bound constructor passed to the `dataset` argument. |
Source code in src/abstract_dataloader/ext/lightning.py
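For instance, a sketch of direct dataset access (indexing assumes a map-style ADL dataset):

```python
val_ds = dm.dataset("val")          # TransformedDataset[Raw, Transformed]
sample = val_ds[0]                  # a single Transformed sample
assert dm.dataset("val") is val_ds  # cached: same object on repeated calls
```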
from_traces (classmethod)

from_traces(
    dataset: Callable[[Sequence[str]], Dataset[Raw]],
    traces: Mapping[str, Sequence[str]],
    transforms: Pipeline[Raw, Transformed, Collated, Processed],
    **kwargs: dict[str, Any],
) -> ADLDataModule[Raw, Transformed, Collated, Processed]

Create a data module from a trace-based dataset constructor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dataset` | `Callable[[Sequence[str]], Dataset[Raw]]` | dataset constructor which takes a list of trace names and returns a dataset object. | required |
| `traces` | `Mapping[str, Sequence[str]]` | mapping of split names to trace names; the dataset constructor will be called with the trace names for each split. | required |
| `transforms` | `Pipeline[Raw, Transformed, Collated, Processed]` | data transforms to apply. | required |
| `kwargs` | `dict[str, Any]` | see the class constructor. | `{}` |
Source code in src/abstract_dataloader/ext/lightning.py
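A sketch of `from_traces`, assuming a hypothetical `TraceDataset` that is constructed from a list of trace names:

```python
dm = ADLDataModule.from_traces(
    dataset=lambda traces: TraceDataset(traces),  # hypothetical constructor
    traces={
        "train": ["trace_00", "trace_01", "trace_02"],
        "val": ["trace_03"],
    },
    transforms=my_pipeline,
    batch_size=32,  # forwarded to the class constructor via **kwargs
)
```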
test_dataloader

test_dataloader() -> DataLoader

Get the test dataloader (`shuffle=False, drop_last=False`).

train_dataloader

train_dataloader() -> DataLoader

Get the training dataloader (`shuffle=True, drop_last=True`).

val_dataloader

val_dataloader() -> DataLoader

Get the validation dataloader (`shuffle=False, drop_last=True`).
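Since these are the standard `LightningDataModule` hooks, the data module can be passed directly to a `Trainer`; a sketch, with `model` being any `LightningModule`:

```python
import lightning as L

trainer = L.Trainer(max_epochs=10)
trainer.fit(model, datamodule=dm)   # uses train_dataloader / val_dataloader
trainer.test(model, datamodule=dm)  # uses test_dataloader
```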