refine chi2
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/a16a3fde Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/a16a3fde Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/a16a3fde Branch: refs/heads/JIRA-22/pr-385 Commit: a16a3fde844ba381dee7eb1e9608ddc2dcfb96fc Parents: 6dc2344 Author: amaya <g...@sapphire.in.net> Authored: Wed Sep 21 13:10:18 2016 +0900 Committer: amaya <g...@sapphire.in.net> Committed: Wed Sep 21 13:35:33 2016 +0900 ---------------------------------------------------------------------- .../hivemall/ftvec/selection/ChiSquareUDF.java | 40 +++++++------ .../java/hivemall/utils/math/StatsUtils.java | 62 +++++++++++--------- 2 files changed, 58 insertions(+), 44 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/a16a3fde/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java index e2b7494..951aeeb 100644 --- a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java +++ b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java @@ -50,6 +50,12 @@ public class ChiSquareUDF extends GenericUDF { private ListObjectInspector expectedRowOI; private PrimitiveObjectInspector expectedElOI; + private int nFeatures = -1; + private double[] observedRow = null; // to reuse + private double[] expectedRow = null; // to reuse + private double[][] observed = null; // shape = (#features, #classes) + private double[][] expected = null; // shape = (#features, #classes) + @Override public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException { if (OIs.length != 2) { @@ -75,12 +81,12 @@ public class ChiSquareUDF extends GenericUDF { expectedRowOI = HiveUtils.asListOI(expectedOI.getListElementObjectInspector()); expectedElOI = HiveUtils.asDoubleCompatibleOI(expectedRowOI.getListElementObjectInspector()); - List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + final List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); return ObjectInspectorFactory.getStandardStructObjectInspector( - Arrays.asList("chi2_vals", "p_vals"), fieldOIs); + Arrays.asList("chi2", "pvalue"), fieldOIs); } @Override @@ -93,28 +99,28 @@ public class ChiSquareUDF extends GenericUDF { final int nClasses = observedObj.size(); Preconditions.checkArgument(nClasses == expectedObj.size()); // same #rows - int nFeatures = -1; - double[] observedRow = null; // to reuse - double[] expectedRow = null; // to reuse - double[][] observed = null; // shape = (#features, #classes) - double[][] expected = null; // shape = (#features, #classes) - // explode and transpose matrix for (int i = 0; i < nClasses; i++) { - if (i == 0) { + final Object observedObjRow = observedObj.get(i); + final Object expectedObjRow = observedObj.get(i); + + Preconditions.checkNotNull(observedObjRow); + Preconditions.checkNotNull(expectedObjRow); + + if (observedRow == null) { // init - observedRow = HiveUtils.asDoubleArray(observedObj.get(i), observedRowOI, - observedElOI, false); - expectedRow = HiveUtils.asDoubleArray(expectedObj.get(i), expectedRowOI, - expectedElOI, false); + observedRow = HiveUtils.asDoubleArray(observedObjRow, observedRowOI, observedElOI, + false); + expectedRow = HiveUtils.asDoubleArray(expectedObjRow, expectedRowOI, expectedElOI, + false); nFeatures = observedRow.length; observed = new double[nFeatures][nClasses]; expected = new double[nFeatures][nClasses]; } else { - HiveUtils.toDoubleArray(observedObj.get(i), observedRowOI, observedElOI, - observedRow, false); - HiveUtils.toDoubleArray(expectedObj.get(i), expectedRowOI, expectedElOI, - expectedRow, false); + HiveUtils.toDoubleArray(observedObjRow, observedRowOI, observedElOI, observedRow, + false); + HiveUtils.toDoubleArray(expectedObjRow, expectedRowOI, expectedElOI, expectedRow, + false); } for (int j = 0; j < nFeatures; j++) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/a16a3fde/core/src/main/java/hivemall/utils/math/StatsUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/math/StatsUtils.java b/core/src/main/java/hivemall/utils/math/StatsUtils.java index d3b25c7..e255b84 100644 --- a/core/src/main/java/hivemall/utils/math/StatsUtils.java +++ b/core/src/main/java/hivemall/utils/math/StatsUtils.java @@ -23,11 +23,15 @@ import hivemall.utils.lang.Preconditions; import javax.annotation.Nonnull; import org.apache.commons.math3.distribution.ChiSquaredDistribution; +import org.apache.commons.math3.exception.DimensionMismatchException; +import org.apache.commons.math3.exception.NotPositiveException; import org.apache.commons.math3.linear.DecompositionSolver; import org.apache.commons.math3.linear.LUDecomposition; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealVector; import org.apache.commons.math3.linear.SingularValueDecomposition; +import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.MathArrays; import java.util.AbstractMap; import java.util.Map; @@ -194,54 +198,59 @@ public final class StatsUtils { } /** - * @param observed mean vector whose value is observed - * @param expected mean vector whose value is expected + * @param observed means non-negative vector + * @param expected means positive vector * @return chi2 value */ public static double chiSquare(@Nonnull final double[] observed, @Nonnull final double[] expected) { - Preconditions.checkArgument(observed.length == expected.length); + if (observed.length < 2) { + throw new DimensionMismatchException(observed.length, 2); + } + if (expected.length != observed.length) { + throw new DimensionMismatchException(observed.length, expected.length); + } + MathArrays.checkPositive(expected); + for (double d : observed) { + if (d < 0.d) { + throw new NotPositiveException(d); + } + } double sumObserved = 0.d; double sumExpected = 0.d; - - for (int ratio = 0; ratio < observed.length; ++ratio) { - sumObserved += observed[ratio]; - sumExpected += expected[ratio]; + for (int i = 0; i < observed.length; i++) { + sumObserved += observed[i]; + sumExpected += expected[i]; } - - double var15 = 1.d; + double ratio = 1.d; boolean rescale = false; - if (Math.abs(sumObserved - sumExpected) > 1.e-5) { - var15 = sumObserved / sumExpected; + if (FastMath.abs(sumObserved - sumExpected) > 10e-6) { + ratio = sumObserved / sumExpected; rescale = true; } - double sumSq = 0.d; - - for (int i = 0; i < observed.length; ++i) { - double dev; + for (int i = 0; i < observed.length; i++) { if (rescale) { - dev = observed[i] - var15 * expected[i]; - sumSq += dev * dev / (var15 * expected[i]); + final double dev = observed[i] - ratio * expected[i]; + sumSq += dev * dev / (ratio * expected[i]); } else { - dev = observed[i] - expected[i]; + final double dev = observed[i] - expected[i]; sumSq += dev * dev / expected[i]; } } - return sumSq; } /** - * @param observed means vector whose value is observed - * @param expected means vector whose value is expected + * @param observed means non-negative vector + * @param expected means positive vector * @return p value */ public static double chiSquareTest(@Nonnull final double[] observed, @Nonnull final double[] expected) { - ChiSquaredDistribution distribution = new ChiSquaredDistribution(null, - (double) expected.length - 1.d); + final ChiSquaredDistribution distribution = new ChiSquaredDistribution( + expected.length - 1.d); return 1.d - distribution.cumulativeProbability(chiSquare(observed, expected)); } @@ -249,8 +258,8 @@ public final class StatsUtils { * This method offers effective calculation for multiple entries rather than calculation * individually * - * @param observeds means matrix whose values are observed - * @param expecteds means matrix + * @param observeds means non-negative matrix + * @param expecteds means positive matrix * @return (chi2 value[], p value[]) */ public static Map.Entry<double[], double[]> chiSquares(@Nonnull final double[][] observeds, @@ -260,8 +269,7 @@ public final class StatsUtils { final int len = expecteds.length; final int lenOfEach = expecteds[0].length; - final ChiSquaredDistribution distribution = new ChiSquaredDistribution(null, - (double) lenOfEach - 1.d); + final ChiSquaredDistribution distribution = new ChiSquaredDistribution(lenOfEach - 1.d); final double[] chi2s = new double[len]; final double[] ps = new double[len];