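Change the chi2 UDF interface: take observed and expected matrices (array<array<number>>) and return per-feature chi2 values and p-values; remove chi2_test and DissociationDegreeUDF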
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/7b07e4a6 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/7b07e4a6 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/7b07e4a6 Branch: refs/heads/JIRA-22/pr-385 Commit: 7b07e4a6e1f700ba0a6e5b68659a040a3d89aa2f Parents: d0e97e6 Author: amaya <[email protected]> Authored: Tue Sep 20 12:03:44 2016 +0900 Committer: amaya <[email protected]> Committed: Tue Sep 20 12:11:42 2016 +0900 ---------------------------------------------------------------------- .../ftvec/selection/ChiSquareTestUDF.java | 21 ---- .../hivemall/ftvec/selection/ChiSquareUDF.java | 124 +++++++++++++++++-- .../ftvec/selection/DissociationDegreeUDF.java | 88 ------------- .../java/hivemall/utils/math/StatsUtils.java | 49 ++++++-- 4 files changed, 155 insertions(+), 127 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/7b07e4a6/core/src/main/java/hivemall/ftvec/selection/ChiSquareTestUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/selection/ChiSquareTestUDF.java b/core/src/main/java/hivemall/ftvec/selection/ChiSquareTestUDF.java deleted file mode 100644 index d367085..0000000 --- a/core/src/main/java/hivemall/ftvec/selection/ChiSquareTestUDF.java +++ /dev/null @@ -1,21 +0,0 @@ -package hivemall.ftvec.selection; - -import hivemall.utils.math.StatsUtils; -import org.apache.hadoop.hive.ql.exec.Description; - -import javax.annotation.Nonnull; - -@Description(name = "chi2_test", - value = "_FUNC_(array<number> expected, array<number> observed) - Returns p-value as double") -public class ChiSquareTestUDF extends DissociationDegreeUDF { - @Override - double calcDissociation(@Nonnull final double[] expected,@Nonnull final double[] observed) { - return 
StatsUtils.chiSquareTest(expected, observed); - } - - @Override - @Nonnull - String getFuncName() { - return "chi2_test"; - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/7b07e4a6/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java index 937b1bd..1954e33 100644 --- a/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java +++ b/core/src/main/java/hivemall/ftvec/selection/ChiSquareUDF.java @@ -1,21 +1,131 @@ package hivemall.ftvec.selection; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.hadoop.WritableUtils; +import hivemall.utils.lang.Preconditions; import hivemall.utils.math.StatsUtils; import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import javax.annotation.Nonnull; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; @Description(name = "chi2", - value = 
"_FUNC_(array<number> expected, array<number> observed) - Returns chi2-value as double") -public class ChiSquareUDF extends DissociationDegreeUDF { + value = "_FUNC_(array<array<number>> observed, array<array<number>> expected)" + + " - Returns chi2_val and p_val of each columns as <array<double>, array<double>>") +public class ChiSquareUDF extends GenericUDF { + private ListObjectInspector observedOI; + private ListObjectInspector observedRowOI; + private PrimitiveObjectInspector observedElOI; + private ListObjectInspector expectedOI; + private ListObjectInspector expectedRowOI; + private PrimitiveObjectInspector expectedElOI; + @Override - double calcDissociation(@Nonnull final double[] expected,@Nonnull final double[] observed) { - return StatsUtils.chiSquare(expected, observed); + public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException { + if (OIs.length != 2) { + throw new UDFArgumentLengthException("Specify two arguments."); + } + + if (!HiveUtils.isNumberListListOI(OIs[0])){ + throw new UDFArgumentTypeException(0, "Only array<array<number>> type argument is acceptable but " + + OIs[0].getTypeName() + " was passed as `observed`"); + } + + if (!HiveUtils.isNumberListListOI(OIs[1])){ + throw new UDFArgumentTypeException(1, "Only array<array<number>> type argument is acceptable but " + + OIs[1].getTypeName() + " was passed as `expected`"); + } + + observedOI = HiveUtils.asListOI(OIs[1]); + observedRowOI=HiveUtils.asListOI(observedOI.getListElementObjectInspector()); + observedElOI = HiveUtils.asDoubleCompatibleOI( observedRowOI.getListElementObjectInspector()); + expectedOI = HiveUtils.asListOI(OIs[0]); + expectedRowOI=HiveUtils.asListOI(expectedOI.getListElementObjectInspector()); + expectedElOI = HiveUtils.asDoubleCompatibleOI(expectedRowOI.getListElementObjectInspector()); + + List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector( + 
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); + + return ObjectInspectorFactory.getStandardStructObjectInspector( + Arrays.asList("chi2_vals", "p_vals"), fieldOIs); + } + + @Override + public Object evaluate(GenericUDF.DeferredObject[] dObj) throws HiveException { + List observedObj = observedOI.getList(dObj[0].get()); // shape = (#classes, #features) + List expectedObj = expectedOI.getList(dObj[1].get()); // shape = (#classes, #features) + + Preconditions.checkNotNull(observedObj); + Preconditions.checkNotNull(expectedObj); + final int nClasses = observedObj.size(); + Preconditions.checkArgument(nClasses == expectedObj.size()); // same #rows + + int nFeatures=-1; + double[] observedRow=null; // to reuse + double[] expectedRow=null; // to reuse + double[][] observed =null; // shape = (#features, #classes) + double[][] expected = null; // shape = (#features, #classes) + + // explode and transpose matrix + for(int i=0;i<nClasses;i++){ + if(i==0){ + // init + observedRow=HiveUtils.asDoubleArray(observedObj.get(i),observedRowOI,observedElOI,false); + expectedRow=HiveUtils.asDoubleArray(expectedObj.get(i),expectedRowOI,expectedElOI, false); + nFeatures = observedRow.length; + observed=new double[nFeatures][nClasses]; + expected = new double[nFeatures][nClasses]; + }else{ + HiveUtils.toDoubleArray(observedObj.get(i),observedRowOI,observedElOI,observedRow,false); + HiveUtils.toDoubleArray(expectedObj.get(i),expectedRowOI,expectedElOI,expectedRow, false); + } + + for(int j=0;j<nFeatures;j++){ + observed[j][i] = observedRow[j]; + expected[j][i] = expectedRow[j]; + } + } + + final Map.Entry<double[],double[]> chi2 = StatsUtils.chiSquares(observed,expected); + + final Object[] result = new Object[2]; + result[0] = WritableUtils.toWritableList(chi2.getKey()); + result[1]=WritableUtils.toWritableList(chi2.getValue()); + 
return result; } @Override - @Nonnull - String getFuncName() { - return "chi2"; + public String getDisplayString(String[] children) { + final StringBuilder sb = new StringBuilder(); + sb.append("chi2"); + sb.append("("); + if (children.length > 0) { + sb.append(children[0]); + for (int i = 1; i < children.length; i++) { + sb.append(", "); + sb.append(children[i]); + } + } + sb.append(")"); + return sb.toString(); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/7b07e4a6/core/src/main/java/hivemall/ftvec/selection/DissociationDegreeUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/ftvec/selection/DissociationDegreeUDF.java b/core/src/main/java/hivemall/ftvec/selection/DissociationDegreeUDF.java deleted file mode 100644 index 0acae82..0000000 --- a/core/src/main/java/hivemall/ftvec/selection/DissociationDegreeUDF.java +++ /dev/null @@ -1,88 +0,0 @@ -package hivemall.ftvec.selection; - -import hivemall.utils.hadoop.HiveUtils; -import hivemall.utils.lang.Preconditions; -import hivemall.utils.math.StatsUtils; -import org.apache.hadoop.hive.ql.exec.Description; -import org.apache.hadoop.hive.ql.exec.UDFArgumentException; -import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; -import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; -import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; -import org.apache.hadoop.hive.serde2.io.DoubleWritable; -import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; - -import javax.annotation.Nonnull; - -@Description(name = "", - value = "_FUNC_(array<number> expected, array<number> observed) - Returns 
dissociation degree as double") -public abstract class DissociationDegreeUDF extends GenericUDF { - private ListObjectInspector expectedOI; - private DoubleObjectInspector expectedElOI; - private ListObjectInspector observedOI; - private DoubleObjectInspector observedElOI; - - @Override - public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException { - if (OIs.length != 2) { - throw new UDFArgumentLengthException("Specify two arguments."); - } - - if (!HiveUtils.isListOI(OIs[0]) - || !HiveUtils.isNumberOI(((ListObjectInspector) OIs[0]).getListElementObjectInspector())){ - throw new UDFArgumentTypeException(0, "Only array<number> type argument is acceptable but " - + OIs[0].getTypeName() + " was passed as `expected`"); - } - - if (!HiveUtils.isListOI(OIs[1]) - || !HiveUtils.isNumberOI(((ListObjectInspector) OIs[1]).getListElementObjectInspector())){ - throw new UDFArgumentTypeException(1, "Only array<number> type argument is acceptable but " - + OIs[1].getTypeName() + " was passed as `observed`"); - } - - expectedOI = (ListObjectInspector) OIs[0]; - expectedElOI = (DoubleObjectInspector) expectedOI.getListElementObjectInspector(); - observedOI = (ListObjectInspector) OIs[1]; - observedElOI = (DoubleObjectInspector) observedOI.getListElementObjectInspector(); - - return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; - } - - @Override - public Object evaluate(GenericUDF.DeferredObject[] dObj) throws HiveException { - final double[] expected = HiveUtils.asDoubleArray(dObj[0].get(),expectedOI,expectedElOI); - final double[] observed = HiveUtils.asDoubleArray(dObj[1].get(),observedOI,observedElOI); - - Preconditions.checkNotNull(expected); - Preconditions.checkNotNull(observed); - Preconditions.checkArgument(expected.length == observed.length); - - final double dissociation = calcDissociation(expected,observed); - - return new DoubleWritable(dissociation); - } - - @Override - public String getDisplayString(String[] children) { - 
final StringBuilder sb = new StringBuilder(); - sb.append(getFuncName()); - sb.append("("); - if (children.length > 0) { - sb.append(children[0]); - for (int i = 1; i < children.length; i++) { - sb.append(", "); - sb.append(children[i]); - } - } - sb.append(")"); - return sb.toString(); - } - - abstract double calcDissociation(@Nonnull final double[] expected,@Nonnull final double[] observed); - - @Nonnull - abstract String getFuncName(); -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/7b07e4a6/core/src/main/java/hivemall/utils/math/StatsUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/math/StatsUtils.java b/core/src/main/java/hivemall/utils/math/StatsUtils.java index 7633419..f9d0f30 100644 --- a/core/src/main/java/hivemall/utils/math/StatsUtils.java +++ b/core/src/main/java/hivemall/utils/math/StatsUtils.java @@ -29,6 +29,9 @@ import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealVector; import org.apache.commons.math3.linear.SingularValueDecomposition; +import java.util.AbstractMap; +import java.util.Map; + public final class StatsUtils { private StatsUtils() {} @@ -191,24 +194,24 @@ public final class StatsUtils { } /** - * @param expected mean vector whose value is expected * @param observed mean vector whose value is observed - * @return chi2-value + * @param expected mean vector whose value is expected + * @return chi2 value */ - public static double chiSquare(@Nonnull final double[] expected, @Nonnull final double[] observed) { - Preconditions.checkArgument(expected.length == observed.length); + public static double chiSquare(@Nonnull final double[] observed, @Nonnull final double[] expected) { + Preconditions.checkArgument(observed.length == expected.length); - double sumExpected = 0.d; double sumObserved = 0.d; + double sumExpected = 0.d; for (int ratio = 0; ratio < observed.length; ++ratio) { - sumExpected += 
expected[ratio]; sumObserved += observed[ratio]; + sumExpected += expected[ratio]; } double var15 = 1.d; boolean rescale = false; - if (Math.abs(sumExpected - sumObserved) > 1.e-5) { + if (Math.abs(sumObserved - sumExpected) > 1.e-5) { var15 = sumObserved / sumExpected; rescale = true; } @@ -230,12 +233,36 @@ public final class StatsUtils { } /** - * @param expected means vector whose value is expected * @param observed means vector whose value is observed - * @return p-value + * @param expected means vector whose value is expected + * @return p value */ - public static double chiSquareTest(@Nonnull final double[] expected,@Nonnull final double[] observed) { + public static double chiSquareTest(@Nonnull final double[] observed, @Nonnull final double[] expected) { ChiSquaredDistribution distribution = new ChiSquaredDistribution(null, (double)expected.length - 1.d); - return 1.d - distribution.cumulativeProbability(chiSquare(expected, observed)); + return 1.d - distribution.cumulativeProbability(chiSquare(observed,expected)); + } + + /** + * This method offers effective calculation for multiple entries rather than calculation individually + * @param observeds means matrix whose values are observed + * @param expecteds means matrix + * @return (chi2 value[], p value[]) + */ + public static Map.Entry<double[],double[]> chiSquares(@Nonnull final double[][] observeds, @Nonnull final double[][] expecteds){ + Preconditions.checkArgument(observeds.length == expecteds.length); + + final int len = expecteds.length; + final int lenOfEach = expecteds[0].length; + + final ChiSquaredDistribution distribution = new ChiSquaredDistribution(null, (double)lenOfEach - 1.d); + + final double[] chi2s = new double[len]; + final double[] ps = new double[len]; + for(int i=0;i<len;i++){ + chi2s[i] = chiSquare(observeds[i],expecteds[i]); + ps[i] = 1.d - distribution.cumulativeProbability(chi2s[i]); + } + + return new AbstractMap.SimpleEntry<double[], double[]>(chi2s,ps); } }
