Class LossMixtureDensity
- java.lang.Object
  - org.nd4j.linalg.lossfunctions.impl.LossMixtureDensity
- All Implemented Interfaces:
Serializable, ILossFunction
public class LossMixtureDensity extends Object implements ILossFunction
- See Also:
- Serialized Form
-
Nested Class Summary
Nested Classes
Modifier and Type | Class | Description
static class | LossMixtureDensity.Builder |
static class | LossMixtureDensity.MixtureDensityComponents | This class is a data holder for the mixture density components for convenient manipulation.
-
Constructor Summary
Constructors
Constructor | Description
LossMixtureDensity() |
-
Method Summary
Modifier and Type | Method | Description
static LossMixtureDensity.Builder | builder() |
INDArray | computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) | This method returns the gradient of the cost function with respect to the output from the previous layer.
Pair<Double,INDArray> | computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) | Compute both the score (loss function value) and gradient.
double | computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) | Computes the aggregate score as a sum of all of the individual scores of each of the labels against each of the outputs of the network.
INDArray | computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) | This method returns the score for each of the given outputs against the given set of labels.
LossMixtureDensity.MixtureDensityComponents | extractComponents(INDArray output) |
int | getLabelWidth() | Returns the width of each label vector.
int | getNMixtures() | Returns the number of Gaussians this loss function will attempt to find.
String | name() | The opName of this function.
String | toString() |
-
Method Detail
-
extractComponents
public LossMixtureDensity.MixtureDensityComponents extractComponents(INDArray output)
-
computeScore
public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average)
Computes the aggregate score as a sum of all of the individual scores of each of the labels against each of the outputs of the network. For the mixture density network, this is the negative log likelihood that the given labels fall within the probability distribution described by the mixture of Gaussians of the network output.
- Specified by:
- computeScore in interface ILossFunction
- Parameters:
- labels - Labels to score against the network.
- preOutput - Output of the network (before activation function has been called).
- activationFn - Activation function for the network.
- mask - Mask to be applied to labels (not used for MDN).
- average - Whether or not to return an average instead of a total score (not used).
- Returns:
- Returns a single double which corresponds to the total score of all label values.
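To make the aggregate score concrete, the following plain-Java sketch computes the negative log likelihood of labels under a mixture of isotropic Gaussians and sums it over examples. It is an illustration only, not this class's implementation: the class `MixtureNllSketch`, its method names, and the use of `double[]` arrays in place of `INDArray` are all assumptions made for the example, and the mixture parameters are taken as already activated (weights sum to 1, sigma positive).

```java
// Hypothetical illustration; not part of ND4J.
final class MixtureNllSketch {

    /** Density of an isotropic Gaussian (mean mu, standard deviation sigma) at point t. */
    static double gaussian(double[] t, double[] mu, double sigma) {
        double sq = 0.0;
        for (int k = 0; k < t.length; k++) {
            double d = t[k] - mu[k];
            sq += d * d;
        }
        double variance = sigma * sigma;
        double norm = Math.pow(2.0 * Math.PI * variance, t.length / 2.0);
        return Math.exp(-sq / (2.0 * variance)) / norm;
    }

    /** Negative log likelihood of one label t under the mixture (alpha, mu, sigma). */
    static double negLogLikelihood(double[] t, double[] alpha, double[][] mu, double[] sigma) {
        double p = 0.0;
        for (int i = 0; i < alpha.length; i++) {
            p += alpha[i] * gaussian(t, mu[i], sigma[i]);
        }
        return -Math.log(p);
    }

    /** Total score: the per-example negative log likelihoods summed over all examples. */
    static double totalScore(double[][] labels, double[][] alpha, double[][][] mu, double[][] sigma) {
        double total = 0.0;
        for (int n = 0; n < labels.length; n++) {
            total += negLogLikelihood(labels[n], alpha[n], mu[n], sigma[n]);
        }
        return total;
    }

    public static void main(String[] args) {
        double[] label = {0.0};
        double nll = negLogLikelihood(label, new double[]{1.0}, new double[][]{{0.0}}, new double[]{1.0});
        System.out.println("NLL of 0 under a standard normal: " + nll);
    }
}
```

The direct `Math.log` of a summed density can underflow for labels far from every component; a production implementation would typically guard against that, which this sketch does not attempt.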
-
computeScoreArray
public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask)
This method returns the score for each of the given outputs against the given set of labels. For a mixture density network, this is done by extracting the "alpha", "mu", and "sigma" components of each Gaussian and computing the negative log likelihood of the labels under the weighted sum (mixture) of these Gaussian distributions. The smaller the negative log likelihood, the higher the probability that the given labels would fall within the distribution. Minimizing the negative log likelihood therefore maximizes the probability that the Gaussian mixture explains the observed labels.
- Specified by:
- computeScoreArray in interface ILossFunction
- Parameters:
- labels - The sample outputs that the network should be trying to converge on.
- preOutput - The output of the last layer (before applying the activation function).
- activationFn - The activation function of the current layer.
- mask - Mask to apply to score evaluation (not supported for this cost function).
- Returns:
-
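The extraction step described for computeScoreArray can be sketched in plain Java as follows. This is a hedged illustration rather than the library's code. It assumes one output row of width (labelWidth + 2) * nMixtures laid out as [alpha | sigma | mu], with a softmax applied to the alpha block so the weights sum to 1 and an exponential applied to the sigma block so it stays positive. The class `MdnComponentsSketch`, the method `fromRow`, and that layout are assumptions for the example.

```java
// Hypothetical illustration; not part of ND4J. The [alpha | sigma | mu] layout is an assumption.
final class MdnComponentsSketch {
    final double[] alpha;   // mixture weights, sum to 1
    final double[] sigma;   // per-mixture standard deviations, strictly positive
    final double[][] mu;    // per-mixture means, shape [nMixtures][labelWidth]

    private MdnComponentsSketch(double[] alpha, double[] sigma, double[][] mu) {
        this.alpha = alpha;
        this.sigma = sigma;
        this.mu = mu;
    }

    /** Splits one raw output row into activated alpha, sigma, and mu components. */
    static MdnComponentsSketch fromRow(double[] row, int nMixtures, int labelWidth) {
        if (row.length != (labelWidth + 2) * nMixtures) {
            throw new IllegalArgumentException("Expected width " + (labelWidth + 2) * nMixtures
                    + " but got " + row.length);
        }
        // Softmax over the first nMixtures entries (max subtracted for numerical stability).
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < nMixtures; i++) {
            max = Math.max(max, row[i]);
        }
        double[] alpha = new double[nMixtures];
        double sum = 0.0;
        for (int i = 0; i < nMixtures; i++) {
            alpha[i] = Math.exp(row[i] - max);
            sum += alpha[i];
        }
        for (int i = 0; i < nMixtures; i++) {
            alpha[i] /= sum;
        }
        // Exponential keeps sigma positive; mu is used as-is.
        double[] sigma = new double[nMixtures];
        double[][] mu = new double[nMixtures][labelWidth];
        for (int i = 0; i < nMixtures; i++) {
            sigma[i] = Math.exp(row[nMixtures + i]);
            for (int k = 0; k < labelWidth; k++) {
                mu[i][k] = row[2 * nMixtures + i * labelWidth + k];
            }
        }
        return new MdnComponentsSketch(alpha, sigma, mu);
    }

    public static void main(String[] args) {
        MdnComponentsSketch c = fromRow(new double[]{0, 0, 0, 0, 1, 2}, 2, 1);
        System.out.println(java.util.Arrays.toString(c.alpha));
    }
}
```

Under these assumptions, a network trained with 2 mixtures and a label width of 1 would need an output layer of width 6.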
computeGradient
public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask)
This method returns the gradient of the cost function with respect to the output from the previous layer. For this cost function, the gradient is derived from Bishop's paper "Mixture Density Networks" (1994), which gives an elegant closed-form expression for the derivatives with respect to each of the output components.
- Specified by:
- computeGradient in interface ILossFunction
- Parameters:
- labels - Labels to train on.
- preOutput - Output of neural network before applying the final activation function.
- activationFn - Activation function of output layer.
- mask - Mask to apply to gradients.
- Returns:
- Gradient of cost function with respect to preOutput parameters.
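The closed-form derivatives from Bishop's paper can be sketched for a single example in plain Java. This is an illustration under stated assumptions, not this class's implementation: the class `MdnGradientSketch` and its methods are hypothetical, the raw output row z is assumed to be laid out as [alpha | sigma | mu] with a softmax on alpha and an exponential on sigma, and the Gaussians are isotropic. With posterior pi_i = alpha_i * phi_i / sum_j(alpha_j * phi_j) and label dimension c, the derivatives are alpha_i - pi_i for the alpha inputs, -pi_i * (|t - mu_i|^2 / sigma_i^2 - c) for the sigma inputs, and pi_i * (mu_ik - t_k) / sigma_i^2 for the means.

```java
// Hypothetical illustration; not part of ND4J. Layout [alpha | sigma | mu] is an assumption.
final class MdnGradientSketch {

    /** Softmax over the first n entries of z (the mixture weights). */
    static double[] softmax(double[] z, int n) {
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < n; i++) max = Math.max(max, z[i]);
        double[] a = new double[n];
        double sum = 0.0;
        for (int i = 0; i < n; i++) { a[i] = Math.exp(z[i] - max); sum += a[i]; }
        for (int i = 0; i < n; i++) a[i] /= sum;
        return a;
    }

    /** Squared distance between label t and the mean of mixture i. */
    static double squaredDistance(double[] z, double[] t, int n, int i) {
        double sq = 0.0;
        for (int k = 0; k < t.length; k++) {
            double d = t[k] - z[2 * n + i * t.length + k];
            sq += d * d;
        }
        return sq;
    }

    /** Returns alpha_i * phi_i(t) for every mixture i. */
    static double[] weightedDensities(double[] z, double[] t, int n) {
        double[] alpha = softmax(z, n);
        double[] w = new double[n];
        for (int i = 0; i < n; i++) {
            double variance = Math.exp(2.0 * z[n + i]);   // sigma_i^2 with sigma_i = exp(z)
            double norm = Math.pow(2.0 * Math.PI * variance, t.length / 2.0);
            w[i] = alpha[i] * Math.exp(-squaredDistance(z, t, n, i) / (2.0 * variance)) / norm;
        }
        return w;
    }

    /** Negative log likelihood of label t given raw outputs z. */
    static double loss(double[] z, double[] t, int n) {
        double total = 0.0;
        for (double v : weightedDensities(z, t, n)) total += v;
        return -Math.log(total);
    }

    /** Gradient of the loss with respect to every raw output, same layout as z. */
    static double[] gradient(double[] z, double[] t, int n) {
        int c = t.length;
        double[] alpha = softmax(z, n);
        double[] w = weightedDensities(z, t, n);
        double total = 0.0;
        for (double v : w) total += v;
        double[] g = new double[z.length];
        for (int i = 0; i < n; i++) {
            double pi = w[i] / total;                      // posterior of mixture i
            double variance = Math.exp(2.0 * z[n + i]);
            g[i] = alpha[i] - pi;
            g[n + i] = -pi * (squaredDistance(z, t, n, i) / variance - c);
            for (int k = 0; k < c; k++) {
                int idx = 2 * n + i * c + k;
                g[idx] = pi * (z[idx] - t[k]) / variance;
            }
        }
        return g;
    }

    public static void main(String[] args) {
        double[] z = {0.3, -0.2, 0.1, -0.4, 1.0, -1.0};
        System.out.println(java.util.Arrays.toString(gradient(z, new double[]{0.5}, 2)));
    }
}
```

A central finite-difference comparison against `loss` is a simple way to check a sketch like this before trusting it.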
-
computeGradientAndScore
public Pair<Double,INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average)
Description copied from interface: ILossFunction
Compute both the score (loss function value) and gradient. This is equivalent to calling ILossFunction.computeScore(INDArray, INDArray, IActivation, INDArray, boolean) and ILossFunction.computeGradient(INDArray, INDArray, IActivation, INDArray) individually.
- Specified by:
- computeGradientAndScore in interface ILossFunction
- Parameters:
- labels - Label/expected output
- preOutput - Output of the model (neural network)
- activationFn - Activation function that should be applied to preOutput
- mask - Mask array; may be null
- average - Whether the score should be averaged (divided by number of rows in labels/output) or not
- Returns:
- The score (loss function value) and gradient
-
name
public String name()
The opName of this function.
- Specified by:
- name in interface ILossFunction
- Returns:
-
getNMixtures
public int getNMixtures()
Returns the number of Gaussians this loss function will attempt to find.
- Returns:
- Number of Gaussians to find.
-
getLabelWidth
public int getLabelWidth()
Returns the width of each label vector.
- Returns:
- Width of label vectors expected.
-
builder
public static LossMixtureDensity.Builder builder()
-
-