abstract_dataloader.ext.objective

Objective base classes and specifications.

Programming Model

  • An Objective is a callable which returns a (batched) scalar loss and a dictionary of metrics.
  • Objectives can be combined into a higher-order objective, MultiObjective, which combines their losses and aggregates their metrics; specify these objectives using a MultiObjectiveSpec. A sketch of both pieces is shown below.
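
This sketch assumes a torch tensor backend; MSEObjective, the weights, and the coarse/fine output keys are hypothetical names chosen for illustration:

```python
import torch
from abstract_dataloader.ext.objective import MultiObjective, MultiObjectiveSpec

class MSEObjective:
    """Toy Objective: a batched MSE loss plus one gradient-free metric."""

    def __call__(self, y_true, y_pred, train: bool = True):
        loss = ((y_pred - y_true) ** 2).mean(dim=-1)   # shape: (batch,)
        with torch.no_grad():                          # metrics carry no gradients
            metrics = {"mae": (y_pred - y_true).abs().mean(dim=-1)}
        return loss, metrics

# Combine two objectives; each reads a different key of the model output dict.
multi = MultiObjective(
    coarse=MultiObjectiveSpec(objective=MSEObjective(), weight=1.0, y_pred="coarse"),
    fine=MultiObjectiveSpec(objective=MSEObjective(), weight=0.1, y_pred="fine"),
)
y_true = torch.zeros(4, 8)
y_pred = {"coarse": torch.randn(4, 8), "fine": torch.randn(4, 8)}
loss, metrics = multi(y_true, y_pred)
```

Each per-objective metric is re-keyed under the objective's name, so the combined metrics here are coarse/mae and fine/mae.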

abstract_dataloader.ext.objective.MultiObjective

Bases: Objective[TArray, YTrue, YPred]

Composite objective that combines multiple objectives.

Hydra Configuration

If using Hydra for dependency injection, a MultiObjective configuration should look like this:

objectives:
    name:
        objective:
            _target_: ...
            kwargs: ...
        weight: 1.0
        y_true: "y_true_key"
        y_pred: "y_pred_key"
    ...

Type Parameters
  • YTrue: ground truth data type.
  • YPred: model output data type.

Parameters:

  • objectives (Mapping | MultiObjectiveSpec, default {}): multiple objectives, organized by name; see MultiObjectiveSpec. Each objective can also be provided as a dict, in which case the key/values are passed to MultiObjectiveSpec.
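
For example, a sketch of the dict form; StubObjective and the depth/semseg names are placeholders:

```python
from abstract_dataloader.ext.objective import MultiObjective

class StubObjective:
    """Placeholder objective: batched mean absolute error, no extra metrics."""

    def __call__(self, y_true, y_pred, train: bool = True):
        return (y_pred - y_true).abs().mean(-1), {}

# Each dict is expanded into MultiObjectiveSpec(**...); the keyword names the objective.
multi = MultiObjective(
    depth={"objective": StubObjective(), "weight": 1.0, "y_pred": "depth"},
    semseg={"objective": StubObjective(), "weight": 0.5, "y_pred": "semseg"},
)
```

The dict form is convenient when a config system (such as the Hydra layout above) supplies plain mappings rather than MultiObjectiveSpec instances.
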
Source code in src/abstract_dataloader/ext/objective.py
class MultiObjective(Objective[TArray, YTrue, YPred]):
    """Composite objective that combines multiple objectives.

    ??? example "Hydra Configuration"

        If using [Hydra](https://hydra.cc/docs/intro/) for dependency
        injection, a `MultiObjective` configuration should look like this:
        ```yaml
        objectives:
            name:
                objective:
                    _target_: ...
                    kwargs: ...
                weight: 1.0
                y_true: "y_true_key"
                y_pred: "y_pred_key"
            ...
        ```

    Type Parameters:
        - `YTrue`: ground truth data type.
        - `YPred`: model output data type.

    Args:
        objectives: multiple objectives, organized by name; see
            [`MultiObjectiveSpec`][^.]. Each objective can also be provided as
            a dict, in which case the key/values are passed to
            `MultiObjectiveSpec`.
    """

    def __init__(self, **objectives: Mapping | MultiObjectiveSpec) -> None:
        if len(objectives) == 0:
            raise ValueError("At least one objective must be provided.")

        self.objectives = {
            k: v if isinstance(v, MultiObjectiveSpec)
            else MultiObjectiveSpec(**v)
            for k, v in objectives.items()}

    def __call__(
        self, y_true: YTrue, y_pred: YPred, train: bool = True
    ) -> tuple[Float[TArray, "batch"], dict[str, Float[TArray, "batch"]]]:
        loss = 0.
        metrics = {}
        for k, v in self.objectives.items():
            k_loss, k_metrics = v.objective(
                v.index_y_true(y_true), v.index_y_pred(y_pred), train=train)
            loss += k_loss * v.weight

            for name, value in k_metrics.items():
                metrics[f"{k}/{name}"] = value

        # We ensure in __init__ that there's at least one objective.
        loss = cast(Float[TArray, ""] | Float[TArray, "batch"], loss)
        return loss, metrics

    def visualizations(
        self, y_true: YTrue, y_pred: YPred
    ) -> dict[str, UInt8[np.ndarray, "H W 3"]]:
        images = {}
        for k, v in self.objectives.items():
            k_images = v.objective.visualizations(
                v.index_y_true(y_true), v.index_y_pred(y_pred))
            for name, image in k_images.items():
                images[f"{k}/{name}"] = image
        return images

    def render(
        self, y_true: YTrue, y_pred: YPred, render_gt: bool = False
    ) -> dict[str, Shaped[np.ndarray, "batch ..."]]:
        rendered = {}
        for k, v in self.objectives.items():
            k_rendered = v.objective.render(
                v.index_y_true(y_true), v.index_y_pred(y_pred),
                render_gt=render_gt)
            for name, image in k_rendered.items():
                rendered[f"{k}/{name}"] = image
        return rendered

abstract_dataloader.ext.objective.MultiObjectiveSpec dataclass

Bases: Generic[YTrue, YPred, YTrueAll, YPredAll]

Specification for a single objective in a multi-objective setup.

The inputs and outputs for each objective are specified using y_true and y_pred:

  • None: The provided y_true and y_pred are passed directly to the objective. This means that if multiple objectives all use None, they will all receive the same data that comes from the dataloader.
  • str: The key indexes into a mapping which has the y_true/y_pred key, or an object which has a matching attribute.
  • Sequence[str]: Each key indexes into the layers of a nested mapping or object.
  • Callable: The callable is applied to the provided y_true and y_pred.

Warning

The user is responsible for ensuring that the y_true and y_pred keys or callables index the appropriate types for this objective.
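
To make these options concrete, a small sketch against a hypothetical nested output dict; `...` stands in for any Objective, since only the indexing behavior is exercised here:

```python
from abstract_dataloader.ext.objective import MultiObjectiveSpec

y_pred = {"outputs": {"depth": 1.0, "semseg": 2.0}}

# str: a single key (or attribute) lookup.
by_key = MultiObjectiveSpec(objective=..., y_pred="outputs")
assert by_key.index_y_pred(y_pred) == {"depth": 1.0, "semseg": 2.0}

# Sequence[str]: one lookup per level of nesting.
by_path = MultiObjectiveSpec(objective=..., y_pred=("outputs", "depth"))
assert by_path.index_y_pred(y_pred) == 1.0

# Callable: applied directly to the full structure.
by_fn = MultiObjectiveSpec(objective=..., y_pred=lambda y: y["outputs"]["semseg"])
assert by_fn.index_y_pred(y_pred) == 2.0

# None (the default): passed through unchanged.
passthrough = MultiObjectiveSpec(objective=...)
assert passthrough.index_y_pred(y_pred) is y_pred
```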

Type Parameters
  • YTrue: objective ground truth data type.
  • YPred: objective model prediction data type.
  • YTrueAll: type of all ground truth data (as loaded by the dataloader).
  • YPredAll: type of all model output data (as produced by the model).

Attributes:

  • objective (Objective): The objective to use.
  • weight (float): Weight of the objective in the overall loss.
  • y_true (str | Sequence[str] | Callable[[YTrueAll], YTrue] | None): Key or callable to index into the ground truth data.
  • y_pred (str | Sequence[str] | Callable[[YPredAll], YPred] | None): Key or callable to index into the model output data.

Source code in src/abstract_dataloader/ext/objective.py
@dataclass
class MultiObjectiveSpec(Generic[YTrue, YPred, YTrueAll, YPredAll]):
    """Specification for a single objective in a multi-objective setup.

    The inputs and outputs for each objective are specified using `y_true` and
    `y_pred`:

    - `None`: The provided `y_true` and `y_pred` are passed directly to the
        objective. This means that if multiple objectives all use `None`, they
        will all receive the same data that comes from the dataloader.
    - `str`: The key indexes into a mapping which has the `y_true`/`y_pred` key,
        or an object which has a matching attribute.
    - `Sequence[str]`: Each key indexes into the layers of a nested mapping or
        object.
    - `Callable`: The callable is applied to the provided `y_true` and `y_pred`.

    !!! warning

        The user is responsible for ensuring that the `y_true` and `y_pred`
        keys or callables index the appropriate types for this objective.

    Type Parameters:
        - `YTrue`: objective ground truth data type.
        - `YPred`: objective model prediction data type.
        - `YTrueAll`: type of all ground truth data (as loaded by the
            dataloader).
        - `YPredAll`: type of all model output data (as produced by the model).

    Attributes:
        objective: The objective to use.
        weight: Weight of the objective in the overall loss.
        y_true: Key or callable to index into the ground truth data.
        y_pred: Key or callable to index into the model output data.
    """

    objective: Objective
    weight: float = 1.0
    y_true: str | Sequence[str] | Callable[[YTrueAll], YTrue] | None = None
    y_pred: str | Sequence[str] | Callable[[YPredAll], YPred] | None = None

    def _index(
        self, data: Any, key: str | Sequence[str] | Callable | None
    ) -> Any:
        """Index into data using the key or callable."""
        def dereference(obj, k):
            if isinstance(obj, Mapping):
                try:
                    return obj[k]
                except KeyError as e:
                    raise KeyError(
                        f"Key {k} not found: {obj}") from e
            else:
                try:
                    return getattr(obj, k)
                except AttributeError as e:
                    raise AttributeError(
                        f"Attribute {k} not found: {obj}") from e

        if isinstance(key, str):
            return dereference(data, key)
        elif isinstance(key, Sequence):
            for k in key:
                data = dereference(data, k)
            return data
        elif callable(key):
            return key(data)
        else:   # key is None
            return data

    def index_y_true(self, y_true: YTrueAll) -> YTrue:
        """Get indexed ground truth data.

        Args:
            y_true: All ground truth data (as loaded by the dataloader).

        Returns:
            Indexed ground truth data.
        """
        return self._index(y_true, self.y_true)

    def index_y_pred(self, y_pred: YPredAll) -> YPred:
        """Get indexed model output data.

        Args:
            y_pred: All model output data (as produced by the model).

        Returns:
            Indexed model output data.
        """
        return self._index(y_pred, self.y_pred)

index_y_pred

index_y_pred(y_pred: YPredAll) -> YPred

Get indexed model output data.

Parameters:

  • y_pred (YPredAll, required): All model output data (as produced by the model).

Returns:

  • YPred: Indexed model output data.

Source code in src/abstract_dataloader/ext/objective.py
def index_y_pred(self, y_pred: YPredAll) -> YPred:
    """Get indexed model output data.

    Args:
        y_pred: All model output data (as produced by the model).

    Returns:
        Indexed model output data.
    """
    return self._index(y_pred, self.y_pred)

index_y_true

index_y_true(y_true: YTrueAll) -> YTrue

Get indexed ground truth data.

Parameters:

  • y_true (YTrueAll, required): All ground truth data (as loaded by the dataloader).

Returns:

  • YTrue: Indexed ground truth data.

Source code in src/abstract_dataloader/ext/objective.py
def index_y_true(self, y_true: YTrueAll) -> YTrue:
    """Get indexed ground truth data.

    Args:
        y_true: All ground truth data (as loaded by the dataloader).

    Returns:
        Indexed ground truth data.
    """
    return self._index(y_true, self.y_true)

abstract_dataloader.ext.objective.Objective

Bases: Protocol, Generic[TArray, YTrue, YPred]

Composable training objective.

Note

Metrics should use torch.no_grad() to make sure gradients are not computed for non-loss metrics!

Type Parameters
  • TArray: backend (jax.Array, torch.Tensor, etc.)
  • YTrue: ground truth data type.
  • YPred: model output data type.
Source code in src/abstract_dataloader/ext/objective.py
@runtime_checkable
class Objective(Protocol, Generic[TArray, YTrue, YPred]):
    """Composable training objective.

    !!! note

        Metrics should use `torch.no_grad()` to make sure gradients are not
        computed for non-loss metrics!

    Type Parameters:
        - `TArray`: backend (`jax.Array`, `torch.Tensor`, etc.)
        - `YTrue`: ground truth data type.
        - `YPred`: model output data type.
    """

    @abstractmethod
    def __call__(
        self, y_true: YTrue, y_pred: YPred, train: bool = True
    ) -> tuple[Float[TArray, "batch"], dict[str, Float[TArray, "batch"]]]:
        """Training metrics implementation.

        Args:
            y_true: data channels (i.e. dataloader output).
            y_pred: model outputs.
            train: Whether in training mode (i.e. skip expensive metrics).

        Returns:
            A tuple containing the loss and a dict of metric values.
        """
        ...

    def visualizations(
        self, y_true: YTrue, y_pred: YPred
    ) -> dict[str, UInt8[np.ndarray, "H W 3"]]:
        """Generate visualizations for each entry in a batch.

        This method may return an empty dict.

        !!! note

            This method should be called only from a "detached" CPU thread so
            as not to affect training throughput; the caller is responsible for
            detaching gradients and sending the data to the CPU. As such,
            implementations are free to use CPU-specific methods.

        Args:
            y_true: data channels (i.e., dataloader output).
            y_pred: model outputs.

        Returns:
            A dict, where each key is the name of a visualization, and the
                value is a stack of RGB images in HWC order, detached from
                Torch and sent to a numpy array.
        """
        ...

    def render(
        self, y_true: YTrue, y_pred: YPred, render_gt: bool = False
    ) -> dict[str, Shaped[np.ndarray, "batch ..."]]:
        """Render model outputs and/or ground truth for later analysis.

        This method may return an empty dict.

        ??? question "How does this differ from `visualizations`?"

            Unlike `visualizations`, which is expected to return a single
            RGB image per batch, `render` is:

            - expected to return a unique rendered value per sample, and
            - may have arbitrary types (as long as they are numpy arrays).

        Args:
            y_true: data channels (i.e. dataloader output).
            y_pred: model outputs.
            render_gt: whether to render ground truth data.

        Returns:
            A dict, where each key is the name of a rendered output, and the
                value is a numpy array of the rendered data (e.g., an image).
        """
        ...

__call__ abstractmethod

__call__(
    y_true: YTrue, y_pred: YPred, train: bool = True
) -> tuple[Float[TArray, batch], dict[str, Float[TArray, batch]]]

Training metrics implementation.

Parameters:

  • y_true (YTrue, required): data channels (i.e. dataloader output).
  • y_pred (YPred, required): model outputs.
  • train (bool, default True): whether in training mode (i.e. skip expensive metrics).

Returns:

  • tuple[Float[TArray, batch], dict[str, Float[TArray, batch]]]: A tuple containing the loss and a dict of metric values.

Source code in src/abstract_dataloader/ext/objective.py
@abstractmethod
def __call__(
    self, y_true: YTrue, y_pred: YPred, train: bool = True
) -> tuple[Float[TArray, "batch"], dict[str, Float[TArray, "batch"]]]:
    """Training metrics implementation.

    Args:
        y_true: data channels (i.e. dataloader output).
        y_pred: model outputs.
        train: Whether in training mode (i.e. skip expensive metrics).

    Returns:
        A tuple containing the loss and a dict of metric values.
    """
    ...

render

render(
    y_true: YTrue, y_pred: YPred, render_gt: bool = False
) -> dict[str, Shaped[ndarray, "batch ..."]]

Render model outputs and/or ground truth for later analysis.

This method may return an empty dict.

How does this differ from visualizations?

Unlike visualizations, which is expected to return a single RGB image per batch, render is:

  • expected to return a unique rendered value per sample, and
  • may have arbitrary types (as long as they are numpy arrays).

Parameters:

  • y_true (YTrue, required): data channels (i.e. dataloader output).
  • y_pred (YPred, required): model outputs.
  • render_gt (bool, default False): whether to render ground truth data.

Returns:

  • dict[str, Shaped[ndarray, 'batch ...']]: A dict, where each key is the name of a rendered output, and the value is a numpy array of the rendered data (e.g., an image).
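
As a hedged sketch of this hook only (DepthObjective and the output names are hypothetical, and the loss/metrics implementation is omitted):

```python
import numpy as np

class DepthObjective:
    """Sketch of the render hook; returns one array entry per sample."""

    def render(self, y_true, y_pred, render_gt: bool = False):
        out = {"depth_pred": np.asarray(y_pred)}   # (batch, ...) array, any dtype
        if render_gt:
            out["depth_gt"] = np.asarray(y_true)
        return out
```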

Source code in src/abstract_dataloader/ext/objective.py
def render(
    self, y_true: YTrue, y_pred: YPred, render_gt: bool = False
) -> dict[str, Shaped[np.ndarray, "batch ..."]]:
    """Render model outputs and/or ground truth for later analysis.

    This method may return an empty dict.

    ??? question "How does this differ from `visualizations`?"

        Unlike `visualizations`, which is expected to return a single
        RGB image per batch, `render` is:

        - expected to return a unique rendered value per sample, and
        - may have arbitrary types (as long as they are numpy arrays).

    Args:
        y_true: data channels (i.e. dataloader output).
        y_pred: model outputs.
        render_gt: whether to render ground truth data.

    Returns:
        A dict, where each key is the name of a rendered output, and the
            value is a numpy array of the rendered data (e.g., an image).
    """
    ...

visualizations

visualizations(
    y_true: YTrue, y_pred: YPred
) -> dict[str, UInt8[ndarray, "H W 3"]]

Generate visualizations for each entry in a batch.

This method may return an empty dict.

Note

This method should be called only from a "detached" CPU thread so as not to affect training throughput; the caller is responsible for detaching gradients and sending the data to the CPU. As such, implementations are free to use CPU-specific methods.

Parameters:

  • y_true (YTrue, required): data channels (i.e., dataloader output).
  • y_pred (YPred, required): model outputs.

Returns:

  • dict[str, UInt8[ndarray, 'H W 3']]: A dict, where each key is the name of a visualization, and the value is a stack of RGB images in HWC order, detached from Torch and sent to a numpy array.
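
For example, a minimal sketch of this hook which assumes y_pred is already a detached (batch, H, W) array of values in [0, 1]; the class and output names are hypothetical:

```python
import numpy as np

class ImageObjective:
    """Sketch of the visualizations hook; inputs are assumed to already be on the CPU."""

    def visualizations(self, y_true, y_pred):
        pred = np.concatenate(list(np.asarray(y_pred)), axis=1)   # tile batch: (H, B*W)
        gray = (255 * np.clip(pred, 0.0, 1.0)).astype(np.uint8)
        return {"prediction": np.stack([gray] * 3, axis=-1)}      # RGB, (H, B*W, 3)
```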

Source code in src/abstract_dataloader/ext/objective.py
def visualizations(
    self, y_true: YTrue, y_pred: YPred
) -> dict[str, UInt8[np.ndarray, "H W 3"]]:
    """Generate visualizations for each entry in a batch.

    This method may return an empty dict.

    !!! note

        This method should be called only from a "detached" CPU thread so
        as not to affect training throughput; the caller is responsible for
        detaching gradients and sending the data to the CPU. As such,
        implementations are free to use CPU-specific methods.

    Args:
        y_true: data channels (i.e., dataloader output).
        y_pred: model outputs.

    Returns:
        A dict, where each key is the name of a visualization, and the
            value is a stack of RGB images in HWC order, detached from
            Torch and sent to a numpy array.
    """
    ...

abstract_dataloader.ext.objective.VisualizationConfig dataclass

General-purpose visualization configuration.

Objectives which make use of this configuration may ignore the provided values.

Attributes:

  • cols (int): number of columns to tile images for in-training visualizations.
  • width (int): width of each sample when rendered.
  • height (int): height of each sample when rendered.
  • cmaps (Mapping[str, str | UInt8[ndarray, 'N 3']]): colormaps to use, where values correspond to the name of a matplotlib colormap or a numpy array of enumerated RGB values.
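
For example, a sketch of constructing this configuration; the depth and semseg colormap names are hypothetical:

```python
import numpy as np
from abstract_dataloader.ext.objective import VisualizationConfig

config = VisualizationConfig(
    cols=4, width=640, height=360,
    cmaps={
        "depth": "viridis",                       # named matplotlib colormap
        "semseg": np.array(                       # explicit RGB lookup table, (N, 3)
            [[0, 0, 0], [255, 0, 0], [0, 255, 0]], dtype=np.uint8),
    },
)
```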

Source code in src/abstract_dataloader/ext/objective.py
@dataclass(frozen=True)
class VisualizationConfig:
    """General-purpose visualization configuration.

    Objectives which make use of this configuration may ignore the provided
    values.

    Attributes:
        cols: number of columns to tile images for in-training visualizations.
        width: width of each sample when rendered.
        height: height of each sample when rendered.
        cmaps: colormaps to use, where values correspond to the name of a
            matplotlib colormap or a numpy array of enumerated RGB values.
    """

    cols: int = 8
    width: int = 512
    height: int = 256
    cmaps: Mapping[
        str, str | UInt8[np.ndarray, "N 3"]] = field(default_factory=dict)