OnlineMetrics


OnlineMetrics provides metrics for online, streaming, and batched machine learning workflows. Metrics implement a small, consistent interface so they can be updated batch-by-batch, merged across devices/processes, and queried for current values with minimal overhead.

Features

  • Compute metrics in batches - no need to store full datasets.
  • Track individual metrics or collections of multiple metrics.
  • Mergeable states for distributed/parallel aggregation.
  • Pluggable data formats (e.g., OneHot) with validation and formatting hooks.
  • Small, explicit API surface for easy extension.

Design philosophy

  • Minimal, explicit interface: Implement AbstractMetric and the handful of required functions (name, initial_state, batch_state, merge_state, metric_value) to add a new metric.
  • Robust input handling: Metrics declare or accept data formats; validation and formatting are separated from core metric logic.
  • Usability: Friendly Base.show representations and MetricCollection tree-printing for quick inspection.

Quickstart

Install:

using Pkg
Pkg.add("OnlineMetrics")

Basic usage:

using OnlineMetrics

m = Metric(Accuracy(2))                           # track a single metric
mc = MetricCollection(Accuracy(2), Precision(2))  # track multiple metrics

step!(m, [0,1,1,0], [0,1,0,0])                    # update with a batch
step!(mc, [0,1,1,0], [0,1,0,0])                   # update collection

value(m)                                          # current value of single metric
value(mc)                                         # NamedTuple of values for collection

Merging states (useful for distributed runs):

m1 = Metric(Accuracy(2)); step!(m1, preds1, labels1)
m2 = Metric(Accuracy(2)); step!(m2, preds2, labels2)
m_merged = merge(m1, m2)
value(m_merged)

Examples

See the src folder for metric implementations and refer to the docs for usage examples.

Development & testing

Run tests from the package root:

julia --project=. -e 'using Pkg; Pkg.test()'

Or in the package REPL:

] test OnlineMetrics

Contributing

Contributions welcome. Please open issues or pull requests for bug fixes, new metrics, or API improvements. Follow existing style in src and add tests to runtests.jl.

License

Distributed under the terms of the MIT license. See the LICENSE file.


Metric Tracking

OnlineMetrics.Metric (Type)
Metric(m::AbstractMetric; name=name(m))

An object to track a single metric and its state over multiple mini-batches.

Parameters

  • m::AbstractMetric: A metric to be tracked.
  • name::String: An optional name for the metric. Defaults to the name of the metric type.

Example

julia> m = Metric(Accuracy(2); name="My Accuracy Metric");

julia> step!(m, [0,1,0,1], [0,1,1,1]);

julia> value(m)
0.75
OnlineMetrics.MetricCollection (Type)
MetricCollection(metrics...)

An object to track one or more metrics concurrently.

Parameters

  • metrics...: A variable number of metrics to be tracked. Each metric can be provided either as an instance of a subtype of AbstractMetric or as a Pair{String, AbstractMetric} to override the default name of the metric.

Example

julia> m = MetricCollection(Accuracy(2), "precision" => Precision(2, agg=nothing), mIoU(2))
MetricCollection
├─ accuracy: value=0.0 | correct=0 | n=0
├─ precision: value=[1.0, 1.0] | tp=[0, 0] | fp=[0, 0]
└─ mIoU: value=1.0 | intersection=[0, 0] | union=[0, 0]


julia> step!(m, [0, 0, 1, 0], [0, 0, 1, 1])
MetricCollection
├─ accuracy: value=0.75 | correct=3 | n=4
├─ precision: value=[0.666667, 1.0] | tp=[2, 1] | fp=[1, 0]
└─ mIoU: value=0.583333 | intersection=[2, 1] | union=[3, 2]
OnlineMetrics.step! (Function)
step!(x::Metric, y_pred, y_true)
step!(x::MetricCollection, y_pred, y_true)

Update the metric(s) with a new batch of predictions and true labels.

OnlineMetrics.merge (Function)
merge(metrics...)

Merge multiple Metric or MetricCollection objects into a single object.

Useful for aggregating metrics across multiple devices or processes.
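For example, two collections updated on separate data shards can be combined into one. This sketch assumes both collections were constructed with identical metric configurations, which is presumably required for their states to be mergeable:

```julia
using OnlineMetrics

# Each collection sees a different shard of the data
mc1 = MetricCollection(Accuracy(2), Precision(2))
mc2 = MetricCollection(Accuracy(2), Precision(2))
step!(mc1, [0,1,1,0], [0,1,0,0])
step!(mc2, [1,1,0,0], [1,1,0,1])

# Combine their states as if one collection had seen every batch
mc = merge(mc1, mc2)
value(mc)   # NamedTuple of aggregated values
```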


Metric Interface

OnlineMetrics.AbstractMetric (Type)

Metrics are measures of a model's performance, such as loss, accuracy, or squared error.

Metrics are updated incrementally as new data arrives, making them suitable for online learning scenarios.

Each metric must implement the following interface:

  • name: Returns the human-readable name of the metric.
  • initial_state: Returns the initial state of the metric.
  • batch_state: Computes the metric's state for a single batch of predictions and labels.
  • merge_state: Merges two metric states into a single state.
  • metric_value: Computes the metric's value from its current state.

Optional Methods

  • data_format: Returns the data format expected by the metric, or nothing if no specific format is required. Defaults to nothing.
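As a sketch of the interface, here is a hypothetical mean-squared-error metric. The exact argument orders of merge_state and metric_value are assumptions based on the function descriptions above, not verified signatures:

```julia
using OnlineMetrics
import OnlineMetrics: name, initial_state, batch_state, merge_state, metric_value

# A hypothetical running mean-squared-error metric
struct MeanSquaredError <: OnlineMetrics.AbstractMetric end

name(::MeanSquaredError) = "mse"

# State: running sum of squared errors and observation count
initial_state(::MeanSquaredError) = (sum=0.0, n=0)

# State contribution of a single batch
batch_state(::MeanSquaredError, y_pred, y_true) =
    (sum=sum(abs2, y_pred .- y_true), n=length(y_true))

# Combine two states (assumed signature: metric, state a, state b)
merge_state(::MeanSquaredError, a, b) = (sum=a.sum + b.sum, n=a.n + b.n)

# Current value from the accumulated state
metric_value(::MeanSquaredError, state) = state.n == 0 ? 0.0 : state.sum / state.n
```

Because no data_format method is defined, the metric falls back to the default of nothing and receives inputs unformatted.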
OnlineMetrics.ClassificationMetric (Type)

Classification metrics are used to evaluate the performance of models that predict a discrete label for each observation. Subtypes of ClassificationMetric have a default data format of OneHot, which provides batch_state with Matrix{Bool} inputs representing one-hot encoded class labels.

OnlineMetrics.batch_state (Function)
batch_state(m::AbstractMetric, y_pred, y_true) -> state

Compute the metric's state for a single batch of predictions and labels.

The types of y_pred and y_true should not be specialized. Data formatting and validation should instead be handled by the format and validate methods defined by the metric's data format. This ensures that users receive more informative error messages when providing invalid data.

OnlineMetrics.data_format (Function)
data_format(m::AbstractMetric) -> Union{AbstractDataFormat, Nothing}

Return the data format expected by the metric, or nothing if no specific format is required.

OnlineMetrics.step (Function)
step(m::AbstractMetric, y_pred, y_true, oldstate) -> newstate

Update the metric state for the given batch of labels and predictions.

Parameters

  • m::AbstractMetric: The metric to be updated.
  • y_pred: The model predictions for the current batch.
  • y_true: The true labels for the current batch.
  • oldstate: The previous state of the metric.

Returns

  • newstate: The updated state of the metric.
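One plausible composition of the interface functions is sketched below. This is an illustration of how the pieces described above could fit together, not necessarily the package's exact definition; in particular, the argument order of merge_state is an assumption:

```julia
# Sketch: how step could compose the interface functions
function step(m::AbstractMetric, y_pred, y_true, oldstate)
    df = data_format(m)                 # nothing or an AbstractDataFormat
    validate(df, y_pred)                # raise ArgumentError on bad input
    validate(df, y_true)
    batch = batch_state(m, format(df, y_pred), format(df, y_true))
    return merge_state(m, oldstate, batch)
end
```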

Classification Metrics

OnlineMetrics.Accuracy (Type)
Accuracy(nclasses::Int)

Measures the model's overall accuracy as correct / n.

Parameters

  • nclasses::Int: The number of classes for the classification task.
OnlineMetrics.IoU (Type)
IoU(nclasses::Int; agg=nothing)

Intersection over Union (IoU) is a measure of the overlap between a prediction and a label.

Parameters

  • nclasses::Int: The number of classes for the classification task.

Keyword Parameters

  • agg: Specifies the type of IoU aggregation to be computed. The possible values are:
    • :mean: Calculates mean IoU (mIoU) by averaging the IoU across all classes.
    • nothing: Calculates the per-class IoU, which is returned as a Vector of length nclasses.
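The two aggregation modes can be compared side by side. A brief sketch, assuming the constructors shown above:

```julia
using OnlineMetrics

per_class = Metric(IoU(2))              # agg=nothing: per-class IoU
mean_iou  = Metric(IoU(2; agg=:mean))   # equivalent to mIoU(2)

step!(per_class, [0,0,1,0], [0,0,1,1])
step!(mean_iou,  [0,0,1,0], [0,0,1,1])

value(per_class)   # Vector of length 2
value(mean_iou)    # single scalar: the mean over classes
```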
OnlineMetrics.mIoU (Function)
mIoU(nclasses::Int)

A convenience constructor for IoU that defaults to mean aggregation.

OnlineMetrics.Precision (Type)
Precision(nclasses::Int; agg=:macro)

Precision is the ratio of true positives to the sum of true positives and false positives, measuring the accuracy of positive predictions.

Parameters

  • nclasses::Int: The number of classes for the classification task.

Keyword Parameters

  • agg: Specifies the type of precision aggregation to be computed. The possible values are:
    • :macro: Calculates macro-averaged precision, which computes the precision for each class independently and then takes the average.
    • :micro: Calculates micro-averaged precision, which aggregates the contributions of all classes to compute a single precision value.
    • nothing: Calculates the per-class precision, which is returned as a Vector of length nclasses.
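The three aggregation modes can be tracked on the same data. A brief sketch, assuming the constructor shown above; the macro value is the mean of the per-class vector:

```julia
using OnlineMetrics

macro_p = Metric(Precision(2))               # agg=:macro (default)
micro_p = Metric(Precision(2; agg=:micro))
class_p = Metric(Precision(2; agg=nothing))  # per-class Vector

for m in (macro_p, micro_p, class_p)
    step!(m, [0,0,1,0], [0,0,1,1])
end

value(class_p)   # Vector of length 2; value(macro_p) averages it
```

The same agg keyword applies to Recall and F1Score below.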
OnlineMetrics.Recall (Type)
Recall(nclasses::Int; agg=:macro)

Recall, also known as sensitivity or true positive rate, is the ratio of true positives to the sum of true positives and false negatives, measuring the ability of the classifier to identify all positive instances.

Parameters

  • nclasses::Int: The number of classes for the classification task.

Keyword Parameters

  • agg: Specifies the type of recall aggregation to be computed. The possible values are:
    • :macro: Calculates macro-averaged recall, which computes the recall for each class independently and then takes the average.
    • :micro: Calculates micro-averaged recall, which aggregates the contributions of all classes to compute a single recall value.
    • nothing: Calculates the per-class recall, which is returned as a Vector of length nclasses.
OnlineMetrics.F1Score (Type)
F1Score(nclasses::Int; agg=:macro)

F1 Score is the harmonic mean of precision and recall, providing a single metric that balances both concerns.

Parameters

  • nclasses::Int: The number of classes for the classification task.

Keyword Parameters

  • agg: Specifies the type of F1 score aggregation to be computed. The possible values are:
    • :macro: Calculates macro-averaged F1 score, which computes the F1 score for each class independently and then takes the average.
    • :micro: Calculates micro-averaged F1 score, which aggregates the contributions of all classes to compute a single F1 score.
    • nothing: Calculates the per-class F1 score, which is returned as a Vector of length nclasses.
OnlineMetrics.ConfusionMatrix (Type)
ConfusionMatrix(nclasses::Int)

Calculate the confusion matrix over two or more classes. The columns of the resulting nclasses x nclasses matrix correspond to the true label while the rows correspond to the prediction.

Parameters

  • nclasses::Int: The number of possible classes in the classification task.
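A brief usage sketch, assuming value returns the accumulated matrix:

```julia
using OnlineMetrics

cm = Metric(ConfusionMatrix(2))
step!(cm, [0,0,1,0], [0,0,1,1])

# 2×2 matrix: columns index the true label, rows the prediction,
# so off-diagonal entries count misclassifications.
value(cm)
```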

Other Metrics

OnlineMetrics.AverageMeasure (Type)
AverageMeasure(measure::Function, name::String)

Tracks the average value of a given measure over mini-batches. Typically used to track loss functions or regression metrics.
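For example, a training loss can be tracked as a running average. This sketch assumes the measure is called as measure(y_pred, y_true), which is not spelled out above; mse here is a user-defined function, not part of the package:

```julia
using OnlineMetrics

# Hypothetical measure: per-batch mean squared error
mse(y_pred, y_true) = sum(abs2, y_pred .- y_true) / length(y_true)

loss = Metric(AverageMeasure(mse, "mse"))
step!(loss, [0.1, 0.9], [0.0, 1.0])
step!(loss, [0.4, 0.6], [0.0, 1.0])

value(loss)   # running average of mse over the two batches
```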


Data Formats

OnlineMetrics.OneHot (Type)
OneHot(nclasses::Int)

A data format consisting of one-hot encoded class labels for nclasses classes.

Input

  • If the input data x is an array of shape (D...,1,N) or (N,), it is interpreted as integer class labels in the range [0, nclasses-1].
  • If the input data x is an array of shape (D...,nclasses,N), it is interpreted as one-hot encoded vectors.

Output

The output is a Matrix{Bool} of shape (nclasses, N).
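A brief sketch of the conversion, assuming format is accessible as shown in its docstring below (it may need to be qualified as OnlineMetrics.format if not exported):

```julia
using OnlineMetrics

df = OneHot(3)

labels = [0, 2, 1, 0]        # shape (N,): integer class labels in [0, 2]
onehot = format(df, labels)  # 3×4 Matrix{Bool}, one true per column
```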

OnlineMetrics.format (Function)
format(df::Nothing, x::Array{<:Real})
format(df::AbstractDataFormat, x::Array{<:Real})

Format the input data x according to the specified data format df.

Parameters

  • df: An instance of a subtype of AbstractDataFormat or nothing.
  • x: An array of real-valued data to be formatted.

Returns

The formatted data.

OnlineMetrics.validate (Function)
validate(df::Nothing, x::Array{<:Real})
validate(df::AbstractDataFormat, x::Array{<:Real})

Validate that the input data x conforms to the specified data format df.

Raises an ArgumentError if the data does not conform.

Parameters

  • df: An instance of a subtype of AbstractDataFormat or nothing.
  • x: An array of real-valued data to be validated.
