Index
MixtureDensityNetworks.MDN (type)
MixtureDensityNetworks.MDN (method)
MixtureDensityNetworks.MixtureDensityNetwork (type)
MixtureDensityNetworks.MixtureDensityNetwork (method)
MixtureDensityNetworks.MultivariateGMM (type)
MixtureDensityNetworks.MultivariateGMM (method)
MixtureDensityNetworks.UnivariateGMM (type)
MixtureDensityNetworks.UnivariateGMM (method)
MixtureDensityNetworks.fit! (method)
MixtureDensityNetworks.generate_data (method)
MixtureDensityNetworks.likelihood_loss (method, multivariate)
MixtureDensityNetworks.likelihood_loss (method, univariate)
MixtureDensityNetworks.predict_mode (method)
API
MixtureDensityNetworks.MDN — Type
MDN
A model type for constructing a Mixture Density Network, based on MixtureDensityNetworks.jl, and implementing the MLJ model interface.
From MLJ, the type can be imported using
MDN = @load MDN pkg=MixtureDensityNetworks
Do model = MDN() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in MDN(mixtures=...).
A neural network which parameterizes a Gaussian Mixture Model (GMM) distributed over the target variable y conditioned on the features X.
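For the univariate case, the conditional density represented by such a network can be written in the standard mixture form below. The symbols K, w_k, μ_k and σ_k are notation introduced here for illustration and are not identifiers from the package.

```latex
p(y \mid x) = \sum_{k=1}^{K} w_k(x)\, \mathcal{N}\!\left(y \mid \mu_k(x),\, \sigma_k(x)^2\right),
\qquad w_k(x) \ge 0, \quad \sum_{k=1}^{K} w_k(x) = 1
```

Here K plays the role of the mixtures hyper-parameter, and the weights, means, and standard deviations are all outputs of the network evaluated at the features x.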
Training Data
In MLJ or MLJBase, bind an MDN instance model to data with mach = machine(model, X, y) where
X: any table of input features (e.g., a DataFrame) whose columns have Continuous scitype.
y: the target, which can be any AbstractVector whose element scitype is Continuous.
Hyperparameters
mixtures=5: number of Gaussian mixtures in the predicted distribution
layers=[128,]: hidden layer topology, starting from the first hidden layer
η=1e-3: learning rate used for the optimizer
epochs=1: number of epochs to train the model
batchsize=32: batch size used during training
Operations
predict(mach, Xnew): return the distributions over the target conditioned on the new features Xnew having the same scitype as X above.
predict_mode(mach, Xnew): return the largest modes of the distributions over targets conditioned on the new features Xnew having the same scitype as X above.
predict_mean(mach, Xnew): return the means of the distributions over targets conditioned on the new features Xnew having the same scitype as X above.
predict_median(mach, Xnew): return the medians of the distributions over targets conditioned on the new features Xnew having the same scitype as X above.
Fitted Parameters
The fields of fitted_params(mach) are:
fitresult: the trained mixture density model, compatible with the Flux ecosystem.
Report
learning_curve: the average training loss for each epoch.
best_epoch: the epoch (starting from 1) with the lowest training loss.
best_loss: the best (lowest) loss encountered during training. Corresponds to the average loss of the best epoch.
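As a small illustration of how the three report fields relate to each other, the sketch below derives a best epoch and best loss from a learning curve using only Base Julia. The loss values are made up for illustration, and this is not the package's internal code.

```julia
# Hypothetical per-epoch average training losses (illustrative values only).
learning_curve = [2.31, 1.87, 1.52, 1.49, 1.55]

# findmin returns the smallest value together with its 1-based index,
# which matches the "epoch (starting from 1)" convention described above.
best_loss, best_epoch = findmin(learning_curve)

println("best_epoch = ", best_epoch)
println("best_loss  = ", best_loss)
```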
Accessor Functions
training_losses(mach) returns the learning curve as a vector of average training losses for each epoch.
Examples
using MLJ
MDN = @load MDN pkg=MixtureDensityNetworks
mdn = MDN(mixtures=12, epochs=100, layers=[512, 256, 128])
X, y = make_regression(100, 1) # synthetic data
mach = machine(mdn, X, y) |> fit!
Xnew, _ = make_regression(3, 1)
ŷ = predict(mach, Xnew) # new predictions
report(mach).best_epoch # best epoch encountered during training
report(mach).best_loss # best loss encountered during training
training_losses(mach) # learning curve
MixtureDensityNetworks.MDN — Method
MDN(; mixtures=5, layers=[128], η=1e-3, epochs=1, batchsize=32)
Defines an MDN model with the given hyperparameters.
Parameters
mixtures: The number of Gaussian mixtures to use in estimating the conditional distribution (default=5).
layers: A vector indicating the number of nodes in each of the hidden layers (default=[128,]).
η: The learning rate to use when training the model (default=1e-3).
epochs: The number of epochs to train the model (default=1).
batchsize: The batch size to use during training (default=32).
MixtureDensityNetworks.MixtureDensityNetwork — Type
struct MixtureDensityNetwork{T}
A Flux model for implementing a standard Mixture Density Network.
Parameters
hidden::Flux.Chain
output::Any
MixtureDensityNetworks.MixtureDensityNetwork — Method
MixtureDensityNetwork(
input::Int64,
output::Int64,
layers::Vector{Int64},
mixtures::Int64
) -> Union{MixtureDensityNetwork{MultivariateGMM}, MixtureDensityNetwork{UnivariateGMM}}
Construct a standard Mixture Density Network.
Parameters
input: The dimension of the input features.
output: The dimension of the output. Setting output = 1 indicates a univariate model, whereas output > 1 indicates a multivariate model.
layers: The topology of the hidden layers, starting from the first layer.
mixtures: The number of Gaussian mixtures to use in the predicted distribution.
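The constructor signature above suggests usage along the following lines. This is a minimal sketch that assumes the package is installed; the layer sizes and mixture count are arbitrary illustrative choices. The GMM layer types are module-qualified in case they are not exported.

```julia
using MixtureDensityNetworks

# Univariate model: 1 input feature, 1 output dimension,
# two hidden layers of 64 units, 5 mixture components.
uni = MixtureDensityNetwork(1, 1, [64, 64], 5)

# Multivariate model: 3 input features, 2 output dimensions.
multi = MixtureDensityNetwork(3, 2, [64, 64], 5)

# Per the documented return type, the type parameter records which GMM layer is used.
println(uni isa MixtureDensityNetwork{MixtureDensityNetworks.UnivariateGMM})
println(multi isa MixtureDensityNetwork{MixtureDensityNetworks.MultivariateGMM})
```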
MixtureDensityNetworks.MultivariateGMM — Type
struct MultivariateGMM
A layer which produces a multivariate Gaussian Mixture Model as its output.
Parameters
outputs::Int64
mixtures::Int64
μ::Flux.Dense
Σ::Flux.Dense
w::Flux.Chain
MixtureDensityNetworks.MultivariateGMM — Method
MultivariateGMM(
input::Int64,
output::Int64,
mixtures::Int64
) -> MultivariateGMM
Construct a layer which returns a multivariate Gaussian Mixture Model as its output.
Parameters
input: Specifies the length of the feature vectors. The layer expects a matrix with the dimensions input x N as input.
output: Specifies the length of the label vectors. The layer returns a matrix with dimensions output x N as output.
mixtures: The number of mixtures to use in the GMM.
MixtureDensityNetworks.UnivariateGMM — Type
struct UnivariateGMM
A layer which produces a univariate Gaussian Mixture Model as its output.
Parameters
μ::Flux.Dense
Σ::Flux.Dense
w::Flux.Chain
MixtureDensityNetworks.UnivariateGMM — Method
UnivariateGMM(
input::Int64,
mixtures::Int64
) -> UnivariateGMM
Construct a layer which returns a univariate Gaussian Mixture Model as its output.
Parameters
input: Specifies the length of the feature vectors. The layer expects a matrix with the dimensions input x N as input.
mixtures: The number of mixtures to use in the GMM.
MixtureDensityNetworks.fit! — Method
fit!(
m,
X::Matrix{<:Real},
Y::Matrix{<:Real};
opt,
batchsize,
epochs,
verbosity
) -> Tuple{Any, NamedTuple{(:learning_curve, :best_epoch, :best_loss), Tuple{Vector{Float64}, Int64, Float64}}}
Fit the model to the data given by X and Y.
Parameters
m: The model to be trained.
X: A d x n matrix where d is the number of input features and n is the number of samples.
Y: A d x n matrix where d is the dimension of the output and n is the number of samples.
opt: The optimization algorithm to use during training (default = Adam(1e-3)).
batchsize: The batch size for each iteration of gradient descent (default = 32).
epochs: The number of epochs to train for (default = 100).
verbosity: Whether to show a progress bar (default = 1) or not (0).
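A training call based on the signature above might look like the following sketch. It assumes Flux and the package are installed, and that generate_data returns one-dimensional features and labels so that a 1-input, 1-output model fits the data; neither assumption is stated in this reference. The layer sizes, mixture count, and epoch count are arbitrary.

```julia
using Flux
using MixtureDensityNetworks

# Synthetic data from the package's helper: features X and labels Y as matrices.
X, Y = generate_data(500)

# Univariate model; assumes the synthetic data has one feature and one target dimension.
model = MixtureDensityNetwork(1, 1, [64, 64], 4)

# fit! is module-qualified to avoid clashing with other packages that define fit!.
model, report = MixtureDensityNetworks.fit!(
    model, X, Y;
    opt=Flux.Adam(1e-3),
    batchsize=32,
    epochs=5,
    verbosity=0,  # suppress the progress bar
)

println(report.best_epoch)
println(report.best_loss)
```

The second element of the returned tuple carries the learning_curve, best_epoch, and best_loss fields listed in the return type.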
MixtureDensityNetworks.generate_data — Method
generate_data(
n_samples::Int64
) -> Tuple{Matrix{Float64}, Matrix{Float64}}
Generate some synthetic data for testing purposes.
Parameters
n_samples: The number of samples we want to generate.
Returns
The synthetic features X and labels Y as a tuple (X, Y).
MixtureDensityNetworks.likelihood_loss — Method
likelihood_loss(
distributions::Vector{<:Distributions.MixtureModel{Distributions.Multivariate}},
y::Matrix{<:Real}
) -> Any
Compute the negative log-likelihood loss for a set of labels y under a set of multivariate Gaussian Mixture Models.
Parameters
distributions: A vector of multivariate Gaussian Mixture Model distributions.
y: A d x n matrix of labels where d is the dimension of each label and n is the number of samples.
MixtureDensityNetworks.likelihood_loss — Method
likelihood_loss(
distributions::Vector{<:Distributions.MixtureModel{Distributions.Univariate}},
y::Matrix{<:Real}
) -> Any
Compute the negative log-likelihood loss for a set of labels y under a set of univariate Gaussian Mixture Models.
Parameters
distributions: A vector of univariate Gaussian Mixture Model distributions.
y: A 1 x n matrix of labels where n is the number of samples.
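To make the quantity concrete, here is a minimal Base-Julia sketch of a negative log-likelihood for univariate Gaussian mixtures. It is not the package's implementation: the helper names (normal_pdf, mixture_pdf, mean_nll) are hypothetical, each mixture is represented as a plain (weights, means, standard deviations) tuple, and averaging over samples is an assumption, since this reference does not say whether the loss is summed or averaged.

```julia
# Density of a univariate normal with mean μ and standard deviation σ.
normal_pdf(y, μ, σ) = exp(-0.5 * ((y - μ) / σ)^2) / (σ * sqrt(2π))

# Density of a Gaussian mixture with weights w, means μ, and standard deviations σ.
mixture_pdf(y, w, μ, σ) = sum(w[k] * normal_pdf(y, μ[k], σ[k]) for k in eachindex(w))

# Mean negative log-likelihood; mixtures[i] is the (w, μ, σ) tuple paired with label ys[i].
function mean_nll(mixtures, ys)
    total = 0.0
    for ((w, μ, σ), y) in zip(mixtures, ys)
        total -= log(mixture_pdf(y, w, μ, σ))
    end
    return total / length(ys)
end

# Two samples, each with its own two-component mixture (illustrative numbers).
mixtures = [
    ([0.5, 0.5], [0.0, 4.0], [1.0, 1.0]),
    ([0.2, 0.8], [-1.0, 1.0], [0.5, 0.5]),
]
ys = [0.1, 0.9]

println(mean_nll(mixtures, ys))
```

Taking the log of a summed density can underflow for labels far from every component, so a numerically careful version would work in log space (for example with a log-sum-exp over components).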
MixtureDensityNetworks.predict_mode — Method
predict_mode(
m::MixtureDensityNetwork,
X::Matrix{<:Real}
) -> Any
Predict the point associated with the highest probability in the conditional distribution P(Y|X).
Parameters
m: The model with which to generate a prediction.
X: The input to be passed to m. Expected to be a matrix with dimensions d x n where d is the number of input features and n is the number of observations.
Returns
The mode of each distribution returned by m(X).
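One common way to approximate the mode of a Gaussian mixture is to evaluate the full mixture density at each component mean and keep the mean with the highest density. The Base-Julia sketch below illustrates that heuristic for the univariate case. Whether predict_mode works this way is not stated in this reference, so treat it as an assumption; the helper names are hypothetical.

```julia
# Density of a univariate normal with mean μ and standard deviation σ.
normal_pdf(y, μ, σ) = exp(-0.5 * ((y - μ) / σ)^2) / (σ * sqrt(2π))

# Density of a Gaussian mixture with weights w, means μ, and standard deviations σ.
mixture_pdf(y, w, μ, σ) = sum(w[k] * normal_pdf(y, μ[k], σ[k]) for k in eachindex(w))

# Approximate the mode by scoring each component mean under the full mixture density.
function approx_mode(w, μ, σ)
    densities = [mixture_pdf(m, w, μ, σ) for m in μ]
    return μ[argmax(densities)]
end

w = [0.3, 0.7]
μ = [-2.0, 3.0]
σ = [0.5, 0.5]
println(approx_mode(w, μ, σ))
```

The component with the largest weight is not always the mode: a narrow component with a small weight can have a higher peak density than a wide component with a large weight, which is why the sketch compares densities rather than weights.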