Author: erans
Date: Thu Mar 21 16:22:02 2013
New Revision: 1459382

URL: http://svn.apache.org/r1459382
Log:
MATH-817
Algorithm for fitting a mixture of multivariate normal distributions
(implemented by Jared Becksfort).
Added "MixtureMultivariateNormalDistribution" class as "syntactic sugar".
Two unit tests are currently set to "@Ignore" because they rely on "equals",
which the patch does not seem to implement consistently.

Added:
    
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/MixtureMultivariateNormalDistribution.java
   (with props)
    
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/fitting/
    
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java
   (with props)
    
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/
    
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
   (with props)

Added: 
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/MixtureMultivariateNormalDistribution.java
URL: 
http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/MixtureMultivariateNormalDistribution.java?rev=1459382&view=auto
==============================================================================
--- 
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/MixtureMultivariateNormalDistribution.java
 (added)
+++ 
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/MixtureMultivariateNormalDistribution.java
 Thu Mar 21 16:22:02 2013
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math3.distribution;
+
+import java.util.List;
+import java.util.ArrayList;
+import org.apache.commons.math3.random.RandomGenerator;
+import org.apache.commons.math3.util.Pair;
+
+/**
+ * Multivariate normal mixture distribution.
+ * This class is mainly syntactic sugar.
+ *
+ * @see MixtureMultivariateRealDistribution
+ */
+public class MixtureMultivariateNormalDistribution
+    extends 
MixtureMultivariateRealDistribution<MultivariateNormalDistribution> {
+    /**
+     * Creates a multivariate normal mixture distribution.
+     *
+     * @param weights Weights of each component.
+     * @param means Mean vector for each component.
+     * @param covariances Covariance matrix for each component.
+     */
+    public MixtureMultivariateNormalDistribution(double[] weights,
+                                                 double[][] means,
+                                                 double[][][] covariances) {
+        super(createComponents(weights, means, covariances));
+    }
+
+    /**
+     * Creates a mixture model from a list of distributions and their
+     * associated weights.
+     *
+     * @param components List of (weight, distribution) pairs from which to 
sample.
+     */
+    public MixtureMultivariateNormalDistribution(List<Pair<Double, 
MultivariateNormalDistribution>> components) {
+        super(components);
+    }
+
+    /**
+     * Creates a mixture model from a list of distributions and their
+     * associated weights.
+     *
+     * @param rng Random number generator.
+     * @param components Distributions from which to sample.
+     * @throws NotPositiveException if any of the weights is negative.
+     * @throws DimensionMismatchException if not all components have the same
+     * number of variables.
+     */
+    public MixtureMultivariateNormalDistribution(RandomGenerator rng,
+                                                 List<Pair<Double, 
MultivariateNormalDistribution>> components) {
+        super(rng, components);
+    }
+
+    /**
+     * @param weights Weights of each component.
+     * @param means Mean vector for each component.
+     * @param covariances Covariance matrix for each component.
+     * @return the list of components.
+     */
+    private static List<Pair<Double, MultivariateNormalDistribution>> 
createComponents(double[] weights,
+                                                                               
        double[][] means,
+                                                                               
        double[][][] covariances) {
+        final List<Pair<Double, MultivariateNormalDistribution>> mvns
+            = new ArrayList<Pair<Double, MultivariateNormalDistribution>>();
+
+        for (int i = 0; i < weights.length; i++) {
+            final MultivariateNormalDistribution dist
+                = new MultivariateNormalDistribution(means[i], covariances[i]);
+
+            mvns.add(new Pair<Double, 
MultivariateNormalDistribution>(weights[i], dist));
+        }
+
+        return mvns;
+    }
+}

Propchange: 
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/MixtureMultivariateNormalDistribution.java
------------------------------------------------------------------------------
    svn:eol-style = native

Propchange: 
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/MixtureMultivariateNormalDistribution.java
------------------------------------------------------------------------------
    svn:keywords = Id Revision

Added: 
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java
URL: 
http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java?rev=1459382&view=auto
==============================================================================
--- 
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java
 (added)
+++ 
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java
 Thu Mar 21 16:22:02 2013
@@ -0,0 +1,440 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with this
+ * work for additional information regarding copyright ownership. The ASF
+ * licenses this file to You under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations 
under
+ * the License.
+ */
+package org.apache.commons.math3.distribution.fitting;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import 
org.apache.commons.math3.distribution.MixtureMultivariateRealDistribution;
+import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
+import 
org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution;
+import org.apache.commons.math3.exception.ConvergenceException;
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.exception.NotStrictlyPositiveException;
+import org.apache.commons.math3.exception.NumberIsTooSmallException;
+import org.apache.commons.math3.exception.NumberIsTooLargeException;
+import org.apache.commons.math3.exception.util.LocalizedFormats;
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.SingularMatrixException;
+import org.apache.commons.math3.stat.correlation.Covariance;
+import org.apache.commons.math3.util.FastMath;
+import org.apache.commons.math3.util.MathArrays;
+import org.apache.commons.math3.util.Pair;
+
+/**
+ * <a 
href="https://www.ee.washington.edu/techsite/papers/documents/UWEETR-2010-0002.pdf";>
+ * Expectation-Maximization</a> algorithm for fitting the parameters of
+ * multivariate normal mixture model distributions.
+ *
+ * This implementation is based on
+ * <a href="http://cran.r-project.org/web/packages/mixtools/index.html";>
+ *  CRAN Mixtools</a>
+ *
+ * @version $Id$
+ * @since 3.2
+ */
+public class MultivariateNormalMixtureExpectationMaximization {
+    /**
+     * The data to fit.
+     */
+    private final double[][] data;
+    /**
+     * The model fit against the data.
+     */
+    private MixtureMultivariateNormalDistribution fittedModel;
+    /**
+     * The log likelihood of the data given the fitted model.
+     */
+    private double logLikelihood = 0d;
+    /**
+     * Default maximum number of iterations allowed per fitting process.
+     */
+    private final int defaultMaxIterations = 1000;
+    /**
+     * Default convergence threshold for fitting.
+     */
+    private final double defaultThreshold = 1E-5;
+
+    /**
+     * Creates an object to fit a multivariate normal mixture model to data.
+     *
+     * @param data Data to use in fitting procedure
+     * @throws NotStrictlyPositiveException if data has no rows
+     * @throws DimensionMismatchException if rows of data have different 
numbers
+     *             of columns
+     * @throws NumberIsTooSmallException if the number of columns in the data 
is
+     *             less than 2
+     */
+    public MultivariateNormalMixtureExpectationMaximization(double[][] data)
+        throws NotStrictlyPositiveException,
+               DimensionMismatchException,
+               NumberIsTooSmallException {
+        if (data.length < 1) {
+            throw new NotStrictlyPositiveException(data.length);
+        }
+
+        this.data = new double[data.length][data[0].length];
+
+        for (int i = 0; i < data.length; i++) {
+            if (data[i].length != data[0].length) {
+                // Jagged arrays not allowed
+                throw new DimensionMismatchException(data[i].length,
+                                                     data[0].length);
+            }
+            if (data[i].length < 2) {
+                throw new 
NumberIsTooSmallException(LocalizedFormats.NUMBER_TOO_SMALL,
+                                                    data[i].length, 2, true);
+            }
+            this.data[i] = MathArrays.copyOf(data[i], data[i].length);
+        }
+    }
+
+    /**
+     * Fit a mixture model to the data supplied to the constructor.
+     *
+     * The quality of the fit depends on the concavity of the data provided to
+     * the constructor and the initial mixture provided to this function. If 
the
+     * data has many local optima, multiple runs of the fitting function with
+     * different initial mixtures may be required to find the optimal solution.
+     * If a SingularMatrixException is encountered, it is possible that another
+     * initialization would work.
+     *
+     * @param initialMixture Model containing initial values of weights and
+     *            multivariate normals
+     * @param maxIterations Maximum iterations allowed for fit
+     * @param threshold Convergence threshold computed as difference in
+     *             logLikelihoods between successive iterations
+     * @throws SingularMatrixException if any component's covariance matrix is
+     *             singular during fitting
+     * @throws NotStrictlyPositiveException if numComponents is less than one
+     *             or threshold is less than Double.MIN_VALUE
+     * @throws DimensionMismatchException if initialMixture mean vector and 
data
+     *             number of columns are not equal
+     * @see #estimateMultivariateNormalMixtureModelDistribution
+     */
+    public void fit(final MixtureMultivariateNormalDistribution initialMixture,
+                    final int maxIterations,
+                    final double threshold)
+            throws SingularMatrixException,
+                   NotStrictlyPositiveException,
+                   DimensionMismatchException {
+        if (maxIterations < 1) {
+            throw new NotStrictlyPositiveException(maxIterations);
+        }
+
+        if (threshold < Double.MIN_VALUE) {
+            throw new NotStrictlyPositiveException(threshold);
+        }
+
+        final int n = data.length;
+
+        // Number of data columns. Jagged data already rejected in constructor,
+        // so we can assume the lengths of each row are equal.
+        final int numCols = data[0].length;
+        final int k = initialMixture.getComponents().size();
+
+        final int numMeanColumns
+            = 
initialMixture.getComponents().get(0).getSecond().getMeans().length;
+
+        if (numMeanColumns != numCols) {
+            throw new DimensionMismatchException(numMeanColumns, numCols);
+        }
+
+        int numIterations = 0;
+        double previousLogLikelihood = 0d;
+
+        logLikelihood = Double.NEGATIVE_INFINITY;
+
+        // Initialize model to fit to initial mixture.
+        fittedModel = new 
MixtureMultivariateNormalDistribution(initialMixture.getComponents());
+
+        while (numIterations++ <= maxIterations &&
+               Math.abs(previousLogLikelihood - logLikelihood) > threshold) {
+            previousLogLikelihood = logLikelihood;
+            double sumLogLikelihood = 0d;
+
+            // Mixture components
+            final List<Pair<Double, MultivariateNormalDistribution>> components
+                = fittedModel.getComponents();
+
+            // Weight and distribution of each component
+            final double[] weights = new double[k];
+
+            final MultivariateNormalDistribution[] mvns = new 
MultivariateNormalDistribution[k];
+
+            for (int j = 0; j < k; j++) {
+                weights[j] = components.get(j).getFirst();
+                mvns[j] = components.get(j).getSecond();
+            }
+
+            // E-step: compute the data dependent parameters of the expectation
+            // function.
+            // The percentage of row's total density between a row and a
+            // component
+            final double[][] gamma = new double[n][k];
+
+            // Sum of gamma for each component
+            final double[] gammaSums = new double[k];
+
+            // Sum of gamma times its row for each each component
+            final double[][] gammaDataProdSums = new double[k][numCols];
+
+            for (int i = 0; i < n; i++) {
+                final double rowDensity = fittedModel.density(data[i]);
+                sumLogLikelihood += Math.log(rowDensity);
+
+                for (int j = 0; j < k; j++) {
+                    gamma[i][j] = weights[j] * mvns[j].density(data[i])
+                            / rowDensity;
+
+                    gammaSums[j] += gamma[i][j];
+
+                    for (int col = 0; col < numCols; col++) {
+                        gammaDataProdSums[j][col] += gamma[i][j] * 
data[i][col];
+                    }
+                }
+            }
+
+            logLikelihood = sumLogLikelihood / n;
+
+            // M-step: compute the new parameters based on the expectation
+            // function.
+            final double[] newWeights = new double[k];
+            final double[][] newMeans = new double[k][numCols];
+
+            for (int j = 0; j < k; j++) {
+                newWeights[j] = gammaSums[j] / n;
+                for (int col = 0; col < numCols; col++) {
+                    newMeans[j][col] = gammaDataProdSums[j][col] / 
gammaSums[j];
+                }
+            }
+
+            // Compute new covariance matrices
+            final RealMatrix[] newCovMats = new RealMatrix[k];
+            for (int j = 0; j < k; j++) {
+                newCovMats[j] = new Array2DRowRealMatrix(numCols, numCols);
+            }
+            for (int i = 0; i < n; i++) {
+                for (int j = 0; j < k; j++) {
+                    final RealMatrix vec
+                        = new 
Array2DRowRealMatrix(MathArrays.ebeSubtract(data[i], newMeans[j]));
+                    final RealMatrix dataCov
+                        = 
vec.multiply(vec.transpose()).scalarMultiply(gamma[i][j]);
+                    newCovMats[j] = newCovMats[j].add(dataCov);
+                }
+            }
+
+            // Converting to arrays for use by fitted model
+            final double[][][] newCovMatArrays = new 
double[k][numCols][numCols];
+            for (int j = 0; j < k; j++) {
+                newCovMats[j] = newCovMats[j].scalarMultiply(1d / 
gammaSums[j]);
+                newCovMatArrays[j] = newCovMats[j].getData();
+            }
+
+            // Update current model
+            fittedModel = new MixtureMultivariateNormalDistribution(newWeights,
+                                                                    newMeans,
+                                                                    
newCovMatArrays);
+        }
+
+        if (Math.abs(previousLogLikelihood - logLikelihood) > threshold) {
+            // Did not converge before the maximum number of iterations
+            throw new ConvergenceException();
+        }
+    }
+
+    /**
+     * Fit a mixture model to the data supplied to the constructor.
+     *
+     * The quality of the fit depends on the concavity of the data provided to
+     * the constructor and the initial mixture provided to this function. If 
the
+     * data has many local optima, multiple runs of the fitting function with
+     * different initial mixtures may be required to find the optimal solution.
+     * If a SingularMatrixException is encountered, it is possible that another
+     * initialization would work.
+     *
+     * @param initialMixture Model containing initial values of weights and
+     *            multivariate normals
+     * @throws SingularMatrixException if any component's covariance matrix is
+     *             singular during fitting
+     * @throws NotStrictlyPositiveException if numComponents is less than one 
or
+     *             threshold is less than Double.MIN_VALUE
+     * @see #estimateMultivariateNormalMixtureModelDistribution
+     */
+    public void fit(MixtureMultivariateNormalDistribution initialMixture)
+        throws SingularMatrixException,
+               NotStrictlyPositiveException {
+        fit(initialMixture, defaultMaxIterations, defaultThreshold);
+    }
+
+    /**
+     * Helper method to create a multivariate normal mixture model which can be
+     * used to initialize {@link #fit(MixtureMultivariateRealDistribution)}.
+     *
+     * This method uses the data supplied to the constructor to try to 
determine
+     * a good mixture model at which to start the fit, but it is not guaranteed
+     * to supply a model which will find the optimal solution or even converge.
+     *
+     * @param data Data to estimate distribution
+     * @param numComponents Number of components for estimated mixture
+     * @return Multivariate normal mixture model estimated from the data
+     * @throws NumberIsTooLargeException if {@code numComponents\ is greater
+     * than the number of data rows.
+     * @throws NumberIsTooSmallException if {@code numComponents < 2}.
+     * @throws NotStrictlyPositiveException if data has less than 2 rows
+     * @throws DimensionMismatchException if rows of data have different 
numbers
+     *             of columns
+     * @see #fit
+     */
+    public static MixtureMultivariateNormalDistribution estimate(final 
double[][] data,
+                                                                 final int 
numComponents)
+        throws NotStrictlyPositiveException,
+               DimensionMismatchException {
+        if (data.length < 2) {
+            throw new NotStrictlyPositiveException(data.length);
+        }
+        if (numComponents < 2) {
+            throw new NumberIsTooSmallException(numComponents, 2, true);
+        }
+        if (numComponents > data.length) {
+            throw new NumberIsTooLargeException(numComponents, data.length, 
true);
+        }
+
+        final int numRows = data.length;
+        final int numCols = data[0].length;
+
+        // sort the data
+        final DataRow[] sortedData = new DataRow[numRows];
+        for (int i = 0; i < numRows; i++) {
+            sortedData[i] = new DataRow(data[i]);
+        }
+        Arrays.sort(sortedData);
+
+        final int totalBins = numComponents;
+
+        // uniform weight for each bin
+        final double weight = 1d / totalBins;
+
+        // components of mixture model to be created
+        final List<Pair<Double, MultivariateNormalDistribution>> components =
+                new ArrayList<Pair<Double, MultivariateNormalDistribution>>();
+
+        // create a component based on data in each bin
+        for (int binNumber = 1; binNumber <= totalBins; binNumber++) {
+            // minimum index from sorted data for this bin
+            final int minIndex
+                = (int) FastMath.max(0,
+                                     FastMath.floor((binNumber - 1) * numRows 
/ totalBins));
+
+            // maximum index from sorted data for this bin
+            final int maxIndex
+                = (int) FastMath.ceil(binNumber * numRows / numComponents) - 1;
+
+            // number of data records that will be in this bin
+            final int numBinRows = maxIndex - minIndex + 1;
+
+            // data for this bin
+            final double[][] binData = new double[numBinRows][numCols];
+
+            // mean of each column for the data in the this bin
+            final double[] columnMeans = new double[numCols];
+
+            // populate bin and create component
+            for (int i = minIndex, iBin = 0; i <= maxIndex; i++, iBin++) {
+                for (int j = 0; j < numCols; j++) {
+                    final double val = sortedData[i].getRow()[j];
+                    columnMeans[j] += val;
+                    binData[iBin][j] = val;
+                }
+            }
+
+            MathArrays.scaleInPlace(1d / numBinRows, columnMeans);
+
+            // covariance matrix for this bin
+            final double[][] covMat
+                = new Covariance(binData).getCovarianceMatrix().getData();
+            final MultivariateNormalDistribution mvn
+                = new MultivariateNormalDistribution(columnMeans, covMat);
+
+            components.add(new Pair<Double, 
MultivariateNormalDistribution>(weight, mvn));
+        }
+
+        return new MixtureMultivariateNormalDistribution(components);
+    }
+
+    /**
+     * Gets the log likelihood of the data under the fitted model.
+     *
+     * @return Log likelihood of data or zero of no data has been fit
+     */
+    public double getLogLikelihood() {
+        return logLikelihood;
+    }
+
+    /**
+     * Gets the fitted model.
+     *
+     * @return fitted model or {@code null} if no fit has been performed yet.
+     */
+    public MixtureMultivariateNormalDistribution getFittedModel() {
+        return new 
MixtureMultivariateNormalDistribution(fittedModel.getComponents());
+    }
+
+    /**
+     * Class used for sorting user-supplied data.
+     */
+    private static class DataRow implements Comparable<DataRow> {
+        /** One data row. */
+        private final double[] row;
+        /** Mean of the data row. */
+        private Double mean;
+
+        /**
+         * Create a data row.
+         * @param data Data to use for the row
+         */
+        DataRow(final double[] data) {
+            // Store reference.
+            row = data;
+            // Compute mean.
+            mean = 0d;
+            for (int i = 0; i < data.length; i++) {
+                mean += data[i];
+            }
+            mean /= data.length;
+        }
+
+        /**
+         * Compare two data rows.
+         * @param other The other row
+         * @return int for sorting
+         */
+        public int compareTo(final DataRow other) {
+            return mean.compareTo(other.mean);
+        }
+
+        /**
+         * Get a data row.
+         * @return data row array
+         */
+        public double[] getRow() {
+            return row;
+        }
+    }
+}
+

Propchange: 
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java
------------------------------------------------------------------------------
    svn:eol-style = native

Propchange: 
commons/proper/math/trunk/src/main/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java
------------------------------------------------------------------------------
    svn:keywords = Id Revision

Added: 
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
URL: 
http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java?rev=1459382&view=auto
==============================================================================
--- 
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
 (added)
+++ 
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
 Thu Mar 21 16:22:02 2013
@@ -0,0 +1,324 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with this
+ * work for additional information regarding copyright ownership. The ASF
+ * licenses this file to You under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package org.apache.commons.math3.distribution.fitting;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import 
org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution;
+import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
+import org.apache.commons.math3.exception.ConvergenceException;
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.exception.NotStrictlyPositiveException;
+import org.apache.commons.math3.exception.NumberIsTooSmallException;
+import org.apache.commons.math3.util.Pair;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.Ignore;
+
+/**
+ * Test that demonstrates the use of
+ * {@link MultivariateNormalMixtureExpectationMaximization}.
+ */
+public class MultivariateNormalMixtureExpectationMaximizationTest {
+
+    // TODO reject initial mixes where means/covMats not computable with data
+    // numCols
+
+    @Test(expected = NotStrictlyPositiveException.class)
+    public void testNonEmptyData() {
+        // Should not accept empty data
+        new MultivariateNormalMixtureExpectationMaximization(new double[][] {});
+    }
+
+    @Test(expected = DimensionMismatchException.class)
+    public void testNonJaggedData() {
+        // Reject data with nonconstant numbers of columns
+        double[][] data = new double[][] {
+                { 1, 2, 3 },
+                { 4, 5, 6, 7 },
+        };
+        new MultivariateNormalMixtureExpectationMaximization(data);
+    }
+
+    @Test(expected = NumberIsTooSmallException.class)
+    public void testMultipleColumnsRequired() {
+        // Data should have at least 2 columns
+        double[][] data = new double[][] {
+                { 1 }, { 2 }
+        };
+        new MultivariateNormalMixtureExpectationMaximization(data);
+    }
+
+    @Test(expected = NotStrictlyPositiveException.class)
+    public void testMaxIterationsPositive() {
+        // Maximum iterations for fit must be positive integer
+        double[][] data = getTestSamples();
+        MultivariateNormalMixtureExpectationMaximization fitter =
+                new MultivariateNormalMixtureExpectationMaximization(data);
+
+        MixtureMultivariateNormalDistribution
+            initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
+
+        fitter.fit(initialMix, 0, 1E-5);
+    }
+
+    @Test(expected = NotStrictlyPositiveException.class)
+    public void testThresholdPositive() {
+        // Convergence threshold for fit must be strictly positive
+        double[][] data = getTestSamples();
+        MultivariateNormalMixtureExpectationMaximization fitter =
+                new MultivariateNormalMixtureExpectationMaximization(
+                    data);
+
+        MixtureMultivariateNormalDistribution
+            initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
+
+        fitter.fit(initialMix, 1000, 0);
+    }
+
+    @Test(expected = ConvergenceException.class)
+    public void testConvergenceException() {
+        // ConvergenceException thrown if fit terminates before threshold met
+        double[][] data = getTestSamples();
+        MultivariateNormalMixtureExpectationMaximization fitter
+            = new MultivariateNormalMixtureExpectationMaximization(data);
+
+        MixtureMultivariateNormalDistribution
+            initialMix = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
+
+        // 5 iterations not enough to meet convergence threshold
+        fitter.fit(initialMix, 5, 1E-5);
+    }
+
+    @Test(expected = DimensionMismatchException.class)
+    public void testIncompatibleIntialMixture() {
+        // Data has 3 columns
+        double[][] data = new double[][] {
+                { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 }
+        };
+        double[] weights = new double[] { 0.5, 0.5 };
+
+        // These distributions are compatible with 2-column data, not 3-column
+        // data
+        MultivariateNormalDistribution[] mvns = new MultivariateNormalDistribution[2];
+
+        mvns[0] = new MultivariateNormalDistribution(new double[] {
+                        -0.0021722935000328823, 3.5432892936887908 },
+                        new double[][] {
+                                { 4.537422569229048, 3.5266152281729304 },
+                                { 3.5266152281729304, 6.175448814169779 } });
+        mvns[1] = new MultivariateNormalDistribution(new double[] {
+                        5.090902706507635, 8.68540656355283 }, new double[][] {
+                        { 2.886778573963039, 1.5257474543463154 },
+                        { 1.5257474543463154, 3.3794567673616918 } });
+
+        // Create components and mixture
+        List<Pair<Double, MultivariateNormalDistribution>> components =
+                new ArrayList<Pair<Double, MultivariateNormalDistribution>>();
+        components.add(new Pair<Double, MultivariateNormalDistribution>(
+                weights[0], mvns[0]));
+        components.add(new Pair<Double, MultivariateNormalDistribution>(
+                weights[1], mvns[1]));
+
+        MixtureMultivariateNormalDistribution badInitialMix
+            = new MixtureMultivariateNormalDistribution(components);
+
+        MultivariateNormalMixtureExpectationMaximization fitter
+            = new MultivariateNormalMixtureExpectationMaximization(data);
+
+        fitter.fit(badInitialMix);
+    }
+
+    @Ignore("Relies on equals(), which the distribution classes do not implement consistently") @Test
+    public void testInitialMixture() {
+        // Testing initial mixture estimated from data
+        double[] correctWeights = new double[] { 0.5, 0.5 };
+
+        MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2];
+
+        correctMVNs[0] = new MultivariateNormalDistribution(new double[] {
+                        -0.0021722935000328823, 3.5432892936887908 },
+                        new double[][] {
+                                { 4.537422569229048, 3.5266152281729304 },
+                                { 3.5266152281729304, 6.175448814169779 } });
+        correctMVNs[1] = new MultivariateNormalDistribution(new double[] {
+                        5.090902706507635, 8.68540656355283 }, new double[][] {
+                        { 2.886778573963039, 1.5257474543463154 },
+                        { 1.5257474543463154, 3.3794567673616918 } });
+
+        final MixtureMultivariateNormalDistribution initialMix
+            = MultivariateNormalMixtureExpectationMaximization.estimate(getTestSamples(), 2);
+
+        int i = 0;
+        for (Pair<Double, MultivariateNormalDistribution> component : initialMix
+                .getComponents()) {
+            Assert.assertEquals(correctWeights[i], component.getFirst(),
+                    Math.ulp(1d));
+            Assert.assertEquals(correctMVNs[i], component.getSecond());
+            i++;
+        }
+    }
+
+    @Ignore("Relies on equals(), which the distribution classes do not implement consistently") @Test
+    public void testFit() {
+        // Test that the loglikelihood, weights, and models are determined and
+        // fitted correctly
+        double[][] data = getTestSamples();
+        double correctLogLikelihood = -4.292431006791994;
+        double[] correctWeights = new double[] { 0.2962324189652912, 0.7037675810347089 };
+        MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2];
+        correctMVNs[0] = new MultivariateNormalDistribution(new double[] {
+                        -1.4213112715121132, 1.6924690505757753 },
+                        new double[][] {
+                                { 1.739356907285747, -0.5867644251487614 },
+                                { -0.5867644251487614, 1.0232932029324642 } });
+
+        correctMVNs[1] = new MultivariateNormalDistribution(new double[] {
+                        4.213612224374709, 7.975621325853645 },
+                        new double[][] {
+                                { 4.245384898007161, 2.5797798966382155 },
+                                { 2.5797798966382155, 3.9200272522448367 } });
+
+        MultivariateNormalMixtureExpectationMaximization fitter
+            = new MultivariateNormalMixtureExpectationMaximization(data);
+
+        MixtureMultivariateNormalDistribution initialMix
+            = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
+        fitter.fit(initialMix);
+        MixtureMultivariateNormalDistribution fittedMix = fitter.getFittedModel();
+        List<Pair<Double, MultivariateNormalDistribution>> components = fittedMix.getComponents();
+
+        Assert.assertEquals(correctLogLikelihood,
+                            fitter.getLogLikelihood(),
+                            Math.ulp(1d));
+
+        int i = 0;
+        for (Pair<Double, MultivariateNormalDistribution> component : components) {
+            double weight = component.getFirst();
+            MultivariateNormalDistribution mvn = component.getSecond();
+            Assert.assertEquals(correctWeights[i], weight, Math.ulp(1d));
+            Assert.assertEquals(correctMVNs[i], mvn);
+            i++;
+        }
+    }
+
+    private double[][] getTestSamples() {
+        // generated using R Mixtools rmvnorm with mean vectors [-1.5, 2] and
+        // [4, 8.2]
+        return new double[][] { { 7.358553610469948, 11.31260831446758 },
+                { 7.175770420124739, 8.988812210204454 },
+                { 4.324151905768422, 6.837727899051482 },
+                { 2.157832219173036, 6.317444585521968 },
+                { -1.890157421896651, 1.74271202875498 },
+                { 0.8922409354455803, 1.999119343923781 },
+                { 3.396949764787055, 6.813170372579068 },
+                { -2.057498232686068, -0.002522983830852255 },
+                { 6.359932157365045, 8.343600029975851 },
+                { 3.353102234276168, 7.087541882898689 },
+                { -1.763877221595639, 0.9688890460330644 },
+                { 6.151457185125111, 9.075011757431174 },
+                { 4.281597398048899, 5.953270070976117 },
+                { 3.549576703974894, 8.616038155992861 },
+                { 6.004706732349854, 8.959423391087469 },
+                { 2.802915014676262, 6.285676742173564 },
+                { -0.6029879029880616, 1.083332958357485 },
+                { 3.631827105398369, 6.743428504049444 },
+                { 6.161125014007315, 9.60920569689001 },
+                { -1.049582894255342, 0.2020017892080281 },
+                { 3.910573022688315, 8.19609909534937 },
+                { 8.180454017634863, 7.861055769719962 },
+                { 1.488945440439716, 8.02699903761247 },
+                { 4.813750847823778, 12.34416881332515 },
+                { 0.0443208501259158, 5.901148093240691 },
+                { 4.416417235068346, 4.465243084006094 },
+                { 4.0002433603072, 6.721937850166174 },
+                { 3.190113818788205, 10.51648348411058 },
+                { 4.493600914967883, 7.938224231022314 },
+                { -3.675669533266189, 4.472845076673303 },
+                { 6.648645511703989, 12.03544085965724 },
+                { -1.330031331404445, 1.33931042964811 },
+                { -3.812111460708707, 2.50534195568356 },
+                { 5.669339356648331, 6.214488981177026 },
+                { 1.006596727153816, 1.51165463112716 },
+                { 5.039466365033024, 7.476532610478689 },
+                { 4.349091929968925, 7.446356406259756 },
+                { -1.220289665119069, 3.403926955951437 },
+                { 5.553003979122395, 6.886518211202239 },
+                { 2.274487732222856, 7.009541508533196 },
+                { 4.147567059965864, 7.34025244349202 },
+                { 4.083882618965819, 6.362852861075623 },
+                { 2.203122344647599, 7.260295257904624 },
+                { -2.147497550770442, 1.262293431529498 },
+                { 2.473700950426512, 6.558900135505638 },
+                { 8.267081298847554, 12.10214104577748 },
+                { 6.91977329776865, 9.91998488301285 },
+                { 0.1680479852730894, 6.28286034168897 },
+                { -1.268578659195158, 2.326711221485755 },
+                { 1.829966451374701, 6.254187605304518 },
+                { 5.648849025754848, 9.330002040750291 },
+                { -2.302874793257666, 3.585545172776065 },
+                { -2.629218791709046, 2.156215538500288 },
+                { 4.036618140700114, 10.2962785719958 },
+                { 0.4616386422783874, 0.6782756325806778 },
+                { -0.3447896073408363, 0.4999834691645118 },
+                { -0.475281453118318, 1.931470384180492 },
+                { 2.382509690609731, 6.071782429815853 },
+                { -3.203934441889096, 2.572079552602468 },
+                { 8.465636032165087, 13.96462998683518 },
+                { 2.36755660870416, 5.7844595007273 },
+                { 0.5935496528993371, 1.374615871358943 },
+                { -2.467481505748694, 2.097224634713005 },
+                { 4.27867444328542, 10.24772361238549 },
+                { -2.013791907543137, 2.013799426047639 },
+                { 6.424588084404173, 9.185334939684516 },
+                { -0.8448238876802175, 0.5447382022282812 },
+                { 1.342955703473923, 8.645456317633556 },
+                { 3.108712208751979, 8.512156853800064 },
+                { 4.343205178315472, 8.056869549234374 },
+                { -2.971767642212396, 3.201180146824761 },
+                { 2.583820931523672, 5.459873414473854 },
+                { 4.209139115268925, 8.171098193546225 },
+                { 0.4064909057902746, 1.454390775518743 },
+                { 3.068642411145223, 6.959485153620035 },
+                { 6.085968972900461, 7.391429799500965 },
+                { -1.342265795764202, 1.454550012997143 },
+                { 6.249773274516883, 6.290269880772023 },
+                { 4.986225847822566, 7.75266344868907 },
+                { 7.642443254378944, 10.19914817500263 },
+                { 6.438181159163673, 8.464396764810347 },
+                { 2.520859761025108, 7.68222425260111 },
+                { 2.883699944257541, 6.777960331348503 },
+                { 2.788004550956599, 6.634735386652733 },
+                { 3.331661231995638, 5.794191300046592 },
+                { 3.526172276645504, 6.710802266815884 },
+                { 3.188298528138741, 10.34495528210205 },
+                { 0.7345539486114623, 5.807604004180681 },
+                { 1.165044595880125, 7.830121829295257 },
+                { 7.146962523500671, 11.62995162065415 },
+                { 7.813872137162087, 10.62827008714735 },
+                { 3.118099164870063, 8.286003148186371 },
+                { -1.708739286262571, 1.561026755374264 },
+                { 1.786163047580084, 4.172394388214604 },
+                { 3.718506403232386, 7.807752990130349 },
+                { 6.167414046828899, 10.01104941031293 },
+                { -1.063477247689196, 1.61176085846339 },
+                { -3.396739609433642, 0.7127911050002151 },
+                { 2.438885945896797, 7.353011138689225 },
+                { -0.2073204144780931, 0.850771146627012 }, };
+    }
+}

Propchange: 
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
------------------------------------------------------------------------------
    svn:eol-style = native

Propchange: 
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
------------------------------------------------------------------------------
    svn:keywords = Id Revision


Reply via email to