Author: erans
Date: Thu Mar 21 16:22:02 2013
New Revision: 1459382
URL: http://svn.apache.org/r1459382
Log:
MATH-817
Algorithm for fitting a mixture of multivariate normal distributions
(implemented by Jared Becksfort).
Added "MixtureMultivariateNormalDistribution" class as "syntactic sugar".
Two unit tests are currently set to "@Ignore" (because they rely on "equals",
which the patch did not seem to implement consistently).
Added:
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/MixtureMultivariateNormalDistribution.java
(with props)
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/fitting/
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java
(with props)
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
(with props)
Added:
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/MixtureMultivariateNormalDistribution.java
URL:
http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/MixtureMultivariateNormalDistribution.java?rev=1459382&view=auto
==============================================================================
---
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/MixtureMultivariateNormalDistribution.java
(added)
+++
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/MixtureMultivariateNormalDistribution.java
Thu Mar 21 16:22:02 2013
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.distribution;
+
+import java.util.List;
+import java.util.ArrayList;
+import org.apache.commons.math3.random.RandomGenerator;
+import org.apache.commons.math3.util.Pair;
+
+/**
+ * Multivariate normal mixture distribution.
+ * This class is mainly syntactic sugar.
+ *
+ * @see MixtureMultivariateRealDistribution
+ */
+public class MixtureMultivariateNormalDistribution
+ extends
MixtureMultivariateRealDistribution<MultivariateNormalDistribution> {
+ /**
+ * Creates a multivariate normal mixture distribution.
+ *
+ * @param weights Weights of each component.
+ * @param means Mean vector for each component.
+ * @param covariances Covariance matrix for each component.
+ */
+ public MixtureMultivariateNormalDistribution(double[] weights,
+ double[][] means,
+ double[][][] covariances) {
+ super(createComponents(weights, means, covariances));
+ }
+
+ /**
+ * Creates a mixture model from a list of distributions and their
+ * associated weights.
+ *
+ * @param components List of (weight, distribution) pairs from which to
sample.
+ */
+ public MixtureMultivariateNormalDistribution(List<Pair<Double,
MultivariateNormalDistribution>> components) {
+ super(components);
+ }
+
+ /**
+ * Creates a mixture model from a list of distributions and their
+ * associated weights.
+ *
+ * @param rng Random number generator.
+ * @param components Distributions from which to sample.
+ * @throws NotPositiveException if any of the weights is negative.
+ * @throws DimensionMismatchException if not all components have the same
+ * number of variables.
+ */
+ public MixtureMultivariateNormalDistribution(RandomGenerator rng,
+ List<Pair<Double,
MultivariateNormalDistribution>> components) {
+ super(rng, components);
+ }
+
+ /**
+ * @param weights Weights of each component.
+ * @param means Mean vector for each component.
+ * @param covariances Covariance matrix for each component.
+ * @return the list of components.
+ */
+ private static List<Pair<Double, MultivariateNormalDistribution>>
createComponents(double[] weights,
+
double[][] means,
+
double[][][] covariances) {
+ final List<Pair<Double, MultivariateNormalDistribution>> mvns
+ = new ArrayList<Pair<Double, MultivariateNormalDistribution>>();
+
+ for (int i = 0; i < weights.length; i++) {
+ final MultivariateNormalDistribution dist
+ = new MultivariateNormalDistribution(means[i], covariances[i]);
+
+ mvns.add(new Pair<Double,
MultivariateNormalDistribution>(weights[i], dist));
+ }
+
+ return mvns;
+ }
+}
Propchange:
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/MixtureMultivariateNormalDistribution.java
------------------------------------------------------------------------------
svn:eol-style = native
Propchange:
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/MixtureMultivariateNormalDistribution.java
------------------------------------------------------------------------------
svn:keywords = Id Revision
Added:
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java
URL:
http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java?rev=1459382&view=auto
==============================================================================
---
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java
(added)
+++
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java
Thu Mar 21 16:22:02 2013
@@ -0,0 +1,440 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with this
+ * work for additional information regarding copyright ownership. The ASF
+ * licenses this file to You under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
under
+ * the License.
+ */
+package org.apache.commons.math3.distribution.fitting;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import
org.apache.commons.math3.distribution.MixtureMultivariateRealDistribution;
+import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
+import
org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution;
+import org.apache.commons.math3.exception.ConvergenceException;
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.exception.NotStrictlyPositiveException;
+import org.apache.commons.math3.exception.NumberIsTooSmallException;
+import org.apache.commons.math3.exception.NumberIsTooLargeException;
+import org.apache.commons.math3.exception.util.LocalizedFormats;
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.SingularMatrixException;
+import org.apache.commons.math3.stat.correlation.Covariance;
+import org.apache.commons.math3.util.FastMath;
+import org.apache.commons.math3.util.MathArrays;
+import org.apache.commons.math3.util.Pair;
+
+/**
+ * <a
href="https://www.ee.washington.edu/techsite/papers/documents/UWEETR-2010-0002.pdf">
+ * Expectation-Maximization</a> algorithm for fitting the parameters of
+ * multivariate normal mixture model distributions.
+ *
+ * This implementation is based on
+ * <a href="http://cran.r-project.org/web/packages/mixtools/index.html">
+ * CRAN Mixtools</a>
+ *
+ * @version $Id$
+ * @since 3.2
+ */
+public class MultivariateNormalMixtureExpectationMaximization {
+ /**
+ * The data to fit.
+ * Each row is a copy of the corresponding row passed to the constructor,
+ * so later changes to the caller's array do not affect the fit.
+ */
+ private final double[][] data;
+ /**
+ * The model fit against the data.
+ * Has no initializer, so it is {@code null} until one of the {@code fit}
+ * methods assigns it.
+ */
+ private MixtureMultivariateNormalDistribution fittedModel;
+ /**
+ * The log likelihood of the data given the fitted model.
+ * The {@code fit} method stores the sum of the per-row log densities
+ * divided by the number of rows; the value is zero before any fit.
+ */
+ private double logLikelihood = 0d;
+ /**
+ * Default maximum number of iterations allowed per fitting process.
+ * Used by the single-argument {@code fit} method.
+ * NOTE(review): this is an instance field although its value is the same
+ * for every instance; it looks like a candidate for a static constant.
+ */
+ private final int defaultMaxIterations = 1000;
+ /**
+ * Default convergence threshold for fitting.
+ * Used by the single-argument {@code fit} method.
+ * NOTE(review): same remark as for the default iteration count.
+ */
+ private final double defaultThreshold = 1E-5;
+
+ /**
+ * Creates an object to fit a multivariate normal mixture model to data.
+ *
+ * @param data Data to use in fitting procedure
+ * @throws NotStrictlyPositiveException if data has no rows
+ * @throws DimensionMismatchException if rows of data have different
numbers
+ * of columns
+ * @throws NumberIsTooSmallException if the number of columns in the data
is
+ * less than 2
+ */
+ public MultivariateNormalMixtureExpectationMaximization(double[][] data)
+ throws NotStrictlyPositiveException,
+ DimensionMismatchException,
+ NumberIsTooSmallException {
+ if (data.length < 1) {
+ throw new NotStrictlyPositiveException(data.length);
+ }
+
+ this.data = new double[data.length][data[0].length];
+
+ for (int i = 0; i < data.length; i++) {
+ if (data[i].length != data[0].length) {
+ // Jagged arrays not allowed
+ throw new DimensionMismatchException(data[i].length,
+ data[0].length);
+ }
+ if (data[i].length < 2) {
+ throw new
NumberIsTooSmallException(LocalizedFormats.NUMBER_TOO_SMALL,
+ data[i].length, 2, true);
+ }
+ this.data[i] = MathArrays.copyOf(data[i], data[i].length);
+ }
+ }
+
+ /**
+  * Fit a mixture model to the data supplied to the constructor.
+  *
+  * The quality of the fit depends on the concavity of the data provided to
+  * the constructor and the initial mixture provided to this function. If the
+  * data has many local optima, multiple runs of the fitting function with
+  * different initial mixtures may be required to find the optimal solution.
+  * If a SingularMatrixException is encountered, it is possible that another
+  * initialization would work.
+  *
+  * On success the fitted model and its log likelihood (sum of the per-row
+  * log densities divided by the number of rows) are stored in this object.
+  *
+  * @param initialMixture Model containing initial values of weights and
+  * multivariate normals
+  * @param maxIterations Maximum number of iterations allowed for the fit
+  * @param threshold Convergence threshold, compared with the absolute
+  * difference in log likelihood between successive iterations
+  * @throws SingularMatrixException if any component's covariance matrix is
+  * singular during fitting
+  * @throws NotStrictlyPositiveException if {@code maxIterations} is less
+  * than one or {@code threshold} is less than {@code Double.MIN_VALUE}
+  * @throws DimensionMismatchException if initialMixture mean vector and data
+  * number of columns are not equal
+  * @throws ConvergenceException if the log likelihood still changes by more
+  * than {@code threshold} after {@code maxIterations} iterations
+  * @see #estimate
+  */
+ public void fit(final MixtureMultivariateNormalDistribution initialMixture,
+                 final int maxIterations,
+                 final double threshold)
+     throws SingularMatrixException,
+            NotStrictlyPositiveException,
+            DimensionMismatchException {
+     if (maxIterations < 1) {
+         throw new NotStrictlyPositiveException(maxIterations);
+     }
+
+     if (threshold < Double.MIN_VALUE) {
+         throw new NotStrictlyPositiveException(threshold);
+     }
+
+     final int n = data.length;
+
+     // Number of data columns. Jagged data already rejected in constructor,
+     // so we can assume the lengths of each row are equal.
+     final int numCols = data[0].length;
+     final int k = initialMixture.getComponents().size();
+
+     final int numMeanColumns
+         = initialMixture.getComponents().get(0).getSecond().getMeans().length;
+
+     if (numMeanColumns != numCols) {
+         throw new DimensionMismatchException(numMeanColumns, numCols);
+     }
+
+     int numIterations = 0;
+     double previousLogLikelihood = 0d;
+
+     // Starting from -infinity makes the first convergence test fail, so
+     // that at least one iteration is always performed.
+     logLikelihood = Double.NEGATIVE_INFINITY;
+
+     // Initialize model to fit to initial mixture.
+     fittedModel = new MixtureMultivariateNormalDistribution(initialMixture.getComponents());
+
+     // Strict comparison: the counter starts at zero, so "<" allows exactly
+     // maxIterations passes ("<=" would allow one more than requested).
+     while (numIterations++ < maxIterations &&
+            Math.abs(previousLogLikelihood - logLikelihood) > threshold) {
+         previousLogLikelihood = logLikelihood;
+         double sumLogLikelihood = 0d;
+
+         // Mixture components
+         final List<Pair<Double, MultivariateNormalDistribution>> components
+             = fittedModel.getComponents();
+
+         // Weight and distribution of each component
+         final double[] weights = new double[k];
+
+         final MultivariateNormalDistribution[] mvns = new MultivariateNormalDistribution[k];
+
+         for (int j = 0; j < k; j++) {
+             weights[j] = components.get(j).getFirst();
+             mvns[j] = components.get(j).getSecond();
+         }
+
+         // E-step: compute the data dependent parameters of the expectation
+         // function.
+         // gamma[i][j] is the share of row i's total mixture density that
+         // is contributed by component j.
+         final double[][] gamma = new double[n][k];
+
+         // Sum of gamma for each component
+         final double[] gammaSums = new double[k];
+
+         // Sum of gamma times its row for each component
+         final double[][] gammaDataProdSums = new double[k][numCols];
+
+         for (int i = 0; i < n; i++) {
+             final double rowDensity = fittedModel.density(data[i]);
+             sumLogLikelihood += Math.log(rowDensity);
+
+             for (int j = 0; j < k; j++) {
+                 gamma[i][j] = weights[j] * mvns[j].density(data[i])
+                     / rowDensity;
+
+                 gammaSums[j] += gamma[i][j];
+
+                 for (int col = 0; col < numCols; col++) {
+                     gammaDataProdSums[j][col] += gamma[i][j] * data[i][col];
+                 }
+             }
+         }
+
+         // Stored value is the mean log density per data row.
+         logLikelihood = sumLogLikelihood / n;
+
+         // M-step: compute the new parameters based on the expectation
+         // function.
+         final double[] newWeights = new double[k];
+         final double[][] newMeans = new double[k][numCols];
+
+         for (int j = 0; j < k; j++) {
+             newWeights[j] = gammaSums[j] / n;
+             for (int col = 0; col < numCols; col++) {
+                 newMeans[j][col] = gammaDataProdSums[j][col] / gammaSums[j];
+             }
+         }
+
+         // Compute new covariance matrices
+         final RealMatrix[] newCovMats = new RealMatrix[k];
+         for (int j = 0; j < k; j++) {
+             newCovMats[j] = new Array2DRowRealMatrix(numCols, numCols);
+         }
+         for (int i = 0; i < n; i++) {
+             for (int j = 0; j < k; j++) {
+                 // Column vector of the row's deviation from the new mean.
+                 final RealMatrix vec
+                     = new Array2DRowRealMatrix(MathArrays.ebeSubtract(data[i], newMeans[j]));
+                 // Outer product of the deviation, weighted by gamma.
+                 final RealMatrix dataCov
+                     = vec.multiply(vec.transpose()).scalarMultiply(gamma[i][j]);
+                 newCovMats[j] = newCovMats[j].add(dataCov);
+             }
+         }
+
+         // Converting to arrays for use by fitted model
+         final double[][][] newCovMatArrays = new double[k][numCols][numCols];
+         for (int j = 0; j < k; j++) {
+             newCovMats[j] = newCovMats[j].scalarMultiply(1d / gammaSums[j]);
+             newCovMatArrays[j] = newCovMats[j].getData();
+         }
+
+         // Update current model
+         fittedModel = new MixtureMultivariateNormalDistribution(newWeights,
+                                                                 newMeans,
+                                                                 newCovMatArrays);
+     }
+
+     if (Math.abs(previousLogLikelihood - logLikelihood) > threshold) {
+         // Did not converge before the maximum number of iterations
+         throw new ConvergenceException();
+     }
+ }
+
+ /**
+ * Fit a mixture model to the data supplied to the constructor, using the
+ * default maximum number of iterations and the default convergence
+ * threshold of this class.
+ *
+ * The quality of the fit depends on the concavity of the data provided to
+ * the constructor and the initial mixture provided to this function. If the
+ * data has many local optima, multiple runs of the fitting function with
+ * different initial mixtures may be required to find the optimal solution.
+ * If a SingularMatrixException is encountered, it is possible that another
+ * initialization would work.
+ *
+ * @param initialMixture Model containing initial values of weights and
+ * multivariate normals
+ * @throws SingularMatrixException if any component's covariance matrix is
+ * singular during fitting
+ * @throws NotStrictlyPositiveException declared because the three-argument
+ * {@code fit} method throws it when the iteration limit is less than one or
+ * the threshold is less than {@code Double.MIN_VALUE}
+ * @throws DimensionMismatchException if initialMixture mean vector and data
+ * number of columns are not equal (thrown by the three-argument {@code fit})
+ * @throws ConvergenceException if the three-argument {@code fit} does not
+ * reach the threshold within the default number of iterations
+ * @see #estimate
+ */
+ public void fit(MixtureMultivariateNormalDistribution initialMixture)
+ throws SingularMatrixException,
+ NotStrictlyPositiveException {
+ fit(initialMixture, defaultMaxIterations, defaultThreshold);
+ }
+
+ /**
+  * Helper method to create a multivariate normal mixture model which can be
+  * used to initialize {@link #fit(MixtureMultivariateNormalDistribution)}.
+  *
+  * This method uses the {@code data} argument to try to determine a good
+  * mixture model at which to start the fit, but it is not guaranteed to
+  * supply a model which will find the optimal solution or even converge.
+  *
+  * The rows are sorted by the mean of their entries and split into
+  * {@code numComponents} contiguous bins; each bin yields one component
+  * (column means and covariance matrix of the bin), all with equal weight.
+  *
+  * @param data Data to estimate distribution
+  * @param numComponents Number of components for estimated mixture
+  * @return Multivariate normal mixture model estimated from the data
+  * @throws NumberIsTooLargeException if {@code numComponents} is greater
+  * than the number of data rows.
+  * @throws NumberIsTooSmallException if {@code numComponents < 2}.
+  * @throws NotStrictlyPositiveException if data has fewer than 2 rows
+  * @throws DimensionMismatchException if rows of data have different numbers
+  * of columns
+  * @see #fit
+  */
+ public static MixtureMultivariateNormalDistribution estimate(final double[][] data,
+                                                              final int numComponents)
+     throws NotStrictlyPositiveException,
+            DimensionMismatchException {
+     if (data.length < 2) {
+         throw new NotStrictlyPositiveException(data.length);
+     }
+     if (numComponents < 2) {
+         throw new NumberIsTooSmallException(numComponents, 2, true);
+     }
+     if (numComponents > data.length) {
+         throw new NumberIsTooLargeException(numComponents, data.length, true);
+     }
+
+     final int numRows = data.length;
+     final int numCols = data[0].length;
+
+     // sort the data (by row mean), rejecting jagged input on the way
+     final DataRow[] sortedData = new DataRow[numRows];
+     for (int i = 0; i < numRows; i++) {
+         if (data[i].length != numCols) {
+             // Jagged arrays not allowed
+             throw new DimensionMismatchException(data[i].length, numCols);
+         }
+         sortedData[i] = new DataRow(data[i]);
+     }
+     Arrays.sort(sortedData);
+
+     // uniform weight for each bin
+     final double weight = 1d / numComponents;
+
+     // components of mixture model to be created
+     final List<Pair<Double, MultivariateNormalDistribution>> components =
+         new ArrayList<Pair<Double, MultivariateNormalDistribution>>();
+
+     // create a component based on data in each bin
+     for (int binNumber = 1; binNumber <= numComponents; binNumber++) {
+         // Integer division truncates, so consecutive bins are contiguous
+         // and never overlap: bin b covers the sorted rows
+         // [(b - 1) * numRows / numComponents, b * numRows / numComponents - 1].
+
+         // minimum index from sorted data for this bin
+         final int minIndex = ((binNumber - 1) * numRows) / numComponents;
+
+         // maximum index from sorted data for this bin
+         final int maxIndex = (binNumber * numRows) / numComponents - 1;
+
+         // number of data records that will be in this bin
+         final int numBinRows = maxIndex - minIndex + 1;
+
+         // data for this bin
+         final double[][] binData = new double[numBinRows][numCols];
+
+         // mean of each column for the data in this bin
+         final double[] columnMeans = new double[numCols];
+
+         // populate bin and accumulate the column sums
+         for (int i = minIndex, iBin = 0; i <= maxIndex; i++, iBin++) {
+             final double[] sortedRow = sortedData[i].getRow();
+             for (int j = 0; j < numCols; j++) {
+                 final double val = sortedRow[j];
+                 columnMeans[j] += val;
+                 binData[iBin][j] = val;
+             }
+         }
+
+         // turn the column sums into column means
+         MathArrays.scaleInPlace(1d / numBinRows, columnMeans);
+
+         // covariance matrix for this bin
+         final double[][] covMat
+             = new Covariance(binData).getCovarianceMatrix().getData();
+         final MultivariateNormalDistribution mvn
+             = new MultivariateNormalDistribution(columnMeans, covMat);
+
+         components.add(new Pair<Double, MultivariateNormalDistribution>(weight, mvn));
+     }
+
+     return new MixtureMultivariateNormalDistribution(components);
+ }
+
+ /**
+ * Gets the log likelihood of the data under the fitted model.
+ * The value stored by the {@code fit} method is the sum of the per-row
+ * log densities divided by the number of data rows.
+ *
+ * @return Log likelihood of data or zero if no data has been fit
+ */
+ public double getLogLikelihood() {
+ return logLikelihood;
+ }
+
+ /**
+  * Gets the fitted model.
+  * A new mixture object is built from the components of the fitted model
+  * on every call.
+  *
+  * @return fitted model or {@code null} if no fit has been performed yet.
+  */
+ public MixtureMultivariateNormalDistribution getFittedModel() {
+     if (fittedModel == null) {
+         // No fit has been performed yet: honor the documented contract
+         // instead of failing with a NullPointerException.
+         return null;
+     }
+     return new MixtureMultivariateNormalDistribution(fittedModel.getComponents());
+ }
+
+ /**
+ * Class used for sorting user-supplied data.
+ */
+ private static class DataRow implements Comparable<DataRow> {
+ /** One data row. */
+ private final double[] row;
+ /** Mean of the data row. */
+ private Double mean;
+
+ /**
+ * Create a data row.
+ * @param data Data to use for the row
+ */
+ DataRow(final double[] data) {
+ // Store reference.
+ row = data;
+ // Compute mean.
+ mean = 0d;
+ for (int i = 0; i < data.length; i++) {
+ mean += data[i];
+ }
+ mean /= data.length;
+ }
+
+ /**
+ * Compare two data rows.
+ * @param other The other row
+ * @return int for sorting
+ */
+ public int compareTo(final DataRow other) {
+ return mean.compareTo(other.mean);
+ }
+
+ /**
+ * Get a data row.
+ * @return data row array
+ */
+ public double[] getRow() {
+ return row;
+ }
+ }
+}
+
Propchange:
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java
------------------------------------------------------------------------------
svn:eol-style = native
Propchange:
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java
------------------------------------------------------------------------------
svn:keywords = Id Revision
Added:
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
URL:
http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java?rev=1459382&view=auto
==============================================================================
---
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
(added)
+++
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
Thu Mar 21 16:22:02 2013
@@ -0,0 +1,324 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with this
+ * work for additional information regarding copyright ownership. The ASF
+ * licenses this file to You under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
under
+ * the License.
+ */
+package org.apache.commons.math3.distribution.fitting;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import
org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution;
+import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
+import org.apache.commons.math3.exception.ConvergenceException;
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.exception.NotStrictlyPositiveException;
+import org.apache.commons.math3.exception.NumberIsTooSmallException;
+import org.apache.commons.math3.util.Pair;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.Ignore;
+
+/**
+ * Test that demonstrates the use of
+ * {@link MultivariateNormalMixtureExpectationMaximization}.
+ */
+public class MultivariateNormalMixtureExpectationMaximizationTest {
+
+ // TODO reject initial mixes where means/covMats not computable with data
+ // numCols
+
+ @Test(expected = NotStrictlyPositiveException.class)
+ public void testNonEmptyData() {
+     // A data set without any row must be rejected by the constructor.
+     final double[][] noRows = new double[0][];
+     new MultivariateNormalMixtureExpectationMaximization(noRows);
+ }
+
+ @Test(expected = DimensionMismatchException.class)
+ public void testNonJaggedData() {
+     // Rows with different numbers of columns (3 and 4) must be rejected.
+     final double[] threeColumns = { 1, 2, 3 };
+     final double[] fourColumns = { 4, 5, 6, 7 };
+     final double[][] jagged = { threeColumns, fourColumns };
+     new MultivariateNormalMixtureExpectationMaximization(jagged);
+ }
+
+ @Test(expected = NumberIsTooSmallException.class)
+ public void testMultipleColumnsRequired() {
+     // Single-column data is not multivariate: at least 2 columns needed.
+     final double[][] singleColumn = { { 1 }, { 2 } };
+     new MultivariateNormalMixtureExpectationMaximization(singleColumn);
+ }
+
+ @Test(expected = NotStrictlyPositiveException.class)
+ public void testMaxIterationsPositive() {
+     // The iteration limit passed to fit must be a positive integer.
+     final double[][] samples = getTestSamples();
+     final MultivariateNormalMixtureExpectationMaximization em
+         = new MultivariateNormalMixtureExpectationMaximization(samples);
+     final MixtureMultivariateNormalDistribution start
+         = MultivariateNormalMixtureExpectationMaximization.estimate(samples, 2);
+
+     final int invalidMaxIterations = 0;
+     em.fit(start, invalidMaxIterations, 1E-5);
+ }
+
+ @Test(expected = NotStrictlyPositiveException.class)
+ public void testThresholdPositive() {
+     // The convergence threshold passed to fit must be positive.
+     final double[][] samples = getTestSamples();
+     final MultivariateNormalMixtureExpectationMaximization em
+         = new MultivariateNormalMixtureExpectationMaximization(samples);
+     final MixtureMultivariateNormalDistribution start
+         = MultivariateNormalMixtureExpectationMaximization.estimate(samples, 2);
+
+     final int invalidThreshold = 0;
+     em.fit(start, 1000, invalidThreshold);
+ }
+
+ @Test(expected = ConvergenceException.class)
+ public void testConvergenceException() {
+     // fit must fail when it stops before the threshold has been met.
+     final double[][] samples = getTestSamples();
+     final MultivariateNormalMixtureExpectationMaximization em
+         = new MultivariateNormalMixtureExpectationMaximization(samples);
+     final MixtureMultivariateNormalDistribution start
+         = MultivariateNormalMixtureExpectationMaximization.estimate(samples, 2);
+
+     // 5 iterations are not enough to meet the convergence threshold.
+     final int tooFewIterations = 5;
+     em.fit(start, tooFewIterations, 1E-5);
+ }
+
+ @Test(expected = DimensionMismatchException.class)
+ public void testIncompatibleIntialMixture() {
+     // The data has 3 columns ...
+     final double[][] threeColumnData = {
+         { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 }
+     };
+
+     // ... whereas both components of the mixture are 2-dimensional.
+     final double[][] means = {
+         { -0.0021722935000328823, 3.5432892936887908 },
+         { 5.090902706507635, 8.68540656355283 }
+     };
+     final double[][][] covariances = {
+         {
+             { 4.537422569229048, 3.5266152281729304 },
+             { 3.5266152281729304, 6.175448814169779 }
+         },
+         {
+             { 2.886778573963039, 1.5257474543463154 },
+             { 1.5257474543463154, 3.3794567673616918 }
+         }
+     };
+
+     // Equal-weight mixture of the two distributions.
+     final List<Pair<Double, MultivariateNormalDistribution>> mix
+         = new ArrayList<Pair<Double, MultivariateNormalDistribution>>();
+     for (int c = 0; c < means.length; c++) {
+         final MultivariateNormalDistribution normal
+             = new MultivariateNormalDistribution(means[c], covariances[c]);
+         mix.add(new Pair<Double, MultivariateNormalDistribution>(0.5, normal));
+     }
+     final MixtureMultivariateNormalDistribution badInitialMix
+         = new MixtureMultivariateNormalDistribution(mix);
+
+     // Fitting must detect the dimension mismatch.
+     new MultivariateNormalMixtureExpectationMaximization(threeColumnData).fit(badInitialMix);
+ }
+
+ @Ignore@Test // Ignored for now: the last assertion of the loop relies on MultivariateNormalDistribution "equals".
+ public void testInitialMixture() {
+ // Checks the weights and components of the initial mixture that "estimate" derives from the sample data (2 components).
+ double[] correctWeights = new double[] { 0.5, 0.5 };
+
+ MultivariateNormalDistribution[] correctMVNs = new
MultivariateNormalDistribution[2];
+
+ correctMVNs[0] = new MultivariateNormalDistribution(new double[] {
+ -0.0021722935000328823, 3.5432892936887908 },
+ new double[][] {
+ { 4.537422569229048, 3.5266152281729304 },
+ { 3.5266152281729304, 6.175448814169779 } });
+ correctMVNs[1] = new MultivariateNormalDistribution(new double[] {
+ 5.090902706507635, 8.68540656355283 }, new double[][] {
+ { 2.886778573963039, 1.5257474543463154 },
+ { 1.5257474543463154, 3.3794567673616918 } });
+
+ final MixtureMultivariateNormalDistribution initialMix
+ =
MultivariateNormalMixtureExpectationMaximization.estimate(getTestSamples(), 2);
+
+ int i = 0;
+ for (Pair<Double, MultivariateNormalDistribution> component :
initialMix
+ .getComponents()) {
+ Assert.assertEquals(correctWeights[i], component.getFirst(),
+ Math.ulp(1d));
+ Assert.assertEquals(correctMVNs[i], component.getSecond());
+ i++;
+ }
+ }
+
+ @Ignore@Test // Ignored for now: the component comparison relies on MultivariateNormalDistribution "equals".
+ public void testFit() {
+ // Test that the log likelihood, weights, and models are determined and
+ // fitted correctly, starting from the mixture returned by "estimate"
+ double[][] data = getTestSamples();
+ double correctLogLikelihood = -4.292431006791994;
+ double[] correctWeights = new double[] { 0.2962324189652912,
0.7037675810347089 };
+ MultivariateNormalDistribution[] correctMVNs = new
MultivariateNormalDistribution[2];
+ correctMVNs[0] = new MultivariateNormalDistribution(new double[] {
+ -1.4213112715121132, 1.6924690505757753 },
+ new double[][] {
+ { 1.739356907285747, -0.5867644251487614 },
+ { -0.5867644251487614, 1.0232932029324642 } });
+
+ correctMVNs[1] = new MultivariateNormalDistribution(new double[] {
+ 4.213612224374709, 7.975621325853645 },
+ new double[][] {
+ { 4.245384898007161, 2.5797798966382155 },
+ { 2.5797798966382155, 3.9200272522448367 } });
+
+ MultivariateNormalMixtureExpectationMaximization fitter
+ = new MultivariateNormalMixtureExpectationMaximization(data);
+
+ MixtureMultivariateNormalDistribution initialMix
+ = MultivariateNormalMixtureExpectationMaximization.estimate(data,
2);
+ fitter.fit(initialMix);
+ MixtureMultivariateNormalDistribution fittedMix =
fitter.getFittedModel();
+ List<Pair<Double, MultivariateNormalDistribution>> components =
fittedMix.getComponents();
+
+ Assert.assertEquals(correctLogLikelihood,
+ fitter.getLogLikelihood(),
+ Math.ulp(1d));
+
+ int i = 0;
+ for (Pair<Double, MultivariateNormalDistribution> component :
components) {
+ double weight = component.getFirst();
+ MultivariateNormalDistribution mvn = component.getSecond();
+ Assert.assertEquals(correctWeights[i], weight, Math.ulp(1d));
+ Assert.assertEquals(correctMVNs[i], mvn);
+ i++;
+ }
+ }
+
+ private double[][] getTestSamples() {
+ // generated using R Mixtools rmvnorm with mean vectors [-1.5, 2] and
+ // [4, 8.2]
+ return new double[][] { { 7.358553610469948, 11.31260831446758 },
+ { 7.175770420124739, 8.988812210204454 },
+ { 4.324151905768422, 6.837727899051482 },
+ { 2.157832219173036, 6.317444585521968 },
+ { -1.890157421896651, 1.74271202875498 },
+ { 0.8922409354455803, 1.999119343923781 },
+ { 3.396949764787055, 6.813170372579068 },
+ { -2.057498232686068, -0.002522983830852255 },
+ { 6.359932157365045, 8.343600029975851 },
+ { 3.353102234276168, 7.087541882898689 },
+ { -1.763877221595639, 0.9688890460330644 },
+ { 6.151457185125111, 9.075011757431174 },
+ { 4.281597398048899, 5.953270070976117 },
+ { 3.549576703974894, 8.616038155992861 },
+ { 6.004706732349854, 8.959423391087469 },
+ { 2.802915014676262, 6.285676742173564 },
+ { -0.6029879029880616, 1.083332958357485 },
+ { 3.631827105398369, 6.743428504049444 },
+ { 6.161125014007315, 9.60920569689001 },
+ { -1.049582894255342, 0.2020017892080281 },
+ { 3.910573022688315, 8.19609909534937 },
+ { 8.180454017634863, 7.861055769719962 },
+ { 1.488945440439716, 8.02699903761247 },
+ { 4.813750847823778, 12.34416881332515 },
+ { 0.0443208501259158, 5.901148093240691 },
+ { 4.416417235068346, 4.465243084006094 },
+ { 4.0002433603072, 6.721937850166174 },
+ { 3.190113818788205, 10.51648348411058 },
+ { 4.493600914967883, 7.938224231022314 },
+ { -3.675669533266189, 4.472845076673303 },
+ { 6.648645511703989, 12.03544085965724 },
+ { -1.330031331404445, 1.33931042964811 },
+ { -3.812111460708707, 2.50534195568356 },
+ { 5.669339356648331, 6.214488981177026 },
+ { 1.006596727153816, 1.51165463112716 },
+ { 5.039466365033024, 7.476532610478689 },
+ { 4.349091929968925, 7.446356406259756 },
+ { -1.220289665119069, 3.403926955951437 },
+ { 5.553003979122395, 6.886518211202239 },
+ { 2.274487732222856, 7.009541508533196 },
+ { 4.147567059965864, 7.34025244349202 },
+ { 4.083882618965819, 6.362852861075623 },
+ { 2.203122344647599, 7.260295257904624 },
+ { -2.147497550770442, 1.262293431529498 },
+ { 2.473700950426512, 6.558900135505638 },
+ { 8.267081298847554, 12.10214104577748 },
+ { 6.91977329776865, 9.91998488301285 },
+ { 0.1680479852730894, 6.28286034168897 },
+ { -1.268578659195158, 2.326711221485755 },
+ { 1.829966451374701, 6.254187605304518 },
+ { 5.648849025754848, 9.330002040750291 },
+ { -2.302874793257666, 3.585545172776065 },
+ { -2.629218791709046, 2.156215538500288 },
+ { 4.036618140700114, 10.2962785719958 },
+ { 0.4616386422783874, 0.6782756325806778 },
+ { -0.3447896073408363, 0.4999834691645118 },
+ { -0.475281453118318, 1.931470384180492 },
+ { 2.382509690609731, 6.071782429815853 },
+ { -3.203934441889096, 2.572079552602468 },
+ { 8.465636032165087, 13.96462998683518 },
+ { 2.36755660870416, 5.7844595007273 },
+ { 0.5935496528993371, 1.374615871358943 },
+ { -2.467481505748694, 2.097224634713005 },
+ { 4.27867444328542, 10.24772361238549 },
+ { -2.013791907543137, 2.013799426047639 },
+ { 6.424588084404173, 9.185334939684516 },
+ { -0.8448238876802175, 0.5447382022282812 },
+ { 1.342955703473923, 8.645456317633556 },
+ { 3.108712208751979, 8.512156853800064 },
+ { 4.343205178315472, 8.056869549234374 },
+ { -2.971767642212396, 3.201180146824761 },
+ { 2.583820931523672, 5.459873414473854 },
+ { 4.209139115268925, 8.171098193546225 },
+ { 0.4064909057902746, 1.454390775518743 },
+ { 3.068642411145223, 6.959485153620035 },
+ { 6.085968972900461, 7.391429799500965 },
+ { -1.342265795764202, 1.454550012997143 },
+ { 6.249773274516883, 6.290269880772023 },
+ { 4.986225847822566, 7.75266344868907 },
+ { 7.642443254378944, 10.19914817500263 },
+ { 6.438181159163673, 8.464396764810347 },
+ { 2.520859761025108, 7.68222425260111 },
+ { 2.883699944257541, 6.777960331348503 },
+ { 2.788004550956599, 6.634735386652733 },
+ { 3.331661231995638, 5.794191300046592 },
+ { 3.526172276645504, 6.710802266815884 },
+ { 3.188298528138741, 10.34495528210205 },
+ { 0.7345539486114623, 5.807604004180681 },
+ { 1.165044595880125, 7.830121829295257 },
+ { 7.146962523500671, 11.62995162065415 },
+ { 7.813872137162087, 10.62827008714735 },
+ { 3.118099164870063, 8.286003148186371 },
+ { -1.708739286262571, 1.561026755374264 },
+ { 1.786163047580084, 4.172394388214604 },
+ { 3.718506403232386, 7.807752990130349 },
+ { 6.167414046828899, 10.01104941031293 },
+ { -1.063477247689196, 1.61176085846339 },
+ { -3.396739609433642, 0.7127911050002151 },
+ { 2.438885945896797, 7.353011138689225 },
+ { -0.2073204144780931, 0.850771146627012 }, };
+ }
+}
Propchange:
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
------------------------------------------------------------------------------
svn:eol-style = native
Propchange:
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
------------------------------------------------------------------------------
svn:keywords = Id Revision